rahul7star committed (verified)
Commit d191426 · 1 Parent(s): 9aeedd9

Update app_flash.py

Files changed (1):
  1. app_flash.py  +28 -61
app_flash.py CHANGED
@@ -6,14 +6,14 @@ import torch.optim as optim
 from datasets import load_dataset
 import gradio as gr
 from transformers import AutoTokenizer, AutoModel
-from flashpack import FlashPackMixin  # keep if your mixin provides save_flashpack
+from flashpack import FlashPackMixin
 from typing import Tuple
 
 # ============================================================
-# 🖥 Force CPU mode (safe for HF Spaces / Kaggle)
+# 🖥 Force CPU mode
 # ============================================================
 device = torch.device("cpu")
-torch.set_num_threads(4)  # reduce CPU contention in shared environments
+torch.set_num_threads(4)  # reduce CPU contention
 print(f"🔧 Forcing device: {device} (CPU-only mode)")
 
 # ============================================================
@@ -37,7 +37,6 @@ class GemmaTrainer(nn.Module, FlashPackMixin):
 # ============================================================
 def build_encoder(model_name="gpt2", max_length: int = 32):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    # Some GPT2 tokenizers have no pad token — set eos as pad
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
@@ -46,10 +45,6 @@ def build_encoder(model_name="gpt2", max_length: int = 32):
 
     @torch.no_grad()
     def encode(prompt: str) -> torch.Tensor:
-        """
-        Encodes a single prompt and returns a CPU tensor of shape (1, hidden_size).
-        Always returns a CPU tensor to avoid device juggling in downstream code.
-        """
         inputs = tokenizer(
             prompt,
             return_tensors="pt",
@@ -57,8 +52,7 @@ def build_encoder(model_name="gpt2", max_length: int = 32):
             padding="max_length",
             max_length=max_length,
         ).to(device)
-
-        outputs = embed_model(**inputs).last_hidden_state.mean(dim=1)  # (1, hidden)
+        outputs = embed_model(**inputs).last_hidden_state.mean(dim=1)
         return outputs.cpu()
 
     return tokenizer, embed_model, encode
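For reference, the encoder returned by build_encoder collapses GPT-2's token-level hidden states into a single vector per prompt by mean pooling. A minimal standalone sketch of the same idea, assuming the gpt2 checkpoint used above and a truncation flag that sits outside the hunk:

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token          # GPT-2 has no pad token by default
embed_model = AutoModel.from_pretrained("gpt2").eval()

@torch.no_grad()
def encode(prompt: str, max_length: int = 32) -> torch.Tensor:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       padding="max_length", max_length=max_length)
    # (1, seq_len, hidden) -> (1, hidden): average over the token dimension
    return embed_model(**inputs).last_hidden_state.mean(dim=1)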
@@ -70,10 +64,10 @@ def train_flashpack_model(
     dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
     model_name: str = "gpt2",
     max_length: int = 32,
-    max_encode: int = 2000,   # maximum number of prompts to encode
+    max_encode: int = 1000,   # use smaller number for CPU
     push_to_hub: bool = False,
     hf_repo: str = "rahul7star/FlashPack",
-) -> tuple:
+) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
 
     # 1️⃣ Load dataset
     print("📦 Loading dataset...")
@@ -84,23 +78,17 @@ def train_flashpack_model(
     dataset = dataset.select(range(limit))
     print(f"⚡ Encoding only {len(dataset)} prompts (max limit {max_encode})")
 
-    # 2️⃣ Setup tokenizer & encoder
-    tokenizer, embed_model, encode_fn = build_encoder(model_name=model_name, max_length=max_length)
+    # 2️⃣ Setup encoder
+    tokenizer, embed_model, encode_fn = build_encoder(model_name, max_length)
 
-    # 3️⃣ Encode dataset (CPU-friendly)
+    # 3️⃣ Encode dataset
     print("🔢 Encoding dataset into embeddings (CPU-friendly)...")
     short_list, long_list = [], []
     for i, item in enumerate(dataset):
         short_list.append(encode_fn(item["short_prompt"]))
         long_list.append(encode_fn(item["long_prompt"]))
 
-        # Exit early if we hit max_encode
-        if (i + 1) >= max_encode:
-            print(f"⚡ Reached max encode limit: {max_encode} prompts, stopping early.")
-            break
-
-        # Progress logging
-        if (i + 1) % 50 == 0:
+        if (i + 1) % 50 == 0 or (i + 1) == len(dataset):
             print(f" → Encoded {i+1}/{limit} prompts")
             gc.collect()
 
@@ -108,7 +96,7 @@ def train_flashpack_model(
     long_embeddings = torch.vstack(long_list)
     print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
 
-    # 4️⃣ Initialize and train model (same as before)
+    # 4️⃣ Initialize & train model
     model = GemmaTrainer(
         input_dim=short_embeddings.shape[1],
         hidden_dim=min(512, short_embeddings.shape[1]),
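The GemmaTrainer definition is outside these hunks; the hunk header only confirms class GemmaTrainer(nn.Module, FlashPackMixin), and the constructor call above shows input_dim and hidden_dim. A hypothetical definition consistent with those signatures (the layer layout and the output_dim argument are assumptions):

import torch.nn as nn
from flashpack import FlashPackMixin

class GemmaTrainer(nn.Module, FlashPackMixin):
    """Hypothetical mapper from short-prompt embeddings to long-prompt embeddings."""
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)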
@@ -117,8 +105,7 @@ def train_flashpack_model(
 
     criterion = nn.MSELoss()
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
-    max_epochs = 50
-    tolerance = 1e-4
+    max_epochs = 20
     batch_size = 32
 
     print("🚀 Training FlashPack mapper model (CPU)...")
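The loop body between this hunk and the next is not part of the diff; only its tail (the epoch logging) appears in the next hunk. A plausible mini-batch MSE loop consistent with the criterion, optimizer, max_epochs, and batch_size shown above, using the names defined earlier in the file (the per-epoch shuffling is an assumption):

for epoch in range(max_epochs):
    perm = torch.randperm(short_embeddings.shape[0])    # assumed shuffling each epoch
    epoch_loss = 0.0
    for start in range(0, len(perm), batch_size):
        idx = perm[start:start + batch_size]
        optimizer.zero_grad()
        pred = model(short_embeddings[idx])             # map short -> long embedding space
        loss = criterion(pred, long_embeddings[idx])    # MSE against the target long embeddings
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    if epoch % 5 == 0 or epoch == max_epochs - 1:
        print(f"Epoch {epoch+1}/{max_epochs}, Loss={epoch_loss:.6f}")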
@@ -143,26 +130,28 @@ def train_flashpack_model(
         if epoch % 5 == 0 or epoch == max_epochs-1:
             print(f"Epoch {epoch+1}/{max_epochs}, Loss={epoch_loss:.6f}")
 
-        if epoch_loss < tolerance:
-            print(f"✅ Converged at epoch {epoch+1}, Loss={epoch_loss:.6f}")
-            break
-
     print("✅ Training finished!")
+
+    # 5️⃣ Push to HF repo if requested
+    if push_to_hub:
+        model.save_flashpack(hf_repo, target_dtype=torch.float32, push_to_hub=True)
+        print(f"✅ Model pushed to HF repo: {hf_repo}")
+
     return model, dataset, embed_model, tokenizer, long_embeddings
 
 # ============================================================
-# 4️⃣ Build everything and prepare for inference
+# 4️⃣ Run training & load model
 # ============================================================
-# For demo speed in CPU mode, you might want a subset_limit (e.g., 1000).
-# Set subset_limit=None to use full dataset.
 model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model(
-    subset_limit=None,    # change to a small int for faster testing
-    push_to_hub=False,    # toggle when you want to actually push
+    max_encode=1000,      # safe CPU-friendly subset
+    push_to_hub=False
 )
 
 model.eval()
 
-# Reusable encode function for inference (returns CPU tensor)
+# ============================================================
+# 5️⃣ Inference helpers
+# ============================================================
 @torch.no_grad()
 def encode_for_inference(prompt: str) -> torch.Tensor:
     inputs = tokenizer(
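The new push path calls model.save_flashpack(...) exactly as shown in the hunk above. On a Space that call only succeeds with write credentials, so a common guard is to log in with a token from the environment first. A sketch assuming a write token exposed as HF_TOKEN (the variable name and the guard are assumptions, not part of this commit):

import os
from huggingface_hub import login

token = os.environ.get("HF_TOKEN")        # hypothetical secret name
if push_to_hub and token:
    login(token=token)
    model.save_flashpack(hf_repo, target_dtype=torch.float32, push_to_hub=True)
elif push_to_hub:
    print("⚠️ push_to_hub requested but no HF_TOKEN found; skipping push.")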
@@ -174,22 +163,13 @@ def encode_for_inference(prompt: str) -> torch.Tensor:
     ).to(device)
     return embed_model(**inputs).last_hidden_state.mean(dim=1).cpu()
 
-# ============================================================
-# 5️⃣ Enhance prompt function (nearest neighbor via cosine)
-# ============================================================
 def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
     chat_history = chat_history or []
+    short_emb = encode_for_inference(user_prompt)
+    mapped = model(short_emb.to(device)).cpu()
 
-    # encode user prompt (CPU tensor)
-    short_emb = encode_for_inference(user_prompt)   # (1, dim)
-    with torch.no_grad():
-        mapped = model(short_emb.to(device)).cpu()  # (1, dim)
-
-    # cosine similarity against dataset long embeddings
     cos = nn.CosineSimilarity(dim=1)
-    # mapped.repeat(len(long_embeddings), 1) is heavy; do efficient matmul similarity:
     sims = (long_embeddings @ mapped.t()).squeeze(1)
-    # normalize: sims / (||long|| * ||mapped||)
     long_norms = long_embeddings.norm(dim=1)
     mapped_norm = mapped.norm()
     sims = sims / (long_norms * (mapped_norm + 1e-12))
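The three lines above are a manual cosine similarity: a matrix-vector product followed by division by the two vector norms. The selection step falls outside the hunk; a compact equivalent using the surrounding names, with the nearest-neighbour lookup written out as an assumption about the elided part (same quantity up to the epsilon guard):

import torch.nn.functional as F

# Same similarity as the manual matmul + norm division above.
sims = F.cosine_similarity(long_embeddings, mapped.expand_as(long_embeddings), dim=1)

best_idx = int(sims.argmax())                 # assumed selection rule: highest cosine similarity
enhanced = dataset[best_idx]["long_prompt"]   # return the stored long prompt for that neighbour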
@@ -209,18 +189,14 @@ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft
         """
         # ✨ Prompt Enhancer (FlashPack mapper)
         Enter a short prompt, and the model will **expand it with details and creative context**.
-        (This demo runs on CPU — expect slower inference/training than GPU.)
+        (CPU-only mode.)
         """
     )
 
     with gr.Row():
         chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
         with gr.Column(scale=1):
-            user_prompt = gr.Textbox(
-                placeholder="Enter a short prompt...",
-                label="Your Prompt",
-                lines=3,
-            )
+            user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
             temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
             max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
             send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
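The Chatbot above is created with type="messages", so whatever enhance_prompt returns must be a list of role/content dictionaries. The tail of enhance_prompt is outside the hunks; it presumably appends the user prompt and the retrieved long prompt before returning, roughly like the sketch below (enhanced is the assumed name for the retrieved text, as in the earlier sketch):

chat_history.append({"role": "user", "content": user_prompt})
chat_history.append({"role": "assistant", "content": enhanced})
return chat_history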
@@ -230,15 +206,6 @@ with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft
     user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
     clear_btn.click(lambda: [], None, chatbot)
 
-    gr.Markdown(
-        """
-        ---
-        💡 **Tips:**
-        - CPU mode: training and large-batch encodes can take a while. Use `subset_limit` in the training call for quick tests.
-        - Increase *Temperature* for more creative outputs (not used in the nearest-neighbour mapper but kept for UI parity).
-        """
-    )
-
 # ============================================================
 # 7️⃣ Launch
 # ============================================================
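The launch block itself is only visible here as a header. On Spaces it is usually a bare launch call on the Blocks object; a sketch assuming the context manager above is bound to a variable named demo (that name is not visible in the diff):

if __name__ == "__main__":
    demo.launch()   # hypothetical; `demo` is the assumed gr.Blocks() handle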
 