|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
from gemma_lens import HookedGemma, proj_steering_hook, non_proj_steering_hook |
|
|
from gemma_sae import feature_dict |
|
|
|
|
|
def getdirectorycontentsasstring(root_dir):
    """Return a text listing of the directory tree rooted at ``root_dir``.

    The tree is walked top-down with ``os.walk``. For every directory
    visited, the output contains one ``Directory:`` line, then one
    ``Sub-directory:`` line per immediate child directory, one ``File:``
    line per file, and finally an empty line. Each group is indented by one
    space per nesting level below ``root_dir``.

    Parameters:
        root_dir (str): The root directory to start traversal.

    Returns:
        str: A string representation of the directory tree. Empty when
        ``os.walk`` yields nothing (for example, a non-existent path).
    """
    lines = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # Sorting in place makes the listing deterministic; for a top-down
        # walk, os.walk also visits sub-directories in the sorted order.
        dirnames.sort()
        filenames.sort()

        # Derive the depth from the path relative to the root. Stripping the
        # root with str.replace() would remove *every* occurrence of the root
        # string (not just the leading prefix) and miscounts when root_dir
        # ends with a path separator.
        rel_path = os.path.relpath(dirpath, root_dir)
        depth = 0 if rel_path == os.curdir else rel_path.count(os.sep) + 1
        indent = ' ' * depth

        # basename is empty when dirpath ends with a separator; fall back to
        # the full path so the line is never blank.
        lines.append(f"{indent}Directory: {os.path.basename(dirpath) or dirpath}")
        lines.extend(f"{indent} Sub-directory: {dirname}" for dirname in dirnames)
        lines.extend(f"{indent} File: {filename}" for filename in filenames)
        lines.append('')

    return '\n'.join(lines)
|
|
|
|
|
class EndpointHandler:
    """Request handler that runs a ``HookedGemma`` model with optional steering.

    A request is a dict read with ``data.get``:

    * ``inputs``: the chat passed to ``model.predict``, or one of the special
      commands ``"healthcheck"`` / ``"ls"``.
    * ``steer_inputs``: optional list of ``{"feature": str, "strength": number}``
      entries selecting steering hooks.
    * ``temperature``: optional, defaults to ``DEFAULT_TEMPERATURE``.
    * ``max_new_tokens``: optional, defaults to ``DEFAULT_MAX_NEW_TOKENS`` and
      is capped at ``MAX_NEW_TOKENS_CAP``.

    ``__call__`` returns the whole generation in one dict; ``stream`` yields
    one dict per decoded token.
    """

    DEFAULT_TEMPERATURE = 1
    DEFAULT_MAX_NEW_TOKENS = 256
    MAX_NEW_TOKENS_CAP = 512

    def __init__(self, path="."):
        """Build the model from ``path`` and remember ``path`` for ``"ls"``."""
        self.model = HookedGemma(path)
        self.path = path

    def _special_response(self, inputs):
        """Return the reply for a special command, or ``None`` for a normal chat."""
        if inputs == "healthcheck":
            return {"response": "Alive", "features": []}
        if inputs == "ls":
            return {
                "response": getdirectorycontentsasstring(self.path),
                "features": [],
            }
        return None

    @classmethod
    def _sampling_params(cls, data):
        """Return ``(temperature, max_new_tokens)`` with defaults and cap applied."""
        temperature = data.get("temperature")
        if temperature is None:
            temperature = cls.DEFAULT_TEMPERATURE

        max_new_tokens = data.get("max_new_tokens")
        if max_new_tokens is None:
            max_new_tokens = cls.DEFAULT_MAX_NEW_TOKENS

        return temperature, min(cls.MAX_NEW_TOKENS_CAP, max_new_tokens)

    def _resolve_steering(self, steer_inputs):
        """Translate steering requests into parallel lists for ``add_hooks``.

        Feature names are matched case-insensitively. ``"refusal"`` uses
        ``proj_steering_hook`` with the model's ``refusal_ablation_dir`` when
        the strength is negative, and ``non_proj_steering_hook`` with
        ``refusal_amplify_dir`` otherwise. Any other name must be a key of
        ``feature_dict``; unknown names are reported on stdout and skipped.

        Returns:
            tuple: ``(features, hook_fns, dirs, strengths)``, index-aligned.
        """
        features, hook_fns, dirs, strengths = [], [], [], []
        for steer_input in steer_inputs:
            feature = steer_input["feature"].lower()
            strength = steer_input["strength"]

            if feature == "refusal":
                if strength < 0:
                    hook_fn = proj_steering_hook
                    dir_ = self.model.refusal_ablation_dir
                else:
                    hook_fn = non_proj_steering_hook
                    dir_ = self.model.refusal_amplify_dir
            elif feature in feature_dict:
                hook_fn = non_proj_steering_hook
                dir_ = feature_dict[feature]
            else:
                print("WARN: UNKNOWN STEERING FEATURE:", feature)
                continue

            features.append(feature)
            hook_fns.append(hook_fn)
            dirs.append(dir_)
            strengths.append(strength)

        return features, hook_fns, dirs, strengths

    def _start_generation(self, data):
        """Parse a chat request and return ``(features, token_iterator)``.

        ``features`` lists the steering features that were recognised. The
        iterator is lazy: hooks are installed and ``model.predict`` is called
        only once iteration starts, and the hooks context is left when the
        iterator is exhausted or closed.
        """
        chat = data.get("inputs")
        temperature, max_new_tokens = self._sampling_params(data)
        # An explicit null in the payload means "no steering", same as absent.
        steer_inputs = data.get("steer_inputs") or []
        features, hook_fns, dirs, strengths = self._resolve_steering(steer_inputs)

        def tokens():
            if not steer_inputs:
                yield from self.model.predict(
                    chat, temperature=temperature, max_new_tokens=max_new_tokens
                )
                return
            # Steering was requested: go through add_hooks even when every
            # requested feature turned out to be unknown (lists are empty).
            with self.model.add_hooks(hook_fns, dirs, strengths):
                yield from self.model.predict(
                    chat, temperature=temperature, max_new_tokens=max_new_tokens
                )

        return features, tokens()

    def stream(self, data):
        """Yield the response incrementally.

        Special commands yield a single ``{"response", "features"}`` dict.
        A chat request yields one ``{"token", "features"}`` dict per decoded
        token produced by ``model.predict``.
        """
        special = self._special_response(data.get("inputs"))
        if special is not None:
            yield special
            return

        features, tokens = self._start_generation(data)
        for tok_dec in tokens:
            yield {"token": tok_dec, "features": features}

    def __call__(self, data):
        """Handle one request and return ``{"response", "features"}``.

        The raw request is printed to stdout first. For a chat request,
        ``response`` is the concatenation of every decoded token.
        """
        print(data)

        special = self._special_response(data.get("inputs"))
        if special is not None:
            return special

        features, tokens = self._start_generation(data)
        return {"response": "".join(tokens), "features": features}
|
|
|
|
|
|