| import torch |
| import torch.nn as nn |
| from enum import Enum |
|
|
| import base64 |
| import json |
| from io import BytesIO |
| from PIL import Image |
| import requests |
| import re |
| from copy import deepcopy |
|
|
class ImageType(Enum):
    """Label for one quadrant of a real image, plus a FAKE label.

    The four REAL_* values run clockwise from the top-left quadrant:
    UP_L=0, UP_R=1, DOWN_R=2, DOWN_L=3.
    """
    REAL_UP_L = 0
    REAL_UP_R = 1
    REAL_DOWN_R = 2
    REAL_DOWN_L = 3
    # NOTE(review): presumably marks generated (non-real) images -- confirm
    # with callers. It has no quadrant, so crop_image_part rejects it.
    FAKE = 4


def crop_image_part(image: torch.Tensor,
                    part: ImageType) -> torch.Tensor:
    """Return one quadrant of ``image``.

    The split point is ``H // 2`` rows and ``W // 2`` columns, where H and W
    are the last two dimensions of ``image``. For odd sizes the lower/right
    quadrants receive the extra row/column.

    Args:
        image: tensor whose last two dimensions are (H, W), e.g. (N, C, H, W)
            or (C, H, W).
        part: which quadrant to return; must be one of the REAL_* members.

    Returns:
        A slicing view of ``image`` restricted to the requested quadrant.

    Raises:
        ValueError: if ``part`` is not a REAL_* quadrant (e.g. ``FAKE``).
    """
    # Height and width are halved independently so that non-square inputs
    # are split correctly along both axes.
    half_h = image.shape[-2] // 2
    half_w = image.shape[-1] // 2

    if part == ImageType.REAL_UP_L:
        return image[..., :half_h, :half_w]
    elif part == ImageType.REAL_UP_R:
        return image[..., :half_h, half_w:]
    elif part == ImageType.REAL_DOWN_L:
        return image[..., half_h:, :half_w]
    elif part == ImageType.REAL_DOWN_R:
        return image[..., half_h:, half_w:]
    else:
        raise ValueError('invalid part')
|
|
|
|
def init_weights(module: nn.Module) -> None:
    """Re-initialise Conv2d / BatchNorm2d parameters in place.

    Suitable for ``model.apply(init_weights)``; other module types are left
    untouched.

    - ``nn.Conv2d``: weight ~ N(0, 0.02). The bias (if any) is not modified.
    - ``nn.BatchNorm2d``: weight ~ N(1, 0.02), bias = 0.

    Args:
        module: the module to initialise.
    """
    if isinstance(module, nn.Conv2d):
        nn.init.normal_(module.weight, 0.0, 0.02)
    elif isinstance(module, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight/bias set to None; skip them
        # instead of crashing inside nn.init.
        if module.weight is not None:
            nn.init.normal_(module.weight, 1.0, 0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
|
|
def load_image_from_local(image_path, image_resize=None):
    """Open an image file from disk with PIL, optionally resizing it.

    Args:
        image_path: path (or file object) accepted by ``Image.open``.
        image_resize: if a tuple, it is forwarded to ``Image.resize``.
            Any other value (including ``None``) leaves the image as opened.

    Returns:
        The opened (and possibly resized) PIL image.
    """
    opened = Image.open(image_path)
    # Only a tuple triggers resizing; other values are deliberately ignored.
    return opened.resize(image_resize) if isinstance(image_resize, tuple) else opened
|
|
def load_image_from_url(image_url, rgba_mode=False, image_resize=None,
                        default_image=None, timeout=30):
    """Download an image over HTTP(S), falling back to a local file on failure.

    Args:
        image_url: URL passed to ``requests.get``.
        rgba_mode: if true, the returned image (downloaded or fallback) is
            converted to ``"RGBA"``.
        image_resize: if a tuple, forwarded to ``Image.resize``; any other
            value leaves the size unchanged.
        default_image: optional local path loaded when the download, decode,
            conversion or resize fails. A falsy value means "no fallback".
        timeout: seconds forwarded to ``requests.get``; ``None`` waits forever.

    Returns:
        A PIL image, or ``None`` if the download failed and no
        ``default_image`` was given.

    Raises:
        Any exception raised while loading ``default_image`` itself.
    """
    import logging

    try:
        # Without a timeout, requests can block indefinitely on an
        # unresponsive server.
        response = requests.get(image_url, timeout=timeout)
        # ``response.content`` is decoded according to Content-Encoding
        # (``.raw`` is not), and reading it fully releases the connection.
        image = Image.open(BytesIO(response.content))

        if rgba_mode:
            image = image.convert("RGBA")

        if isinstance(image_resize, tuple):
            image = image.resize(image_resize)

    except Exception:
        # Best-effort by design: any network/decode failure falls back to
        # default_image (or None). Log instead of discarding the error.
        logging.getLogger(__name__).warning(
            "failed to load image from %s", image_url, exc_info=True)
        image = None
        if default_image:
            image = load_image_from_local(default_image, image_resize=image_resize)
            # Honour rgba_mode for the fallback too, so callers get a
            # consistent mode regardless of which path produced the image.
            if rgba_mode:
                image = image.convert("RGBA")

    return image
|
|
def image_to_base64(image_array):
    """Encode an image as a PNG ``data:`` URL.

    Args:
        image_array: object with a PIL-style ``save(fp, format=...)`` method.
            Presumably a ``PIL.Image.Image`` despite the name; confirm with
            callers.

    Returns:
        ``"data:image/png;base64,<payload>"``.
    """
    buffered = BytesIO()
    image_array.save(buffered, format="PNG")
    image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    # RFC 2397 grammar allows no whitespace between "," and the payload.
    return f"data:image/png;base64,{image_b64}"
|
|
|
|
def copy_G_params(model):
    """Snapshot a model's parameter values as independent tensors.

    Returns a list with one tensor per entry of ``model.parameters()``, in the
    same order. Each tensor is a deep copy of the parameter's ``.data``, so
    later in-place updates to the model do not alter the snapshot.
    """
    # NOTE(review): the "G" presumably refers to a generator network; the
    # function itself works for any nn.Module.
    snapshot = [param.data for param in model.parameters()]
    return deepcopy(snapshot)
|
|
|
|
def load_params(model, new_param):
    """Copy ``new_param`` tensors into ``model``'s parameters in place.

    Args:
        model: module whose ``parameters()`` are overwritten, in order.
        new_param: iterable of tensors, one per model parameter and in the
            same order (e.g. a list of parameter snapshots). Each must be
            broadcastable to the corresponding parameter's shape.

    Raises:
        ValueError: if the number of tensors does not match the number of
            model parameters. Without this check, ``zip`` would silently
            stop at the shorter sequence and leave the model partially loaded.
    """
    params = list(model.parameters())
    new_params = list(new_param)
    if len(params) != len(new_params):
        raise ValueError(
            f"expected {len(params)} parameter tensors, got {len(new_params)}")
    for p, new_p in zip(params, new_params):
        p.data.copy_(new_p)
|
|