rahul7star committed on
Commit 7c8fb46 · verified · 1 Parent(s): ee55050

Update app_flash.py

Files changed (1)
  1. app_flash.py +105 -101
app_flash.py CHANGED
@@ -1,128 +1,158 @@
-import gc
 import os
+import re
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import gradio as gr
+from typing import Tuple
 from datasets import load_dataset
-from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
 from flashpack import FlashPackMixin
-from typing import Tuple
+from huggingface_hub import HfApi, create_repo, repo_exists
+import gradio as gr
+from transformers import AutoTokenizer, AutoModel
 
 # ============================================================
-# 🖥 Force CPU mode
+# ⚙️ Setup
 # ============================================================
 device = torch.device("cpu")
 torch.set_num_threads(4)
-print(f"🔧 Forcing device: {device} (CPU-only mode)")
+print(f"🔧 Using device: {device} (CPU-only mode)")
+
+HF_REPO = "rahul7star/FlashPack"
+MODEL_ID = HF_REPO
+
 
 # ============================================================
-# 1️⃣ Define FlashPack model
+# 🧠 Define FlashPack Trainer
 # ============================================================
 class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self, input_dim: int = 768, hidden_dim: int = 512, output_dim: int = 768):
+    def __init__(self, input_dim=768, hidden_dim=512, output_dim=768):
         super().__init__()
         self.fc1 = nn.Linear(input_dim, hidden_dim)
         self.relu = nn.ReLU()
         self.fc2 = nn.Linear(hidden_dim, output_dim)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.fc1(x)
-        x = self.relu(x)
-        x = self.fc2(x)
-        return x
+    def forward(self, x):
+        return self.fc2(self.relu(self.fc1(x)))
 
 
 # ============================================================
-# 2️⃣ Build encoder (for embedding)
+# 🔤 Encoder Builder (GPT2 base)
 # ============================================================
-def build_encoder(model_name="gpt2", max_length: int = 32):
+def build_encoder(model_name="gpt2", max_length=32):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-
     embed_model = AutoModel.from_pretrained(model_name).to(device)
     embed_model.eval()
 
     @torch.no_grad()
-    def encode(prompt: str) -> torch.Tensor:
+    def encode(text: str):
         inputs = tokenizer(
-            prompt,
+            text,
             return_tensors="pt",
             truncation=True,
             padding="max_length",
             max_length=max_length,
         ).to(device)
-        outputs = embed_model(**inputs).last_hidden_state.mean(dim=1)
-        return outputs.cpu()
+        return embed_model(**inputs).last_hidden_state.mean(dim=1).cpu()
 
     return tokenizer, embed_model, encode
 
 
 # ============================================================
-# 3️⃣ Load pretrained FlashPack model (skip training)
+# 🧩 FlashPack: Train and Upload (uses Gemma only internally)
 # ============================================================
-def load_flashpack_model(hf_repo="rahul7star/FlashPack"):
-    model = GemmaTrainer.from_flashpack(hf_repo)
-    tokenizer = model.tokenizer if hasattr(model, "tokenizer") else None
-    embed_model = model.embed_model if hasattr(model, "embed_model") else None
-    return model, tokenizer, embed_model
-
-# def load_flashpack_model(hf_repo="rahul7star/FlashPack"):
-#     print(f"🔍 Loading FlashPack model from: {hf_repo}")
-
-#     model = GemmaTrainer.from_flashpack(hf_repo)
-
-#     model.eval()
-#     tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
-#     return model, tokenizer, embed_model
-
-
-# ============================================================
-# 4️⃣ Load Gemma text model for prompt enhancement
-# ============================================================
-MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
-
-tokenizer_gemma = AutoTokenizer.from_pretrained(MODEL_ID)
-model_gemma = AutoModelForCausalLM.from_pretrained(MODEL_ID)
-
-pipe_gemma = pipeline(
-    "text-generation",
-    model=model_gemma,
-    tokenizer=tokenizer_gemma,
-    device=-1,  # CPU
-)
-
-import re
-
-def extract_later_part(user_prompt, generated_text):
-    """Cleans the model output and extracts only the enhanced (later) portion."""
-    cleaned = re.sub(r"<.*?>", "", generated_text).strip()
-    cleaned = re.sub(r"\s+", " ", cleaned)
-    user_prompt_clean = user_prompt.strip().lower()
-    cleaned_lower = cleaned.lower()
-    if cleaned_lower.startswith(user_prompt_clean):
-        cleaned = cleaned[len(user_prompt):].strip(",. ").strip()
-    return cleaned
+def train_flashpack_model(hf_repo=HF_REPO):
+    print(f"🚀 Training new FlashPack model for repo: {hf_repo}")
+    model = GemmaTrainer()
+    tokenizer, embed_model, encode = build_encoder("gpt2")
+
+    # Load dataset (Gemma-expanded dataset)
+    dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
+
+    # Compute embeddings for training (short → long)
+    X, Y = [], []
+    for p in dataset.select(range(300)):
+        short_emb = encode(p["short_prompt"])
+        long_emb = encode(p["long_prompt"])
+        X.append(short_emb)
+        Y.append(long_emb)
+
+    X = torch.vstack(X)
+    Y = torch.vstack(Y)
+
+    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+    for epoch in range(10):
+        out = model(X)
+        loss = nn.MSELoss()(out, Y)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        print(f"Epoch {epoch+1}/10 | Loss: {loss.item():.6f}")
+
+    # Save FlashPack model and push
+    model.to_flashpack("flashpack_model")
+    print("💾 Model saved locally. Uploading to Hugging Face...")
+
+    api = HfApi()
+    if not repo_exists(hf_repo):
+        create_repo(hf_repo, repo_type="model", exist_ok=True)
+    model.push_to_hub(hf_repo, commit_message="Initial FlashPack model training")
+
+    print(f"✅ Model uploaded successfully to {hf_repo}")
+    return model, tokenizer, embed_model
+
+
+# ============================================================
+# 📦 Load FlashPack from Hub
+# ============================================================
+def load_flashpack_model(hf_repo=HF_REPO):
+    print(f"📥 Loading FlashPack model from {hf_repo}...")
+    model = GemmaTrainer.from_flashpack(hf_repo)
+    tokenizer, embed_model, encode = build_encoder("gpt2")
+    print("✅ Loaded FlashPack model successfully.")
+    return model, tokenizer, embed_model
+
+
+# ============================================================
+# ⚡ Auto Load or Train
+# ============================================================
+def get_flashpack_model(hf_repo=HF_REPO):
+    try:
+        api = HfApi()
+        if repo_exists(hf_repo):
+            print("✅ Found trained model on Hub.")
+            return load_flashpack_model(hf_repo)
+        else:
+            print("❌ Model not found, training new one using Gemma dataset...")
+            return train_flashpack_model(hf_repo)
+    except Exception as e:
+        print(f"⚠️ Repo check failed: {e}. Retraining model locally.")
+        return train_flashpack_model(hf_repo)
 
 
 # ============================================================
-# 5️⃣ Load FlashPack + Dataset + Encoder
+# 📚 Dataset + Model
 # ============================================================
-model, tokenizer, embed_model = load_flashpack_model("rahul7star/FlashPack")
+model, tokenizer, embed_model = get_flashpack_model()
 dataset = load_dataset("gokaygokay/prompt-enhancer-dataset", split="train")
-long_embeddings = torch.vstack(
-    [embed_model(**tokenizer(p["long_prompt"], return_tensors="pt", truncation=True, padding="max_length", max_length=32)).last_hidden_state.mean(dim=1).cpu()
-     for p in dataset.select(range(min(500, len(dataset))))]
-)
-print("✅ Loaded FlashPack and Gemma models.")
+
+long_embeddings = torch.vstack([
+    embed_model(**tokenizer(
+        p["long_prompt"], return_tensors="pt",
+        truncation=True, padding="max_length", max_length=32
+    )).last_hidden_state.mean(dim=1).cpu()
+    for p in dataset.select(range(min(500, len(dataset))))
+])
+
+print("✅ FlashPack model and embeddings loaded.")
 
 
 # ============================================================
-# 6️⃣ FlashPack inference helper
+# 🧠 Inference Helpers
 # ============================================================
 @torch.no_grad()
-def encode_for_inference(prompt: str) -> torch.Tensor:
+def encode_for_inference(prompt: str):
     inputs = tokenizer(
         prompt,
         return_tensors="pt",
@@ -140,10 +170,7 @@ def enhance_prompt_flashpack(user_prompt: str, temperature: float, max_tokens: i
     mapped = model(short_emb.to(device)).cpu()
 
     sims = (long_embeddings @ mapped.t()).squeeze(1)
-    long_norms = long_embeddings.norm(dim=1)
-    mapped_norm = mapped.norm()
-    sims = sims / (long_norms * (mapped_norm + 1e-12))
-
+    sims /= (long_embeddings.norm(dim=1) * (mapped.norm() + 1e-12))
     best_idx = int(sims.argmax().item())
     enhanced_prompt = dataset[best_idx]["long_prompt"]
 
@@ -153,36 +180,14 @@ def enhance_prompt_flashpack(user_prompt: str, temperature: float, max_tokens: i
 
 
 # ============================================================
-# 7️⃣ Gemma prompt enhancer
+# 💬 Gradio UI
 # ============================================================
-def enhance_prompt_gemma(user_prompt, temperature, max_tokens, chat_history):
-    chat_history = chat_history or []
-    messages = [
-        {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
-        {"role": "user", "content": user_prompt}
-    ]
-    prompt = tokenizer_gemma.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    output = pipe_gemma(
-        prompt,
-        max_new_tokens=int(max_tokens),
-        temperature=float(temperature),
-        do_sample=True,
-    )[0]["generated_text"]
-    enhanced_text = extract_later_part(user_prompt, output)
-    chat_history.append({"role": "user", "content": user_prompt})
-    chat_history.append({"role": "assistant", "content": enhanced_text})
-    return chat_history
-
-
-# ============================================================
-# 8️⃣ Gradio UI
-# ============================================================
-with gr.Blocks(title="Prompt Enhancer – FlashPack + Gemma (CPU)", theme=gr.themes.Soft()) as demo:
+with gr.Blocks(title="Prompt Enhancer – FlashPack Only", theme=gr.themes.Soft()) as demo:
     gr.Markdown("""
-    # ✨ Prompt Enhancer (FlashPack + Gemma)
-    - **Gemma model**: Enhances prompts with natural language.
-    - **FlashPack model**: Finds similar expanded prompts from dataset.
-    - CPU-only, for reproducibility.
+    # ✨ FlashPack Prompt Enhancer
+    - Uses pre-trained **FlashPack model** (`rahul7star/FlashPack`)
+    - Matches short prompts to enhanced long prompts using learned embeddings
+    - CPU-only, no Gemma dependency during inference.
     """)
 
     with gr.Row():
@@ -191,17 +196,16 @@ with gr.Blocks(title="Prompt Enhancer – FlashPack + Gemma (CPU)", theme=gr.the
             user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
            temperature = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
            max_tokens = gr.Slider(32, 512, value=256, label="Max Tokens")
-            send_gemma = gr.Button("💬 Enhance (Gemma)")
-            send_flashpack = gr.Button("🔗 Enhance (FlashPack)")
+            send_flashpack = gr.Button("🔗 Enhance Prompt")
             clear_btn = gr.Button("🧹 Clear Chat")
 
-    send_gemma.click(enhance_prompt_gemma, [user_prompt, temperature, max_tokens, chatbot], chatbot)
     send_flashpack.click(enhance_prompt_flashpack, [user_prompt, temperature, max_tokens, chatbot], chatbot)
-    user_prompt.submit(enhance_prompt_gemma, [user_prompt, temperature, max_tokens, chatbot], chatbot)
+    user_prompt.submit(enhance_prompt_flashpack, [user_prompt, temperature, max_tokens, chatbot], chatbot)
     clear_btn.click(lambda: [], None, chatbot)
 
+
 # ============================================================
-# 9️⃣ Launch
+# 🚀 Launch App
 # ============================================================
 if __name__ == "__main__":
     demo.launch(show_error=True)
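The enhance_prompt_flashpack path kept by this commit works by mapping the short-prompt embedding through the trained GemmaTrainer MLP and then returning the dataset long prompt whose embedding is most cosine-similar to the mapped vector. Below is a minimal, self-contained sketch of that lookup step; the helper name retrieve_best_long_prompt and the random stand-in tensors are illustrative assumptions, not part of the committed file.

import torch

def retrieve_best_long_prompt(mapped, long_embeddings, long_prompts):
    """Pick the long prompt whose embedding is most cosine-similar to `mapped`.

    mapped:           (1, D) tensor, e.g. the GemmaTrainer output for a short prompt
    long_embeddings:  (N, D) tensor of precomputed long-prompt embeddings
    long_prompts:     list of N strings aligned with long_embeddings
    """
    sims = (long_embeddings @ mapped.t()).squeeze(1)               # raw dot products, shape (N,)
    sims /= long_embeddings.norm(dim=1) * (mapped.norm() + 1e-12)  # normalise to cosine similarity
    return long_prompts[int(sims.argmax().item())]

# Toy usage with random stand-ins for the real GPT-2 embeddings (D = 768 in the app).
if __name__ == "__main__":
    long_embeddings = torch.randn(500, 768)
    mapped = torch.randn(1, 768)
    prompts = [f"long prompt {i}" for i in range(500)]
    print(retrieve_best_long_prompt(mapped, long_embeddings, prompts))

In the app itself, mapped comes from model(short_emb) after encode_for_inference, and long_embeddings is the stack built from the first 500 rows of gokaygokay/prompt-enhancer-dataset.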