import gradio as gr import torch from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline import os MODEL_DIR = "./mistral-7b-brvm-finetuned" FULL_MODEL_DIR = "./mistral-7b-brvm-full-finetuned" # Fonction d’entraînement LoRA def train_model(): os.system("python finetune.py") return "✅ Entraînement LoRA terminé ! Modèle sauvegardé dans " + MODEL_DIR # Fonction d’entraînement Full Fine-tune def train_model_full(): os.system("python finetune_full.py") return "✅ Full fine-tune terminé ! Modèle sauvegardé dans " + FULL_MODEL_DIR # Chargement du modèle (par défaut LoRA, sinon base) def load_model(): if os.path.exists(FULL_MODEL_DIR): model_name = FULL_MODEL_DIR elif os.path.exists(MODEL_DIR): model_name = MODEL_DIR else: model_name = "mistralai/Mistral-7B-Instruct-v0.3" tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( model_name, device_map="auto", torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, trust_remote_code=True, load_in_8bit=True, # <<<<< important ) return pipeline("text-generation", model=model, tokenizer=tokenizer) # Charger pipeline pipe = load_model() # Fonction de test def chat(prompt): outputs = pipe(prompt, max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9) return outputs[0]["generated_text"] # Interface Gradio with gr.Blocks() as demo: gr.Markdown("# 🐟 BRVM Finetuner (Mistral-7B)") with gr.Tab("🚀 Entraînement"): with gr.Row(): train_btn = gr.Button("Lancer LoRA Fine-tuning") train_full_btn = gr.Button("Lancer Full Fine-tuning") train_output = gr.Textbox(label="Logs") train_btn.click(fn=train_model, outputs=train_output) train_full_btn.click(fn=train_model_full, outputs=train_output) with gr.Tab("💬 Tester le modèle"): input_text = gr.Textbox(label="Votre question :", placeholder="Posez une question...") output_text = gr.Textbox(label="Réponse du modèle") submit_btn = gr.Button("Envoyer") submit_btn.click(fn=chat, inputs=input_text, outputs=output_text) demo.launch()