rahul7star committed (verified)
Commit d071e42 · 1 Parent(s): 248fe25

Update app_flash.py

Files changed (1):
app_flash.py  +15 -14
app_flash.py CHANGED
@@ -3,47 +3,48 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from flashpack.integrations.transformers import FlashPackTransformersModelMixin
 
 # ============================================================
-# 1️⃣ Define FlashPack-enabled model class
+# 1️⃣ FlashPack-enabled model class
 # ============================================================
 class FlashPackGemmaModel(AutoModelForCausalLM, FlashPackTransformersModelMixin):
     """AutoModelForCausalLM extended with FlashPackMixin for fast save/load"""
     pass
 
+MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
 
 # ============================================================
-# 2️⃣ Load or prepare model
+# 2️⃣ Load model and tokenizer with FlashPack
 # ============================================================
-MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
-
 try:
-    print("📂 Trying to load FlashPack model...")
+    print("📂 Trying to load model from FlashPack directory...")
     model = FlashPackGemmaModel.from_pretrained_flashpack("model_flashpack")
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 except Exception as e:
-    print("⚙️ FlashPack not found, loading from Hugging Face Hub...")
+    print("⚙️ FlashPack model not found, loading from Hugging Face Hub...")
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
-    # Save as FlashPack for faster next load
+    # Load Hugging Face model and wrap into FlashPack class
+    model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
+    # Save for future faster loads
     model.save_pretrained_flashpack("model_flashpack")
     print("✅ Model saved as FlashPack for next startup!")
 
-# Create the Hugging Face text-generation pipeline
+# ============================================================
+# 3️⃣ Create text-generation pipeline
+# ============================================================
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
 
 
 # ============================================================
-# 3️⃣ Define inference logic
+# 4️⃣ Define prompt enhancement logic
 # ============================================================
 def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
     chat_history = chat_history or []
 
-    # Build messages for chat-template
     messages = [
         {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
         {"role": "user", "content": user_prompt},
     ]
 
-    # Use tokenizer.apply_chat_template
+    # Use chat-template
     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
     outputs = pipe(
@@ -63,7 +64,7 @@ def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
 
 
 # ============================================================
-# 4️⃣ Gradio Interface
+# 5️⃣ Gradio Interface
 # ============================================================
 with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
     gr.Markdown(
@@ -103,7 +104,7 @@ with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft())
 
 
 # ============================================================
-# 5️⃣ Launch App
+# 6️⃣ Launch App
 # ============================================================
 if __name__ == "__main__":
     demo.launch(show_error=True)
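
The substantive fix is in hunk 1: the old fallback branch loaded the Hub weights with plain AutoModelForCausalLM.from_pretrained, so the returned model would generally not carry the FlashPack mixin, and the following model.save_pretrained_flashpack(...) call could fail with an AttributeError. The new code loads through the mixin-enabled FlashPackGemmaModel on both paths. A minimal standalone sketch of that load-or-prime pattern, mirroring the names in app_flash.py (it assumes FlashPack's transformers integration resolves these classmethods on the subclass, as the committed code does):

from transformers import AutoModelForCausalLM, AutoTokenizer
from flashpack.integrations.transformers import FlashPackTransformersModelMixin

class FlashPackGemmaModel(AutoModelForCausalLM, FlashPackTransformersModelMixin):
    """Mixin adds from_pretrained_flashpack / save_pretrained_flashpack."""
    pass

MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
FLASHPACK_DIR = "model_flashpack"  # local cache directory, as in the app

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
try:
    # Fast path: FlashPack weights saved by a previous run.
    model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_DIR)
except Exception:
    # Slow path (first run): load from the Hub through the mixin-enabled
    # class -- a plain AutoModelForCausalLM instance would lack
    # save_pretrained_flashpack, which is what this commit fixes.
    model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
    model.save_pretrained_flashpack(FLASHPACK_DIR)  # prime cache for next startup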
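For reference, step 4️⃣ formats the conversation with the tokenizer's chat template rather than hand-built prompt strings, so the system and user turns are wrapped in Gemma's expected control tokens. Continuing the sketch above, a hedged usage example; the generation kwargs are illustrative, since the hunk is truncated at `outputs = pipe(`:

from transformers import pipeline

# Text-generation pipeline over the FlashPack-loaded model (step 3️⃣).
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")

messages = [
    {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
    {"role": "user", "content": "a cat sitting on a windowsill"},  # example input
]

# Render the turns into Gemma's chat format; add_generation_prompt=True
# appends the assistant header so generation continues as the reply.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Illustrative sampling parameters -- the commit's actual pipe(...) arguments
# fall outside the visible hunk.
outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)
print(outputs[0]["generated_text"])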