import subprocess

from flask import Flask, request
from flask_cors import CORS
from transformer_lens import utils

from llama_lens import HookedLlama

device = utils.get_device()

app = Flask(__name__)
CORS(app)


@app.route("/")
def alive():
    try:
        # Run nvidia-smi and capture the output
        result = subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE)
        # Decode the output from bytes to string
        output = result.stdout.decode("utf-8")
    except (OSError, subprocess.SubprocessError):
        output = "NVIDIA-SMI ERROR"
    return f"Endpoint service is alive running on {device}\nNVIDIA SMI OUTPUT:\n{output}"


@app.route("/llama", methods=["POST"])
def steer():
    payload = request.get_json(force=True)
    # Expected request body:
    # {
    #   "chat": {
    #     "base": [                                    # list of messages
    #       {"role": "user" | "assistant", "message": str},
    #       ...
    #     ],
    #     "steered": [...]
    #   },
    #   "steer_topics": [
    #     {"topic": str, "direction": float},          # negative or positive
    #     ...
    #   ]
    # }
    chat = payload["chat"]
    base_hist = chat["base"]
    steered_hist = chat["steered"]
    steer_topics = payload["steer_topics"]

    # Unmodified generation for the base chat history
    base_response = model.predict(base_hist)

    # Generation with refusal-ablation hooks active for the steered history
    with model.add_refusal_ablation_hooks("normal"):
        steered_response = model.predict(steered_hist)

    return {
        "base": base_response,
        "steered": steered_response,
        "steer_topics": steer_topics,
    }


# Load the model at import time so the /llama route can use it
# regardless of how the app is served.
model = HookedLlama()

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=80, debug=True)
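
# Example client call (a minimal sketch, not part of the service itself).
# It assumes the server is reachable at http://localhost:80 and that the
# `requests` package is installed; the message contents below are purely
# illustrative and follow the request schema documented in steer() above.
#
#   import requests
#
#   payload = {
#       "chat": {
#           "base": [{"role": "user", "message": "Tell me about volcanoes."}],
#           "steered": [{"role": "user", "message": "Tell me about volcanoes."}],
#       },
#       "steer_topics": [{"topic": "refusal", "direction": -1.0}],
#   }
#   resp = requests.post("http://localhost:80/llama", json=payload)
#   body = resp.json()
#   print(body["base"])      # response without ablation hooks
#   print(body["steered"])   # response with refusal-ablation hooks applied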