import os
import pickle
import random

import torch
import torch.nn.functional as F

EMPTY_DATA_PATH = "tangram_pngs/"
CLIP_FOLDER = "clip_similarities"

def generate_complete_game():
    # First load the corpus and the precomputed CLIP similarity tables
    curr_corpus = get_data()
    clip_files = os.listdir(CLIP_FOLDER)
    clip_model = {}
    for filename in clip_files:
        # Load the similarity scores stored for this tangram
        with open(os.path.join(CLIP_FOLDER, filename), 'rb') as f:
            curr_similarities = pickle.load(f)
        # Recover the tangram name (e.g. "page6-51") from the filename
        tangram_name = '-'.join(filename.split('-')[:2])
        clip_model[tangram_name] = curr_similarities
    # Next build the pragmatic context
    context_dict = get_pragmatic_context(curr_corpus, clip_model)
    return context_dict

def get_pragmatic_context(curr_corpus, clip_model):
    # Initialize the lists needed for generation
    overall_context = []
    base_tangrams = []
    individual_blocks = []
    # Split the 10-image context evenly across 3 similarity blocks
    block_sizes = evenly_spread_values(10, 3)
    for i in range(3):
        # Sample the base tangram for this block
        base_tangram = sample_similarity_block_base(curr_corpus, clip_model, overall_context)
        base_tangrams.append(base_tangram)
        # Sample the similarity block around the base
        similarity_block = sample_similarity_block(curr_corpus, base_tangram, block_sizes[i], clip_model)
        individual_blocks.append(similarity_block)
        overall_context.extend(similarity_block)
        # Remove the sampled tangrams from the corpus
        curr_corpus = [tangram for tangram in curr_corpus if tangram not in overall_context]
    # Sample the targets at random
    targets = random.sample(overall_context, 3)
    # Shuffle the context independently for the speaker and the listener
    speaker_order = list(range(len(overall_context)))
    random.shuffle(speaker_order)
    speaker_images = [overall_context[i] for i in speaker_order]
    listener_order = list(range(len(overall_context)))
    random.shuffle(listener_order)
    listener_images = [overall_context[i] for i in listener_order]
    context_dict = {
        "speaker_context": speaker_images,
        "listener_context": listener_images,
        "targets": targets,
    }
    return context_dict

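# The returned dictionary holds the same 10 filenames in two independent
# shuffles ("speaker_context" and "listener_context"), plus 3 "targets"
# drawn from them.
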
def evenly_spread_values(block_size, num_similarity_blocks):
    # Distribute block_size slots across the blocks as evenly as possible
    sim_block_sizes = [0 for _ in range(num_similarity_blocks)]
    for i in range(block_size):
        idx = i % num_similarity_blocks
        sim_block_sizes[idx] += 1
    return sim_block_sizes

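# Worked example: evenly_spread_values(10, 3) returns [4, 3, 3], since the
# block indices cycle 0, 1, 2, 0, 1, 2, ... over the 10 slots.
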
def sample_similarity_block_base(curr_corpus, clip_model, overall_context):
    # Pick a base uniformly at random from the candidates that are
    # sufficiently dissimilar from everything already in the context
    candidate_base_tangrams = get_candidate_base_tangrams(curr_corpus, clip_model,
                                                          overall_context)
    base_tangram = random.choice(candidate_base_tangrams)
    return base_tangram

def get_candidate_base_tangrams(curr_corpus, clip_model, overall_context):
    candidate_base_tangrams = []
    for tangram in curr_corpus:
        if valid_base_tangram(overall_context, tangram, clip_model):
            candidate_base_tangrams.append(tangram)
    return candidate_base_tangrams

def valid_base_tangram(overall_context, tangram, clip_model):
    # A tangram is a valid base only if its similarity to every tangram
    # already in the context stays below the threshold; the ".png"
    # extension is stripped to index into the similarity dictionaries
    for context_tangram in overall_context:
        if clip_model[context_tangram[:-4]][tangram[:-4]] > 1:
            return False
    return True

def sample_similarity_block(curr_corpus, base_tangram, similarity_block_size,
                            clip_model):
    # Rank the remaining corpus by similarity to the base tangram,
    # excluding the base itself in case the similarity table has a self-entry
    base_similarities = clip_model[base_tangram[:-4]]
    sorted_similarities = sorted(base_similarities.items(), reverse=True, key=lambda x: x[1])
    sorted_similarities = [sim for sim in sorted_similarities
                           if sim[0] + ".png" in curr_corpus and sim[0] + ".png" != base_tangram]
    # Separate out the tangrams and the scores
    sorted_tangrams = [sim[0] + ".png" for sim in sorted_similarities]
    sorted_scores = [sim[1] for sim in sorted_similarities]
    # Fill out the block by sampling k more tangrams, weighted toward the
    # most similar ones
    k = similarity_block_size - 1
    distribution = get_similarity_distribution(sorted_scores, 0.055)
    sampled_indices = sample_without_replacement(distribution, k)
    similarity_block = [base_tangram] + [sorted_tangrams[i] for i in sampled_indices]
    return similarity_block

def get_similarity_distribution(scores, temperature):
    # Temperature-scaled softmax over the raw similarity scores
    logits = torch.Tensor([score / temperature for score in scores])
    probs = F.softmax(logits, dim=0)
    return probs

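# Worked example: with scores [0.9, 0.8] and temperature 0.055, the logits
# differ by (0.9 - 0.8) / 0.055 ≈ 1.82, so softmax puts about 0.86 of the
# probability mass on the first tangram; the low temperature sharpens the
# distribution toward the most similar tangrams.
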
def sample_without_replacement(distribution, K):
    # Draw K distinct indices by repeatedly sampling one index, zeroing it
    # out, and renormalizing the remaining probability mass
    new_distribution = torch.clone(distribution)
    samples = []
    for _ in range(K):
        current_sample = torch.multinomial(new_distribution, num_samples=1).item()
        samples.append(current_sample)
        new_distribution[current_sample] = 0
        new_distribution = new_distribution / torch.sum(new_distribution)
    return samples

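# Note: this loop should be equivalent to
# torch.multinomial(distribution, num_samples=K, replacement=False),
# which also draws distinct indices weighted by the distribution.
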
def get_data(restricted_dataset=""):
    # Get the list of all image paths
    if restricted_dataset == "":
        paths = os.listdir(EMPTY_DATA_PATH)
    else:
        with open(restricted_dataset, 'rb') as f:
            paths = pickle.load(f)
        # Restore the ".png" extension the rest of the pipeline expects
        paths = [path + ".png" for path in paths]
    paths = [path for path in paths if ".DS_Store" not in path]
    random.shuffle(paths)
    # Remove known duplicates
    for duplicate in ['page6-51.png', 'page6-66.png', 'page4-170.png']:
        if duplicate in paths:
            paths.remove(duplicate)
    return paths

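# A minimal usage sketch, assuming "tangram_pngs/" and "clip_similarities/"
# exist locally with matching tangram names:
if __name__ == "__main__":
    game = generate_complete_game()
    print("Speaker context:", game["speaker_context"])
    print("Listener context:", game["listener_context"])
    print("Targets:", game["targets"])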