Spaces:
Runtime error
Runtime error
| from collections import OrderedDict | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| ''' | |
| # -------------------------------------------- | |
| # Advanced nn.Sequential | |
| # https://github.com/xinntao/BasicSR | |
| # -------------------------------------------- | |
| ''' | |
def sequential(*args):
    """Chain modules into one nn.Sequential, flattening nested ones.

    Args:
        *args: nn.Module or nn.Sequential instances. The children of any
            nn.Sequential argument are spliced in one by one; arguments
            that are not nn.Module instances are skipped.

    Returns:
        The lone argument itself when exactly one argument is given,
        otherwise a new nn.Sequential holding the collected modules.

    Raises:
        NotImplementedError: If the only argument is an OrderedDict.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return only  # A single module needs no wrapper.

    flat = []
    for item in args:
        if isinstance(item, nn.Sequential):
            # Unwrap so the result stays one level deep.
            flat.extend(item.children())
        elif isinstance(item, nn.Module):
            flat.append(item)
    return nn.Sequential(*flat)
| # -------------------------------------------- | |
| # return nn.Sequential of (Conv + BN + ReLU) | |
| # -------------------------------------------- | |
def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', negative_slope=0.2):
    """Build a stack of layers, one layer per character of ``mode``.

    Mode characters:
        'C'          nn.Conv2d
        'T'          nn.ConvTranspose2d
        'B'          nn.BatchNorm2d (momentum=0.9, eps=1e-4, affine)
        'I'          nn.InstanceNorm2d (affine)
        'R' / 'r'    nn.ReLU, in-place / out-of-place
        'L' / 'l'    nn.LeakyReLU, in-place / out-of-place
        '2' '3' '4'  nn.PixelShuffle with that upscale factor
        'U' 'u' 'v'  nearest-neighbour nn.Upsample by 2 / 3 / 4
        'M' / 'A'    nn.MaxPool2d / nn.AvgPool2d (padding fixed to 0)

    Args:
        in_channels: Input channels for 'C' and 'T' layers.
        out_channels: Output channels for 'C'/'T' and feature count for
            'B'/'I' layers.
        kernel_size: Kernel size for 'C', 'T', 'M' and 'A' layers.
        stride: Stride for 'C', 'T', 'M' and 'A' layers.
        padding: Padding for 'C' and 'T' layers only.
        bias: Whether 'C' and 'T' layers have a bias term.
        mode: String of layer codes, applied left to right.
        negative_slope: Slope used by 'L' and 'l' layers.

    Returns:
        Whatever ``sequential`` returns for the built layers: the bare layer
        for a one-character mode, otherwise an nn.Sequential.

    Raises:
        NotImplementedError: If ``mode`` contains an unknown character.
    """
    # Factories are lazy so that only the layers named in `mode` get built.
    builders = {
        'C': lambda: nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=bias),
        'T': lambda: nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
        'B': lambda: nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True),
        'I': lambda: nn.InstanceNorm2d(out_channels, affine=True),
        'R': lambda: nn.ReLU(inplace=True),
        'r': lambda: nn.ReLU(inplace=False),
        'L': lambda: nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
        'l': lambda: nn.LeakyReLU(negative_slope=negative_slope, inplace=False),
        '2': lambda: nn.PixelShuffle(upscale_factor=2),
        '3': lambda: nn.PixelShuffle(upscale_factor=3),
        '4': lambda: nn.PixelShuffle(upscale_factor=4),
        'U': lambda: nn.Upsample(scale_factor=2, mode='nearest'),
        'u': lambda: nn.Upsample(scale_factor=3, mode='nearest'),
        'v': lambda: nn.Upsample(scale_factor=4, mode='nearest'),
        'M': lambda: nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0),
        'A': lambda: nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0),
    }

    layers = []
    for t in mode:
        if t not in builders:
            # The original message had no placeholder, so the bad code was
            # silently dropped from the error text.
            raise NotImplementedError('Undefined type: {}'.format(t))
        layers.append(builders[t]())
    return sequential(*layers)
| # -------------------------------------------- | |
| # Res Block: x + conv(relu(conv(x))) | |
| # -------------------------------------------- | |
class ResBlock(nn.Module):
    """Residual block: the output is ``x + res(x)``.

    ``res`` is the layer stack that ``conv`` builds from ``mode``; the
    remaining constructor arguments are forwarded to ``conv`` unchanged.
    Input and output channel counts must match so the sum is well defined.
    """

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2):
        super().__init__()
        assert in_channels == out_channels, 'Only support in_channels==out_channels.'

        # Swap a leading in-place activation ('R'/'L') for its out-of-place
        # code ('r'/'l') so x is still unmodified when forward adds it back.
        first, remainder = mode[0], mode[1:]
        if first in ('R', 'L'):
            mode = first.lower() + remainder

        self.res = conv(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
            mode=mode,
            negative_slope=negative_slope,
        )

    def forward(self, x):
        """Return the input plus the residual branch applied to it."""
        return x + self.res(x)
| # -------------------------------------------- | |
| # conv + subp (+ relu) | |
| # -------------------------------------------- | |
def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    """Upsample with a convolution followed by PixelShuffle.

    The first character of ``mode`` ('2', '3' or '4') is the upscale factor;
    any further characters are extra ``conv`` layer codes appended after the
    shuffle. The convolution emits ``out_channels * factor**2`` channels.

    Returns:
        The module built by ``conv`` with mode ``'C' + mode``.
    """
    assert len(mode) < 4 and mode[0] in ('2', '3', '4'), 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    factor = int(mode[0])
    expanded_channels = out_channels * (factor ** 2)
    return conv(in_channels, expanded_channels, kernel_size, stride, padding, bias,
                mode='C' + mode, negative_slope=negative_slope)
| # -------------------------------------------- | |
| # nearest_upsample + conv (+ R) | |
| # -------------------------------------------- | |
def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    """Upsample with nearest-neighbour interpolation followed by a convolution.

    The first character of ``mode`` ('2', '3' or '4') selects the scale and
    is rewritten to the matching ``conv`` codes ('UC', 'uC' or 'vC'); any
    further characters are passed through as extra layer codes.

    Returns:
        The module built by ``conv`` from the rewritten mode string.
    """
    assert len(mode) < 4 and mode[0] in ('2', '3', '4'), 'mode examples: 2, 2R, 2BR, 3, ..., 4BR'
    scale_code = mode[0]
    upsample_then_conv = {'2': 'UC', '3': 'uC', '4': 'vC'}[scale_code]
    layer_codes = mode.replace(scale_code, upsample_then_conv)
    return conv(in_channels, out_channels, kernel_size, stride, padding, bias,
                mode=layer_codes, negative_slope=negative_slope)
| # -------------------------------------------- | |
| # convTranspose (+ relu) | |
| # -------------------------------------------- | |
def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
    """Upsample with a transposed convolution.

    The first character of ``mode`` ('2', '3' or '4') is used as both kernel
    size and stride, so the ``kernel_size`` and ``stride`` arguments are
    overridden. That character becomes the 'T' layer code; any further
    characters are passed through as extra ``conv`` layer codes.

    Returns:
        The module built by ``conv`` from the rewritten mode string.
    """
    assert len(mode) < 4 and mode[0] in ('2', '3', '4'), 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    factor = int(mode[0])
    layer_codes = mode.replace(mode[0], 'T')
    return conv(in_channels, out_channels, factor, factor, padding, bias, layer_codes, negative_slope)
| ''' | |
| # -------------------------------------------- | |
| # Downsampler | |
| # Kai Zhang, https://github.com/cszn/KAIR | |
| # -------------------------------------------- | |
| # downsample_strideconv | |
| # downsample_maxpool | |
| # downsample_avgpool | |
| # -------------------------------------------- | |
| ''' | |
| # -------------------------------------------- | |
| # strideconv (+ relu) | |
| # -------------------------------------------- | |
def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', negative_slope=0.2):
    """Downsample with a strided convolution.

    The first character of ``mode`` ('2', '3' or '4') is used as both kernel
    size and stride, so the ``kernel_size`` and ``stride`` arguments are
    overridden. That character becomes the 'C' layer code; any further
    characters are passed through as extra ``conv`` layer codes.

    Returns:
        The module built by ``conv`` from the rewritten mode string.
    """
    assert len(mode) < 4 and mode[0] in ('2', '3', '4'), 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.'
    factor = int(mode[0])
    layer_codes = mode.replace(mode[0], 'C')
    return conv(in_channels, out_channels, factor, factor, padding, bias, layer_codes, negative_slope)
| # -------------------------------------------- | |
| # maxpooling + conv (+ relu) | |
| # -------------------------------------------- | |
def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', negative_slope=0.2):
    """Downsample with max pooling followed by a convolution.

    The first character of ``mode`` ('2' or '3') is the pooling kernel size
    and stride; it is rewritten to the ``conv`` codes 'MC', and any further
    characters follow the convolution. ``kernel_size``, ``stride`` and
    ``padding`` apply to the convolution, not the pooling layer. Note that
    ``padding`` defaults to 0 here.

    Returns:
        ``sequential(pool, tail)``: the pooling layer and the conv stack.
    """
    assert len(mode) < 4 and mode[0] in ('2', '3'), 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    window = int(mode[0])
    layer_codes = mode.replace(mode[0], 'MC')
    pool_code, tail_codes = layer_codes[0], layer_codes[1:]
    pooling = conv(kernel_size=window, stride=window, mode=pool_code, negative_slope=negative_slope)
    tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias,
                mode=tail_codes, negative_slope=negative_slope)
    return sequential(pooling, tail)
| # -------------------------------------------- | |
| # averagepooling + conv (+ relu) | |
| # -------------------------------------------- | |
def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', negative_slope=0.2):
    """Downsample with average pooling followed by a convolution.

    The first character of ``mode`` ('2' or '3') is the pooling kernel size
    and stride; it is rewritten to the ``conv`` codes 'AC', and any further
    characters follow the convolution. ``kernel_size``, ``stride`` and
    ``padding`` apply to the convolution, not the pooling layer.

    Returns:
        ``sequential(pool, tail)``: the pooling layer and the conv stack.
    """
    assert len(mode) < 4 and mode[0] in ('2', '3'), 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.'
    window = int(mode[0])
    layer_codes = mode.replace(mode[0], 'AC')
    pool_code, tail_codes = layer_codes[0], layer_codes[1:]
    pooling = conv(kernel_size=window, stride=window, mode=pool_code, negative_slope=negative_slope)
    tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias,
                mode=tail_codes, negative_slope=negative_slope)
    return sequential(pooling, tail)
class QFAttention(nn.Module):
    """Residual block whose branch is scaled and shifted before the skip add.

    The output is ``x + (gamma * res(x) + beta)``, where ``res`` is the layer
    stack that ``conv`` builds from ``mode``. Input and output channel counts
    must match.
    """

    def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', negative_slope=0.2):
        super().__init__()
        assert in_channels == out_channels, 'Only support in_channels==out_channels.'

        # Swap a leading in-place activation ('R'/'L') for its out-of-place
        # code ('r'/'l') so x is still unmodified when forward adds it back.
        first, remainder = mode[0], mode[1:]
        if first in ('R', 'L'):
            mode = first.lower() + remainder

        self.res = conv(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
            mode=mode,
            negative_slope=negative_slope,
        )

    def forward(self, x, gamma, beta):
        """Apply the modulated residual branch.

        ``gamma`` and ``beta`` each get two trailing singleton dimensions so
        they broadcast over the last two dimensions of ``res(x)``.
        Presumably both are (batch, channels) - confirm against the caller.
        """
        scale = gamma.unsqueeze(-1).unsqueeze(-1)
        shift = beta.unsqueeze(-1).unsqueeze(-1)
        modulated = scale * self.res(x) + shift
        return x + modulated
class FBCNN(nn.Module):
    """Encoder-decoder network whose decoder is modulated by a scalar ``qf``.

    The encoder is a head convolution followed by three stages of ``nb``
    ResBlocks plus a 2x downsample, then ``nb`` more ResBlocks. From the
    encoder output, ``qf_pred`` estimates a scalar in (0, 1) (Sigmoid
    output). That scalar, or a caller-supplied one, is embedded by
    ``qf_embed`` and projected to per-channel ``gamma`` (Sigmoid) and
    ``beta`` (Tanh) vectors that drive the QFAttention blocks in each of the
    three decoder stages. Encoder features are added back as skip
    connections.

    Args:
        in_nc: Channels of the input image.
        out_nc: Channels of the output image.
        nc: Four channel widths, one per resolution level. It is read but
            never mutated.
        nb: Number of ResBlock / QFAttention blocks per stage.
        act_mode: Activation code placed between the two convolutions of
            every block (the block mode is ``'C' + act_mode + 'C'``).
        downsample_mode: 'avgpool', 'maxpool' or 'strideconv'.
        upsample_mode: 'upconv', 'pixelshuffle' or 'convtranspose'.

    Raises:
        NotImplementedError: For an unknown downsample or upsample mode.
    """

    def __init__(self, in_nc=3, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv',
                 upsample_mode='convtranspose'):
        super(FBCNN, self).__init__()

        block_mode = 'C' + act_mode + 'C'
        embed_dim = 512  # Width of the qf embedding and of the qf_pred MLP.

        def res_blocks(channels):
            return [ResBlock(channels, channels, bias=True, mode=block_mode) for _ in range(nb)]

        def attention_blocks(channels):
            return [QFAttention(channels, channels, bias=True, mode=block_mode) for _ in range(nb)]

        # NOTE: submodules are created in the same order as before so that
        # state-dict keys and random initialisation order are unchanged.
        self.m_head = conv(in_nc, nc[0], bias=True, mode='C')

        self.nb = nb
        self.nc = nc

        # downsample
        downsamplers = {
            'avgpool': downsample_avgpool,
            'maxpool': downsample_maxpool,
            'strideconv': downsample_strideconv,
        }
        if downsample_mode not in downsamplers:
            raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode))
        downsample_block = downsamplers[downsample_mode]

        self.m_down1 = sequential(*res_blocks(nc[0]), downsample_block(nc[0], nc[1], bias=True, mode='2'))
        self.m_down2 = sequential(*res_blocks(nc[1]), downsample_block(nc[1], nc[2], bias=True, mode='2'))
        self.m_down3 = sequential(*res_blocks(nc[2]), downsample_block(nc[2], nc[3], bias=True, mode='2'))

        self.m_body_encoder = sequential(*res_blocks(nc[3]))
        self.m_body_decoder = sequential(*res_blocks(nc[3]))

        # upsample
        upsamplers = {
            'upconv': upsample_upconv,
            'pixelshuffle': upsample_pixelshuffle,
            'convtranspose': upsample_convtranspose,
        }
        if upsample_mode not in upsamplers:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        upsample_block = upsamplers[upsample_mode]

        # ModuleList rather than Sequential: the attention blocks take
        # (x, gamma, beta), which nn.Sequential cannot forward.
        self.m_up3 = nn.ModuleList([upsample_block(nc[3], nc[2], bias=True, mode='2'),
                                    *attention_blocks(nc[2])])
        self.m_up2 = nn.ModuleList([upsample_block(nc[2], nc[1], bias=True, mode='2'),
                                    *attention_blocks(nc[1])])
        self.m_up1 = nn.ModuleList([upsample_block(nc[1], nc[0], bias=True, mode='2'),
                                    *attention_blocks(nc[0])])

        self.m_tail = conv(nc[0], out_nc, bias=True, mode='C')

        # The pooled + flattened feature has nc[3] entries, so the first
        # Linear must take nc[3] inputs (previously hard-coded to 512, which
        # only worked for the default channel widths).
        self.qf_pred = sequential(*res_blocks(nc[3]),
                                  torch.nn.AdaptiveAvgPool2d((1, 1)),
                                  torch.nn.Flatten(),
                                  torch.nn.Linear(nc[3], embed_dim),
                                  nn.ReLU(),
                                  torch.nn.Linear(embed_dim, embed_dim),
                                  nn.ReLU(),
                                  torch.nn.Linear(embed_dim, 1),
                                  nn.Sigmoid()
                                  )

        self.qf_embed = sequential(torch.nn.Linear(1, embed_dim),
                                   nn.ReLU(),
                                   torch.nn.Linear(embed_dim, embed_dim),
                                   nn.ReLU(),
                                   torch.nn.Linear(embed_dim, embed_dim),
                                   nn.ReLU()
                                   )

        self.to_gamma_3 = sequential(torch.nn.Linear(embed_dim, nc[2]), nn.Sigmoid())
        self.to_beta_3 = sequential(torch.nn.Linear(embed_dim, nc[2]), nn.Tanh())
        self.to_gamma_2 = sequential(torch.nn.Linear(embed_dim, nc[1]), nn.Sigmoid())
        self.to_beta_2 = sequential(torch.nn.Linear(embed_dim, nc[1]), nn.Tanh())
        self.to_gamma_1 = sequential(torch.nn.Linear(embed_dim, nc[0]), nn.Sigmoid())
        self.to_beta_1 = sequential(torch.nn.Linear(embed_dim, nc[0]), nn.Tanh())

    @staticmethod
    def _decode_stage(stage, x, gamma, beta):
        """Run one decoder stage: its upsampler, then each attention block."""
        upsample, *attention = stage
        x = upsample(x)
        for block in attention:
            x = block(x, gamma, beta)
        return x

    def forward(self, x, qf_input=None):
        """Run the network.

        Args:
            x: Input tensor whose last two dimensions are height and width.
            qf_input: Optional tensor used in place of the predicted ``qf``
                to drive the decoder. Its last dimension must be 1 because
                it feeds ``Linear(1, 512)``; presumably (batch, 1).

        Returns:
            Tuple ``(out, qf)``: ``out`` is cropped back to the input height
            and width, and ``qf`` is always the network's own prediction,
            even when ``qf_input`` is supplied.
        """
        h, w = x.size()[-2:]
        # Three 2x downsamples need H and W to be multiples of 8; pad at the
        # bottom/right by edge replication and crop again at the end.
        pad_bottom = (-h) % 8
        pad_right = (-w) % 8
        x = F.pad(x, (0, pad_right, 0, pad_bottom), mode='replicate')

        x1 = self.m_head(x)
        x2 = self.m_down1(x1)
        x3 = self.m_down2(x2)
        x4 = self.m_down3(x3)
        x = self.m_body_encoder(x4)
        qf = self.qf_pred(x)
        x = self.m_body_decoder(x)

        qf_embedding = self.qf_embed(qf if qf_input is None else qf_input)
        gamma_3 = self.to_gamma_3(qf_embedding)
        beta_3 = self.to_beta_3(qf_embedding)
        gamma_2 = self.to_gamma_2(qf_embedding)
        beta_2 = self.to_beta_2(qf_embedding)
        gamma_1 = self.to_gamma_1(qf_embedding)
        beta_1 = self.to_beta_1(qf_embedding)

        x = self._decode_stage(self.m_up3, x + x4, gamma_3, beta_3)
        x = self._decode_stage(self.m_up2, x + x3, gamma_2, beta_2)
        x = self._decode_stage(self.m_up1, x + x2, gamma_1, beta_1)

        x = self.m_tail(x + x1)
        x = x[..., :h, :w]
        return x, qf
if __name__ == "__main__":
    # Smoke test: push one random 96x96 RGB batch through the default model
    # and print the shapes of the restored image and the predicted qf.
    x = torch.randn(1, 3, 96, 96)
    # The original instantiated `FBAR`, which is not defined in this file;
    # the model class is FBCNN.
    model = FBCNN()
    y, qf = model(x)
    print(y.shape, qf.shape)