# DisLLM Split GPT-2 Model
This repository contains a split GPT-2 model trained using federated learning with LoRA fine-tuning.
## Model Architecture
- Base Model: GPT-2 Small (124M parameters)
- Fine-tuning Method: LoRA (Low-Rank Adaptation)
- Trainable Parameters: ~1.7M
- Split Configuration:
  - First Part (Client): 4 transformer blocks
  - Second Part (Server): 8 transformer blocks
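
For orientation, the sketch below shows how LoRA adapters might be attached to GPT-2 Small with the Hugging Face `peft` library. This is an assumption for illustration only: the exact rank, alpha, and target modules used to produce this checkpoint are not documented here.

```python
# Illustrative LoRA setup; hyperparameters are assumptions, not the values
# used to train this checkpoint.
from peft import LoraConfig, get_peft_model
from transformers import GPT2LMHeadModel

base = GPT2LMHeadModel.from_pretrained("gpt2")  # GPT-2 Small, 124M parameters
lora_cfg = LoraConfig(
    r=16,                       # assumed LoRA rank
    lora_alpha=32,              # assumed scaling factor
    target_modules=["c_attn"],  # GPT-2's fused attention projection
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # reports trainable vs. total parameters
```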
## Files
- `central_trained_first_part_20251210_162256.pth`: Client-side model (first 4 layers)
- `central_trained_second_part_20251210_162256.pth`: Server-side model (remaining 8 layers)
## Training Details
- Dataset: WikiText-2
- Training Method: Federated Learning with Split Learning
- Context Length: 1024 tokens
- Batch Size: 2
- Learning Rate: 1e-6
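
In split learning, the client runs the first part of the network, sends the intermediate activations to the server, which computes the loss and returns the activation gradients so the client can finish backpropagation. The sketch below shows one such training step; the `first_part`/`second_part` interfaces and optimizer setup are assumptions, and the network transfer is emulated locally with a detach at the cut point.

```python
import torch

# Hypothetical interfaces: first_part maps token ids -> hidden states,
# second_part maps hidden states -> next-token logits.
def split_training_step(first_part, second_part, opt_client, opt_server,
                        input_ids, labels):
    opt_client.zero_grad()
    opt_server.zero_grad()

    # Client-side forward pass (first 4 transformer blocks).
    hidden = first_part(input_ids)

    # In deployment the activations cross the network here; detaching and
    # re-enabling grad emulates that cut point in a single process.
    hidden_server = hidden.detach().requires_grad_(True)

    # Server-side forward pass (remaining 8 blocks + LM head) and shifted LM loss.
    logits = second_part(hidden_server)
    loss = torch.nn.functional.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels[:, 1:].reshape(-1),
    )

    # Server backward, then send the activation gradient back to the client
    # so it can backpropagate through its own blocks.
    loss.backward()
    hidden.backward(hidden_server.grad)

    opt_server.step()
    opt_client.step()
    return loss.item()
```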
## Usage
```python
import torch
from transformers import GPT2Config

# Load the two model parts (contents may be state dicts or full serialized
# modules, depending on how they were saved).
first_part = torch.load('central_trained_first_part_20251210_162256.pth', map_location='cpu')
second_part = torch.load('central_trained_second_part_20251210_162256.pth', map_location='cpu')

# Use with the DisLLM architecture
# (requires the FirstPartModel and SecondPartModel class definitions).
```
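
Once both parts are instantiated, inference chains them by feeding the client output into the server part. The sketch below continues from the snippet above and is only a sketch: `FirstPartModel`/`SecondPartModel` are the DisLLM class definitions not included in this repository, and treating the `.pth` files as state dicts is an assumption.

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Placeholder classes; substitute the actual DisLLM definitions.
client_model = FirstPartModel(GPT2Config())   # first 4 transformer blocks
server_model = SecondPartModel(GPT2Config())  # remaining 8 blocks + LM head
client_model.load_state_dict(first_part)      # assumes the files hold state dicts
server_model.load_state_dict(second_part)
client_model.eval()
server_model.eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    hidden = client_model(inputs["input_ids"])  # client-side activations
    logits = server_model(hidden)               # server-side prediction
next_token_id = logits[:, -1].argmax(dim=-1)
print(tokenizer.decode(next_token_id))
```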
## Performance
Training reduces perplexity from roughly 45 to roughly 30-35 on the train, validation, and test sets.
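
Perplexity here is the exponential of the mean per-token cross-entropy loss, as in the small helper below (the evaluation loop that accumulates the loss is not shown).

```python
import math

# Perplexity = exp(total negative log-likelihood / number of tokens).
def perplexity(total_nll: float, total_tokens: int) -> float:
    return math.exp(total_nll / total_tokens)

# Example: a mean loss of ~3.4 nats/token corresponds to perplexity ~30.
print(perplexity(3.4, 1))  # ~29.96
```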
## Citation
If you use this model, please cite the original DisLLM work and GPT-2 paper.
## License
This model inherits the license from the GPT-2 model and training code.