rahul7star committed
Commit a8678a6 · verified · 1 Parent(s): be09bfa

Update app_flash.py

Files changed (1)
  1. app_flash.py +78 -44
app_flash.py CHANGED
@@ -1,72 +1,107 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from flashpack import FlashPackMixin
+from datasets import load_dataset
 import gradio as gr
-from transformers import AutoTokenizer
-from flashpack.integrations.transformers import FlashPackTransformersModelMixin
-from transformers import AutoModelForCausalLM, pipeline as hf_pipeline
+
+device = "cuda" if torch.cuda.is_available() else "cpu"

 # ============================================================
-# 1️⃣ Define FlashPack-enabled model class
+# 1️⃣ Define FlashPack model
 # ============================================================
-class FlashPackGemmaModel(AutoModelForCausalLM, FlashPackTransformersModelMixin):
-    """Gemma 3 model wrapped with FlashPackTransformersModelMixin"""
-    pass
+class GemmaTrainer(nn.Module, FlashPackMixin):
+    def __init__(self, input_dim=768, hidden_dim=1024, output_dim=768):
+        super().__init__()
+        self.fc1 = nn.Linear(input_dim, hidden_dim)
+        self.relu = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_dim, output_dim)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.fc2(x)
+        return x

 # ============================================================
-# 2️⃣ Load tokenizer
+# 2️⃣ Load dataset
 # ============================================================
-MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
-FLASHPACK_REPO = "rahul7star/FlashPack"
+dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
+
+# Example: convert short_prompt and long_prompt to embeddings
+from transformers import AutoTokenizer, AutoModel
+tokenizer = AutoTokenizer.from_pretrained("gpt2")
+embed_model = AutoModel.from_pretrained("gpt2").to(device)
+
+def encode_prompt(prompt):
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding="max_length", max_length=32).to(device)
+    with torch.no_grad():
+        return embed_model(**inputs).last_hidden_state.mean(dim=1)
+
+short_embeddings = torch.vstack([encode_prompt(p["short_prompt"]) for p in dataset])
+long_embeddings = torch.vstack([encode_prompt(p["long_prompt"]) for p in dataset])

-tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+# ============================================================
+# 3️⃣ Train FlashPack model
+# ============================================================
+model = GemmaTrainer(input_dim=short_embeddings.shape[1], output_dim=long_embeddings.shape[1]).to(device)
+criterion = nn.MSELoss()
+optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+max_epochs = 1000
+tolerance = 1e-4
+
+for epoch in range(max_epochs):
+    optimizer.zero_grad()
+    outputs = model(short_embeddings)
+    loss = criterion(outputs, long_embeddings)
+    loss.backward()
+    optimizer.step()
+    if loss.item() < tolerance:
+        print(f"Training converged at epoch {epoch+1}")
+        break
+    if epoch % 50 == 0:
+        print(f"Epoch {epoch+1}, Loss: {loss.item():.6f}")

 # ============================================================
-# 3️⃣ Load or create FlashPack model
+# 4️⃣ Save to FlashPack Hub
 # ============================================================
-try:
-    print("📂 Loading model from FlashPack repository...")
-    model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_REPO)
-except FileNotFoundError:
-    print("⚠️ FlashPack model not found. Loading from HF Hub and uploading FlashPack...")
-    model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
-    model.save_pretrained_flashpack(FLASHPACK_REPO, push_to_hub=True)
-    print(f"✅ FlashPack model uploaded to Hugging Face Hub: {FLASHPACK_REPO}")
+FLASHPACK_REPO = "rahul7star/FlashPack"
+model.save_flashpack(FLASHPACK_REPO, target_dtype=torch.float32, push_to_hub=True)
+print("✅ Model saved to FlashPack Hub!")

 # ============================================================
-# 4️⃣ Build text-generation pipeline
+# 5️⃣ Load FlashPack model
 # ============================================================
-pipe = hf_pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    device_map="auto"
-)
+loaded_model = model.from_flashpack(FLASHPACK_REPO)

 # ============================================================
-# 5️⃣ Define prompt enhancement function
+# 6️⃣ Gradio interface
 # ============================================================
 def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
     chat_history = chat_history or []

-    messages = [
-        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
-        {"role": "user", "content": user_prompt},
-    ]
+    # Encode short prompt
+    short_emb = encode_prompt(user_prompt)

-    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    # Generate expanded embedding via trained model
+    with torch.no_grad():
+        long_emb = loaded_model(short_emb)

-    outputs = pipe(
-        prompt,
-        max_new_tokens=int(max_tokens),
-        temperature=float(temperature),
-        do_sample=True
-    )
-    enhanced = outputs[0]["generated_text"].strip()
+    # Decode embedding back to text (approximate via nearest training example)
+    # Simple approach: cosine similarity to long_embeddings
+    cos = nn.CosineSimilarity(dim=1)
+    sims = cos(long_emb.repeat(len(long_embeddings),1), long_embeddings)
+    best_idx = sims.argmax()
+    enhanced_prompt = dataset[best_idx]["long_prompt"]

+    # Update chat history
     chat_history.append({"role": "user", "content": user_prompt})
-    chat_history.append({"role": "assistant", "content": enhanced})
+    chat_history.append({"role": "assistant", "content": enhanced_prompt})
     return chat_history

 # ============================================================
-# 6️⃣ Gradio UI
+# 7️⃣ Gradio UI
 # ============================================================
 with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
     gr.Markdown(
@@ -90,7 +125,6 @@ with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft())
     send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
     clear_btn = gr.Button("🧹 Clear Chat")

-    # Bind UI actions
     send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
     user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
     clear_btn.click(lambda: [], None, chatbot)
@@ -105,7 +139,7 @@ with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft())
     )

 # ============================================================
-# 7️⃣ Launch
+# 8️⃣ Launch
 # ============================================================
 if __name__ == "__main__":
     demo.launch(show_error=True)
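
Note on the load path: the updated script trains the mapper at import time and then calls from_flashpack on the freshly trained instance, so every Space restart retrains before loading. Below is a minimal sketch of reloading the pushed checkpoint in a separate process instead. It reuses the GemmaTrainer definition from the commit and assumes from_flashpack can also be called on the class (only the instance-level call appears above, so treat this as unverified flashpack usage).

import torch
import torch.nn as nn
from flashpack import FlashPackMixin

# Same mapper architecture as in the commit above.
class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self, input_dim=768, hidden_dim=1024, output_dim=768):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

# Assumption: from_flashpack is callable on the class and fetches the
# checkpoint pushed to rahul7star/FlashPack without retraining.
loaded_model = GemmaTrainer.from_flashpack("rahul7star/FlashPack")
loaded_model.eval()

with torch.no_grad():
    # One 768-dim prompt embedding, as produced by encode_prompt in the commit.
    out = loaded_model(torch.randn(1, 768))
print(out.shape)  # expected: torch.Size([1, 768])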