First commit
- app.py +84 -0
- datasets/rg_masks.py +326 -0
- models/layers.py +86 -0
- models/tiramisu.py +121 -0
- requirements.txt +0 -0
app.py
ADDED
@@ -0,0 +1,84 @@
import streamlit as st
from PIL import Image
import numpy as np
from datasets.rg_masks import get_transforms
from models import tiramisu
from torchvision.transforms.functional import to_pil_image
import torch
from astropy.io import fits


def load_fits(path):
    array = fits.getdata(path).astype(np.float32)
    array = np.expand_dims(array, 2)
    return array


def load_image(path):
    image = Image.open(path)
    array = np.array(image)
    array = np.expand_dims(array[:, :, 0], 2)
    return array


def load_weights(model, fpath, device="cuda"):
    print("loading weights '{}'".format(fpath))
    weights = torch.load(fpath, map_location=torch.device(device))
    model.load_state_dict(weights['state_dict'])


# Apply a color overlay to the input image based on the segmentation mask
def apply_color_overlay(input_image, segmentation_mask, alpha=0.5):
    r = (segmentation_mask == 1).float()
    g = (segmentation_mask == 2).float()
    b = (segmentation_mask == 3).float()
    overlay = torch.cat([r, g, b], dim=0)
    overlay = to_pil_image(overlay)
    output = Image.blend(input_image, overlay, alpha=alpha)
    return output


# Streamlit app
def main():
    st.title("Tiramisu for semantic segmentation of radio astronomy images")
    st.write("Upload an image and see the segmentation result!")

    uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "fits"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = tiramisu.FCDenseNet67(n_classes=4).to(device)
    # Pass the actual device so the checkpoint also loads on CPU-only machines
    load_weights(model, "weights/real.pth", device=device)
    model.eval()

    st.markdown(
        """
        Category Legend:
        - :blue[Extended]
        - :green[Compact]
        - :red[Spurious]
        """
    )
    if uploaded_image is not None:
        # Load the uploaded image
        if uploaded_image.name.endswith(".fits"):
            input_array = load_fits(uploaded_image)
        else:
            input_array = load_image(uploaded_image)

        input_array = input_array.transpose(2, 0, 1)
        transforms = get_transforms(input_array.shape[1])
        image = transforms(input_array)
        image = image.to(device)

        with torch.no_grad():
            output = model(image)
        preds = output.argmax(1)

        pil_image = to_pil_image(image[0])
        # Apply color overlay to the input image
        segmented_image = apply_color_overlay(pil_image, preds)

        # Display the input image and the segmented output
        st.image([pil_image, segmented_image], caption=["Input Image", "Segmented Output"], width=300)


if __name__ == "__main__":
    main()
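Note: the app is meant to be launched with `streamlit run app.py`. The same inference path can also be sketched outside the UI; the snippet below is a non-authoritative sketch that assumes the `weights/real.pth` checkpoint from this repo is present, and `example.fits` is only a placeholder file name.

# Minimal offline sketch of the app's inference path (placeholder FITS file name).
import numpy as np
import torch
from astropy.io import fits

from datasets.rg_masks import get_transforms
from models import tiramisu

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = tiramisu.FCDenseNet67(n_classes=4).to(device)
state = torch.load("weights/real.pth", map_location=device)
model.load_state_dict(state["state_dict"])
model.eval()

arr = fits.getdata("example.fits").astype(np.float32)[None, :, :]  # channel-first (1, H, W)
image = get_transforms(arr.shape[1])(arr).to(device)                # preprocessed (1, 3, S, S) tensor
with torch.no_grad():
    preds = model(image).argmax(1)                                   # per-pixel class indices 0..3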
datasets/rg_masks.py
ADDED
@@ -0,0 +1,326 @@
import json
import math
import random
import warnings
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from astropy.io import fits
from astropy.io.fits.verify import VerifyWarning
from astropy.stats import sigma_clip
from astropy.visualization import ZScaleInterval
from einops import rearrange
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid, save_image

warnings.simplefilter('ignore', category=VerifyWarning)


CLASSES = ['background', 'spurious', 'compact', 'extended']
COLORS = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]


def get_transforms(img_size):
    return T.Compose([
        RemoveNaNs(),
        ZScale(),
        SigmaClip(),
        ToTensor(),
        torch.nn.Tanh(),
        MinMaxNormalize(),
        Unsqueeze(),
        T.Resize((img_size, img_size)),
        RepeatChannels(3)
    ])


class RemoveNaNs(object):
    def __call__(self, img):
        img[np.isnan(img)] = 0
        return img


class ZScale(object):
    def __init__(self, contrast=0.15):
        self.contrast = contrast

    def __call__(self, img):
        interval = ZScaleInterval(contrast=self.contrast)
        vmin, vmax = interval.get_limits(img)
        img = (img - vmin) / (vmax - vmin)
        return img


class SigmaClip(object):
    def __init__(self, sigma=3, masked=True):
        self.sigma = sigma
        self.masked = masked

    def __call__(self, img):
        img = sigma_clip(img, sigma=self.sigma, masked=self.masked)
        return img


class MinMaxNormalize(object):
    def __call__(self, img):
        img = (img - img.min()) / (img.max() - img.min())
        return img


class ToTensor(object):
    def __call__(self, img):
        return torch.tensor(img, dtype=torch.float32)


class RepeatChannels(object):
    def __init__(self, ch):
        self.ch = ch

    def __call__(self, img):
        return img.repeat(1, self.ch, 1, 1)


class FromNumpy(object):
    def __call__(self, img):
        return torch.from_numpy(img.astype(np.float32)).type(torch.float32)


class Unsqueeze(object):
    def __call__(self, img):
        return img.unsqueeze(0)


def mask_to_rgb(mask):
    rgb_mask = torch.zeros_like(mask, device=mask.device).repeat(1, 3, 1, 1)
    for i, c in enumerate(COLORS):
        color_mask = torch.tensor(c, device=mask.device).unsqueeze(
            1).unsqueeze(2) * (mask == i)
        rgb_mask += color_mask
    return rgb_mask


def get_data_loader(dataset, batch_size, split="train"):
    workers = min(8, batch_size)
    is_train = split == "train"
    return DataLoader(dataset, shuffle=is_train, batch_size=batch_size,
                      num_workers=workers, persistent_workers=True,
                      drop_last=is_train)


def rgb_to_tensor(mask):
    r, g, b = mask
    r *= 1
    g *= 2
    b *= 3
    mask, _ = torch.max(torch.stack([r, g, b]), dim=0, keepdim=True)
    return mask


def rand_horizontal_flip(img, mask):
    if random.random() < 0.5:
        img = TF.hflip(img)
        mask = TF.hflip(mask)
    return img, mask


class RGDataset(Dataset):
    def __init__(self, data_dir, img_paths, img_size=128):
        super().__init__()
        data_dir = Path(data_dir)
        with open(img_paths) as f:
            self.img_paths = f.read().splitlines()
        self.img_paths = [data_dir / p for p in self.img_paths]

        # Same preprocessing pipeline as get_transforms
        self.transforms = get_transforms(img_size)
        self.img_size = img_size

        self.mask_transforms = T.Compose([
            FromNumpy(),
            Unsqueeze(),
            T.Resize((img_size, img_size),
                     interpolation=T.InterpolationMode.NEAREST),
        ])

    def get_mask(self, img_path, type):
        assert type in ["real", "synthetic"], f"Type {type} not supported"
        if type == "real":
            ann_path = str(img_path).replace(
                'imgs', 'masks').replace('.fits', '.json')
            ann_dir = Path(ann_path).parent
            ann_path = ann_dir / f'mask_{ann_path.split("/")[-1]}'
            with open(ann_path) as j:
                mask_info = json.load(j)

            # Merge the per-object masks into a single segmentation map
            masks = []
            for obj in mask_info['objs']:
                seg_path = ann_dir / obj['mask']
                mask = fits.getdata(seg_path)
                mask = self.mask_transforms(mask.astype(np.float32))
                masks.append(mask)
            mask, _ = torch.max(torch.stack(masks), dim=0)

        elif type == "synthetic":
            mask_path = str(img_path).replace("gen_fits", "cond_fits")
            mask = fits.getdata(mask_path)
            mask = self.mask_transforms(mask)
            mask = mask.squeeze()
            if mask.shape[0] == 3:
                mask = rgb_to_tensor(mask)
        return mask

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = self.img_paths[idx]
        img = fits.getdata(image_path)
        img = self.transforms(img)

        if "synthetic" in str(image_path):
            mask = self.get_mask(image_path, type='synthetic')
        else:
            mask = self.get_mask(image_path, type='real')

        mask = mask.long()
        return img.squeeze(), mask.squeeze()


class SyntheticRGDataset(Dataset):
    def __init__(self, data_dir, img_paths, img_size=128):
        super().__init__()
        data_dir = Path(data_dir)
        with open(img_paths) as f:
            self.img_paths = f.read().splitlines()
        self.img_paths = [data_dir / p for p in self.img_paths]

        # Same preprocessing pipeline as get_transforms
        self.transforms = get_transforms(img_size)
        self.img_size = img_size

        self.mask_transforms = T.Compose([
            FromNumpy(),
            Unsqueeze(),
            T.Resize((img_size, img_size),
                     interpolation=T.InterpolationMode.NEAREST),
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = self.img_paths[idx]
        img = fits.getdata(image_path)
        img = self.transforms(img)
        img = img.squeeze()

        mask_path = str(image_path).replace("gen_fits", "cond_fits")
        mask = fits.getdata(mask_path)
        mask = self.mask_transforms(mask)

        img, mask = rand_horizontal_flip(img, mask)

        mask = mask.squeeze().long()
        return img, mask


if __name__ == '__main__':
    rgtrain = SyntheticRGDataset('data/rg-dataset/data',
                                 'data/rg-dataset/val_w_bg.txt')
    # __getitem__ returns (image, mask): image is (3, H, W), mask is (H, W)
    image, mask = next(iter(rgtrain))
    to_pil_image(image).save('image.png')
    rgb_mask = mask_to_rgb(mask[None, None].float())[0]
    to_pil_image(rgb_mask).save('mask.png')

    bs = 256

    loader = torch.utils.data.DataLoader(
        rgtrain, batch_size=bs, shuffle=False, num_workers=16)
    for i, batch in enumerate(loader):
        image, mask = batch
        rgb_mask = mask_to_rgb(mask[:, None].float())
        nrow = int(math.sqrt(bs))
        grid = make_grid(rgb_mask, nrow=nrow, padding=0)
        save_image(grid, f'mask_{nrow}x{nrow}.png')
        break
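Note: a rough shape-check sketch of the preprocessing and mask utilities above; the random array below merely stands in for a FITS cutout, and class indices follow CLASSES (0=background, 1=spurious, 2=compact, 3=extended).

# Shape check for the preprocessing pipeline (random array stands in for a FITS cutout).
import numpy as np
import torch

from datasets.rg_masks import CLASSES, get_transforms, mask_to_rgb

arr = np.random.rand(1, 132, 132).astype(np.float32)    # channel-first (C, H, W) input
x = get_transforms(128)(arr)                             # -> torch.Size([1, 3, 128, 128])

mask = torch.randint(0, len(CLASSES), (1, 1, 128, 128))  # per-pixel class indices
rgb = mask_to_rgb(mask.float())                          # -> (1, 3, 128, 128) colour-coded per COLORS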
models/layers.py
ADDED
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn


class DenseLayer(nn.Sequential):
    def __init__(self, in_channels, growth_rate):
        super().__init__()
        self.add_module('norm', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(True))
        self.add_module('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3,
                                          stride=1, padding=1, bias=True))
        self.add_module('drop', nn.Dropout2d(0.2))

    def forward(self, x):
        return super().forward(x)


class DenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate, n_layers, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.layers = nn.ModuleList([DenseLayer(
            in_channels + i*growth_rate, growth_rate)
            for i in range(n_layers)])

    def forward(self, x):
        if self.upsample:
            new_features = []
            # We pass all previous activations into each dense layer normally,
            # but we only store each dense layer's output in the new_features array
            for layer in self.layers:
                out = layer(x)
                x = torch.cat([x, out], 1)
                new_features.append(out)
            return torch.cat(new_features, 1)
        else:
            for layer in self.layers:
                out = layer(x)
                x = torch.cat([x, out], 1)  # 1 = channel axis
            return x


class TransitionDown(nn.Sequential):
    def __init__(self, in_channels):
        super().__init__()
        self.add_module('norm', nn.BatchNorm2d(num_features=in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(in_channels, in_channels,
                                          kernel_size=1, stride=1,
                                          padding=0, bias=True))
        self.add_module('drop', nn.Dropout2d(0.2))
        self.add_module('maxpool', nn.MaxPool2d(2))

    def forward(self, x):
        return super().forward(x)


class TransitionUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.convTrans = nn.ConvTranspose2d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=3, stride=2, padding=0, bias=True)

    def forward(self, x, skip):
        out = self.convTrans(x)
        out = center_crop(out, skip.size(2), skip.size(3))
        out = torch.cat([out, skip], 1)
        return out


class Bottleneck(nn.Sequential):
    def __init__(self, in_channels, growth_rate, n_layers):
        super().__init__()
        self.add_module('bottleneck', DenseBlock(
            in_channels, growth_rate, n_layers, upsample=True))

    def forward(self, x):
        return super().forward(x)


def center_crop(layer, max_height, max_width):
    _, _, h, w = layer.size()
    xy1 = (w - max_width) // 2
    xy2 = (h - max_height) // 2
    return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)]
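Note: a small, non-authoritative sanity check of the channel bookkeeping in these blocks (untrained modules, random input). A DenseBlock with upsample=False returns its input concatenated with n_layers * growth_rate new feature maps, with upsample=True it returns only the new maps, and TransitionDown halves the spatial resolution.

# Channel/shape sanity check for the building blocks (untrained modules).
import torch
from models.layers import DenseBlock, TransitionDown

x = torch.randn(2, 48, 64, 64)
db = DenseBlock(48, growth_rate=16, n_layers=5)                   # keeps the input: 48 + 5*16 channels
print(db(x).shape)                                                # torch.Size([2, 128, 64, 64])

db_up = DenseBlock(48, growth_rate=16, n_layers=5, upsample=True)
print(db_up(x).shape)                                             # only the new maps: torch.Size([2, 80, 64, 64])

td = TransitionDown(128)
print(td(db(x)).shape)                                            # torch.Size([2, 128, 32, 32])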
models/tiramisu.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn

from .layers import *


class FCDenseNet(nn.Module):
    def __init__(self, in_channels=3, down_blocks=(5, 5, 5, 5, 5),
                 up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5,
                 growth_rate=16, out_chans_first_conv=48, n_classes=12):
        super().__init__()
        self.down_blocks = down_blocks
        self.up_blocks = up_blocks
        cur_channels_count = 0
        skip_connection_channel_counts = []

        ## First Convolution ##

        self.add_module('firstconv', nn.Conv2d(in_channels=in_channels,
                        out_channels=out_chans_first_conv, kernel_size=3,
                        stride=1, padding=1, bias=True))
        cur_channels_count = out_chans_first_conv

        #####################
        # Downsampling path #
        #####################

        self.denseBlocksDown = nn.ModuleList([])
        self.transDownBlocks = nn.ModuleList([])
        for i in range(len(down_blocks)):
            self.denseBlocksDown.append(
                DenseBlock(cur_channels_count, growth_rate, down_blocks[i]))
            cur_channels_count += (growth_rate*down_blocks[i])
            skip_connection_channel_counts.insert(0, cur_channels_count)
            self.transDownBlocks.append(TransitionDown(cur_channels_count))

        #####################
        #    Bottleneck     #
        #####################

        self.add_module('bottleneck', Bottleneck(cur_channels_count,
                                                 growth_rate, bottleneck_layers))
        prev_block_channels = growth_rate*bottleneck_layers
        cur_channels_count += prev_block_channels

        #######################
        #   Upsampling path   #
        #######################

        self.transUpBlocks = nn.ModuleList([])
        self.denseBlocksUp = nn.ModuleList([])
        for i in range(len(up_blocks)-1):
            self.transUpBlocks.append(TransitionUp(
                prev_block_channels, prev_block_channels))
            cur_channels_count = prev_block_channels + \
                skip_connection_channel_counts[i]

            self.denseBlocksUp.append(DenseBlock(
                cur_channels_count, growth_rate, up_blocks[i],
                upsample=True))
            prev_block_channels = growth_rate*up_blocks[i]
            cur_channels_count += prev_block_channels

        ## Final DenseBlock ##

        self.transUpBlocks.append(TransitionUp(
            prev_block_channels, prev_block_channels))
        cur_channels_count = prev_block_channels + \
            skip_connection_channel_counts[-1]

        self.denseBlocksUp.append(DenseBlock(
            cur_channels_count, growth_rate, up_blocks[-1],
            upsample=False))
        cur_channels_count += growth_rate*up_blocks[-1]

        ## Softmax ##

        self.finalConv = nn.Conv2d(in_channels=cur_channels_count,
                                   out_channels=n_classes, kernel_size=1, stride=1,
                                   padding=0, bias=True)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        out = self.firstconv(x)

        skip_connections = []
        for i in range(len(self.down_blocks)):
            out = self.denseBlocksDown[i](out)
            skip_connections.append(out)
            out = self.transDownBlocks[i](out)

        out = self.bottleneck(out)
        for i in range(len(self.up_blocks)):
            skip = skip_connections.pop()
            out = self.transUpBlocks[i](out, skip)
            out = self.denseBlocksUp[i](out)

        out = self.finalConv(out)
        out = self.softmax(out)
        return out


def FCDenseNet57(n_classes):
    return FCDenseNet(
        in_channels=3, down_blocks=(4, 4, 4, 4, 4),
        up_blocks=(4, 4, 4, 4, 4), bottleneck_layers=4,
        growth_rate=12, out_chans_first_conv=48, n_classes=n_classes)


def FCDenseNet67(n_classes):
    return FCDenseNet(
        in_channels=3, down_blocks=(5, 5, 5, 5, 5),
        up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5,
        growth_rate=16, out_chans_first_conv=48, n_classes=n_classes)


def FCDenseNet103(n_classes):
    return FCDenseNet(
        in_channels=3, down_blocks=(4, 5, 7, 10, 12),
        up_blocks=(12, 10, 7, 5, 4), bottleneck_layers=15,
        growth_rate=16, out_chans_first_conv=48, n_classes=n_classes)
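Note: a quick forward-pass sketch with the same configuration app.py instantiates (untrained weights, random input); the output is per-pixel log-probabilities from the final LogSoftmax.

# Forward-pass sketch with the configuration used in app.py (untrained weights).
import torch
from models import tiramisu

model = tiramisu.FCDenseNet67(n_classes=4).eval()
x = torch.randn(1, 3, 128, 128)           # 3-channel input, as produced by get_transforms
with torch.no_grad():
    log_probs = model(x)                   # (1, 4, 128, 128) log-probabilities (LogSoftmax over dim=1)
    preds = log_probs.argmax(1)            # (1, 128, 128) per-pixel class indices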
requirements.txt
ADDED
Binary file (60 Bytes)