Abstract
This repository provides a domain-adapted, instruction-tuned Turkish legal model derived from meta-llama/Llama-3.1-8B-Instruct. The model corresponds to the BF16 baseline configuration trained on 8 nodes with a global batch size of 32 as part of the “Harnessing Fully Sharded Data Parallelism v2 with Float8 Precision for Faster Training” study. It was fine-tuned on the newmindai/EuroHPC-Legal corpus (Q/A format) to enhance reasoning across multiple Turkish legal subdomains. In this run the training loss fell from a maximum of 3.2311 to 0.8448 and the validation loss reached 0.8977 over 35 steps, which makes the model a useful baseline for evaluating mixed-precision strategies under large-scale distributed training.
Experiment Context
This model was trained as part of our study comparing FSDP2 with bfloat16 precision against FSDP2 with FP8 mixed precision (the bfp16-fp8 configuration).
We used meta-llama/Llama-3.1-8B-Instruct as the base model. The model was loaded with torch_dtype=bfloat16 and wrapped with FSDP2 in a single call; bfloat16 was also used for computation during the forward and backward passes.
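from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh

# use_cuda, world_size, and model are assumed to be defined earlier in the training script.
mesh_device_type = "cuda" if use_cuda else "cpu"
# One-dimensional mesh spanning every rank in the job.
mesh = DeviceMesh(mesh_device_type, list(range(world_size)))
fsdp_kwargs = {
    "mesh": mesh,
    "reshard_after_forward": True,  # free the gathered parameters after each forward pass
}
# Wrap the whole model in a single call.
model = fully_shard(model, **fsdp_kwargs)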
Base Model Technical Specifications
- Parameters: 8 Billion
- Architecture Family: Llama 3.1
- Maximum Position Embeddings: 131,072
- Attention Heads: 32 (num_attention_heads)
- Key-Value Heads: 8 (num_key_value_heads)
- Hidden Layers: 32 (num_hidden_layers)
- Hidden Size: 4,096 (hidden_size)
- Intermediate Size: 14,336
- Vocabulary Size: 128,256
- Precision: bfloat16
- RoPE Scaling: type llama3, factor = 8.0
- RMS Norm Epsilon: 1e-05
- Activation: SiLU
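As a rough sanity check, the specifications above can be combined into a parameter count. The sketch below assumes the standard Llama layout (bias-free projections, a gated MLP with three weight matrices, two RMSNorm weights per layer, and untied input and output embeddings). None of those assumptions is stated in this card, and the function name is illustrative.

```python
def llama_param_count(
    vocab_size: int = 128_256,
    hidden_size: int = 4_096,
    intermediate_size: int = 14_336,
    num_hidden_layers: int = 32,
    num_attention_heads: int = 32,
    num_key_value_heads: int = 8,
) -> int:
    """Estimate the parameter count from the specifications listed above.

    Assumptions (not stated in this card): no bias terms, untied
    embeddings, a gated MLP (gate/up/down), and two RMSNorm weights
    per decoder layer plus one final norm.
    """
    head_dim = hidden_size // num_attention_heads   # 128
    kv_dim = num_key_value_heads * head_dim         # 1,024 with 8 KV heads

    # q_proj and o_proj are hidden x hidden; k_proj and v_proj are hidden x kv_dim.
    attention = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    # gate_proj, up_proj, and down_proj.
    mlp = 3 * hidden_size * intermediate_size
    norms = 2 * hidden_size
    per_layer = attention + mlp + norms

    embeddings = 2 * vocab_size * hidden_size       # input embedding + LM head
    final_norm = hidden_size
    return num_hidden_layers * per_layer + embeddings + final_norm


print(f"{llama_param_count():,}")  # 8,030,261,248
```

Under these assumptions the result agrees with the Total Parameters value listed under Model & Data Card Metadata.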
Training Methodology
Training Configuration
- Model: meta-llama/Llama-3.1-8B-Instruct
- Sequence Length: 4,096 (seq_len)
- Epochs: 2
- Max Steps: 1,200
- Per-Device Micro Batch Size: 4
- Gradient Accumulation: 8
- GPUs: 4 per node (via CUDA_VISIBLE_DEVICES=0,1,2,3)
- dtype: bf16&&fp8=false
- Weights: bfloat16
- Activations: bfloat16
- Optimizer: AdamW
- Learning Rate: 2e-5
- Weight Decay: 0.01
- Betas: (0.9, 0.95)
- Epsilon: 1e-8
- LR Scheduler: Cosine; warmup = 10% (warmup_ratio=0.1) | also warmup_steps=100
- Max Grad Norm: 1.0
- Gradient Checkpointing: Enabled
- Checkpointing: every 10 steps; keep last 5; select best by eval_loss
- Logging: every step to file; Weights & Biases in offline mode
- Seed: 100
- Distributed Training: torch.distributed.run (multi-node, multi-GPU) with FSDP2 (Optimized Fully Sharded Data Parallel)
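To illustrate the learning-rate settings above, here is a minimal sketch of a cosine schedule with linear warmup. It assumes warmup_steps=100 is the value that takes effect (a 10% ratio of 1,200 max steps would give 120 instead) and that the rate decays to zero at the last step. The card states neither, and the function name is hypothetical.

```python
import math


def cosine_lr(step: int, base_lr: float = 2e-5,
              warmup_steps: int = 100, max_steps: int = 1200) -> float:
    """Linear warmup to base_lr, then cosine decay to zero (assumed shape)."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(progress, 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 50, 100, 650, 1200):
    print(step, f"{cosine_lr(step):.2e}")
```

The rate peaks at 2e-5 once warmup ends and is back at half that value midway through the decay phase (step 650 under these assumptions).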
Setups
- Precision: bfloat16 (half precision) as the data type and for computation.
- Hardware: HPC cluster (EuroHPC/BSC-class), 8 nodes with 4 × NVIDIA H100 GPUs each (32 GPUs in total).
- Framework: PyTorch with torchrun for distributed training.
Dependencies
| Package | Version |
|---|---|
| transformers | 4.57.1 |
| torch | 2.9.0+cu128 |
| accelerate | 0.14.1 |
| datasets | 4.3.0 |
| huggingface-hub | 0.36.0 |
| tensorboard | 2.20.0 |
| tensorboard-data-server | 0.7.2 |
| wandb | 0.22.1 |
Job Details
| Model | Job ID | Runtime (mins) | Nodes | GPUs per node | Node-hours | GPU-hours | micro-batch | batch-size | gradient_accumulation | total_batch_size |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct-w16a16-1node | 31472940 | 51.50 | 1 | 4 | 0.858 | 3.433 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-1node | 31473092 | 47.25 | 1 | 4 | 0.788 | 3.151 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-4nodes | 31478433 | 31.75 | 4 | 4 | 2.117 | 8.467 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-4nodes | 31478468 | 39.75 | 4 | 4 | 2.650 | 10.600 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a16-8nodes | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 4 | 8 | 1024 |
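The derived columns in the table appear to follow from simple arithmetic on runtime, node count, and batch settings. The sketch below reproduces them under that reading. The formulas are inferred from the table values rather than stated by the authors, and the function name is illustrative.

```python
def job_accounting(runtime_min: float, nodes: int, gpus_per_node: int,
                   micro_batch: int, grad_accum: int):
    """Return (node-hours, GPU-hours, total batch size) for one job."""
    node_hours = runtime_min / 60 * nodes
    gpu_hours = node_hours * gpus_per_node
    # Samples per optimizer step across all data-parallel ranks.
    total_batch = micro_batch * grad_accum * nodes * gpus_per_node
    return round(node_hours, 3), round(gpu_hours, 3), total_batch


# w16a16 on 8 nodes: 22.00 min, 4 GPUs per node, micro-batch 4, accumulation 8
print(job_accounting(22.00, 8, 4, 4, 8))  # (2.933, 11.733, 1024)
```

The same function gives (0.858, 3.433, 32) for the single-node w16a16 row.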
Computational Infrastructure
- Platform: HPC
- GPUs: NVIDIA H100 (32)
All six models were trained on 1 node, 4 nodes, and 8 nodes, each with both the bfp16-fp8 and bfp16 configurations.
| Model | Batch Size | Max Loss (train) | Min Loss (train) | Avg Loss (train) | ± Std (train) | Final Loss (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | ± Std (val) | Final Loss (val) | Total Steps | Best Step |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct-w16a16-1node | 8 | 3.1235 | 0.7203 | 0.9750 | 0.3344 | 0.7612 | 1.9113 | 0.8907 | 0.9831 | 0.1897 | 0.8907 | 312 | — |
| Llama-3.1-8B-Instruct-w16a8-1node | 8 | 3.1661 | 0.7261 | 0.9804 | 0.3374 | 0.7672 | 1.9230 | 0.8948 | 0.9867 | 0.1906 | 0.8951 | 312 | — |
| Llama-3.1-8B-Instruct-w16a16-4nodes | 32 | 3.2452 | 0.7414 | 0.9665 | 0.4844 | 0.7504 | 1.0538 | 0.8382 | 0.8844 | 0.0725 | 0.8382 | 70 | — |
| Llama-3.1-8B-Instruct-w16a8-4nodes | 32 | 3.2840 | 0.7478 | 0.9748 | 0.4905 | 0.7581 | 1.0701 | 0.8430 | 0.8922 | 0.0764 | 0.8430 | 70 | — |
| Llama-3.1-8B-Instruct-w16a16-8nodes | 32 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 | 0.8977 | 35 | — |
| Llama-3.1-8B-Instruct-w16a8-8nodes | 32 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 | 0.8992 | 35 | — |
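One way to read the table is to compare the final validation loss of each w16a8 run against its w16a16 counterpart. The sketch below computes that relative gap from the values in the table. The dictionary keys and the helper name are illustrative, and a single run per configuration does not establish statistical significance.

```python
# Final validation loss per node count: (w16a16, w16a8), copied from the table.
final_val_loss = {
    "1node": (0.8907, 0.8951),
    "4nodes": (0.8382, 0.8430),
    "8nodes": (0.8977, 0.8992),
}


def relative_gap_percent(bf16: float, fp8: float) -> float:
    """Relative increase of the w16a8 loss over the w16a16 loss, in percent."""
    return (fp8 - bf16) / bf16 * 100


gaps = {name: relative_gap_percent(*pair) for name, pair in final_val_loss.items()}
for name, gap in gaps.items():
    print(f"{name}: {gap:.2f}%")
```

For all three node counts the gap stays below 1%.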
Implementation
Usage
Note: the final model has been saved in bfloat16 format. For inference, load the model in bfloat16 or float16 as shown below:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "newmindai/Llama-3.1-8B-Instruct-w16a16-8nodes"
dtype = torch.bfloat16

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto"
)

# Q/A-style prompt, matching the format of the fine-tuning corpus.
prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)

# Greedy decoding without gradient tracking.
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False
    )

print(tok.decode(out[0], skip_special_tokens=True))
Ethical Considerations and Disclaimers
- Research & development purposes only; not a substitute for professional legal counsel.
- Users must ensure compliance with data protection and sector regulations.
- Potential biases may exist in domain data and model outputs.
Model & Data Card Metadata
- Total Parameters: 8,030,261,248
- Serialized Size (approx.): 16,060,522,496 bytes
- Config precision: bfloat16
- RoPE: llama3 scaling, factor 8.0
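The serialized size above is consistent with storing every parameter as a 2-byte bfloat16 value. A small sketch of that arithmetic follows. The per-rank figure assumes parameters are split evenly across 32 GPUs and ignores padding, gradients, and optimizer state, so it is only a rough lower bound. The names are illustrative.

```python
TOTAL_PARAMS = 8_030_261_248
BYTES_PER_BF16 = 2  # bfloat16 uses 16 bits per value


def serialized_bytes(num_params: int, bytes_per_param: int = BYTES_PER_BF16) -> int:
    """Raw weight size in bytes, excluding file-format metadata."""
    return num_params * bytes_per_param


size = serialized_bytes(TOTAL_PARAMS)
print(f"{size:,} bytes = {size / 2**30:.2f} GiB")  # 16,060,522,496 bytes = 14.96 GiB

# Idealized parameter shard per rank under even sharding across 32 GPUs (assumption).
print(f"{size // 32:,} bytes per rank")  # 501,891,328 bytes per rank
```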
References and Citations
Base Model
@misc{meta_llama31_8b_instruct,
title={Llama 3.1 8B Instruct},
author={Meta AI},
year={2024},
howpublished={\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}
Training Dataset
@misc{euro_hpc_legal,
title={EuroHPC-Legal},
author={newmindai},
year={2025},
howpublished={\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}



