FlexiDepth-Llama-3-8B-Instruct

This model is the official implementation of the paper Adaptive Layer-skipping in Pre-trained LLMs. FlexiDepth-Llama-3-8B-Instruct is built on meta-llama/Meta-Llama-3-8B-Instruct and is trained to dynamically skip layers during inference, enabling significant speedups.

📰 News: We have updated our training method for improved results! For details on the new training method and datasets, please refer to our GitHub repository: luoxuan-cs/Flexidepth.

📚 Resources

Note that the current implementation requires transformers==4.57.0.

Model Description

FlexiDepth-Llama-3-8B-Instruct is an enhanced version of the Llama-3-8B-Instruct model, incorporating the FlexiDepth method to enable adaptive layer-skipping during text generation. This approach reveals distinct layer allocation patterns, showing how computational demand varies across tokens. The token depth map visualization (see below) shows that summarization typically requires more layers than extractive question answering, while in mathematical reasoning tasks such as addition, tokens on the left-hand side of an equation use fewer layers than those on the right. For further insights, refer to the dataset at xuan-luo/FlexiPatterns-Llama-3-8B-Instruct.

[Figure: FlexiDepth banner]
  • Developed by: Xuan Luo, Weizhi Wang, Xifeng Yan
  • Model type: Causal Language Model with adaptive layer-skipping
  • Language(s) (NLP): English (en)
  • License: Apache-2.0
  • Finetuned from model: meta-llama/Meta-Llama-3-8B-Instruct
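
To build intuition for how adaptive layer-skipping can work, the sketch below wraps a decoder block with a small per-token gate that decides whether the block is executed or bypassed via the residual path, while recording per-token layer usage. This is a simplified illustration under our own assumptions (module names, hard thresholding, and blending are hypothetical), not the actual FlexiDepth router; see the paper and GitHub repository for the real design and training procedure.

import torch
import torch.nn as nn


class GatedSkipLayer(nn.Module):
    """Illustrative sketch only: a per-token gate deciding whether a wrapped
    block runs or is skipped. Hypothetical names, not the FlexiDepth module."""

    def __init__(self, layer: nn.Module, hidden_size: int, threshold: float = 0.5):
        super().__init__()
        self.layer = layer                     # wrapped transformer block (or any token-wise module)
        self.gate = nn.Linear(hidden_size, 1)  # tiny router: one score per token
        self.threshold = threshold

    def forward(self, hidden_states: torch.Tensor):
        # Per-token probability of routing through the full layer
        p_use = torch.sigmoid(self.gate(hidden_states))        # (batch, seq, 1)
        use_layer = (p_use > self.threshold).float()

        # Skipped tokens keep their input (residual path); routed tokens take the layer output
        layer_out = self.layer(hidden_states)
        output = use_layer * layer_out + (1.0 - use_layer) * hidden_states

        # 1.0 where this layer was used; summing across layers gives a per-token depth
        return output, use_layer.squeeze(-1)


# Toy usage: a stack of gated MLP blocks over random hidden states
hidden = 64
blocks = [
    GatedSkipLayer(
        nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden)),
        hidden,
    )
    for _ in range(4)
]
x = torch.randn(1, 8, hidden)
depth = torch.zeros(1, 8)
for block in blocks:
    x, used = block(x)
    depth += used
print(depth)  # number of layers each of the 8 tokens passed through

Summing the per-layer usage indicators in this way yields per-token layer counts of the kind visualized in the token depth map above.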

Get the number of layers used when generating different tokens

import transformers
from transformers import TextStreamer
import torch
from transformers.generation.streamers import BaseStreamer


class TokenStreamer(BaseStreamer):
    """
    Simple token streamer that prints each token with its corresponding layers used.
    
    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `True`):
            Whether to skip the prompt tokens in the output. Useful for chatbots.
    """

    def __init__(self, tokenizer, skip_prompt=True):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives generated token ids, decodes each one, and prints it on its own line.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Process each token in the received tensor
        for token_id in value.tolist():
            token_text = self.tokenizer.decode([token_id])
            print(f"={repr(token_text)}", end="\n", flush=True)

    def end(self):
        """Prints a newline at the end of generation."""
        self.next_tokens_are_prompt = True
        print()  # Print a newline at the end



# Model repository on the Hugging Face Hub
model_id = "xuan-luo/FlexiDepth-Llama-3-8B-Instruct"
# Load the tokenizer and the FlexiDepth model (trust_remote_code is required for the custom architecture)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)

# The model and tokenizer are already instantiated above, so no extra
# model_kwargs, device_map, or trust_remote_code flags are needed here.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

messages = [
    {"role": "user", "content": \
"""
Please calculate the sum of the eight numbers in the list: [99, 45, 12, 78, 33, 66, 21, 54]. Please solve this problem step by step.
"""},
]

terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]


streamer = TokenStreamer(tokenizer)
outputs = pipeline(
    messages,
    max_new_tokens=512,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=1.0,
    streamer=streamer,
)
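
If you prefer not to use the pipeline wrapper, the same streaming demo can be run through the standard generate interface. This is a minimal variant under the assumption that the FlexiDepth model's generate follows the usual transformers API; it reuses model, tokenizer, TokenStreamer, messages, and terminators from the script above.

# Build the Llama-3 chat prompt and stream tokens directly from generate
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

streamer = TokenStreamer(tokenizer)
_ = model.generate(
    input_ids,
    max_new_tokens=512,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=1.0,
    streamer=streamer,
)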

Evaluation

The evaluation was conducted with the lm-eval-harness framework (version 0.4.9.1), an update from the version used in our original paper (v0.4.8). A key change in the newer version is the addition of the humaneval_instruct benchmark, which is better suited to instruction-tuned models, so its results are included below.

All evaluation scripts and detailed results are available in the evals folder of this repository.
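
As a rough sketch (not our exact setup, which lives in the evals folder), a single task can be reproduced with lm-eval-harness's Python API along the following lines; the model_args string and batch size here are illustrative assumptions.

import lm_eval

# Example: GSM8K, 5-shot, mirroring one row of the table below
results = lm_eval.simple_evaluate(
    model="hf",
    model_args=(
        "pretrained=xuan-luo/FlexiDepth-Llama-3-8B-Instruct,"
        "dtype=bfloat16,trust_remote_code=True"
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size=8,
)
print(results["results"]["gsm8k"])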

Performance Comparison

The table below compares FlexiDepth-Llama-3-8B-Instruct against the baseline Llama-3-8B-Instruct. For our model, we report both the performance score and the average number of layers used per task, demonstrating its efficiency.

| Benchmark | Shots | Metric | FlexiDepth Score | FlexiDepth Avg. Layers | Llama-3 Score | Llama-3 Layers |
|---|---|---|---|---|---|---|
| MMLU | 5 | acc | 0.6642 | 28.31 | 0.6732 | 32 |
| Hellaswag | 5 | acc_norm | 0.7451 | 30.15 | 0.7066 | 32 |
| Winogrande | 5 | acc | 0.7545 | 27.65 | 0.7380 | 32 |
| GSM8K | 5 | strict-match | 0.7013 | 22.39 | 0.6687 | 32 |
| HumanEval | 0 | pass@1 | 0.3476 | 22.97 | 0.2927 | 32 |
| HumanEval-Instruct | 0 | pass@1 | 0.6098 | 22.18 | 0.5976 | 32 |
| CoQA | 0 | f1 | 0.7878 | 25.17 | 0.7816 | 32 |

These results show that FlexiDepth-Llama-3-8B-Instruct maintains comparable or superior performance across most benchmarks while significantly reducing the average number of layers used for inference.
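
To quantify that reduction, the average layer counts in the table correspond to the following fraction of the full 32-layer stack (plain arithmetic on the values above):

# Fraction of the 32 layers used on average per benchmark (values from the table)
avg_layers = {
    "MMLU": 28.31, "Hellaswag": 30.15, "Winogrande": 27.65, "GSM8K": 22.39,
    "HumanEval": 22.97, "HumanEval-Instruct": 22.18, "CoQA": 25.17,
}
for task, layers in avg_layers.items():
    print(f"{task}: {layers / 32:.1%} of layers used ({1 - layers / 32:.1%} skipped)")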

Model Card Authors

Xuan Luo, Weizhi Wang, Xifeng Yan

Model Card Contact

For questions or inquiries, please contact [email protected].
