# prompt_enhancer_flashpack_cpu_publish_v2.py

import gc
import os
import tempfile
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from flashpack import FlashPackMixin
from huggingface_hub import Repository

# ============================================================
# 🖥 Device setup (CPU-only safe)
# ============================================================
device = torch.device("cpu")
torch.set_num_threads(4)
print(f"🔧 Using device: {device} (CPU-only mode)")

# ============================================================
# 1️⃣ Define improved FlashPack model
# ============================================================
class GemmaTrainer(nn.Module, FlashPackMixin):
    def __init__(self, input_dim: int, hidden_dim: int = 1024, output_dim: int = 768):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

# ============================================================
# 2️⃣ Encoder with mean+max pooling
# ============================================================
def build_encoder(model_name: str = "gpt2", max_length: int = 128):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    embed_model = AutoModel.from_pretrained(model_name).to(device)
    embed_model.eval()

    @torch.no_grad()
    def encode(prompt: str) -> torch.Tensor:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_length,
        ).to(device)
        last_hidden = embed_model(**inputs).last_hidden_state
        mean_pool = last_hidden.mean(dim=1)
        max_pool, _ = last_hidden.max(dim=1)
        # Concatenate mean and max pooling -> 2 * hidden_size features
        return torch.cat([mean_pool, max_pool], dim=1).cpu()

    return tokenizer, embed_model, encode

# ============================================================
# 3️⃣ Push FlashPack model to HF
# ============================================================
def push_flashpack_model_to_hf(model, hf_repo: str):
    logs = []
    with tempfile.TemporaryDirectory() as tmp_dir:
        logs.append(f"📂 Using temporary directory: {tmp_dir}")
        repo = Repository(local_dir=tmp_dir, clone_from=hf_repo, use_auth_token=True)
        logs.append(f"🌐 Hugging Face repo cloned to: {tmp_dir}")

        pack_path = os.path.join(tmp_dir, "model.flashpack")
        logs.append(f"💾 Saving model to: {pack_path}")
        model.save_flashpack(pack_path, target_dtype=torch.float32)
        logs.append("✅ Model saved successfully.")

        readme_path = os.path.join(tmp_dir, "README.md")
        with open(readme_path, "w") as f:
            f.write("# FlashPack Model\nThis repo contains a FlashPack model.")
        logs.append("📄 README.md added.")

        logs.append("🚀 Pushing repo to Hugging Face Hub...")
        repo.push_to_hub()
        logs.append(f"✅ Model successfully pushed to: {hf_repo}")
    return logs
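# ------------------------------------------------------------
# Shape sanity check (illustrative sketch, not part of the training
# flow). With "gpt2" the hidden size is 768, so the mean+max
# concatenation in build_encoder yields 1536-dim vectors, which is
# why GemmaTrainer's input_dim ends up as 1536 below. The names
# _tok, _model, _enc are hypothetical, for demonstration only:
#
#   _tok, _model, _enc = build_encoder("gpt2", max_length=128)
#   vec = _enc("a cat")
#   assert vec.shape == (1, 2 * _model.config.hidden_size)  # (1, 1536)
# ------------------------------------------------------------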
# ============================================================
# 4️⃣ Train FlashPack model
# ============================================================
def train_flashpack_model(
    dataset_name: str = "rahul7star/prompt-enhancer-dataset",
    max_encode: int = 1000,
    hidden_dim: int = 1024,
    push_to_hub: bool = True,
    hf_repo: str = "rahul7star/FlashPack",
) -> Tuple[GemmaTrainer, object, object, object, torch.Tensor]:
    print("📦 Loading dataset...")
    dataset = load_dataset(dataset_name, split="train")
    limit = min(max_encode, len(dataset))
    dataset = dataset.select(range(limit))
    print(f"⚡ Using {len(dataset)} prompts for training (max {max_encode})")

    # Build encoder
    tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)

    # Encode prompts
    short_list, long_list = [], []
    for i, item in enumerate(dataset):
        short_list.append(encode_fn(item["short_prompt"]))
        long_list.append(encode_fn(item["long_prompt"]))
        if (i + 1) % 50 == 0 or (i + 1) == len(dataset):
            print(f"   → Encoded {i + 1}/{limit} prompts")
            gc.collect()

    short_embeddings = torch.vstack(short_list)
    long_embeddings = torch.vstack(long_list)
    print(f"✅ Finished encoding {short_embeddings.shape[0]} prompts")

    input_dim = short_embeddings.shape[1]
    output_dim = long_embeddings.shape[1]

    # Build model
    model = GemmaTrainer(input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim).to(device)

    # Loss & optimizer
    criterion = nn.CosineSimilarity(dim=1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    max_epochs = 50
    batch_size = 32
    n = short_embeddings.shape[0]

    print("🚀 Training model...")
    for epoch in range(max_epochs):
        model.train()
        epoch_loss = 0.0
        perm = torch.randperm(n)
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            inputs = short_embeddings[idx].to(device)
            targets = long_embeddings[idx].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = 1 - criterion(outputs, targets).mean()  # cosine-similarity loss
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * inputs.size(0)

        epoch_loss /= n
        if epoch % 5 == 0 or epoch == max_epochs - 1:
            print(f"Epoch {epoch + 1}/{max_epochs}, Loss={epoch_loss:.6f}")

    print("✅ Training finished!")

    # Push to HF
    logs = []
    if push_to_hub:
        logs = push_flashpack_model_to_hf(model, hf_repo)
    for log in logs:
        print(log)

    return model, dataset, embed_model, tokenizer, long_embeddings

# ============================================================
# 5️⃣ Load FlashPack model (train if missing)
# ============================================================
def get_flashpack_model(hf_repo: str = "rahul7star/FlashPack"):
    try:
        print(f"🔍 Attempting to load FlashPack model from {hf_repo}")
        model = GemmaTrainer.from_flashpack(hf_repo)
        model.eval()
        print("✅ Loaded model successfully from HF")
        # Rebuild the encoder with the same settings used for training
        # (max_length=128), and re-encode the long prompts so that the
        # retrieval index exists even when training is skipped. Without
        # this, the load path would return only three values and the
        # module-level five-way unpacking below would fail.
        tokenizer, embed_model, encode_fn = build_encoder("gpt2", max_length=128)
        dataset = load_dataset("rahul7star/prompt-enhancer-dataset", split="train")
        dataset = dataset.select(range(min(1000, len(dataset))))
        long_embeddings = torch.vstack([encode_fn(item["long_prompt"]) for item in dataset])
        return model, tokenizer, embed_model, dataset, long_embeddings
    except Exception as e:
        print(f"⚠️ Load failed: {e}")
        print("⬇️ Training a new FlashPack model locally...")
        # train_flashpack_model already pushes to the Hub (push_to_hub=True),
        # so no second push is needed here.
        model, dataset, embed_model, tokenizer, long_embeddings = train_flashpack_model()
        return model, tokenizer, embed_model, dataset, long_embeddings

# ============================================================
# 6️⃣ Load or train
# ============================================================
try:
    model, tokenizer, embed_model, dataset, long_embeddings = get_flashpack_model()
except Exception as e:
    raise SystemExit(f"❌ Failed to load or train FlashPack model: {e}")
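# ------------------------------------------------------------
# Retrieval in brief (sketch): the trained MLP maps a short prompt's
# 1536-dim embedding toward the space of long-prompt embeddings, and
# the best match is picked by cosine similarity. The helpers below do
# this normalization manually; an equivalent formulation would use
# torch.nn.functional (assumed imported as F), broadcasting the
# single mapped row against all N long-prompt rows:
#
#   sims = F.cosine_similarity(long_embeddings, mapped, dim=1)  # shape [N]
# ------------------------------------------------------------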
# ============================================================
# 7️⃣ Inference helpers
# ============================================================
@torch.no_grad()
def encode_for_inference(prompt: str) -> torch.Tensor:
    # Must mirror the training encoder exactly: same max_length and the
    # same mean+max pooling. Mean pooling alone would produce 768-dim
    # vectors, which would not match the model's 1536-dim input layer.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128,
    ).to(device)
    last_hidden = embed_model(**inputs).last_hidden_state
    mean_pool = last_hidden.mean(dim=1)
    max_pool, _ = last_hidden.max(dim=1)
    return torch.cat([mean_pool, max_pool], dim=1).cpu()

def enhance_prompt(user_prompt: str, temperature: float, max_tokens: int, chat_history):
    # temperature and max_tokens are accepted for UI compatibility but are
    # unused: enhancement is retrieval over the dataset, not generation.
    chat_history = chat_history or []
    short_emb = encode_for_inference(user_prompt)
    mapped = model(short_emb.to(device)).cpu()

    # Cosine similarity between the mapped embedding and every long prompt
    sims = (long_embeddings @ mapped.t()).squeeze(1)
    long_norms = long_embeddings.norm(dim=1)
    mapped_norm = mapped.norm()
    sims = sims / (long_norms * (mapped_norm + 1e-12))

    best_idx = int(sims.argmax().item())
    enhanced_prompt = dataset[best_idx]["long_prompt"]

    chat_history.append({"role": "user", "content": user_prompt})
    chat_history.append({"role": "assistant", "content": enhanced_prompt})
    return chat_history

# ============================================================
# 8️⃣ Gradio UI
# ============================================================
with gr.Blocks(title="Prompt Enhancer – FlashPack (CPU)", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # ✨ Prompt Enhancer (FlashPack mapper)
        Enter a short prompt, and the model will **expand it with details and creative context**.
        (CPU-only mode.)
        """
    )
    with gr.Row():
        chatbot = gr.Chatbot(height=400, label="Enhanced Prompts", type="messages")
        with gr.Column(scale=1):
            user_prompt = gr.Textbox(placeholder="Enter a short prompt...", label="Your Prompt", lines=3)
            temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Temperature")
            max_tokens = gr.Slider(32, 256, value=128, step=16, label="Max Tokens")
            send_btn = gr.Button("🚀 Enhance Prompt", variant="primary")
            clear_btn = gr.Button("🧹 Clear Chat")

    send_btn.click(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    user_prompt.submit(enhance_prompt, [user_prompt, temperature, max_tokens, chatbot], chatbot)
    clear_btn.click(lambda: [], None, chatbot)

# ============================================================
# 9️⃣ Launch app
# ============================================================
if __name__ == "__main__":
    demo.launch(show_error=True)
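# ------------------------------------------------------------
# Illustrative usage (sketch): the same retrieval path the UI uses
# can be exercised without Gradio, e.g. from a Python REPL after
# importing this module. The inputs below are example values only:
#
#   history = enhance_prompt("a cat on a beach", temperature=0.7,
#                            max_tokens=128, chat_history=[])
#   print(history[-1]["content"])   # the retrieved long prompt
# ------------------------------------------------------------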