import spaces
import gradio as gr
import torch
import random
import os
from typing import List, Tuple

from config_generator import generate_complete_game
from dataset import get_processor, joint_speaker_input, joint_listener_input, get_index_to_token
from models import get_model
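
# Lay the listener's guess buttons out in a grid; the "radio-group" class is
# attached to the radio component in create_app() below.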
css = """
.radio-group .wrap {
    display: grid;
    grid-template-columns: repeat(5, 1fr);
    grid-template-rows: repeat(5, 1fr);
    width: 100%;
    height: 100%;
}
"""


def initialize_game() -> List[Tuple[List[str], List[str], str, str]]:
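    """Build the 12-turn game from 4 generated contexts with 3 targets each.

    Roles are the model's: turns 1-3 listener, 4-6 speaker, 7-9 listener,
    10-12 speaker.
    """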
    context_dicts = [generate_complete_game() for _ in range(4)]
    roles = ["listener"] * 3 + ["speaker"] * 3 + ["listener"] * 3 + ["speaker"] * 3
    speaker_images = []
    listener_images = []
    targets = []
    for context_dict in context_dicts:
        for i in range(3):
            speaker_images.append(context_dict["speaker_context"])
            listener_images.append(context_dict["listener_context"])
            targets.append(context_dict["targets"][i])
    return list(zip(speaker_images, listener_images, targets, roles))


def get_model_response(
    model, adapter_name, processor, index_to_token, role: str,
    image_paths: List[str], user_message: str = "", target_image: str = ""
) -> str:
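    """Query the model in the given role: produce a description as speaker,
    or pick an image from the context as listener."""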
    if role == "speaker":
        img_dir = "tangram_pngs"
        print("Starting processing")
        input_tokens, attn_mask, images, image_attn_mask, label = joint_speaker_input(
            processor, image_paths, target_image, model.get_listener().device
        )
        image_paths = [image_paths]
        print("Starting inference")
        captions = get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths,
                                        processor, img_dir, index_to_token, adapter_name)
        print("Done")
        response = captions[0]
    else:  # listener
        print("Starting processing")
        images, l_input_tokens, l_attn_mask, l_image_attn_mask, s_input_tokens, s_attn_mask, \
            s_image_attn_mask, s_target_mask, s_target_label = joint_listener_input(
                processor, image_paths, user_message, model.get_listener().device
            )
        print("Starting inference")
        response = get_listener_response(
            model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
            s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths, adapter_name
        )
        print("Done")
    return response


def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths,
                         processor, img_dir, index_to_token, adapter_name):
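    """Sample candidate descriptions of the target image; the caller uses the first one."""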
    if model.model.active_adapter != adapter_name:
        model.model.set_adapter(adapter_name)
    model = model.cuda()
    with torch.no_grad():
        captions, _, _, _, _ = model.generate(
            images.cuda(), input_tokens.cuda(), attn_mask.cuda(), image_attn_mask.cuda(), label.cuda(),
            image_paths, processor, img_dir, index_to_token,
            max_steps=30, sampling_type="nucleus", temperature=0.7,
            top_k=50, top_p=1, repetition_penalty=1, num_samples=5
        )
    return captions


def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
                          s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label,
                          image_paths, adapter_name):
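    """Score each context image under the joint model and return the most probable one."""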
    if model.model.active_adapter != adapter_name:
        model.model.set_adapter(adapter_name)
    model = model.cuda()
    with torch.no_grad():
        _, _, joint_log_probs = model.comprehension_side([
            images.cuda(), l_input_tokens.cuda(), l_attn_mask.cuda(), l_image_attn_mask.cuda(), index_to_token,
            s_input_tokens.cuda(), s_attn_mask.cuda(), s_image_attn_mask.cuda(), s_target_mask.cuda(), s_target_label.cuda(),
        ])
    target_idx = joint_log_probs[0].argmax().item()
    response = image_paths[target_idx]
    return response


def initialize_interaction(model_iteration):
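    """Create a fresh game state; model_iteration selects the "initial" or "final" adapter."""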
    # Initialize the overall history
    new_history = {
        'adapter_name': 'initial' if model_iteration == "Initial System" else 'final',
        'image_role_pairs': initialize_game(),
        'conversation': [],
        'turn': 0,
        'num_correct': 0,
    }
    # Initialize the first turn (the model always starts as listener, so the human speaks first)
    turn = new_history['turn']
    image_role_pairs = new_history['image_role_pairs']
    speaker_image, listener_image, target_image, _ = image_role_pairs[turn]
    target_idx = speaker_image.index(target_image)
    new_history['conversation'].extend([
        f"TURN: {turn + 1}/12",
        f"Generate a description for the target image. Your target is Image {target_idx + 1}"
    ])
    return new_history


def progress_game(user_message, model, processor, index_to_token, current_state):
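    """Apply the human's action to the current turn, query the model, and set up the next turn.

    Returns the image context the human should now see, the conversation log,
    the human's next role, the turn and accuracy strings, and the updated
    state ({} once the game is over).
    """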
    # First get the game state
    turn = current_state['turn']
    image_role_pairs = current_state['image_role_pairs']
    speaker_image, listener_image, target_image, model_role = image_role_pairs[turn]
    human_role = "Speaker" if model_role == "listener" else "Listener"

    # Next, resolve the current turn
    if model_role == "listener":
        human_context = speaker_image
        model_context = listener_image
        # If the model is the listener, the human must have sent a description
        current_state['conversation'].append(f"You: {user_message}")
        model_message = get_model_response(
            model, current_state['adapter_name'], processor, index_to_token, model_role,
            model_context, user_message=user_message
        )
        model_idx = human_context.index(model_message)
        target_idx = human_context.index(target_image)
        if model_idx == target_idx:
            current_state['conversation'].append("The model guessed correctly!\n")
            current_state['num_correct'] += 1
        else:
            current_state['conversation'].append("The model guessed incorrectly.\n")
    else:
        human_context = listener_image
        model_context = speaker_image
        # If the model is the speaker, the human must have made a guess
        target_idx = human_context.index(target_image)
        current_state['conversation'][-1] += str(user_message)
        if int(user_message) == target_idx + 1:
            current_state['conversation'].append("Correct!\n")
            current_state['num_correct'] += 1
        else:
            current_state['conversation'].append("Incorrect!\n")

    # Move on to the next turn
    current_state['turn'] += 1
    acc_message = f"{current_state['num_correct']}/{current_state['turn']}"
    if current_state['turn'] == len(image_role_pairs):
        # All 12 turns have been played: report the final turn rather than an out-of-range 13/12
        current_state['conversation'].append('The game is over!')
        return human_context, current_state['conversation'], human_role, f"{current_state['turn']}/12", acc_message, {}
    turn_message = f"{current_state['turn'] + 1}/12"
    speaker_image, listener_image, target_image, model_role = image_role_pairs[current_state['turn']]
    human_role = "Listener" if model_role == "speaker" else "Speaker"
    if model_role == "speaker":
        human_context = listener_image
        model_context = speaker_image
        current_state['conversation'].extend([
            f"TURN: {current_state['turn'] + 1}/12",
            "Guess the target image given the speaker's description.",
        ])
        model_message = get_model_response(model, current_state['adapter_name'], processor, index_to_token,
                                           model_role, model_context, target_image=target_image)
        current_state['conversation'].append(f"Model: {model_message}")
        current_state['conversation'].append("You: The target is Image ")
    else:
        human_context = speaker_image
        model_context = listener_image
        target_idx = human_context.index(target_image)
        current_state['conversation'].extend([
            f"TURN: {current_state['turn'] + 1}/12",
            f"Generate a description for the target image. Your target is Image {target_idx + 1}",
        ])
    return human_context, current_state['conversation'], human_role, turn_message, acc_message, current_state


def get_current_images(current_history):
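    """Return the image context the human should currently see."""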
    turn = current_history['turn']
    image_role_pairs = current_history['image_role_pairs']
    speaker_image, listener_image, target_image, model_role = image_role_pairs[turn]
    human_context = listener_image if model_role == "speaker" else speaker_image
    return human_context


def get_human_role(current_history):
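    """Return the human's role for the current turn (the opposite of the model's)."""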
    turn = current_history['turn']
    image_role_pairs = current_history['image_role_pairs']
    speaker_image, listener_image, target_image, model_role = image_role_pairs[turn]
    return "Listener" if model_role == "speaker" else "Speaker"


def create_app():
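    """Build the Gradio interface and wire the game callbacks to it."""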
    with gr.Blocks(css=css) as app:
        game_history = gr.State(value={})

        gr.Markdown("# Tangram Reference Game")
        gr.Markdown(
            '### You will play a sequence of reference games against a model. To start a game, first select whether '
            'you wish to play against our initial trained model ("Initial System") or our model at the end of deployment ("Final System"), '
            'then press the "Start Game" button. There will be 12 rounds of reference games, and you will take on either the "listener" or the "speaker" role in each round.'
        )
        gr.Markdown(
            '### In the speaker role, you will be assigned a target image. Your goal is to describe this image (via a message in the textbox) '
            'so that your partner can guess which one it is.'
        )
        gr.Markdown(
            '### In the listener role, you will be given a description. Your goal is '
            'to select the image that the description best matches (by clicking the relevant button).'
        )
        gr.Markdown(
            '### Press "Send" to submit your action in either role and advance the game.'
        )
        with gr.Row():
            model_iteration = gr.Radio(["Initial System", "Final System"], label="Model Iteration")
            start_btn = gr.Button("Start Game")
        with gr.Row():
            current_role = gr.Textbox(label="YOUR ROLE")
            current_turn = gr.Textbox(label="TURN")
            accuracy = gr.Textbox(label="ACCURACY")
        with gr.Row():
            image_output = gr.Gallery(
                label="CONTEXT", show_label=False, elem_id="gallery",
                columns=5, rows=2, object_fit="contain", height="250px",
                allow_preview=False, container=True
            )
        with gr.Row():
            conversation_output = gr.Textbox(label="Interaction History")
            with gr.Column():
                user_input = gr.Textbox(label="Your Message as Speaker", interactive=False)
                radio_buttons = gr.Radio(
                    label="Your Guess as Listener",
                    elem_classes="radio-group",
                    choices=list(range(1, 11)),
                    interactive=False,
                )
                send_btn = gr.Button("Send", interactive=False)
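
        # Load the model, processor, and vocabulary index once when the app is built;
        # the callbacks below close over these shared objects.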
        model = get_model()
        processor = get_processor()
        index_to_token = get_index_to_token()

        def start_interaction(model_iteration):
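            """Start a new game, or prompt the user to pick a model iteration first."""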
            # Initialize the interaction
            if model_iteration is None:
                return [], "Please select a model iteration.", "", "", "", gr.update(interactive=False), \
                    gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), {}
            current_history = initialize_interaction(model_iteration)

            # Unpack the relevant items
            images = get_current_images(current_history)
            conversation = current_history["conversation"]
            role = get_human_role(current_history)
            human_listener = role == "Listener"
            current_turn = current_history['turn'] + 1
            turn_msg = f"{current_turn}/12"
            acc_msg = "0/0"
            return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)], "\n".join(conversation), role, turn_msg, acc_msg, \
                gr.update(interactive=not human_listener), gr.update(interactive=human_listener), gr.update(interactive=True), gr.update(interactive=False), current_history

        def send_message(message, radio_choice, current_state):
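            """Apply the human's message (as speaker) or guess (as listener) and refresh the UI."""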
            # If the game has already ended, the state has been reset to {}: keep the
            # display as-is and lock the inputs. (Checking truthiness first avoids a
            # KeyError on the reset state.)
            if not current_state or current_state['turn'] >= len(current_state['image_role_pairs']):
                return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(interactive=False), \
                    gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, value=None), {}
            # Regular game progress
            user_output = message if radio_choice is None else radio_choice
            images, conversation, role, turn, acc_message, current_state = progress_game(user_output, model, processor, index_to_token, current_state)
            human_listener = role == "Listener"
            return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)], "\n".join(conversation), role, turn, \
                acc_message, gr.update(interactive=not human_listener, value=""), gr.update(interactive=human_listener, value=None), \
                gr.update(interactive=True), gr.update(interactive=False), current_state

        start_btn.click(
            start_interaction,
            inputs=[model_iteration],
            outputs=[
                image_output, conversation_output, current_role, current_turn, accuracy,
                user_input, radio_buttons, send_btn, model_iteration, game_history
            ],
            queue=False
        )
        send_btn.click(
            send_message,
            inputs=[user_input, radio_buttons, game_history],
            outputs=[
                image_output, conversation_output, current_role, current_turn, accuracy,
                user_input, radio_buttons, send_btn, model_iteration, game_history
            ],
            queue=True
        )

    return app


app = create_app()
app.queue()
app.launch()