DisLLM Split GPT-2 Model

This repository contains GPT-2 Small split into a client part and a server part, fine-tuned with LoRA in a federated split-learning setup.

Model Architecture

  • Base Model: GPT-2 Small (124M parameters)
  • Fine-tuning Method: LoRA (Low-Rank Adaptation); a setup sketch follows this list
  • Trainable Parameters: ~1.7M (LoRA adapter weights only)
  • Split Configuration:
    • First Part (Client): 4 transformer blocks
    • Second Part (Server): 8 transformer blocks
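
The exact LoRA rank and target modules used for these checkpoints are not documented in this card. The snippet below is only a minimal sketch of such a setup using the peft library, with placeholder hyperparameters (the rank, alpha, dropout, and target modules are assumptions), to show how LoRA keeps the trainable-parameter count small relative to the 124M-parameter base model:

from peft import LoraConfig, get_peft_model
from transformers import GPT2LMHeadModel

base = GPT2LMHeadModel.from_pretrained("gpt2")  # GPT-2 Small, 124M parameters

# Placeholder LoRA settings -- the values used for this checkpoint may differ
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["c_attn"],   # GPT-2's fused QKV projection (a Conv1D module)
    fan_in_fan_out=True,         # required for GPT-2's Conv1D layers
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()   # reports trainable vs. total parameter counts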

Files

  • central_trained_first_part_20251210_162256.pth: Client-side model (first 4 layers)
  • central_trained_second_part_20251210_162256.pth: Server-side model (remaining 8 layers)

Training Details

  • Dataset: WikiText-2 (a data-preparation sketch follows this list)
  • Training Method: Federated Learning with Split Learning
  • Context Length: 1024 tokens
  • Batch Size: 2
  • Learning Rate: 1e-6
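
As a rough sketch of how these hyperparameters fit together, the snippet below loads WikiText-2, packs it into 1024-token blocks, and builds a batch-size-2 loader. The dataset variant, chunking strategy, and AdamW optimizer are assumptions; only the context length, batch size, and learning rate come from this card.

import torch
from datasets import load_dataset
from transformers import GPT2TokenizerFast

CONTEXT_LENGTH = 1024
BATCH_SIZE = 2
LEARNING_RATE = 1e-6

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

# Concatenate the corpus and slice it into fixed 1024-token blocks
# (a common way to prepare WikiText-2; the repo's exact preprocessing may differ)
ids = tokenizer("\n\n".join(raw["text"]))["input_ids"]
n_blocks = len(ids) // CONTEXT_LENGTH
blocks = torch.tensor(ids[: n_blocks * CONTEXT_LENGTH]).view(n_blocks, CONTEXT_LENGTH)
loader = torch.utils.data.DataLoader(blocks, batch_size=BATCH_SIZE, shuffle=True)

# Optimizer choice is an assumption; only the learning rate comes from this card
# optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)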

Usage

import torch
from transformers import GPT2Config

# Default GPT2Config describes GPT-2 Small (12 layers, hidden size 768) and can be
# passed to the model part classes if they need it
config = GPT2Config()

# Load the two checkpoint halves; map_location avoids device errors on CPU-only
# machines if the checkpoints were saved from a GPU
first_part = torch.load('central_trained_first_part_20251210_162256.pth', map_location='cpu')
second_part = torch.load('central_trained_second_part_20251210_162256.pth', map_location='cpu')

# Use with the DisLLM architecture
# (requires the FirstPartModel and SecondPartModel class definitions from the
# training code; an illustrative stand-in sketch follows below)
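
The FirstPartModel and SecondPartModel classes come from the DisLLM training code and are not bundled here. The stand-ins below are only a minimal sketch of how a 4/8-block split of GPT-2 Small can be wired together; the real class definitions, and whether the .pth files hold state dicts or pickled modules, may differ.

import torch
from transformers import GPT2Config, GPT2LMHeadModel

class ClientPart(torch.nn.Module):
    """Illustrative client side: embeddings plus the first 4 transformer blocks."""
    def __init__(self, gpt2, n_client_blocks=4):
        super().__init__()
        self.wte = gpt2.transformer.wte              # token embeddings
        self.wpe = gpt2.transformer.wpe              # position embeddings
        self.blocks = torch.nn.ModuleList(gpt2.transformer.h[:n_client_blocks])

    def forward(self, input_ids):
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        hidden = self.wte(input_ids) + self.wpe(positions)
        for block in self.blocks:
            out = block(hidden)
            hidden = out[0] if isinstance(out, tuple) else out
        return hidden                                # activations sent to the server

class ServerPart(torch.nn.Module):
    """Illustrative server side: remaining 8 blocks, final layer norm, LM head."""
    def __init__(self, gpt2, n_client_blocks=4):
        super().__init__()
        self.blocks = torch.nn.ModuleList(gpt2.transformer.h[n_client_blocks:])
        self.ln_f = gpt2.transformer.ln_f
        self.lm_head = gpt2.lm_head

    def forward(self, hidden):
        for block in self.blocks:
            out = block(hidden)
            hidden = out[0] if isinstance(out, tuple) else out
        return self.lm_head(self.ln_f(hidden))       # vocabulary logits

gpt2 = GPT2LMHeadModel(GPT2Config())                 # randomly initialised GPT-2 Small
client, server = ClientPart(gpt2), ServerPart(gpt2)
logits = server(client(torch.tensor([[50256]])))     # 50256 = GPT-2's end-of-text token
print(logits.shape)                                  # torch.Size([1, 1, 50257])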

Performance

Training reduces perplexity from roughly 45 to roughly 30-35 on the train, validation, and test sets.
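
Perplexity here is the exponential of the mean per-token cross-entropy loss, so the figures above correspond roughly to mean losses of about 3.8 and 3.4-3.6 nats per token:

import torch

# Perplexity = exp(mean cross-entropy loss); e.g. a mean loss of 3.5 nats/token
# corresponds to a perplexity of about 33
mean_loss = torch.tensor(3.5)
print(torch.exp(mean_loss).item())   # ~33.1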

Citation

If you use this model, please cite the original DisLLM work and GPT-2 paper.

License

This model inherits the license from the GPT-2 model and training code.
