# tlens_gemma_steering.py
import os
import torch
from tqdm import tqdm
import plotly.express as px
from datasets import load_dataset
from transformer_lens import HookedTransformer, utils
from functools import partial
from sae_lens import SAE  # pip install sae-lens
from contextlib import contextmanager
device = "cuda"
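# Load Gemma Scope residual-stream SAEs (canonical 16k-width) for layers 20, 10, and 4 of gemma-2-2b.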
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_20/width_16k/canonical",
    device=device,
)
sae_10, _, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_10/width_16k/canonical",
    device=device,
)
sae_4, _, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_4/width_16k/canonical",
    device=device,
)
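# Load the instruction-tuned model without TransformerLens weight processing, in bfloat16,
# with left padding so batched chat prompts all end at the final position for generation.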
model = HookedTransformer.from_pretrained_no_processing(
    model_name="google/gemma-2-2b-it",
    device=device,
    dtype=torch.bfloat16,
    default_padding_side="left",
)
layer = 20
# Inference only: put every SAE in eval mode.
for s in (sae, sae_10, sae_4):
    s.eval()
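# Map each steerable concept to the SAE hosting its feature and the feature (latent) index within that SAE.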
feature_dict = {
    "dog": {"sae": sae, "index": 12082},
    "harry potter4": {"sae": sae_4, "index": 12445},
    "harry potter10": {"sae": sae_10, "index": 6520},
}
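# Hook function: add the selected SAE decoder direction, scaled by `strength`, to the activation at every position.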
def sae_hook(activation, hook, subject, strength):
    feature = feature_dict[subject]
    steering_vector = feature["sae"].W_dec[feature["index"]] * strength
    return activation + steering_vector
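# Context manager that installs the steering hook at resid_pre of every layer and removes all hooks on exit.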
@contextmanager
def steering(subject, strength):
    for layer in range(model.cfg.n_layers):
        model.add_hook(
            utils.get_act_name('resid_pre', layer),
            partial(sae_hook, subject=subject, strength=strength)
        )
    try:
        yield
    finally:
        # Always remove hooks, even if generation raises.
        model.reset_hooks()
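# A batch containing a single chat conversation.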
batched_chat = [
    [{"role": "user", "content": "What book is Hermione from?"}]
]
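# Apply the Gemma chat template and tokenize the batch.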
tokens = model.tokenizer.apply_chat_template(
    batched_chat,
    padding=True,
    tokenize=True,
    add_generation_prompt=True,  # end the prompt with the model turn so generation continues as the assistant
    return_tensors="pt",
).to(device)
print(tokens)
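# Generate twice from the same prompt: first with the "harry potter10" feature steered negatively
# (strength -5), then without steering, to compare the responses.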
for i in range(2):
    if i == 0:
        print("steering")
        with steering(subject="harry potter10", strength=-5):
            with torch.set_grad_enabled(False):
                batch_output = model.generate(tokens, max_new_tokens=256)
    else:
        print("no steering")
        with torch.set_grad_enabled(False):
            batch_output = model.generate(tokens, max_new_tokens=256)
    # Drop the prompt tokens from each generated sequence, then decode only the continuation.
    response_tokens = [combined[len(prompt):] for prompt, combined in zip(tokens, batch_output)]
    responses = model.tokenizer.batch_decode(response_tokens, skip_special_tokens=True)
    print(responses[0])