# modeling_creativity_iti.py
"""
Auto-apply creativity ITI wrapper for LLaMA 3.1 8B
"""
import torch
import pickle
import json
from transformers import LlamaForCausalLM
from huggingface_hub import hf_hub_download


class CreativityITILlamaForCausalLM(LlamaForCausalLM):
    """LLaMA with automatic creativity ITI application"""

    def __init__(self, config):
        super().__init__(config)
        try:
            # Get model name from config
            model_name = getattr(config, "_name_or_path", "")

            # Download ITI files
            print("Loading Creativity ITI components...")
            top_heads_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_top_heads.pkl",
                repo_type="model"
            )
            directions_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_directions.pkl",
                repo_type="model"
            )
            config_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_config.json",
                repo_type="model"
            )

            # Load files
            with open(top_heads_path, 'rb') as f:
                self.top_heads = pickle.load(f)
            with open(directions_path, 'rb') as f:
                self.directions = pickle.load(f)
            with open(config_path, 'r') as f:
                iti_config = json.load(f)
            self.alpha = iti_config['alpha']

            # Model dimensions
            self.num_heads = config.num_attention_heads
            self.head_dim = config.hidden_size // self.num_heads

            # Register hooks
            self._register_iti_hooks()
            print(f"✓ Creativity ITI active: α={self.alpha}, {len(self.top_heads)} heads")
        except Exception as e:
            print(f"Warning: Could not load ITI: {e}")
            self.top_heads = []
            self.directions = {}
            self.alpha = 0

    def _register_iti_hooks(self):
        """Register ITI intervention hooks on the attention output projections"""
        if not self.top_heads:
            return

        # Group the selected heads by layer so one hook is registered per layer
        heads_by_layer = {}
        for head_info in self.top_heads:
            layer = head_info['layer']
            head = head_info['head']
            if layer not in heads_by_layer:
                heads_by_layer[layer] = []
            heads_by_layer[layer].append(head)

        for layer_idx, head_indices in heads_by_layer.items():
            def make_hook(layer_idx, head_indices):
                def hook_fn(module, input, output):
                    if isinstance(output, tuple):
                        hidden_states = output[0]
                    else:
                        hidden_states = output
                    batch_size, seq_len, hidden_size = hidden_states.shape

                    # View the output as per-head slices so individual heads can be steered
                    hidden_reshaped = hidden_states.view(
                        batch_size, seq_len, self.num_heads, self.head_dim
                    )
                    for head_idx in head_indices:
                        if (layer_idx, head_idx) in self.directions:
                            direction = torch.tensor(
                                self.directions[(layer_idx, head_idx)],
                                dtype=hidden_reshaped.dtype,
                                device=hidden_reshaped.device
                            )
                            # Shift only the last token position along the learned direction
                            hidden_reshaped[:, -1, head_idx, :] += self.alpha * direction
                    hidden_states = hidden_reshaped.view(batch_size, seq_len, hidden_size)

                    if isinstance(output, tuple):
                        return (hidden_states,) + output[1:]
                    else:
                        return hidden_states
                return hook_fn

            hook = make_hook(layer_idx, head_indices)
            self.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
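

# --- Usage sketch (illustrative) ---------------------------------------------
# A minimal example of loading and running the wrapper directly. The repo id
# below is a hypothetical placeholder; on the Hub this class would typically be
# exposed through an `auto_map` entry in config.json so that
# AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
# resolves to CreativityITILlamaForCausalLM automatically.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "your-username/Llama-3.1-8B-creativity-iti"  # hypothetical repo id
    model = CreativityITILlamaForCausalLM.from_pretrained(
        repo_id, torch_dtype=torch.bfloat16
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(repo_id)

    prompt = "Write an unexpected opening line for a short story:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs, max_new_tokens=64, do_sample=True, temperature=0.9
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))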