Files
SimSwapPlus/components/Liif.py
T
chenxuanhong 3783ef0e75 init
2022-01-10 15:03:58 +08:00

146 lines
5.5 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: Liif.py
# Created Date: Monday October 18th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Tuesday, 19th October 2021 10:27:09 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid centers.

    Args:
        shape: Sequence of ints, the number of cells along each axis.
        ranges: Optional sequence of ``(v0, v1)`` pairs, one per axis, giving
            the extent covered by that axis. Every axis spans ``[-1, 1]``
            when omitted.
        flatten: When true, collapse the grid axes so the result has shape
            ``(prod(shape), len(shape))``; otherwise the result has shape
            ``(*shape, len(shape))``.

    Returns:
        A float tensor whose last dimension holds, for each grid cell, the
        coordinate of the cell centre along every axis. Axis ``i`` of the
        grid varies along dimension ``i`` of the result (matrix / "ij" order).
    """
    sizes = [int(n) for n in shape]
    ndim = len(sizes)
    axes = []
    for i, n in enumerate(sizes):
        v0, v1 = (-1, 1) if ranges is None else ranges[i]
        # Half the width of one cell; centres sit at v0 + r, v0 + 3r, ...
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        # Broadcast the 1-D sequence across the full grid. This reproduces
        # "ij"-indexed meshgrid output without depending on the
        # version-specific `indexing` argument of torch.meshgrid.
        view_shape = [1] * ndim
        view_shape[i] = n
        axes.append(seq.view(view_shape).expand(sizes))
    ret = torch.stack(axes, dim=-1)
    if flatten:
        ret = ret.reshape(-1, ret.shape[-1])
    return ret
class MLP(nn.Module):
    """Fully connected network applied to the last dimension of its input.

    The network is a stack of ``Linear -> ReLU`` pairs, one per entry of
    ``hidden_list``, followed by a final ``Linear`` projecting to ``out_dim``.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        hidden_list: Iterable of hidden layer widths; may be empty, in which
            case the network is a single linear layer.
    """

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` of shape ``(..., in_dim)``.

        Returns a tensor of shape ``(..., out_dim)`` with the same leading
        dimensions as ``x``.
        """
        lead_shape = x.shape[:-1]
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # permute, are accepted instead of raising.
        out = self.layers(x.reshape(-1, x.shape[-1]))
        return out.reshape(*lead_shape, out.shape[-1])
