import torch
from functools import partial
from contextlib import contextmanager, nullcontext

from transformer_lens import HookedTransformer, utils
from sae_lens import SAE  # pip install sae-lens

device = "cuda"

# Gemma Scope residual-stream SAEs (16k features) for layers 20, 10, and 4.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_20/width_16k/canonical",
    device=device,
)
sae_10, _, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_10/width_16k/canonical",
    device=device,
)
sae_4, _, _ = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_4/width_16k/canonical",
    device=device,
)

model = HookedTransformer.from_pretrained_no_processing(
    model_name="google/gemma-2-2b-it",
    device=device,
    dtype=torch.bfloat16,
    default_padding_side="left",
)

for s in (sae, sae_10, sae_4):
    s.eval()

# Map each steerable concept to the SAE that contains it and its feature index.
feature_dict = {
    "dog": {"sae": sae, "index": 12082},
    "harry potter4": {"sae": sae_4, "index": 12445},
    "harry potter10": {"sae": sae_10, "index": 6520},
}


def sae_hook(activation, hook, subject, strength):
    # Add a scaled SAE decoder direction to the residual stream.
    feature = feature_dict[subject]
    steering_vector = feature["sae"].W_dec[feature["index"]] * strength
    # Cast to the activation dtype so the bfloat16 residual stream is not upcast.
    return activation + steering_vector.to(activation.dtype)


@contextmanager
def steering(subject, strength):
    # Install the steering hook on resid_pre at every layer; always clean up afterwards.
    try:
        for layer in range(model.cfg.n_layers):
            model.add_hook(
                utils.get_act_name("resid_pre", layer),
                partial(sae_hook, subject=subject, strength=strength),
            )
        yield
    finally:
        model.reset_hooks()


batched_chat = [
    [{"role": "user", "content": "What book is Hermione from?"}]
]

tokens = model.tokenizer.apply_chat_template(
    batched_chat,
    padding=True,
    tokenize=True,
    return_tensors="pt",
).to(device)
print(tokens)

# Generate once with negative steering on the layer-10 Harry Potter feature, once without.
for steer in (True, False):
    print("steering" if steer else "no steering")
    ctx = steering(subject="harry potter10", strength=-5) if steer else nullcontext()
    with ctx, torch.no_grad():
        batch_output = model.generate(tokens, max_new_tokens=256)
    # Keep only the newly generated tokens, dropping the prompt prefix.
    response_tokens = [combined[len(prompt):] for prompt, combined in zip(tokens, batch_output)]
    responses = model.tokenizer.batch_decode(response_tokens, skip_special_tokens=True)
    print(responses[0])
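
# Optional sanity check (a sketch, not part of the original script): confirm that a
# hard-coded feature index actually fires on a relevant prompt before steering with it.
# This assumes the `model` and `sae_10` objects defined above, and that the Gemma Scope
# layer-10 residual SAE reads from the "blocks.10.hook_resid_post" hook; the prompt and
# hook name are illustrative and may need adjusting.
check_prompt = "Hermione Granger studied magic at Hogwarts with Harry Potter."
with torch.no_grad():
    _, cache = model.run_with_cache(check_prompt)
    resid = cache["blocks.10.hook_resid_post"]              # [batch, pos, d_model]
    feature_acts = sae_10.encode(resid.to(torch.float32))   # [batch, pos, d_sae]
print(feature_acts[0, :, 6520])  # activation of feature 6520 at each token position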