lamekemal commited on
Commit
ab8689d
·
verified ·
1 Parent(s): 82bf711

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -15
app.py CHANGED
@@ -4,32 +4,40 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  import os
5
 
6
  MODEL_DIR = "./mistral-7b-brvm-finetuned"
 
7
 
8
- # Fonction d’entraînement (appelle ton script de fine-tuning)
9
  def train_model():
10
- os.system("python finetune.py") # tu mets ton code d'entraînement dans finetune.py
11
- return "✅ Entraînement terminé ! Le modèle est sauvegardé dans " + MODEL_DIR
12
 
13
- # Chargement du modèle (fine-tuné si dispo, sinon base)
 
 
 
 
 
14
  def load_model():
15
- model_name = MODEL_DIR if os.path.exists(MODEL_DIR) else "mistralai/Mistral-7B-Instruct-v0.3"
 
 
 
 
 
 
16
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
- device_map="auto", # Accelerate gère la répartition CPU/GPU
20
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
21
  trust_remote_code=True,
22
  )
 
23
 
24
- # Ne PAS passer device quand on utilise accelerate
25
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
26
-
27
- return pipe
28
-
29
- # On charge le pipeline une fois au démarrage
30
  pipe = load_model()
31
 
32
- # Fonction de test du modèle
33
  def chat(prompt):
34
  outputs = pipe(prompt, max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9)
35
  return outputs[0]["generated_text"]
@@ -39,9 +47,12 @@ with gr.Blocks() as demo:
39
  gr.Markdown("# 🐟 BRVM Finetuner (Mistral-7B)")
40
 
41
  with gr.Tab("🚀 Entraînement"):
42
- train_btn = gr.Button("Lancer l’entraînement")
 
 
43
  train_output = gr.Textbox(label="Logs")
44
  train_btn.click(fn=train_model, outputs=train_output)
 
45
 
46
  with gr.Tab("💬 Tester le modèle"):
47
  input_text = gr.Textbox(label="Votre question :", placeholder="Posez une question...")
@@ -49,4 +60,4 @@ with gr.Blocks() as demo:
49
  submit_btn = gr.Button("Envoyer")
50
  submit_btn.click(fn=chat, inputs=input_text, outputs=output_text)
51
 
52
- demo.launch()
 
4
  import os
5
 
6
  MODEL_DIR = "./mistral-7b-brvm-finetuned"
7
+ FULL_MODEL_DIR = "./mistral-7b-brvm-full-finetuned"
8
 
9
+ # Fonction d’entraînement LoRA
10
  def train_model():
11
+ os.system("python finetune.py")
12
+ return "✅ Entraînement LoRA terminé ! Modèle sauvegardé dans " + MODEL_DIR
13
 
14
+ # Fonction d’entraînement Full Fine-tune
15
+ def train_model_full():
16
+ os.system("python finetune_full.py")
17
+ return "✅ Full fine-tune terminé ! Modèle sauvegardé dans " + FULL_MODEL_DIR
18
+
19
+ # Chargement du modèle (par défaut LoRA, sinon base)
20
  def load_model():
21
+ if os.path.exists(FULL_MODEL_DIR):
22
+ model_name = FULL_MODEL_DIR
23
+ elif os.path.exists(MODEL_DIR):
24
+ model_name = MODEL_DIR
25
+ else:
26
+ model_name = "mistralai/Mistral-7B-Instruct-v0.3"
27
+
28
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
29
  model = AutoModelForCausalLM.from_pretrained(
30
  model_name,
31
+ device_map="auto",
32
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
33
  trust_remote_code=True,
34
  )
35
+ return pipeline("text-generation", model=model, tokenizer=tokenizer)
36
 
37
+ # Charger pipeline
 
 
 
 
 
38
  pipe = load_model()
39
 
40
+ # Fonction de test
41
  def chat(prompt):
42
  outputs = pipe(prompt, max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9)
43
  return outputs[0]["generated_text"]
 
47
  gr.Markdown("# 🐟 BRVM Finetuner (Mistral-7B)")
48
 
49
  with gr.Tab("🚀 Entraînement"):
50
+ with gr.Row():
51
+ train_btn = gr.Button("Lancer LoRA Fine-tuning")
52
+ train_full_btn = gr.Button("Lancer Full Fine-tuning")
53
  train_output = gr.Textbox(label="Logs")
54
  train_btn.click(fn=train_model, outputs=train_output)
55
+ train_full_btn.click(fn=train_model_full, outputs=train_output)
56
 
57
  with gr.Tab("💬 Tester le modèle"):
58
  input_text = gr.Textbox(label="Votre question :", placeholder="Posez une question...")
 
60
  submit_btn = gr.Button("Envoyer")
61
  submit_btn.click(fn=chat, inputs=input_text, outputs=output_text)
62
 
63
+ demo.launch()