Spaces:
Sleeping
Sleeping
File size: 2,299 Bytes
513cce0 9d1b236 513cce0 ab8689d 513cce0 ab8689d 513cce0 ab8689d 513cce0 ab8689d 513cce0 ab8689d 513cce0 ab8689d 513cce0 97f8596 9d1b236 ab8689d 2bbbd42 ab8689d 513cce0 205396b ab8689d 513cce0 205396b 513cce0 9d1b236 513cce0 ab8689d 513cce0 ab8689d 9d1b236 513cce0 9d1b236 ab8689d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import os
# Output directories for the two fine-tuning strategies (checked by load_model).
MODEL_DIR = "./mistral-7b-brvm-finetuned"  # LoRA fine-tune output
FULL_MODEL_DIR = "./mistral-7b-brvm-full-finetuned"  # full fine-tune output
# LoRA training entry point (triggered from the Gradio UI).
def train_model():
    """Run the LoRA fine-tuning script and return a status message.

    Returns:
        str: success message with the output directory, or a failure
        message carrying the shell exit code.
    """
    # os.system returns the shell exit status; the original ignored it and
    # always reported success even when finetune.py crashed.
    exit_code = os.system("python finetune.py")
    if exit_code != 0:
        return f"❌ Entraînement LoRA échoué (code {exit_code})"
    return "✅ Entraînement LoRA terminé ! Modèle sauvegardé dans " + MODEL_DIR
# Full fine-tune training entry point (triggered from the Gradio UI).
def train_model_full():
    """Run the full fine-tuning script and return a status message.

    Returns:
        str: success message with the output directory, or a failure
        message carrying the shell exit code.
    """
    # os.system returns the shell exit status; the original ignored it and
    # always reported success even when finetune_full.py crashed.
    exit_code = os.system("python finetune_full.py")
    if exit_code != 0:
        return f"❌ Full fine-tune échoué (code {exit_code})"
    return "✅ Full fine-tune terminé ! Modèle sauvegardé dans " + FULL_MODEL_DIR
# Model loading (prefers full fine-tune, then LoRA, then base checkpoint).
def load_model():
    """Build a text-generation pipeline from the best available checkpoint.

    Preference order: full fine-tuned weights, then LoRA fine-tuned weights,
    then the base Mistral-7B-Instruct model from the Hub.

    Returns:
        transformers.Pipeline: a "text-generation" pipeline.
    """
    if os.path.exists(FULL_MODEL_DIR):
        model_name = FULL_MODEL_DIR
    elif os.path.exists(MODEL_DIR):
        model_name = MODEL_DIR
    else:
        model_name = "mistralai/Mistral-7B-Instruct-v0.3"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    use_cuda = torch.cuda.is_available()
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
        trust_remote_code=True,
        # 8-bit loading requires bitsandbytes on a CUDA device; the original
        # passed load_in_8bit=True unconditionally, which raises at load time
        # on CPU-only hosts, so gate it on CUDA availability like the dtype.
        load_in_8bit=use_cuda,
    )
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
# Build the shared text-generation pipeline once at startup (used by chat).
pipe = load_model()
# Inference helper wired to the "Tester le modèle" tab.
def chat(prompt):
    """Generate a sampled completion for *prompt* via the shared pipeline.

    Returns:
        str: the generated text of the first (and only) candidate.
    """
    generations = pipe(
        prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )
    first_candidate = generations[0]
    return first_candidate["generated_text"]
# Gradio UI: one tab to launch training jobs, one tab to query the model.
with gr.Blocks() as demo:
    gr.Markdown("# 🐟 BRVM Finetuner (Mistral-7B)")

    with gr.Tab("🚀 Entraînement"):
        # Two side-by-side buttons sharing a single log textbox.
        with gr.Row():
            lora_button = gr.Button("Lancer LoRA Fine-tuning")
            full_button = gr.Button("Lancer Full Fine-tuning")
        logs_box = gr.Textbox(label="Logs")
        lora_button.click(fn=train_model, outputs=logs_box)
        full_button.click(fn=train_model_full, outputs=logs_box)

    with gr.Tab("💬 Tester le modèle"):
        question_box = gr.Textbox(label="Votre question :", placeholder="Posez une question...")
        answer_box = gr.Textbox(label="Réponse du modèle")
        send_button = gr.Button("Envoyer")
        send_button.click(fn=chat, inputs=question_box, outputs=answer_box)

demo.launch()
|