# modeling_creativity_iti.py
"""
Auto-apply creativity ITI wrapper for LLaMA 3.1 8B
"""

import torch
import pickle
import json
from transformers import LlamaForCausalLM
from huggingface_hub import hf_hub_download

class CreativityITILlamaForCausalLM(LlamaForCausalLM):
    """LLaMA with automatic creativity ITI application"""
    
    def __init__(self, config):
        super().__init__(config)
        
        try:
            # Get model name from config
            model_name = getattr(config, "_name_or_path", "")
            
            # Download ITI files
            print(f"Loading Creativity ITI components...")
            
            top_heads_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_top_heads.pkl",
                repo_type="model"
            )
            
            directions_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_directions.pkl",
                repo_type="model"
            )
            
            config_path = hf_hub_download(
                repo_id=model_name,
                filename="iti_config.json",
                repo_type="model"
            )
            
            # Load files
            with open(top_heads_path, 'rb') as f:
                self.top_heads = pickle.load(f)
            
            with open(directions_path, 'rb') as f:
                self.directions = pickle.load(f)
            
            with open(config_path, 'r') as f:
                iti_config = json.load(f)
                self.alpha = iti_config['alpha']
            
            # Model dimensions
            self.num_heads = config.num_attention_heads
            self.head_dim = config.hidden_size // self.num_heads
            
            # Register hooks
            self._register_iti_hooks()
            print(f"✓ Creativity ITI active: α={self.alpha}, {len(self.top_heads)} heads")
            
        except Exception as e:
            print(f"Warning: Could not load ITI: {e}")
            self.top_heads = []
            self.directions = {}
            self.alpha = 0
    
    def _register_iti_hooks(self):
        """Register ITI intervention hooks"""
        if not self.top_heads:
            return
            
        heads_by_layer = {}
        for head_info in self.top_heads:
            layer = head_info['layer']
            head = head_info['head']
            if layer not in heads_by_layer:
                heads_by_layer[layer] = []
            heads_by_layer[layer].append(head)
        
        for layer_idx, head_indices in heads_by_layer.items():
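            # Factory closure: bind this layer's index and head list so each
            # registered hook keeps its own values instead of the loop's last ones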
            def make_hook(layer_idx, head_indices):
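                # Forward hook on o_proj: reshape its output into head-sized chunks
                # and add the creativity directions for the selected heads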
                def hook_fn(module, input, output):
                    if isinstance(output, tuple):
                        hidden_states = output[0]
                    else:
                        hidden_states = output
                    
                    batch_size, seq_len, hidden_size = hidden_states.shape
                    
                    # View the projected attention output as (batch, seq, num_heads, head_dim)
                    # so individual head slices can be addressed
                    hidden_reshaped = hidden_states.view(
                        batch_size, seq_len, self.num_heads, self.head_dim
                    )
                    
                    # Shift the final token position of each selected head along its
                    # creativity direction, scaled by the intervention strength alpha
                    for head_idx in head_indices:
                        if (layer_idx, head_idx) in self.directions:
                            direction = torch.tensor(
                                self.directions[(layer_idx, head_idx)],
                                dtype=hidden_reshaped.dtype,
                                device=hidden_reshaped.device
                            )
                            hidden_reshaped[:, -1, head_idx, :] += self.alpha * direction
                    
                    hidden_states = hidden_reshaped.view(batch_size, seq_len, hidden_size)
                    
                    if isinstance(output, tuple):
                        return (hidden_states,) + output[1:]
                    else:
                        return hidden_states
                
                return hook_fn
            
            hook = make_hook(layer_idx, head_indices)
            # Hook this layer's attention output projection (o_proj)
            self.model.layers[layer_idx].self_attn.o_proj.register_forward_hook(hook)
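

# ---------------------------------------------------------------------------
# Example usage (sketch): load the wrapper directly from the Hub. The repo id
# below is a placeholder; substitute the repo that ships this file together
# with iti_top_heads.pkl, iti_directions.pkl, and iti_config.json. If that
# repo's config.json maps its auto classes to this module, the same model can
# also be loaded via AutoModelForCausalLM with trust_remote_code=True.
#
#     from transformers import AutoTokenizer
#     from modeling_creativity_iti import CreativityITILlamaForCausalLM
#
#     repo_id = "your-username/creativity-iti-llama-3.1-8b"  # placeholder
#     tokenizer = AutoTokenizer.from_pretrained(repo_id)
#     model = CreativityITILlamaForCausalLM.from_pretrained(repo_id, torch_dtype="auto")
#
#     inputs = tokenizer("Write an unexpected opening line for a story:", return_tensors="pt")
#     outputs = model.generate(**inputs, max_new_tokens=64)
#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# ---------------------------------------------------------------------------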