Spaces:
Sleeping
Sleeping
import json
import os

import numpy as np
import torch
from PIL import Image
from torchvision import transforms
def read_json(file_name, suppress_console_info=False):
    """Load and return the JSON document stored in ``file_name``.

    Args:
        file_name: Path of the JSON file to read.
        suppress_console_info: When False (the default), print the path that
            was read after a successful load.

    Returns:
        The decoded JSON value (callers in this module index it like a dict).
    """
    # JSON text is UTF-8 (RFC 8259). Without an explicit encoding, open() uses
    # the platform locale encoding, which fails on non-ASCII captions/labels
    # under e.g. cp1252 or a C/POSIX locale.
    with open(file_name, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if not suppress_console_info:
        print("Read from:", file_name)
    return data
def get_file_names(data, imgs_folder, feature_folder, suppress_console_info=False):
    """Index every image record in ``data`` by a ``"<sample_id>_<image_name>"`` id.

    Args:
        data: Parsed dataset JSON. Only ``data['images']`` is read; each entry
            must provide ``image_name``, ``sample_id``, ``description``,
            ``labels``, ``lat`` and ``lon``.
        imgs_folder: Root folder of the images, with or without a trailing
            separator.
        feature_folder: Root folder of the per-image ``.npy`` feature files,
            with or without a trailing separator.
        suppress_console_info: Currently unused; kept so existing callers
            that pass it keep working.

    Returns:
        Six dicts, all keyed by image id and returned in this order: image
        paths, feature paths, captions, labels, latitudes, longitudes.
        If two records share an id, the later record wins.
    """
    image_file_names = {}
    feature_paths = {}
    captions = {}
    labels = {}
    lats = {}
    lons = {}
    for img in data['images']:
        image_name = img["image_name"]
        sample_id = img["sample_id"]
        image_id = f'{sample_id}_{image_name}'
        # os.path.join tolerates roots with or without a trailing slash; plain
        # string '+' silently produced e.g. '/imgs44/x.jpg' when it was missing.
        relative_path = f'{sample_id}/{image_name}'
        image_file_names[image_id] = os.path.join(imgs_folder, relative_path)
        feature_paths[image_id] = os.path.join(feature_folder, relative_path + '.npy')
        captions[image_id] = img["description"]
        labels[image_id] = img["labels"]
        lats[image_id] = img["lat"]
        lons[image_id] = img["lon"]
    return image_file_names, feature_paths, captions, labels, lats, lons
def get_data(image_file_names, captions, feature_pathes, labels, lats, lons, image_id):
    """Fetch every per-image attribute stored under ``image_id``.

    Beware the ordering: ``captions`` is passed before ``feature_pathes``,
    but the feature path is returned before the caption.

    Returns:
        Tuple ``(image_file_name, feature_path, caption, label, lat, lon)``.

    Raises:
        KeyError: if ``image_id`` is absent from any of the mappings (the
            mappings are consulted in the same order as the returned fields).
    """
    ordered_sources = (image_file_names, feature_pathes, captions, labels, lats, lons)
    return tuple(source[image_id] for source in ordered_sources)
def read_by_image_id(data_dir, imgs_folder, feature_folder, image_id=None):
    '''Load one sample (image, caption, feature, labels, location) by image id.

    The dataset JSON is re-read and re-indexed on every call.

    Args:
        data_dir: Path of the dataset JSON *file* (it is opened with ``open``).
        imgs_folder: Image root folder, forwarded to ``get_file_names``.
        feature_folder: Feature root folder, forwarded to ``get_file_names``.
        image_id: Id of the form ``"<sample_id>_<image_name>"``. Leaving it
            as ``None`` raises ``KeyError``, since ids are always strings.

    return:
        img           -- the image as an RGB ``PIL.Image``
        img_          -- ``transform(img)`` as a float32 ``torch.Tensor``
        caption
        image_feature -- ``numpy.ndarray`` loaded from the ``.npy`` file
        label         -- the raw label ids of the image
        label_en      -- names (keys of ``data_info['labels']``) whose id
                         appears in ``label``
        lat
        lon
    '''
    data_info = read_json(data_dir)
    image_file_names, image_features_path, captions, labels, lats, lons = get_file_names(
        data_info, imgs_folder, feature_folder)
    image_file_name, image_feature_path, caption, label, lat, lon = get_data(
        image_file_names, captions, image_features_path, labels, lats, lons, image_id)

    # Map label ids back to names. The name->id table drives the outer loop, so
    # label_en follows the table's order and repeats a name once per match.
    label131 = data_info['labels']
    label_en = [
        label_name
        for label_name, label_id in label131.items()
        for label_single in label
        if label_single == label_id
    ]

    image_feature = np.load(image_feature_path)

    # Close the file handle deterministically; convert() returns an image that
    # no longer depends on the opened file.
    with Image.open(image_file_name) as raw_img:
        img = raw_img.convert('RGB')

    # Resize/crop to 224 and normalize; the mean/std appear to be the standard
    # CLIP normalization constants.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    # The transform already yields a tensor, so the former numpy round-trip and
    # the unreachable "transform is None" branch are gone. Cast explicitly so
    # the result stays float32 regardless of torch's default dtype.
    img_ = transform(img).to(torch.float32)
    return img, img_, caption, image_feature, label, label_en, lat, lon
# Manual smoke-test configuration: example paths on the original author's
# machine. These assignments run at import time but only bind strings (no I/O).
# NOTE(review): consider moving them under `if __name__ == "__main__":` if no
# other module imports these names.
data_dir = '/data02/xy/dataEngine/json_data/merged_output_combined_9w_resplit.json'
imgs_folder = '/data02/xy/Clip-hash//datasets/image/'
feature_folder = '/data02/xy/Clip-hash/image_feature/georsclip_21_r0.9_fpn/'
# Id format is "<sample_id>_<image_name>" (see get_file_names); presumably
# sample_id "sample44" and image_name "889.jpg" here — confirm against the JSON.
image_id = 'sample44_889.jpg'
# Uncomment the two lines below to load and print a single sample.
# img, img_, caption, image_feature, label, label_en, lat, lon = read_by_image_id(data_dir, imgs_folder, feature_folder, image_id)
# print(img, img_, caption, image_feature, label, label_en, lat, lon)