rahul7star committed (verified)
Commit 2e79f16
1 Parent(s): 400b637

Update app_flash.py


testing more encoding changes, but the version before this one works just fine

Files changed (1)
  1. app_flash.py +58 -20
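For context on the diff below: the new build_encoder concatenates mean pooling and max pooling over GPT-2's last hidden state, so each prompt embedding doubles from 768 to 1536 dimensions, which is why GemmaTrainer now takes input_dim from the encoded data instead of defaulting to 768. A minimal standalone sketch of just the pooling change (the sample prompt, return_tensors/truncation arguments, and shape comments are illustrative assumptions, not part of the commit):

# Illustrative sketch only (not part of the commit): the mean+max pooling
# introduced in build_encoder, shown in isolation for gpt2 (hidden size 768).
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token              # gpt2 has no pad token
embed_model = AutoModel.from_pretrained("gpt2")

inputs = tokenizer(
    "a short prompt",                                  # sample text, made up
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=128,
)
with torch.no_grad():
    last_hidden = embed_model(**inputs).last_hidden_state   # (1, 128, 768)
mean_pool = last_hidden.mean(dim=1)                          # (1, 768)
max_pool, _ = last_hidden.max(dim=1)                         # (1, 768)
embedding = torch.cat([mean_pool, max_pool], dim=1)          # (1, 1536)
print(embedding.shape)                                       # torch.Size([1, 1536])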
app_flash.py CHANGED
@@ -17,27 +17,49 @@ from typing import Tuple
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"🔧 Using device: {device} (CPU-only mode)")
+# prompt_enhancer_flashpack_cpu_publish_v2.py
+import gc
+import os
+import tempfile
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from datasets import load_dataset
+from transformers import AutoTokenizer, AutoModel
+from flashpack import FlashPackMixin
+from huggingface_hub import Repository
+
+device = torch.device("cpu")
+torch.set_num_threads(4)
+print(f"🔧 Using device: {device} (CPU-only mode)")
+

# ============================================================
-# 1️⃣ Define FlashPack model
+# 1️⃣ Define improved FlashPack model
# ============================================================
class GemmaTrainer(nn.Module, FlashPackMixin):
-    def __init__(self, input_dim: int = 768, hidden_dim: int = 512, output_dim: int = 768):
+    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 768):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
-        self.fc2 = nn.Linear(hidden_dim, output_dim)
+        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
+        x = self.relu(x)
+        x = self.fc3(x)
        return x

+
# ============================================================
-# 2️⃣ Build tokenizer + encoder
+# 2️⃣ Encoder with mean+max pooling
# ============================================================
-def build_encoder(model_name="gpt2", max_length: int = 32):
+def build_encoder(model_name="gpt2", max_length: int = 128):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
@@ -54,11 +76,14 @@ def build_encoder(model_name="gpt2", max_length: int = 32):
            padding="max_length",
            max_length=max_length
        ).to(device)
-        outputs = embed_model(**inputs).last_hidden_state.mean(dim=1)
-        return outputs.cpu()
+        last_hidden = embed_model(**inputs).last_hidden_state
+        mean_pool = last_hidden.mean(dim=1)
+        max_pool, _ = last_hidden.max(dim=1)
+        return torch.cat([mean_pool, max_pool], dim=1).cpu()

    return tokenizer, embed_model, encode

+
# ============================================================
# 3️⃣ Push FlashPack model to HF
# ============================================================
@@ -85,22 +110,28 @@ def push_flashpack_model_to_hf(model, hf_repo: str):

    return logs

+
# ============================================================
# 4️⃣ Train FlashPack model
# ============================================================
def train_flashpack_model(
-    dataset_name: str = "gokaygokay/prompt-enhancer-dataset",
+    dataset_name: str = "rahul7star/prompt-enhancer-dataset",
    max_encode: int = 5000,
-    device: str = "cpu"
+    hidden_dim: int = 1024,
+    push_to_hub: bool = True,
+    hf_repo: str = "rahul7star/FlashPack"
) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
+
    print("📦 Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    limit = min(max_encode, len(dataset))
    dataset = dataset.select(range(limit))
-    print(f"⚡ Encoding {len(dataset)} prompts (max {max_encode})")
+    print(f"⚡ Using {len(dataset)} prompts for training (max {max_encode})")

-    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=32)
+    # Build encoder
+    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)

+    # Encode prompts
    short_list, long_list = [], []
    for i, item in enumerate(dataset):
        short_list.append(encode_fn(item["short_prompt"]))
@@ -112,21 +143,20 @@ def train_flashpack_model(
    short_embeddings = torch.vstack(short_list)
    long_embeddings = torch.vstack(long_list)
    print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")
+    input_dim = short_embeddings.shape[1]
+    output_dim = long_embeddings.shape[1]

    # Build model
-    model = GemmaTrainer(
-        input_dim=short_embeddings.shape[1],
-        hidden_dim=min(512, short_embeddings.shape[1]),
-        output_dim=long_embeddings.shape[1]
-    ).to(device)
+    model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)

-    criterion = nn.MSELoss()
+    # Loss & optimizer
+    criterion = nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
-    max_epochs = 20
+    max_epochs = 30
    batch_size = 32
+    n = short_embeddings.shape[0]

    print("🚀 Training model...")
-    n = short_embeddings.shape[0]
    for epoch in range(max_epochs):
        model.train()
        epoch_loss = 0.0
@@ -138,7 +168,7 @@

            optimizer.zero_grad()
            outputs = model(inputs)
-            loss = criterion(outputs, targets)
+            loss = 1 - criterion(outputs, targets).mean()  # Cosine similarity loss
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * inputs.size(0)
@@ -148,6 +178,14 @@
        print(f"Epoch {epoch+1}/{max_epochs}, Loss={epoch_loss:.6f}")

    print("✅ Training finished!")
+
+    # Push to HF
+    logs = []
+    if push_to_hub:
+        logs = push_flashpack_model_to_hf(model, hf_repo)
+        for log in logs:
+            print(log)
+
    return model, dataset, embed_model, tokenizer, long_embeddings

# ============================================================
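If the updated entry point is run end to end, a call along the lines of the sketch below would exercise it; the keyword arguments come straight from the new signature in this diff, while the smaller max_encode and push_to_hub=False are illustrative choices for a quick local CPU test rather than anything in the commit.

# Hypothetical driver (not in the commit): exercising the new signature locally.
if __name__ == "__main__":
    model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model(
        dataset_name="rahul7star/prompt-enhancer-dataset",
        max_encode=1000,        # smaller subset than the 5000 default for a quick run
        hidden_dim=1024,
        push_to_hub=False,      # skip push_flashpack_model_to_hf while testing
    )
    print(long_embeddings.shape)  # (n_prompts, 1536) with the mean+max GPT-2 encoder

On the loss change itself: replacing nn.MSELoss with a 1 − cosine-similarity objective optimizes for directional agreement between the predicted and target embeddings rather than matching their raw magnitudes, which fits the commit's shift toward pooled, concatenated embeddings.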