# llama-lens-endpoint / handler.py
import os

from gemma_lens import HookedGemma, proj_steering_hook, non_proj_steering_hook
from gemma_sae import feature_dict
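# Request payload shape, as inferred from the handlers below; "steer_inputs",
# "temperature", and "max_new_tokens" are optional:
# {
#     "inputs": "<chat prompt, or the special commands 'healthcheck' / 'ls'>",
#     "steer_inputs": [{"feature": "<feature name>", "strength": <float>}],
#     "temperature": <float>,        # defaults to 1
#     "max_new_tokens": <int>        # defaults to 256, capped at 512
# }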
def get_directory_contents_as_string(root_dir):
    """
    Traverse the directory tree starting from root_dir and return a
    formatted string listing every directory and file.

    Parameters:
        root_dir (str): The root directory to start traversal from.

    Returns:
        str: A string representation of the directory tree.
    """
lines = []
for dirpath, dirnames, filenames in os.walk(root_dir):
# Calculate the indentation based on directory depth
depth = dirpath.replace(root_dir, '').count(os.sep)
indent = ' ' * depth
# Append the current directory
lines.append(f"{indent}Directory: {os.path.basename(dirpath) or dirpath}")
# Append subdirectories
for dirname in dirnames:
lines.append(f"{indent} Sub-directory: {dirname}")
# Append files
for filename in filenames:
lines.append(f"{indent} File: {filename}")
# Add a blank line for readability
lines.append('')
# Join all lines into a single string
return '\n'.join(lines)
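# Illustrative output for a root directory "." containing a sub-directory
# "weights" and a file "config.json":
#   Directory: .
#     Sub-directory: weights
#     File: config.json
#
#     Directory: weights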
class EndpointHandler:
    def __init__(self, path="."):
        # Load the hooked Gemma model from the serving directory.
        self.model = HookedGemma(path)
        self.path = path
    def _resolve_steering(self, steer_inputs):
        """Map steering requests onto hook functions, directions, and strengths.

        Each entry in steer_inputs is {"feature": str, "strength": float}.
        Unknown feature names are skipped with a warning.
        """
        features, hook_fns, dirs, strengths = [], [], [], []
        for steer_input in steer_inputs:
            feature = steer_input["feature"].lower()
            strength = steer_input["strength"]
            if feature == "refusal":
                # Negative strength removes the refusal direction by projection;
                # non-negative strength adds the amplification direction back in.
                if strength < 0:
                    hook_fn = proj_steering_hook
                    dir_ = self.model.refusal_ablation_dir
                else:
                    hook_fn = non_proj_steering_hook
                    dir_ = self.model.refusal_amplify_dir
            elif feature in feature_dict:
                hook_fn = non_proj_steering_hook
                dir_ = feature_dict[feature]
            else:
                print("WARN: UNKNOWN STEERING FEATURE:", feature)
                continue
            features.append(feature)
            hook_fns.append(hook_fn)
            dirs.append(dir_)
            strengths.append(strength)
        return features, hook_fns, dirs, strengths

    def stream(self, data):
        """Generate a response, yielding one dict per decoded token."""
        inputs = data.get("inputs")
        if inputs == "healthcheck":
            yield {"response": "Alive", "features": []}
        elif inputs == "ls":
            yield {"response": get_directory_contents_as_string(self.path), "features": []}
        else:
            chat = inputs
            steer_inputs = data.get("steer_inputs", [])
            temperature = data.get("temperature", 1)
            max_new_tokens = min(512, data.get("max_new_tokens", 256))
            if steer_inputs:
                features, hook_fns, dirs, strengths = self._resolve_steering(steer_inputs)
                # Keep the steering hooks attached for the whole generation.
                with self.model.add_hooks(hook_fns, dirs, strengths):
                    for tok_dec in self.model.predict(chat, temperature=temperature, max_new_tokens=max_new_tokens):
                        yield {"token": tok_dec, "features": features}
            else:
                for tok_dec in self.model.predict(chat, temperature=temperature, max_new_tokens=max_new_tokens):
                    yield {"token": tok_dec, "features": []}
    def __call__(self, data):
        """Generate a full response and return it as a single dict."""
        print(data)  # log the raw request payload
        inputs = data.get("inputs")
        if inputs == "healthcheck":
            return {"response": "Alive", "features": []}
        elif inputs == "ls":
            return {"response": get_directory_contents_as_string(self.path), "features": []}
        else:
            chat = inputs
            steer_inputs = data.get("steer_inputs", [])
            temperature = data.get("temperature", 1)
            max_new_tokens = min(512, data.get("max_new_tokens", 256))
            if steer_inputs:
                features, hook_fns, dirs, strengths = self._resolve_steering(steer_inputs)
                with self.model.add_hooks(hook_fns, dirs, strengths):
                    response = "".join(self.model.predict(chat, temperature=temperature, max_new_tokens=max_new_tokens))
                return {"response": response, "features": features}
            else:
                response = "".join(self.model.predict(chat, temperature=temperature, max_new_tokens=max_new_tokens))
                return {"response": response, "features": []}
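

# Minimal local smoke test, not part of the endpoint contract. It assumes the
# Gemma weights are present in the current directory (as __init__ expects) and
# exercises the healthcheck path plus steered generation with refusal ablation.
if __name__ == "__main__":
    handler = EndpointHandler(".")
    print(handler({"inputs": "healthcheck"}))
    print(handler({
        "inputs": "Tell me a short story.",
        "steer_inputs": [{"feature": "refusal", "strength": -1.0}],
        "temperature": 0.7,
        "max_new_tokens": 64,
    }))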