""" Fine-tune a quantized Qwen2.5:7b model using SFT + LoRA on expanded preventative health prompts. """ # Imports import json import time from datasets import load_dataset, Dataset from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForSeq2Seq from peft import LoraConfig, get_peft_model, TaskType import torch # ==== # General configuration SEED_DATA_FILE = "expanded_templates.json" # Created locally MODEL_NAME = "Qwen/Qwen2.5-7B" # From Hugging Face Model Hub OUTPUT_DIR = "./qwen_lora_adapter" BATCH_SIZE = 4 EPOCHS = 3 LEARNING_RATE = 2e-4 MAX_LENGTH = 512 LORA_RANK = 16 LORA_ALPHA = 32 LORA_DROPOUT = 0.05 print("🔄 Starting fine-tuning pipeline for Qwen2.5 with LoRA...") start_time = time.time() # =========================== # LOAD DATASET # =========================== print("📂 Loading dataset from:", SEED_DATA_FILE) with open(SEED_DATA_FILE, "r", encoding="utf-8") as f: data = json.load(f) print(f"✅ Loaded {len(data)} samples.") # ==== # Convert to Hugging Face Dataset dataset = Dataset.from_list([{ "prompt": entry["prompt"], "response": entry["response"], "topic": entry.get("topic", "general") } for entry in data]) print("✅ Converted to Hugging Face Dataset.") # ==== # Concatenate prompt + response for causal LM def tokenize_function(examples, tokenizer): # Takes the dictionary and formats it into a string texts = [ f"### Topic: {t}\n### Instruction:\n{p}\n\n### Response:\n{r}" for p, r, t in zip(examples["prompt"], examples["response"], examples["topic"]) ] # Tokenize the concatenated texts tokenized = tokenizer( texts, max_length=MAX_LENGTH, # Max length of tokens padding="max_length", truncation=True, return_tensors="pt" # Returns input_ids, attention_maks, and labels as matrices ) tokenized["labels"] = tokenized["input_ids"].clone() return tokenized # =========================== # LOAD TOKENIZER AND MODEL # =========================== print("🧠 Loading tokenizer and 8-bit quantized model:", MODEL_NAME) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16 ) print("✅ Model and tokenizer loaded successfully.") # =========================== # CONFIGURE LoRA # =========================== print("⚙️ Configuring LoRA adapters...") lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, r=LORA_RANK, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT, target_modules=["q_proj", "v_proj"] # typical for Qwen2.5 ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() print("✅ LoRA configuration complete.") # =========================== # TOKENIZE DATASET # =========================== print("🧩 Tokenizing dataset... (this might take a while)") tokenized_dataset = dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True) data_collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt") print("✅ Dataset tokenized and ready for training.") # =========================== # TRAINING ARGUMENTS # =========================== print("📘 Setting up training arguments...") training_args = TrainingArguments( output_dir=OUTPUT_DIR, per_device_train_batch_size=BATCH_SIZE, gradient_accumulation_steps=2, learning_rate=LEARNING_RATE, num_train_epochs=EPOCHS, logging_steps=10, save_strategy="epoch", fp16=True, save_total_limit=3, report_to="none", # disable wandb if not set up ) print("✅ Training arguments configured.") # =========================== # TRAINING # =========================== print("🚀 Starting training...") trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_dataset, tokenizer=tokenizer, data_collator=data_collator ) train_start = time.time() trainer.train() print(f"✅ Training completed in {(time.time() - train_start)/60:.2f} minutes.") # =========================== # SAVE LoRA ADAPTER ONLY # =========================== print("💾 Saving LoRA adapter...") model.save_pretrained(OUTPUT_DIR) print(f"✅ LoRA adapter saved at: {OUTPUT_DIR}") print(f"🏁 All done! Total pipeline time: {(time.time() - start_time)/60:.2f} minutes.")