Medical Reasoning Model: Qwen2.5-0.5B (LoRA/QLoRA Fine-tuned)
A compact medical-reasoning model fine-tuned from Qwen/Qwen2.5-0.5B-Instruct for step-by-step clinical reasoning. It is optimized to break down cases, propose differential diagnoses, and justify its answers.
Note: Outputs may include intermediate reasoning. Do not treat responses as medical advice.
Model Summary
- Base: Qwen2.5-0.5B-Instruct
- Architecture: Decoder-only transformer
- Finetuning: LoRA/QLoRA on medical reasoning SFT data
- Primary language: English
- Intended tasks: clinical vignette reasoning, answer justification, treatment option comparison, simple dosage calculations
- Not for: real-world diagnosis or treatment; no protected-health-information handling
Intended Use
- Educational exploration of clinical reasoning
- Generating explanations for practice questions
- Drafting step-by-step rationales for study aids
Out-of-scope / High-risk
Do not use for clinical decisions, triage, or patient-specific recommendations.
Data
- Training set: FreedomIntelligence/medical-o1-reasoning-SFT (instruction-style, reasoning-focused)
- Preprocessing:
  - Deduplication & light cleaning
  - Prompt formatting into {system, instruction, input} style
  - Truncation/padding to max sequence length (🔧 L=1024 by default)
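The preprocessing steps above can be sketched in plain Python. The exact prompt template used in training is not documented in this card, so the `### System` / `### Instruction` section markers, the pad id of 0, and the helper names (`format_prompt`, `truncate_and_pad`, `dedupe`) are hypothetical; in a real pipeline the tokenizer's own pad token and truncation settings would take their place.

```python
from typing import List, Tuple

# Hypothetical template: the card only states a {system, instruction, input} layout.
TEMPLATE = (
    "### System\n{system}\n\n"
    "### Instruction\n{instruction}\n\n"
    "### Input\n{input}\n\n"
    "### Response\n"
)


def format_prompt(system: str, instruction: str, input_text: str = "") -> str:
    """Render one example into a single prompt string."""
    return TEMPLATE.format(
        system=system.strip(), instruction=instruction.strip(), input=input_text.strip()
    )


def truncate_and_pad(
    token_ids: List[int], max_len: int = 1024, pad_id: int = 0
) -> Tuple[List[int], List[int]]:
    """Right-truncate to max_len, then right-pad; returns (ids, attention_mask)."""
    ids = token_ids[:max_len]
    mask = [1] * len(ids)
    n_pad = max_len - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad


def dedupe(examples: List[dict], key: str = "instruction") -> List[dict]:
    """Light deduplication: drop examples whose case- and whitespace-normalized key repeats."""
    seen, kept = set(), []
    for ex in examples:
        norm = " ".join(ex[key].lower().split())
        if norm not in seen:
            seen.add(norm)
            kept.append(ex)
    return kept
```

For causal-LM training, label positions that correspond to padding (mask value 0) are usually set to -100 so the loss ignores them.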
Training Procedure
Strategy
- Method: LoRA / QLoRA (4-bit NF4) to fit small GPUs while retaining quality
- Rationale: adapters reduce the number of trainable parameters, which makes training faster and cheaper and lowers overfitting risk
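To make the parameter saving concrete, here is a back-of-the-envelope count of LoRA weights for the configuration used on this card (r=8 on q_proj, v_proj, o_proj). The model dimensions below are assumed from the published Qwen2.5-0.5B config and should be checked against the model's `config.json`; the function names are illustrative.

```python
# Dimensions assumed from the published Qwen2.5-0.5B config (verify against config.json).
HIDDEN = 896            # hidden_size
N_LAYERS = 24           # num_hidden_layers
N_HEADS = 14            # num_attention_heads
N_KV_HEADS = 2          # num_key_value_heads (grouped-query attention)
HEAD_DIM = HIDDEN // N_HEADS  # 64


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A LoRA adapter adds A (r x d_in) and B (d_out x r), i.e. r * (d_in + d_out) weights."""
    return r * (d_in + d_out)


def trainable_lora_params(r: int = 8) -> int:
    kv_dim = N_KV_HEADS * HEAD_DIM  # 128: output width of v_proj under GQA
    per_layer = (
        lora_params(HIDDEN, HIDDEN, r)    # q_proj: 896 -> 896
        + lora_params(HIDDEN, kv_dim, r)  # v_proj: 896 -> 128
        + lora_params(HIDDEN, HIDDEN, r)  # o_proj: 896 -> 896
    )
    return per_layer * N_LAYERS


if __name__ == "__main__":
    n = trainable_lora_params(8)
    print(f"{n:,} trainable LoRA params (~{n / 0.49e9:.2%} of ~0.49B)")
```

Under these assumed dimensions the count is 884,736 trainable parameters, roughly 0.18% of the base model; `model.print_trainable_parameters()` from peft reports the real figure.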
Environment (example)
- GPU: 🔧 1× NVIDIA RTX 4060 (8 GB) / L4 / A10G
- Frameworks: transformers, peft, bitsandbytes, accelerate
- Precision: bf16 (fallback fp16 if needed)
- Attention: Flash-Attention 2 when available
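The "bf16 with fp16 fallback, Flash-Attention 2 when available" policy can be expressed as a small decision helper. This is a sketch under stated assumptions: `choose_runtime` and `detect_runtime` are hypothetical names, bf16 support is used as a rough proxy for an Ampere-or-newer GPU (which Flash-Attention 2 requires), and PyTorch's `sdpa` backend is the fallback.

```python
import importlib.util


def choose_runtime(cuda: bool, bf16_ok: bool, flash_attn_installed: bool) -> dict:
    """Map hardware capabilities to TrainingArguments precision flags and an attention backend."""
    if not cuda:
        # CPU fallback: full precision, PyTorch scaled-dot-product attention
        return {"bf16": False, "fp16": False, "attn_implementation": "sdpa"}
    return {
        "bf16": bf16_ok,
        "fp16": not bf16_ok,  # fp16 fallback on GPUs without bf16 support
        # FA2 needs the flash_attn package and an Ampere-or-newer GPU;
        # bf16 support stands in here as a rough proxy for the latter.
        "attn_implementation": (
            "flash_attention_2" if (flash_attn_installed and bf16_ok) else "sdpa"
        ),
    }


def detect_runtime() -> dict:
    """Probe the current machine; torch is imported lazily so choose_runtime stays dependency-free."""
    import torch

    cuda = torch.cuda.is_available()
    bf16_ok = cuda and torch.cuda.is_bf16_supported()
    fa2 = importlib.util.find_spec("flash_attn") is not None
    return choose_runtime(cuda, bf16_ok, fa2)
```

The result can feed `AutoModelForCausalLM.from_pretrained(..., attn_implementation=rt["attn_implementation"])` and `TrainingArguments(bf16=rt["bf16"], fp16=rt["fp16"], ...)`.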
Key Hyperparameters (update if you used different values)
- LoRA config: r=8, alpha=16, lora_dropout=0.1, target_modules=["q_proj","v_proj","o_proj"]
- Quantization (QLoRA): load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True
- Optimizer: AdamW8bit (bitsandbytes) or adamw_torch_fused
- LR & schedule: lr=2e-4, cosine decay, warmup_ratio=0.06
- Regularization: weight decay 0.01, label smoothing 0.05, gradient clipping 1.0
- Batching: per_device_train_batch_size=4, gradient_accumulation_steps=8 → effective batch 32 (single GPU)
- Sequence length: 🔧 max_seq_length=1024
- Epochs: 🔧 2–3 (use early stopping on val loss)
- Gradient checkpointing: enabled (reduces activation memory)
- Eval/Save: every 🔧 500 steps, keep best on val loss
- Checkpoint offload: keep the N most recent checkpoints locally and sync to S3 via a callback
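The batching and learning-rate settings above combine as follows. This sketch reproduces the shape of a linear-warmup plus half-cosine decay schedule, which is what transformers' cosine-with-warmup scheduler produces with its default of half a cycle; the warmup step count is rounded up, as recent `TrainingArguments` versions do for `warmup_ratio`. The function names and the 3 × 3125 step total are illustrative.

```python
import math


def effective_batch(per_device: int = 4, grad_accum: int = 8, num_gpus: int = 1) -> int:
    """Samples contributing to one optimizer step."""
    return per_device * grad_accum * num_gpus


def warmup_steps(total_steps: int, warmup_ratio: float = 0.06) -> int:
    # Rounded up, matching how warmup_ratio is converted to steps
    return math.ceil(total_steps * warmup_ratio)


def cosine_lr(step: int, total_steps: int, base_lr: float = 2e-4, warmup_ratio: float = 0.06) -> float:
    """Linear warmup from 0 to base_lr, then half-cosine decay to 0 at total_steps."""
    w = warmup_steps(total_steps, warmup_ratio)
    if step < w:
        return base_lr * step / max(1, w)
    progress = (step - w) / max(1, total_steps - w)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


if __name__ == "__main__":
    total = 3 * 3125  # e.g. 3 epochs at 3125 optimizer steps each
    for s in (0, warmup_steps(total), total // 2, total):
        print(s, f"{cosine_lr(s, total):.2e}")
```

With the example total of 9375 steps, warmup lasts 563 steps; the rate peaks at 2e-4 there and decays to 0 at the final step.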
Reproducibility
- seed=42
- Deterministic ops where practical (may reduce throughput)
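A minimal seeding helper along the lines of the reproducibility notes above might look like this. It is a sketch, not the exact setup used for this model: `seed_everything` is a hypothetical name, numpy and torch are treated as optional, and `warn_only=True` makes non-deterministic ops warn rather than raise.

```python
import os
import random


def seed_everything(seed: int = 42, deterministic: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs; optionally request deterministic kernels."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)  # only affects interpreters started after this
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
    except ImportError:
        return
    torch.manual_seed(seed)  # seeds the RNGs of all devices, CUDA included
    if deterministic:
        # cuBLAS needs this for deterministic matmuls; set it before any CUDA work
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        torch.use_deterministic_algorithms(True, warn_only=True)
        torch.backends.cudnn.benchmark = False  # autotuning can pick different kernels per run
```

`transformers.set_seed(42)` covers the RNG part as well; the deterministic-kernel switches are what typically cost throughput.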
Training Time & Throughput (example on RTX 4060)
- Throughput (train): ~15k tokens/s (depends on sequence length L, micro-batch size, and attention kernels)
- Steps/epoch: ceil(num_samples / effective_batch), e.g., 🔧 100k / 32 = 3125
- Epoch time (L=1024): ~1.9–2.3 h (incl. overhead)
Your numbers will vary; for an accurate estimate, time about 200 steps after warmup.
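The timing arithmetic above, plus a warmup-then-measure throughput probe, can be sketched as below. The helper names are illustrative; `epoch_hours` assumes every sample occupies the full sequence length (padded or packed) and ignores eval and checkpoint overhead.

```python
import math
import time
from typing import Callable


def steps_per_epoch(num_samples: int, effective_batch: int) -> int:
    return math.ceil(num_samples / effective_batch)


def epoch_hours(num_samples: int, seq_len: int, tokens_per_s: float) -> float:
    """Compute-only estimate: assumes every sample is padded/packed to seq_len tokens."""
    return num_samples * seq_len / tokens_per_s / 3600


def measure_tokens_per_s(
    step_fn: Callable[[], None], tokens_per_step: int, warmup: int = 20, timed: int = 200
) -> float:
    """Run `warmup` untimed steps, then time `timed` steps.

    On GPU, step_fn should end with torch.cuda.synchronize() so queued kernels are counted.
    """
    for _ in range(warmup):
        step_fn()
    start = time.perf_counter()
    for _ in range(timed):
        step_fn()
    elapsed = max(time.perf_counter() - start, 1e-9)  # guard against a zero interval
    return timed * tokens_per_step / elapsed
```

With the card's example figures (100k samples, L=1024, ~15k tokens/s), `epoch_hours` gives about 1.9 h before overhead, which matches the lower end of the range above. For one optimizer step at effective batch 32, `tokens_per_step` would be 32 × 1024.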
Evaluation
- Metrics: perplexity (val), exact match / ROUGE-L on held-out prompts
- Human spot-checks: correctness of differentials, contraindications, and justification clarity
- Known failure modes:
  - Confident but incorrect rationale (hallucinations)
  - Outdated guidelines
  - Arithmetic slips on edge cases
Add your concrete scores here once computed.
- Val perplexity: 🔧 …
- Exact match / ROUGE-L: 🔧 …
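The metrics listed in this section can be computed with a small, dependency-free sketch. The text normalization (lowercasing, stripping ASCII punctuation, whitespace tokenization) is an assumption; libraries such as `rouge_score` use their own tokenizer, so their numbers may differ slightly. Perplexity here is the exponential of a mean token-level loss such as the Trainer's `eval_loss`.

```python
import math
import string

_PUNCT = str.maketrans("", "", string.punctuation)


def perplexity(mean_nll: float) -> float:
    """Perplexity = exp(mean token-level negative log-likelihood)."""
    return math.exp(mean_nll)


def normalize(text: str) -> str:
    """Lowercase, drop ASCII punctuation, collapse whitespace."""
    return " ".join(text.lower().translate(_PUNCT).split())


def exact_match(pred: str, ref: str) -> bool:
    return normalize(pred) == normalize(ref)


def _lcs_len(a: list, b: list) -> int:
    # Longest common subsequence via the classic DP, keeping a single previous row
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, 1):
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]


def rouge_l_f1(pred: str, ref: str) -> float:
    """ROUGE-L F1: harmonic mean of LCS-based precision and recall over tokens."""
    p, r = normalize(pred).split(), normalize(ref).split()
    if not p or not r:
        return 0.0
    lcs = _lcs_len(p, r)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(p), lcs / len(r)
    return 2 * precision * recall / (precision + recall)
```

Exact match is strict for free-text clinical answers, so ROUGE-L and human spot-checks usually carry more signal for the reasoning portions.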
Safety & Ethics
- The model may produce incorrect or harmful medical content.
- No PHI was used; training data come from public, anonymized sources.
- Add a disclaimer in any downstream app; keep a human-in-the-loop.
How to Use
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
repo = "YOUR_ORG/medical-qwen25-0_5b-lora" # replace with your repo
tok = AutoTokenizer.from_pretrained(repo, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
prompt = """You are a clinician. Reason step by step.
Patient: 56F with chest pain radiating to jaw, diaphoresis...
Question: Most likely diagnosis and initial management?"""
inputs = tok(prompt, return_tensors="pt").to(model.device)
out = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=True,  # temperature/top_p are ignored (greedy decoding) without sampling
    temperature=0.2,
    top_p=0.9,
)
print(tok.decode(out[0], skip_special_tokens=True))
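If the repo stores only LoRA adapter weights (an `adapter_config.json` plus adapter weights) rather than a merged checkpoint, the adapter can also be attached explicitly with peft and merged into the base weights for faster inference. This is a sketch under that assumption: `load_lora_model` is a hypothetical helper, and the repo id is the same placeholder as above.

```python
def load_lora_model(base_id: str, adapter_id: str, merge: bool = True):
    """Load the base model, attach LoRA adapters, and optionally merge them."""
    # Imports are local so the helper can be defined without the libraries installed
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(base_id)
    base = AutoModelForCausalLM.from_pretrained(
        base_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    model = PeftModel.from_pretrained(base, adapter_id)
    if merge:
        # Folds the scaled low-rank update (B @ A * alpha / r) into the base weights
        model = model.merge_and_unload()
    model.eval()
    return tok, model


# Usage (downloads weights):
# tok, model = load_lora_model("Qwen/Qwen2.5-0.5B-Instruct", "YOUR_ORG/medical-qwen25-0_5b-lora")
```

Merging removes the adapter indirection at inference time; keep `merge=False` if you want to swap or disable adapters later.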
Training Script Snippet (LoRA/QLoRA Fine-Tuning)
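from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

base = "Qwen/Qwen2.5-0.5B-Instruct"
tok = AutoTokenizer.from_pretrained(base)

# 4-bit NF4 settings belong in a BitsAndBytesConfig, not as loose from_pretrained kwargs
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    base,
    quantization_config=bnb,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
# Prepares the quantized model for training (input grads, gradient checkpointing)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model.config.use_cache = False  # KV cache is not needed during training

lora = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora)
model.print_trainable_parameters()

args = TrainingArguments(
    output_dir="ckpts",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # effective batch 32 on one GPU
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.06,
    weight_decay=0.01,
    label_smoothing_factor=0.05,
    max_grad_norm=1.0,  # gradient clipping
    optim="adamw_bnb_8bit",  # AdamW8bit from bitsandbytes
    num_train_epochs=3,
    bf16=True,
    save_steps=500,
    eval_strategy="steps",  # named evaluation_strategy in transformers < 4.41
    eval_steps=500,
    logging_steps=50,
    save_safetensors=True,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    gradient_checkpointing=True,
)

# train_ds / val_ds: tokenized datasets (preparation not shown); S3-sync callback not shown
# trainer = Trainer(model=model, args=args, train_dataset=train_ds,
#                   eval_dataset=val_ds, callbacks=[...])
# trainer.train()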
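The "sync to S3 (callback)" step mentioned above could be implemented with a `TrainerCallback` that uploads each checkpoint after it is written. This is a hedged sketch: `S3SyncCallback`, `make_s3_sync_callback`, `plan_uploads`, and `s3_key` are hypothetical names, boto3 credentials are assumed to be configured, and the bucket/prefix are placeholders. Local retention of the N most recent checkpoints is handled separately by `save_total_limit=N` in `TrainingArguments`.

```python
import os
from typing import List, Tuple


def s3_key(ckpt_dir: str, local_path: str, prefix: str) -> str:
    """Map a file inside <ckpt_dir> to '<prefix>/<checkpoint-name>/<relative path>'."""
    name = os.path.basename(os.path.normpath(ckpt_dir))
    rel = os.path.relpath(local_path, ckpt_dir).replace(os.sep, "/")
    return f"{prefix.rstrip('/')}/{name}/{rel}"


def plan_uploads(ckpt_dir: str, prefix: str) -> List[Tuple[str, str]]:
    """List (local_path, s3_key) pairs for every file under a checkpoint directory."""
    pairs = []
    for root, _dirs, files in os.walk(ckpt_dir):
        for fname in files:
            local = os.path.join(root, fname)
            pairs.append((local, s3_key(ckpt_dir, local, prefix)))
    return sorted(pairs)


def make_s3_sync_callback(bucket: str, prefix: str):
    """Build a TrainerCallback that uploads each newly saved checkpoint to S3."""
    import boto3
    from transformers import TrainerCallback

    s3 = boto3.client("s3")

    class S3SyncCallback(TrainerCallback):
        def on_save(self, args, state, control, **kwargs):
            # Trainer writes checkpoints to <output_dir>/checkpoint-<global_step>
            ckpt = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
            if state.is_world_process_zero and os.path.isdir(ckpt):
                for local, key in plan_uploads(ckpt, prefix):
                    s3.upload_file(local, bucket, key)  # blocks until the upload finishes

    return S3SyncCallback()
```

Passing `callbacks=[make_s3_sync_callback("my-bucket", "medical-qwen25/ckpts")]` to `Trainer` would enable it. Uploads run synchronously inside `on_save`, so large checkpoints pause training briefly; a background thread is a possible refinement.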