Spaces:
Sleeping
Sleeping
import os
import subprocess
import sys

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)
# Directory the LoRA run is reported to save into (presumably written by
# finetune.py -- TODO confirm against that script).
MODEL_DIR = "./mistral-7b-brvm-finetuned"
# Directory the full fine-tune is reported to save into (presumably written by
# finetune_full.py -- TODO confirm against that script).
FULL_MODEL_DIR = "./mistral-7b-brvm-full-finetuned"


def _run_training_script(script, success_message):
    """Run a Python training script to completion and report the outcome.

    Args:
        script: Path of the script to run, relative to the current working
            directory (same resolution as the former ``os.system`` call).
        success_message: Text returned when the script exits with status 0.

    Returns:
        ``success_message`` when the script succeeds; otherwise a French error
        message (matching the UI language) naming the script and the failure.
    """
    # Use the interpreter running this app rather than whatever "python" is on
    # PATH, so the script sees the same installed packages. The argument-list
    # form also avoids going through a shell.
    interpreter = sys.executable or "python"
    try:
        completed = subprocess.run([interpreter, script], check=False)
    except OSError as err:
        return f"❌ Impossible de lancer {script} : {err}"
    if completed.returncode != 0:
        return (
            f"❌ Échec de l'entraînement : {script} s'est terminé avec le code "
            f"{completed.returncode}. Consultez les logs du Space."
        )
    return success_message


# LoRA fine-tuning launcher
def train_model():
    """Run finetune.py (LoRA fine-tuning) and return a status message.

    Blocks until the script exits. Returns an error message instead of the
    success text when the script fails.

    NOTE: the chat pipeline is built once at startup, so newly trained weights
    are only used after the app is restarted.
    """
    return _run_training_script(
        "finetune.py",
        "✅ Entraînement LoRA terminé ! Modèle sauvegardé dans " + MODEL_DIR,
    )


# Full fine-tuning launcher
def train_model_full():
    """Run finetune_full.py (full fine-tuning) and return a status message.

    Blocks until the script exits. Returns an error message instead of the
    success text when the script fails.

    NOTE: the chat pipeline is built once at startup, so newly trained weights
    are only used after the app is restarted.
    """
    return _run_training_script(
        "finetune_full.py",
        "✅ Full fine-tune terminé ! Modèle sauvegardé dans " + FULL_MODEL_DIR,
    )
# Model loading: prefer the full fine-tune, then the LoRA fine-tune, then the
# base Mistral-7B-Instruct checkpoint.
def load_model():
    """Build the text-generation pipeline used by the chat tab.

    The model source is the first that exists, in this order:
    FULL_MODEL_DIR, then MODEL_DIR, then the base
    ``mistralai/Mistral-7B-Instruct-v0.3`` checkpoint.

    With a CUDA GPU, weights are loaded in 8-bit via bitsandbytes and bfloat16
    is used for the non-quantized parts. Without one, the model is loaded
    unquantized in float32. For a 7B model that is roughly 28 GB of RAM.

    Returns:
        A ``transformers`` text-generation pipeline.
    """
    if os.path.exists(FULL_MODEL_DIR):
        model_name = FULL_MODEL_DIR
    elif os.path.exists(MODEL_DIR):
        model_name = MODEL_DIR
    else:
        model_name = "mistralai/Mistral-7B-Instruct-v0.3"

    use_cuda = torch.cuda.is_available()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    model_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.bfloat16 if use_cuda else torch.float32,
        "trust_remote_code": True,
    }
    if use_cuda:
        # bitsandbytes 8-bit loading generally requires a CUDA GPU. Requesting
        # it unconditionally made CPU-only hardware fail to load at all. The
        # bare ``load_in_8bit=True`` kwarg is deprecated in favour of
        # ``quantization_config``.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
# Load the pipeline once, at import time; every chat() call reuses it.
# Running a training job from the UI does not rebuild it.
pipe = load_model()
# Handler for the model-testing tab
def chat(prompt):
    """Generate a sampled completion for ``prompt`` with the global pipeline.

    Args:
        prompt: Text typed into the question textbox.

    Returns:
        Only the newly generated text (the prompt is not echoed back), or a
        short hint when the prompt is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        return "⚠️ Veuillez saisir une question."
    # NOTE(review): the prompt is sent raw, without the Mistral-Instruct chat
    # template; confirm this matches the format finetune*.py trains on.
    outputs = pipe(
        prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        # By default the text-generation pipeline prepends the prompt to its
        # output. The answer box should only show the model's reply.
        return_full_text=False,
    )
    return outputs[0]["generated_text"]
# Gradio interface: one tab to launch training, one tab to query the model.
with gr.Blocks() as demo:
    gr.Markdown("# 🐟 BRVM Finetuner (Mistral-7B)")

    # Training tab: each button runs its training function, and the returned
    # status string is shown in the shared "Logs" textbox.
    with gr.Tab("🚀 Entraînement"):
        with gr.Row():
            train_btn = gr.Button("Lancer LoRA Fine-tuning")
            train_full_btn = gr.Button("Lancer Full Fine-tuning")
        # NOTE(review): the source's indentation was lost; placing this textbox
        # below the button row (rather than inside it) is a reconstruction --
        # confirm.
        train_output = gr.Textbox(label="Logs")
        # NOTE(review): these are two independent click events, so depending on
        # Gradio's concurrency settings both trainings may run at the same
        # time -- consider a shared concurrency limit.
        train_btn.click(fn=train_model, outputs=train_output)
        train_full_btn.click(fn=train_model_full, outputs=train_output)

    # Test tab: sends the question to chat() and shows its return value.
    with gr.Tab("💬 Tester le modèle"):
        input_text = gr.Textbox(label="Votre question :", placeholder="Posez une question...")
        output_text = gr.Textbox(label="Réponse du modèle")
        submit_btn = gr.Button("Envoyer")
        submit_btn.click(fn=chat, inputs=input_text, outputs=output_text)

# Start the web server. This runs at import time; there is no __main__ guard.
demo.launch()