Update utils.py
Browse files
utils.py
CHANGED
|
@@ -174,7 +174,7 @@ def save_attention_maps(attn_maps, tokenizer, prompts, base_dir='attn_maps', unc
|
|
| 174 |
|
| 175 |
|
| 176 |
total_attn_map /= total_attn_map_number
|
| 177 |
-
final_attn_map =
|
| 178 |
for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
|
| 179 |
batch_dir = os.path.join(base_dir, f'batch-{batch}')
|
| 180 |
if not os.path.exists(batch_dir):
|
|
@@ -198,6 +198,6 @@ def save_attention_maps(attn_maps, tokenizer, prompts, base_dir='attn_maps', unc
|
|
| 198 |
token = '-' + token + '-'
|
| 199 |
|
| 200 |
|
| 201 |
-
final_attn_map
|
| 202 |
|
| 203 |
return final_attn_map
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
total_attn_map /= total_attn_map_number
|
| 177 |
+
final_attn_map = []
|
| 178 |
for batch, (attn_map, tokens) in enumerate(zip(total_attn_map, total_tokens)):
|
| 179 |
batch_dir = os.path.join(base_dir, f'batch-{batch}')
|
| 180 |
if not os.path.exists(batch_dir):
|
|
|
|
| 198 |
token = '-' + token + '-'
|
| 199 |
|
| 200 |
|
| 201 |
+
final_attn_map.append((to_pil(a.to(torch.float32)), f'{i}-{token}'))
|
| 202 |
|
| 203 |
return final_attn_map
|