146 lines
5.5 KiB
Python
146 lines
5.5 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
#############################################################
|
|
# File: Liif.py
|
|
# Created Date: Monday October 18th 2021
|
|
# Author: Chen Xuanhong
|
|
# Email: chenxuanhongzju@outlook.com
|
|
# Last Modified: Tuesday, 19th October 2021 10:27:09 am
|
|
# Modified By: Chen Xuanhong
|
|
# Copyright (c) 2021 Shanghai Jiao Tong University
|
|
#############################################################
|
|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid centers.

    Args:
        shape: Iterable of ints, the number of cells along each axis.
        ranges: Optional sequence of ``(v0, v1)`` pairs, one per axis, giving
            the interval each axis spans. Defaults to ``(-1, 1)`` on every
            axis when ``None``.
        flatten: If True, collapse all grid axes so the result has shape
            ``(prod(shape), len(shape))``; otherwise the result has shape
            ``(*shape, len(shape))``.

    Returns:
        A float tensor holding, for every grid cell, the coordinate of its
        center along each axis (last dimension indexes the axis).
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        v0, v1 = (-1, 1) if ranges is None else ranges[i]
        # Half the width of one cell; centers sit at v0 + r, v0 + 3r, ...
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)

    try:
        # Request matrix ("ij") indexing explicitly; newer torch versions
        # warn when the argument is omitted.
        grids = torch.meshgrid(*coord_seqs, indexing="ij")
    except TypeError:
        # Older torch versions do not accept `indexing`; their only
        # behaviour is "ij", so the result is the same.
        grids = torch.meshgrid(*coord_seqs)

    ret = torch.stack(grids, dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret
|
|
|
|
class MLP(nn.Module):
    """Fully connected network applied along the last dimension.

    The network is ``Linear -> ReLU`` for every entry of ``hidden_list``
    followed by one final ``Linear`` that maps to ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        # Widths seen by consecutive hidden layers: in_dim, h0, h1, ...
        widths = [in_dim, *hidden_list]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        modules.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        # Remember every leading dimension, run the stack on a 2-D view,
        # then restore the leading dimensions around the new feature size.
        leading = x.shape[:-1]
        flat = x.view(-1, x.shape[-1])
        out = self.layers(flat)
        return out.view(*leading, -1)
|
|
|
|
class LIIF(nn.Module):
    """Query an MLP at output-pixel coordinates using nearby input features.

    Usage is two-step: call :meth:`gen_coord` once for a given input shape
    and output size to precompute the sampling coordinates and blending
    weights, then call the module on feature maps of that shape.
    :meth:`forward` reads attributes that only :meth:`gen_coord` sets, so
    it cannot be used before :meth:`gen_coord` has run.

    Args:
        mlp_in_dim: Channel count of the feature map passed to ``forward``.
        mlp_out_dim: Number of channels the MLP predicts per output pixel.
        mlp_hidden_list: Hidden layer widths forwarded to ``MLP``.
    """

    def __init__(self, mlp_in_dim, mlp_out_dim, mlp_hidden_list):
        super().__init__()

        # 3x3 unfolding multiplies the channels by 9; the MLP additionally
        # receives a 2-D relative coordinate and a 2-D relative cell size.
        imnet_in_dim = mlp_in_dim * 9 + 2 + 2
        self.imnet = MLP(imnet_in_dim, mlp_out_dim, mlp_hidden_list)
        # Keep the historical "lives on the GPU right after construction"
        # behaviour, but do not crash on hosts without CUDA.
        if torch.cuda.is_available():
            self.imnet = self.imnet.cuda()

    def gen_coord(self, in_shape, output_size):
        """Precompute sampling coordinates and blending weights.

        Args:
            in_shape: Shape of the feature map that will be given to
                ``forward``; index 0 is used as the batch size and the last
                two entries as the feature height and width.
            output_size: ``(height, width)`` of the image to produce.

        Sets ``image_size``, ``coord_``, ``rel_coord``, ``rel_cell`` (each
        of the latter three shaped ``(2, 2, batch, H_out * W_out, 2)``) and
        ``area_weights`` (a list of four ``(batch, H_out * W_out, 1)``
        tensors), all placed on the device of the MLP parameters.
        """
        self.vx_lst = [-1, 1]
        self.vy_lst = [-1, 1]
        eps_shift = 1e-6
        self.image_size = output_size

        batch = in_shape[0]
        feat_h, feat_w = in_shape[-2], in_shape[-1]
        out_h, out_w = output_size[0], output_size[1]
        num_q = out_h * out_w
        # Follow the MLP so forward() never mixes devices.
        device = next(self.imnet.parameters()).device

        # field radius (global: [-1, 1])
        rx = 2 / feat_h / 2
        ry = 2 / feat_w / 2

        coord = make_coord(output_size, flatten=False) \
            .expand(batch, out_h, out_w, 2) \
            .view(batch, num_q, 2)

        cell = torch.ones_like(coord)
        # NOTE(review): coord is (batch, H_out * W_out, 2) here, so these
        # divisors are H_out * W_out and 2, not the output height and width.
        # This looks unintended, but it is left untouched because changing
        # it would alter the inputs that already-trained weights expect.
        # TODO confirm the intended scaling.
        cell[:, :, 0] *= 2 / coord.shape[-2]
        cell[:, :, 1] *= 2 / coord.shape[-1]

        feat_coord = make_coord(in_shape[-2:], flatten=False) \
            .permute(2, 0, 1) \
            .unsqueeze(0).expand(batch, 2, *in_shape[-2:])

        table_shape = (2, 2, batch, num_q, 2)
        self.rel_coord = torch.zeros(table_shape)
        self.rel_cell = torch.zeros(table_shape)
        self.coord_ = torch.zeros(table_shape)

        areas = []
        for vx in self.vx_lst:
            for vy in self.vy_lst:
                # Map the shift sign (-1 / +1) to a table index (0 / 1).
                ix, iy = (vx + 1) // 2, (vy + 1) // 2

                # Views into the tables: in-place edits write through.
                shifted = self.coord_[ix, iy]
                shifted[...] = coord
                shifted[:, :, 0] += vx * rx + eps_shift
                shifted[:, :, 1] += vy * ry + eps_shift
                # Keep the query strictly inside the sampling range.
                shifted.clamp_(-1 + 1e-6, 1 - 1e-6)

                # Coordinate of the feature cell that nearest-neighbour
                # sampling picks for each shifted query.
                q_coord = F.grid_sample(
                    feat_coord, shifted.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)

                rel = self.rel_coord[ix, iy]
                rel[...] = coord - q_coord
                rel[:, :, 0] *= feat_h
                rel[:, :, 1] *= feat_w

                rel_cell = self.rel_cell[ix, iy]
                rel_cell[...] = cell
                rel_cell[:, :, 0] *= feat_h
                rel_cell[:, :, 1] *= feat_w

                area = torch.abs(rel[:, :, 0] * rel[:, :, 1])
                areas.append(area + 1e-9)  # avoid a zero total below

        tot_area = torch.stack(areas).sum(dim=0)
        # Each prediction is weighted by the area of the diagonally opposite
        # corner: swapping 0<->3 and 1<->2 is a plain reversal.
        areas.reverse()
        self.area_weights = [
            (area / tot_area).unsqueeze(-1).to(device) for area in areas
        ]

        self.rel_coord = self.rel_coord.to(device)
        self.rel_cell = self.rel_cell.to(device)
        self.coord_ = self.coord_.to(device)

    def forward(self, feat):
        """Predict the output image from a 4-D feature map.

        Args:
            feat: Tensor of shape ``(batch, channels, height, width)``
                matching the ``in_shape`` given to :meth:`gen_coord`.

        Returns:
            Tensor of shape ``(batch, C_out, image_size[0], image_size[1])``
            where ``C_out`` is whatever the MLP emits per query.
        """
        batch, chans, height, width = feat.shape
        # B K*K*Cin H W
        feat = F.unfold(feat, 3, padding=1).view(
            batch, chans * 9, height, width)

        # Batch size and query count fixed by gen_coord.
        bs, q = self.coord_.shape[2:4]

        preds = []
        for vx in [0, 1]:
            for vy in [0, 1]:
                q_feat = F.grid_sample(
                    feat, self.coord_[vx, vy].flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                inp = torch.cat(
                    [q_feat, self.rel_coord[vx, vy], self.rel_cell[vx, vy]],
                    dim=-1)
                pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
                preds.append(pred)

        ret = 0
        for pred, area in zip(preds, self.area_weights):
            ret = ret + pred * area

        # Infer the channel count from the prediction instead of assuming 3,
        # so any mlp_out_dim yields a correctly shaped image.
        return ret.permute(0, 2, 1).reshape(
            bs, -1, self.image_size[0], self.image_size[1])