Token Importance Classifier

This model is a single-layer attention-based classifier trained to predict token importance in sequences.

Model Details

  • Architecture: Single-layer self-attention network with RoPE positional embeddings (see the sketch below)
  • Base Model: meta-llama/Llama-3.1-8B
  • Hidden Dimension: 4096
  • Number of Heads: 32
  • Max Sequence Length: 131072
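
The checkpoint stores only this lightweight classifier head, not the Llama backbone. As a rough illustration of the architecture described above, the sketch below shows what such a single-layer attention classifier might look like in PyTorch; the class name, constructor arguments, and the omission of RoPE are assumptions, not the repository's actual implementation.

import torch
import torch.nn as nn

class TokenSelector(nn.Module):
    """Illustrative single-layer attention classifier (names are assumptions)."""
    def __init__(self, hidden_dim=4096, num_heads=32, dropout=0.1):
        super().__init__()
        # Self-attention over frozen Llama-3.1-8B token embeddings.
        # The described model applies RoPE inside attention; this sketch omits
        # rotary embeddings for brevity.
        self.attn = nn.MultiheadAttention(hidden_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(hidden_dim)
        # One logit per token: "important" vs. "not important".
        self.classifier = nn.Linear(hidden_dim, 1)

    def forward(self, x, key_padding_mask=None):
        # x: (batch, seq_len, hidden_dim) token embeddings
        attn_out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        h = self.norm(x + attn_out)
        return self.classifier(h).squeeze(-1)  # (batch, seq_len) logits

# Example: per-token importance scores in [0, 1]
model = TokenSelector()
scores = torch.sigmoid(model(torch.randn(1, 16, 4096)))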

Training Configuration

data:
  max_seq_len: 131072
  path: /root/workspace/data_generation/data/sample_output.jsonl
  tokenizer_path: meta-llama/Llama-3.1-8B
  valid_split: 0.1
final_metrics:
  accuracy: 0.8365938756296772
  f1: 0.9094284550391643
  precision: 0.8365938756296772
  recall: 1.0
huggingface:
  private: false
  push_to_hub: true
  repo_id: Slicky325/token-selector-model
model:
  base_model_dir: meta-llama/Llama-3.1-8B
  dropout: 0.1
  hidden_dim: 4096
  max_seq_len: 131072
  num_heads: 32
  rope_theta: 500000
  save_embeddings: false
  save_path: models/selector.pt
  train_embeddings: false
  use_positional: true
system:
  device: cuda
  num_workers: 2
training:
  batch_size: 4
  epochs: 1
  grad_clip: 1.0
  learning_rate: 0.001
  seed: 42
  weight_decay: 0.0
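
For orientation, the sketch below shows how the training fields above might map onto a standard PyTorch optimization step. The YAML filename, the placeholder model, and the choice of AdamW are assumptions; the actual training script is not part of this card.

import torch
import yaml

# Parse the configuration block above (assumed to be saved locally as config.yaml).
cfg = yaml.safe_load(open("config.yaml"))
torch.manual_seed(cfg["training"]["seed"])

# Placeholder module standing in for the classifier (see the sketch above).
model = torch.nn.Linear(cfg["model"]["hidden_dim"], 1)

# The optimizer class is an assumption; the config only fixes lr and weight decay.
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=cfg["training"]["learning_rate"],
                              weight_decay=cfg["training"]["weight_decay"])

# One illustrative optimization step on random data (batch_size: 4).
x = torch.randn(cfg["training"]["batch_size"], cfg["model"]["hidden_dim"])
target = torch.rand(cfg["training"]["batch_size"], 1)
loss = torch.nn.functional.binary_cross_entropy_with_logits(model(x), target)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), cfg["training"]["grad_clip"])
optimizer.step()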

Validation Metrics

  • Accuracy: 0.8365938756296772
  • Precision: 0.8365938756296772
  • Recall: 1.0
  • F1 Score: 0.9094284550391643
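
These are standard per-token binary-classification scores computed on the 10% validation split. The snippet below shows one way to reproduce such numbers from model outputs with scikit-learn; the 0.5 decision threshold and the random inputs are assumptions for illustration.

import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

logits = torch.randn(1, 128)             # per-token logits from the classifier
labels = torch.randint(0, 2, (1, 128))   # ground-truth importance labels

# A 0.5 probability threshold is an assumption; the card does not state one.
preds = (torch.sigmoid(logits) > 0.5).int()

y_true, y_pred = labels.flatten().numpy(), preds.flatten().numpy()
print("accuracy: ", accuracy_score(y_true, y_pred))
print("precision:", precision_score(y_true, y_pred))
print("recall:   ", recall_score(y_true, y_pred))
print("f1:       ", f1_score(y_true, y_pred))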

Usage

import torch

# Load the checkpoint (it contains the classifier weights and the training config)
checkpoint = torch.load('selector.pt', map_location='cpu')
model_state = checkpoint['model_state_dict']
config = checkpoint['config']

# Initialize your model architecture with `config`, then load the weights
# model.load_state_dict(model_state)
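
The checkpoint can also be fetched straight from the Hub rather than from a local path. A minimal sketch, assuming the file is stored under the name from save_path above (the exact filename in the repository may differ):

import torch
from huggingface_hub import hf_hub_download

# Download from the Hub; the filename is inferred from the config's save_path.
ckpt_path = hf_hub_download(repo_id="Slicky325/token-selector-model",
                            filename="selector.pt")

checkpoint = torch.load(ckpt_path, map_location="cpu")
print(checkpoint["config"])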

Citation

If you use this model in your research, please cite it appropriately.
