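# NOTE: the next lines are web-page header residue from the repository
# viewer, kept as comments so the file parses as Python.
# brvm_finetuner / app.py
# lamekemal's picture
# Update app.py
# 97f8596 verified
# raw
# history blame
# 2.3 kB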
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import os
# Directory checked by load_model() and named in train_model()'s success
# message as the LoRA output location (presumably written by finetune.py —
# TODO confirm against that script).
MODEL_DIR = "./mistral-7b-brvm-finetuned"
# Directory checked first by load_model() and named in train_model_full()'s
# success message (presumably written by finetune_full.py — TODO confirm).
FULL_MODEL_DIR = "./mistral-7b-brvm-full-finetuned"
# LoRA training entry point (wired to a button in the Gradio UI).
def train_model():
    """Run the LoRA fine-tuning script and report whether it succeeded.

    Launches ``finetune.py`` as a child process using the interpreter that
    is running this app, and blocks until the script exits.

    Returns:
        str: A success message naming ``MODEL_DIR`` when the script exits
        with status 0, otherwise a failure message containing the exit
        status.
    """
    # Function-scope imports keep this block self-contained.
    import subprocess
    import sys

    # An argument list (no shell) plus sys.executable guarantees the script
    # runs under the same Python environment as the app itself.
    completed = subprocess.run([sys.executable, "finetune.py"], check=False)
    if completed.returncode != 0:
        # Previously the exit status was discarded and success was always
        # reported, even when the script crashed or was missing.
        return f"❌ Entraînement LoRA échoué (code de sortie {completed.returncode})."
    return "✅ Entraînement LoRA terminé ! Modèle sauvegardé dans " + MODEL_DIR
# Full fine-tuning entry point (wired to a button in the Gradio UI).
def train_model_full():
    """Run the full fine-tuning script and report whether it succeeded.

    Launches ``finetune_full.py`` as a child process using the interpreter
    that is running this app, and blocks until the script exits.

    Returns:
        str: A success message naming ``FULL_MODEL_DIR`` when the script
        exits with status 0, otherwise a failure message containing the
        exit status.
    """
    # Function-scope imports keep this block self-contained.
    import subprocess
    import sys

    # An argument list (no shell) plus sys.executable guarantees the script
    # runs under the same Python environment as the app itself.
    completed = subprocess.run([sys.executable, "finetune_full.py"], check=False)
    if completed.returncode != 0:
        # Previously the exit status was discarded and success was always
        # reported, even when the script crashed or was missing.
        return f"❌ Full fine-tune échoué (code de sortie {completed.returncode})."
    return "✅ Full fine-tune terminé ! Modèle sauvegardé dans " + FULL_MODEL_DIR
# Model loading (full fine-tune first, then LoRA output, then the base model).
def load_model():
    """Build the text-generation pipeline for the best available checkpoint.

    Checkpoint priority:
      1. ``FULL_MODEL_DIR`` if that path exists,
      2. otherwise ``MODEL_DIR`` if that path exists,
      3. otherwise the hub model ``mistralai/Mistral-7B-Instruct-v0.3``.

    The model is loaded with 8-bit quantization and ``device_map="auto"``.

    Returns:
        A transformers ``"text-generation"`` pipeline wrapping the loaded
        model and tokenizer.
    """
    # Function-scope import: the file's top-level transformers import does
    # not include this name.
    from transformers import BitsAndBytesConfig

    for candidate in (FULL_MODEL_DIR, MODEL_DIR):
        if os.path.exists(candidate):
            model_name = candidate
            break
    else:
        # Neither local directory exists: fall back to the base model.
        model_name = "mistralai/Mistral-7B-Instruct-v0.3"

    # NOTE(review): if MODEL_DIR holds only LoRA adapter weights, loading it
    # through AutoModelForCausalLM presumably relies on the PEFT
    # integration being installed — confirm against finetune.py.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
        # Passing load_in_8bit directly to from_pretrained is deprecated;
        # the quantization settings belong in a BitsAndBytesConfig.
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
# Build the generation pipeline once, when the module is executed; chat()
# reads this module-level object. Nothing in the visible code reassigns it
# after a training run, so freshly trained weights are presumably only
# picked up on the next restart — TODO confirm.
pipe = load_model()
# Handler for the test tab.
def chat(prompt):
    """Generate a sampled completion for ``prompt`` with the global pipeline.

    Args:
        prompt: Text forwarded unchanged to ``pipe``.

    Returns:
        The ``"generated_text"`` field of the first result returned by
        ``pipe``.
    """
    generation_options = {
        "max_new_tokens": 200,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
    }
    first_candidate = pipe(prompt, **generation_options)[0]
    return first_candidate["generated_text"]
# Gradio interface: one tab to launch training, one tab to query the model.
with gr.Blocks() as demo:
    gr.Markdown("# 🐟 BRVM Finetuner (Mistral-7B)")

    # --- Training tab: two launch buttons side by side, one shared log box.
    with gr.Tab("🚀 Entraînement"):
        with gr.Row():
            lora_button = gr.Button("Lancer LoRA Fine-tuning")
            full_button = gr.Button("Lancer Full Fine-tuning")
        training_log = gr.Textbox(label="Logs")
        # Both handlers take no inputs and write their status string to the
        # same textbox.
        lora_button.click(fn=train_model, outputs=training_log)
        full_button.click(fn=train_model_full, outputs=training_log)

    # --- Test tab: question in, model answer out.
    with gr.Tab("💬 Tester le modèle"):
        question_box = gr.Textbox(
            label="Votre question :",
            placeholder="Posez une question...",
        )
        answer_box = gr.Textbox(label="Réponse du modèle")
        send_button = gr.Button("Envoyer")
        send_button.click(fn=chat, inputs=question_box, outputs=answer_box)

# Start the web server as soon as the module is executed.
demo.launch()