alielfilali01 commited on
Commit
4a44340
·
verified ·
1 Parent(s): 00f7c3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -113
app.py CHANGED
@@ -13,7 +13,6 @@ from PIL import Image
13
  def load_results():
14
  # Get the directory of the current file
15
  current_dir = os.path.dirname(os.path.abspath(__file__))
16
- # Construct the path to the JSON file (assumes file is stored in "files/aragen_v1_results.json")
17
  results_file = os.path.join(current_dir, "files", "aragen_v1_results.json")
18
  with open(results_file, "r") as f:
19
  data = json.load(f)
@@ -75,7 +74,6 @@ def generate_heatmap_image(model_entry):
75
  image = Image.open(buf).convert("RGB")
76
 
77
  # Resize the image to a reasonable fixed size for the gallery
78
- # This helps maintain consistency and prevent oversized images
79
  max_size = (800, 600)
80
  image.thumbnail(max_size, Image.Resampling.LANCZOS)
81
 
@@ -103,12 +101,18 @@ with gr.Blocks(css="""
103
  object-fit: contain !important;
104
  }
105
  """) as demo:
106
- gr.Markdown("## 3C3H Heatmap Generator")
107
- gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
 
 
 
 
 
 
108
 
109
  with gr.Row():
110
  default_models = ["silma-ai/SILMA-9B-Instruct-v1.0", "google/gemma-2-9b-it"]
111
- model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=default_models) # value=MODEL_NAMES[:3]
112
 
113
  generate_btn = gr.Button("Generate Heatmaps")
114
 
@@ -122,112 +126,4 @@ with gr.Blocks(css="""
122
 
123
  generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
124
 
125
- # Launch the Gradio app
126
  demo.launch()
127
-
128
-
129
- # import gradio as gr
130
- # import json
131
- # import os
132
- # import numpy as np
133
- # import matplotlib.pyplot as plt
134
- # import seaborn as sns
135
- # from io import BytesIO
136
- # from PIL import Image
137
-
138
- # # -------------------------------
139
- # # 1. Load Results from Local File
140
- # # -------------------------------
141
- # def load_results():
142
- # # Get the directory of the current file
143
- # current_dir = os.path.dirname(os.path.abspath(__file__))
144
- # # Construct the path to the JSON file (assumes file is stored in "files/aragen_v1_results.json")
145
- # results_file = os.path.join(current_dir, "files", "aragen_v1_results.json")
146
- # with open(results_file, "r") as f:
147
- # data = json.load(f)
148
- # # Filter out any non-model entries (e.g., timestamp entries)
149
- # model_data = [entry for entry in data if "Meta" in entry]
150
- # return model_data
151
-
152
- # # Load the JSON data once when the app starts
153
- # DATA = load_results()
154
-
155
- # # Extract model names for the dropdown from the JSON "Meta" field
156
- # def get_model_names(data):
157
- # model_names = [entry["Meta"]["Model Name"] for entry in data]
158
- # return model_names
159
-
160
- # MODEL_NAMES = get_model_names(DATA)
161
-
162
- # # -------------------------------
163
- # # 2. Define Metrics and Heatmap Generation Functions
164
- # # -------------------------------
165
- # # Define the six metrics in the desired order.
166
- # METRICS = ["Correctness", "Completeness", "Conciseness", "Helpfulness", "Honesty", "Harmlessness"]
167
-
168
- # def generate_heatmap_image(model_entry):
169
- # """
170
- # For a given model entry, extract the six metrics and compute a 6x6 similarity matrix
171
- # using the definition: similarity = 1 - |v_i - v_j|, then return the heatmap as a PIL image.
172
- # """
173
- # scores = model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]
174
- # # Create a vector with the metrics in the defined order.
175
- # v = np.array([scores[m] for m in METRICS])
176
- # # Compute the 6x6 similarity matrix.
177
- # matrix = 1 - np.abs(np.subtract.outer(v, v))
178
- # # Create a mask for the upper triangle (keeping the diagonal visible).
179
- # mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
180
-
181
- # plt.figure(figsize=(6, 5))
182
- # sns.heatmap(matrix,
183
- # mask=mask,
184
- # annot=True,
185
- # fmt=".2f",
186
- # cmap="viridis",
187
- # xticklabels=METRICS,
188
- # yticklabels=METRICS,
189
- # cbar_kws={"label": "Similarity"})
190
- # plt.title(f"Confusion Matrix for Model: {model_entry['Meta']['Model Name']}")
191
- # plt.xlabel("Metrics")
192
- # plt.ylabel("Metrics")
193
- # plt.tight_layout()
194
-
195
- # # Save the plot to a bytes buffer.
196
- # buf = BytesIO()
197
- # plt.savefig(buf, format="png")
198
- # plt.close()
199
- # buf.seek(0)
200
- # # Convert the buffer into a PIL Image.
201
- # image = Image.open(buf).convert("RGB")
202
- # return image
203
-
204
- # def generate_heatmaps(selected_model_names):
205
- # """
206
- # Filter the global DATA for entries matching the selected model names,
207
- # generate a heatmap for each, and return a list of PIL images.
208
- # """
209
- # filtered_entries = [entry for entry in DATA if entry["Meta"]["Model Name"] in selected_model_names]
210
- # images = []
211
- # for entry in filtered_entries:
212
- # img = generate_heatmap_image(entry)
213
- # images.append(img)
214
- # return images
215
-
216
- # # -------------------------------
217
- # # 3. Build the Gradio Interface
218
- # # -------------------------------
219
- # with gr.Blocks() as demo:
220
- # gr.Markdown("## 3C3H Heatmap Generator")
221
- # gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
222
-
223
- # with gr.Row():
224
- # model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=MODEL_NAMES[:3])
225
-
226
- # generate_btn = gr.Button("Generate Heatmaps")
227
- # # Use the 'columns' parameter to set a grid layout in the gallery.
228
- # gallery = gr.Gallery(label="Heatmaps", columns=2)
229
-
230
- # generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
231
-
232
- # # Launch the Gradio app
233
- # demo.launch()
 
13
  def load_results():
14
  # Get the directory of the current file
15
  current_dir = os.path.dirname(os.path.abspath(__file__))
 
16
  results_file = os.path.join(current_dir, "files", "aragen_v1_results.json")
17
  with open(results_file, "r") as f:
18
  data = json.load(f)
 
74
  image = Image.open(buf).convert("RGB")
75
 
76
  # Resize the image to a reasonable fixed size for the gallery
 
77
  max_size = (800, 600)
78
  image.thumbnail(max_size, Image.Resampling.LANCZOS)
79
 
 
101
  object-fit: contain !important;
102
  }
103
  """) as demo:
104
+ gr.Markdown("""
105
+ <center>
106
+ <br></br>
107
+ <h1>3C3H Heatmap Generator</h1>
108
+ <br></br>
109
+ </center>
110
+ """)
111
+ gr.Markdown("<center>Select the models you want to compare and generate their heatmaps below.</center>")
112
 
113
  with gr.Row():
114
  default_models = ["silma-ai/SILMA-9B-Instruct-v1.0", "google/gemma-2-9b-it"]
115
+ model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=default_models)
116
 
117
  generate_btn = gr.Button("Generate Heatmaps")
118
 
 
126
 
127
  generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
128
 
 
129
  demo.launch()