File size: 2,299 Bytes
513cce0
9d1b236
513cce0
 
 
 
ab8689d
513cce0
ab8689d
513cce0
ab8689d
 
513cce0
ab8689d
 
 
 
 
 
513cce0
ab8689d
 
 
 
 
 
 
513cce0
 
 
ab8689d
513cce0
 
97f8596
9d1b236
ab8689d
2bbbd42
ab8689d
513cce0
205396b
ab8689d
513cce0
 
 
205396b
513cce0
 
 
9d1b236
513cce0
ab8689d
 
 
513cce0
 
ab8689d
9d1b236
513cce0
 
 
 
 
9d1b236
ab8689d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import os

MODEL_DIR = "./mistral-7b-brvm-finetuned"
FULL_MODEL_DIR = "./mistral-7b-brvm-full-finetuned"

# Fonction d’entraînement LoRA
def train_model():
    os.system("python finetune.py")
    return "✅ Entraînement LoRA terminé ! Modèle sauvegardé dans " + MODEL_DIR

# Fonction d’entraînement Full Fine-tune
def train_model_full():
    os.system("python finetune_full.py")
    return "✅ Full fine-tune terminé ! Modèle sauvegardé dans " + FULL_MODEL_DIR

# Chargement du modèle (par défaut LoRA, sinon base)
def load_model():
    if os.path.exists(FULL_MODEL_DIR):
        model_name = FULL_MODEL_DIR
    elif os.path.exists(MODEL_DIR):
        model_name = MODEL_DIR
    else:
        model_name = "mistralai/Mistral-7B-Instruct-v0.3"

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        trust_remote_code=True,
        load_in_8bit=True,   # <<<<< important
    )
    return pipeline("text-generation", model=model, tokenizer=tokenizer)

# Charger pipeline
pipe = load_model()

# Fonction de test
def chat(prompt):
    outputs = pipe(prompt, max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9)
    return outputs[0]["generated_text"]

# Interface Gradio
with gr.Blocks() as demo:
    gr.Markdown("# 🐟 BRVM Finetuner (Mistral-7B)")

    with gr.Tab("🚀 Entraînement"):
        with gr.Row():
            train_btn = gr.Button("Lancer LoRA Fine-tuning")
            train_full_btn = gr.Button("Lancer Full Fine-tuning")
        train_output = gr.Textbox(label="Logs")
        train_btn.click(fn=train_model, outputs=train_output)
        train_full_btn.click(fn=train_model_full, outputs=train_output)

    with gr.Tab("💬 Tester le modèle"):
        input_text = gr.Textbox(label="Votre question :", placeholder="Posez une question...")
        output_text = gr.Textbox(label="Réponse du modèle")
        submit_btn = gr.Button("Envoyer")
        submit_btn.click(fn=chat, inputs=input_text, outputs=output_text)

demo.launch()