class LIIF(nn.Module):
    """Decode a feature map into an image by querying an MLP per output pixel.

    For every output pixel the four surrounding feature-map cells are sampled
    (nearest neighbour, 3x3-unfolded features). Each sample is fed to the MLP
    together with the pixel's offset from that cell and a cell-size vector,
    and the four predictions are blended with area weights.

    ``gen_coord`` must be called once (and again whenever the input or output
    size changes) before ``forward``; it precomputes the query coordinates,
    relative offsets and blending weights that ``forward`` reads.

    Args:
        mlp_in_dim: Number of channels of the feature map given to ``forward``.
        mlp_out_dim: Number of channels predicted per output pixel.
        mlp_hidden_list: Hidden layer widths of the MLP.
    """

    def __init__(self, mlp_in_dim, mlp_out_dim, mlp_hidden_list):
        super().__init__()
        # 3x3 unfolded neighbourhood (x9) + 2 relative coords + 2 cell sizes,
        # matching the concatenation performed in forward().
        imnet_in_dim = mlp_in_dim * 9 + 2 + 2
        self.imnet = MLP(imnet_in_dim, mlp_out_dim, mlp_hidden_list)
        # Keep the historical behaviour of living on the GPU right after
        # construction, but stay usable on hosts without CUDA.
        if torch.cuda.is_available():
            self.imnet = self.imnet.cuda()

    def gen_coord(self, in_shape, output_size):
        """Precompute the sampling grid and blending weights.

        Args:
            in_shape: Shape of the feature map later passed to ``forward``.
                Index 0 is used as the batch size and the last two entries as
                the spatial size (assumed ``(B, C, H, W)`` - TODO confirm
                against callers).
            output_size: ``(height, width)`` of the image to produce.

        Side effects:
            Sets ``image_size``, ``coord_``, ``rel_coord``, ``rel_cell`` (each
            of the latter three shaped ``(2, 2, B, H_out * W_out, 2)``) and
            ``area_weights`` (list of four ``(B, H_out * W_out, 1)`` tensors),
            all placed on the device of the MLP parameters.
        """
        self.vx_lst = [-1, 1]
        self.vy_lst = [-1, 1]
        eps_shift = 1e-6
        self.image_size = output_size

        batch = in_shape[0]
        feat_h, feat_w = in_shape[-2], in_shape[-1]
        n_query = output_size[0] * output_size[1]

        # field radius (global: [-1, 1]): half a feature cell along each axis
        rx = 2 / feat_h / 2
        ry = 2 / feat_w / 2

        # Centre of every output pixel, shared (stride-0 expand) across batch.
        coord = make_coord(output_size, flatten=False) \
            .expand(batch, output_size[0], output_size[1], 2) \
            .view(batch, n_query, 2)

        cell = torch.ones_like(coord)
        # NOTE(review): coord is already flattened to (B, H_out*W_out, 2), so
        # these divisors are H_out*W_out and 2, not the output height and
        # width. Kept as-is because changing them alters the MLP inputs that
        # existing checkpoints were trained with - confirm intent.
        cell[:, :, 0] *= 2 / coord.shape[-2]
        cell[:, :, 1] *= 2 / coord.shape[-1]

        # Centre coordinate of every feature cell, laid out as (B, 2, H, W) so
        # it can be sampled with grid_sample like a 2-channel image.
        feat_coord = make_coord(in_shape[-2:], flatten=False) \
            .permute(2, 0, 1) \
            .unsqueeze(0).expand(batch, 2, *in_shape[-2:])

        # The cell term does not depend on the shift direction: compute once.
        scaled_cell = cell.clone()
        scaled_cell[:, :, 0] *= feat_h
        scaled_cell[:, :, 1] *= feat_w

        slot_shape = (2, 2, batch, n_query, 2)
        coord_ = torch.zeros(slot_shape)
        rel_coord = torch.zeros(slot_shape)
        rel_cell = torch.zeros(slot_shape)
        areas = []
        for ix, vx in enumerate(self.vx_lst):
            for iy, vy in enumerate(self.vy_lst):
                # Shift each query half a feature cell towards one of the four
                # neighbours; clamp so it stays inside the sampling range.
                shifted = coord.clone()
                shifted[:, :, 0] += vx * rx + eps_shift
                shifted[:, :, 1] += vy * ry + eps_shift
                shifted.clamp_(-1 + 1e-6, 1 - 1e-6)
                coord_[ix, iy] = shifted

                # grid_sample expects (x, y) order, hence the flip.
                q_coord = F.grid_sample(
                    feat_coord, shifted.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)

                # Offset from the sampled cell centre, in feature-cell units.
                rel = coord - q_coord
                rel[:, :, 0] *= feat_h
                rel[:, :, 1] *= feat_w
                rel_coord[ix, iy] = rel
                rel_cell[ix, iy] = scaled_cell

                areas.append(torch.abs(rel[:, :, 0] * rel[:, :, 1]) + 1e-9)

        tot_area = torch.stack(areas).sum(dim=0)
        # Each prediction is weighted by the area spanned by the diagonally
        # opposite neighbour, i.e. the list of areas in reverse order.
        areas.reverse()

        device = next(self.imnet.parameters()).device
        self.area_weights = [
            (area / tot_area).unsqueeze(-1).to(device) for area in areas
        ]
        self.rel_coord = rel_coord.to(device)
        self.rel_cell = rel_cell.to(device)
        self.coord_ = coord_.to(device)

    def forward(self, feat):
        """Render the image for ``feat`` using the grid from ``gen_coord``.

        Args:
            feat: Feature map of shape ``(B, C, H, W)`` whose batch and
                spatial sizes match the ``in_shape`` given to ``gen_coord``.

        Returns:
            Tensor of shape ``(B, mlp_out_dim, H_out, W_out)`` where
            ``(H_out, W_out)`` is the ``output_size`` given to ``gen_coord``.
        """
        batch, channels, height, width = feat.shape
        # B K*K*Cin H W
        feat = F.unfold(feat, 3, padding=1).view(
            batch, channels * 9, height, width)

        # No-ops when gen_coord already placed the tensors on feat's device.
        device = feat.device
        coord_ = self.coord_.to(device)
        rel_coord = self.rel_coord.to(device)
        rel_cell = self.rel_cell.to(device)

        bs, q = coord_.shape[2:4]
        preds = []
        for vx in range(2):
            for vy in range(2):
                q_feat = F.grid_sample(
                    feat, coord_[vx, vy].flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                inp = torch.cat(
                    [q_feat, rel_coord[vx, vy], rel_cell[vx, vy]], dim=-1)
                preds.append(self.imnet(inp.view(bs * q, -1)).view(bs, q, -1))

        ret = 0
        for pred, area in zip(preds, self.area_weights):
            ret = ret + pred * area.to(device)
        # Use the predicted channel count instead of assuming 3 (RGB).
        out_channels = ret.shape[-1]
        return ret.permute(0, 2, 1).reshape(
            -1, out_channels, self.image_size[0], self.image_size[1])