Spaces:
Build error
Build error
| from collections import OrderedDict | |
| from spiga.data.loaders.dl_config import DatabaseStruct | |
# Common prefix of the Google Drive direct-download links used for the weights.
_GDRIVE_DOWNLOAD_PREFIX = "https://drive.google.com/uc?export=download&confirm=yes&id="

# Google Drive file id of the pretrained weights, keyed by dataset name.
_MODELS_GDRIVE_ID = {
    "wflw": "1h0qA5ysKorpeDNRXe9oYkVcVe8UYyzP7",
    "300wpublic": "1YrbScfMzrAAWMJQYgxdLZ9l57nmTdpQC",
    "300wprivate": "1fYv-Ie7n14eTD0ROxJYcn6SXZY5QU9SM",
    "merlrav": "1GKS1x0tpsTVivPZUk_yrSiMhwEAcAkg6",
}
# "cofw68" points at the very same file as "300wprivate"; alias it instead of
# repeating the id so the two entries cannot drift apart.
_MODELS_GDRIVE_ID["cofw68"] = _MODELS_GDRIVE_ID["300wprivate"]

# Download URL of the pretrained weights for each supported dataset name.
MODELS_URL = {
    name: _GDRIVE_DOWNLOAD_PREFIX + file_id
    for name, file_id in _MODELS_GDRIVE_ID.items()
}
class ModelConfig(object):
    """Plain attribute container holding the model configuration.

    Every option is a public instance attribute initialised to a default in
    ``__init__``. Options can be overridden in bulk with :meth:`update` and
    exported with :meth:`state_dict`.

    Args:
        dataset_name: Optional dataset identifier. When not ``None`` the
            dataset-dependent options are filled in right away through
            :meth:`update_with_dataset`.
        load_model_url: When truthy, :meth:`update_with_dataset` also sets
            ``model_weights_url`` from the module-level ``MODELS_URL`` table.
    """

    def __init__(self, dataset_name=None, load_model_url=True):
        # Model configuration
        self.model_weights = None
        self.model_weights_path = "./"
        self.load_model_url = load_model_url
        self.model_weights_url = None
        # Pretreatment
        self.focal_ratio = 1.5  # Camera matrix focal length ratio.
        self.target_dist = 1.6  # Target distance zoom in/out around face.
        self.image_size = (256, 256)
        # Outputs
        self.ftmap_size = (64, 64)
        # Dataset
        self.dataset = None

        if dataset_name is not None:
            self.update_with_dataset(dataset_name)

    def update_with_dataset(self, dataset_name):
        """Fill in the options that depend on the chosen dataset.

        Sets ``dataset`` to ``DatabaseStruct(dataset_name)`` and
        ``model_weights`` to ``"spiga_<dataset_name>.pt"``. If
        ``load_model_url`` is truthy, ``model_weights_url`` is also set.

        Args:
            dataset_name: Dataset identifier, e.g. one of the ``MODELS_URL``
                keys.

        Raises:
            KeyError: If ``load_model_url`` is truthy and ``dataset_name`` has
                no entry in ``MODELS_URL``.
        """
        config_dict = {
            "dataset": DatabaseStruct(dataset_name),
            "model_weights": "spiga_%s.pt" % dataset_name,
        }
        if dataset_name == "cofw68":  # Test only
            # cofw68 uses the 300wprivate weights file instead of its own.
            config_dict["model_weights"] = "spiga_300wprivate.pt"
        if self.load_model_url:
            config_dict["model_weights_url"] = MODELS_URL[dataset_name]
        self.update(config_dict)

    def update(self, params_dict):
        """Override existing options with the values in ``params_dict``.

        All keys are validated before anything is assigned, so a rejected
        dictionary leaves the configuration untouched.

        Args:
            params_dict: Mapping of option name to new value. Each key must be
                listed in :meth:`state_dict` or be an existing attribute.

        Raises:
            Warning: If a key is not a known option. The first offending key
                in iteration order is the one reported.
        """
        known_options = self.state_dict()
        for k, v in params_dict.items():
            if not (k in known_options or hasattr(self, k)):
                raise Warning("Unknown option: {}: {}".format(k, v))

        # Every key has been accepted; applying them cannot fail half-way.
        for k, v in params_dict.items():
            setattr(self, k, v)

    def state_dict(self):
        """Return the public options as an ``OrderedDict``.

        Instance attributes whose name starts with ``"_"`` are skipped. The
        order is that of the instance ``__dict__``.
        """
        return OrderedDict(
            (k, getattr(self, k))
            for k in self.__dict__
            if not k.startswith("_")
        )