|
|
import os
from contextlib import contextmanager
from functools import partial

import torch
import plotly.express as px
from datasets import load_dataset
from tqdm import tqdm
from transformer_lens import HookedTransformer, utils
from sae_lens import SAE

# Fall back to CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
# Load the canonical Gemma Scope residual-stream SAEs (16k latents each).
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_20/width_16k/canonical",
    device=device,
)
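# Optional sanity check (a sketch; the exact keys depend on the sae_lens
# release): cfg_dict describes the SAE, e.g. its width and the hook point
# it was trained on.
print(cfg_dict.get("d_sae"), cfg_dict.get("hook_name"))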
|
|
|
|
|
sae_10, _, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_10/width_16k/canonical",
    device=device,
)

sae_4, _, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_4/width_16k/canonical",
    device=device,
)
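# One decoder row is one latent's direction in the residual stream:
# W_dec has shape (d_sae, d_model). For this 16k-width SAE on Gemma 2 2B
# we expect torch.Size([16384, 2304]).
print(sae.W_dec.shape)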
|
|
|
|
|
# Load Gemma 2 2B (instruction-tuned) without weight processing, so the
# residual stream matches the activations the Gemma Scope SAEs were trained on.
model = HookedTransformer.from_pretrained_no_processing(
    model_name="google/gemma-2-2b-it",
    device=device,
    dtype=torch.bfloat16,
    default_padding_side="left",
)

# Inference only: put every SAE in eval mode.
for s in (sae, sae_10, sae_4):
    s.eval()
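# Quick smoke test (sketch): confirm the hook name used for steering below.
print(utils.get_act_name("resid_pre", 20))  # -> 'blocks.20.hook_resid_pre'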
|
|
|
|
|
# Map a human-readable label to the SAE and latent index used for steering.
# (The indices are specific latents found by prior feature exploration.)
feature_dict = {
    "dog": {"sae": sae, "index": 12082},
    "harry potter4": {"sae": sae_4, "index": 12445},
    "harry potter10": {"sae": sae_10, "index": 6520},
}
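# Sketch of how a latent could be sanity-checked: cache the layer-20 residual
# stream on a related prompt and see whether the "dog" latent fires. The prompt
# here is an illustrative assumption.
_, cache = model.run_with_cache("My dog loves playing fetch.")
resid = cache[utils.get_act_name("resid_pre", 20)].to(sae.W_dec.dtype)
acts = sae.encode(resid)
print(acts[0, :, feature_dict["dog"]["index"]].max())  # > 0 if the latent fires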
|
|
|
|
|
def sae_hook(activation, hook, subject, strength):
    """Add a scaled copy of one SAE latent's decoder direction to the activation."""
    feature = feature_dict[subject]
    steering_vector = feature["sae"].W_dec[feature["index"]] * strength
    return activation + steering_vector
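# The decoder row's norm (relative to typical residual-stream norms) governs
# how strong a given `strength` is in practice; a quick look (sketch):
print(feature_dict["harry potter10"]["sae"].W_dec[6520].norm())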
|
|
|
|
|
|
|
|
@contextmanager
def steering(subject, strength):
    """Temporarily steer the model by hooking every layer's resid_pre."""
    for layer in range(model.cfg.n_layers):
        model.add_hook(
            utils.get_act_name("resid_pre", layer),
            partial(sae_hook, subject=subject, strength=strength),
        )
    try:
        yield
    finally:
        # Remove the hooks even if generation raises.
        model.reset_hooks()
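# Design note: the hooks above steer at every layer, not just the layer the SAE
# was trained on. A plausible alternative (untested sketch) is to steer only at
# the matching layer, e.g. utils.get_act_name("resid_pre", 10) for sae_10.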
|
|
|
|
|
|
|
|
batched_chat = [
    [
        {"role": "user", "content": "What book is Hermione from?"},
    ],
]

tokens = model.tokenizer.apply_chat_template(
    batched_chat,
    padding=True,
    tokenize=True,
    add_generation_prompt=True,  # append the assistant header so the model replies
    return_tensors="pt",
).to(device)  # move the prompt tokens to the model's device
print(tokens)
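# Sketch: decode the templated prompt to eyeball exactly what the model sees.
print(model.tokenizer.batch_decode(tokens)[0])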
|
|
|
|
|
def generate_responses(tokens):
    """Generate continuations and decode only the newly generated tokens."""
    with torch.no_grad():
        batch_output = model.generate(tokens, max_new_tokens=256)
    # Strip the prompt from each row so only the response is decoded.
    response_tokens = [combined[len(prompt):] for prompt, combined in zip(tokens, batch_output)]
    return model.tokenizer.batch_decode(response_tokens, skip_special_tokens=True)


print("steering")
with steering(subject="harry potter10", strength=-5):
    responses = generate_responses(tokens)
print(responses[0])

print("no steering")
responses = generate_responses(tokens)
print(responses[0])
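# Optional follow-up (sketch): sweep steering strengths to find a value that
# changes behavior without destroying fluency. The range here is an assumption.
for s in (-10, -5, 0, 5, 10):
    with steering(subject="harry potter10", strength=s):
        out = generate_responses(tokens)
    print(s, "->", out[0][:120])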