# Token Importance Classifier
This model is a single-layer attention-based classifier trained to predict token importance in sequences.
## Model Details
- Architecture: Single-layer self-attention network with RoPE positional embeddings (see the sketch after this list)
- Base Model: meta-llama/Llama-3.1-8B
- Hidden Dimension: 4096
- Number of Heads: 32
- Max Sequence Length: 131072
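
The modeling code itself is not shown in this card, so the following PyTorch sketch is only an illustration of what a single-layer self-attention selector with RoPE could look like at these dimensions. The class name `TokenImportanceSelector` and every implementation detail (fused QKV projection, half-split RoPE, a linear per-token head) are assumptions, not the actual training code.

```python
import torch
import torch.nn as nn


class TokenImportanceSelector(nn.Module):
    """Illustrative single-layer self-attention token classifier with RoPE.

    Only the dimensions, head count, dropout, and rope_theta come from the
    config in this card; the structure itself is an assumption.
    """

    def __init__(self, hidden_dim=4096, num_heads=32, dropout=0.1, rope_theta=500000):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads          # 4096 / 32 = 128
        self.rope_theta = rope_theta
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim, 1)        # per-token importance logit

    def _rope(self, x):
        # Rotary position embedding; x is (batch, heads, seq, head_dim)
        seq_len, dim = x.shape[-2], x.shape[-1]
        inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=x.device).float() / dim))
        pos = torch.arange(seq_len, device=x.device).float()
        freqs = torch.outer(pos, inv_freq)                # (seq, head_dim / 2)
        cos, sin = freqs.cos(), freqs.sin()
        x1, x2 = x[..., 0::2], x[..., 1::2]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq, hidden_dim) token embeddings
        b, s, d = hidden_states.shape
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        q = q.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        q, k = self._rope(q), self._rope(k)
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        attn = attn.transpose(1, 2).reshape(b, s, d)
        out = self.dropout(self.out_proj(attn))
        return self.classifier(out).squeeze(-1)           # (batch, seq) logits
```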
## Training Configuration
```yaml
data:
  max_seq_len: 131072
  path: /root/workspace/data_generation/data/sample_output.jsonl
  tokenizer_path: meta-llama/Llama-3.1-8B
  valid_split: 0.1
final_metrics:
  accuracy: 0.8365938756296772
  f1: 0.9094284550391643
  precision: 0.8365938756296772
  recall: 1.0
huggingface:
  private: false
  push_to_hub: true
  repo_id: Slicky325/token-selector-model
model:
  base_model_dir: meta-llama/Llama-3.1-8B
  dropout: 0.1
  hidden_dim: 4096
  max_seq_len: 131072
  num_heads: 32
  rope_theta: 500000
  save_embeddings: false
  save_path: models/selector.pt
  train_embeddings: false
  use_positional: true
system:
  device: cuda
  num_workers: 2
training:
  batch_size: 4
  epochs: 1
  grad_clip: 1.0
  learning_rate: 0.001
  seed: 42
  weight_decay: 0.0
```
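
The block above is plain YAML, so it can be reused programmatically with PyYAML; note that the `config.yaml` file name below is an assumption, not a file guaranteed to ship with this repo.

```python
import yaml

# Parse the training configuration shown above (file name is an assumption)
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

# Pull out the architecture hyperparameters used by the selector
print(cfg["model"]["hidden_dim"], cfg["model"]["num_heads"], cfg["model"]["rope_theta"])
```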
## Validation Metrics
- Accuracy: 0.8365938756296772
- Precision: 0.8365938756296772
- Recall: 1.0
- F1 Score: 0.9094284550391643
## Usage
```python
import torch

# Load the checkpoint (contains the weights and the training config shown above)
checkpoint = torch.load('selector.pt', map_location='cpu')
model_state = checkpoint['model_state_dict']
config = checkpoint['config']

# Initialize your model architecture, then load the weights:
# model.load_state_dict(model_state)
```
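
As a fuller, but still hypothetical, example, the snippet below downloads the checkpoint from the Hub and rebuilds the classifier using the `TokenImportanceSelector` sketch from the Model Details section; the uploaded filename `selector.pt` and the layout of `checkpoint['config']` are assumptions.

```python
import torch
from huggingface_hub import hf_hub_download

# Fetch the checkpoint from the Hub (assumes it was uploaded as "selector.pt")
ckpt_path = hf_hub_download(repo_id="Slicky325/token-selector-model", filename="selector.pt")
checkpoint = torch.load(ckpt_path, map_location="cpu")

# Rebuild the classifier; assumes checkpoint["config"] mirrors the YAML above
cfg = checkpoint["config"]["model"]
model = TokenImportanceSelector(
    hidden_dim=cfg["hidden_dim"],
    num_heads=cfg["num_heads"],
    dropout=cfg["dropout"],
    rope_theta=cfg["rope_theta"],
)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Score a dummy batch of token embeddings: (batch, seq_len, hidden_dim)
embeddings = torch.randn(1, 16, cfg["hidden_dim"])
with torch.no_grad():
    scores = torch.sigmoid(model(embeddings))  # per-token importance in [0, 1]
print(scores.shape)  # torch.Size([1, 16])
```

In practice the inputs would be token embeddings from the base model rather than random tensors; the config's `train_embeddings: false` suggests the embedding table is taken frozen from meta-llama/Llama-3.1-8B.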
## Citation
If you use this model in your research, please cite appropriately.