#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
#############################################################
|
|
# File: PatchNCE.py
|
|
# Created Date: Friday January 21st 2022
|
|
# Author: Chen Xuanhong
|
|
# Email: chenxuanhongzju@outlook.com
|
|
# Last Modified: Monday, 24th January 2022 9:56:24 am
|
|
# Modified By: Chen Xuanhong
|
|
# Copyright (c) 2022 Shanghai Jiao Tong University
|
|
#############################################################
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import init
|
|
import torch.nn.functional as F
|
|
from packaging import version
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None, debug=False, initialize_weights=True):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)             -- the network to be initialized
        init_type (str)           -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)         -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (list)            -- which devices the network runs on, e.g. ['cuda:0']; None or [] keeps the net where it is
        debug (bool)              -- forwarded to init_weights; prints each initialized layer's class name
        initialize_weights (bool) -- whether to run init_weights on the network

    Return an initialized network.
    """
    # Use None as the default instead of a shared mutable list ([] as a default
    # argument is a classic Python pitfall); behavior is unchanged for callers.
    if gpu_ids is None:
        gpu_ids = []
    if len(gpu_ids) > 0:
        # Raise instead of assert so the check survives `python -O`.
        if not torch.cuda.is_available():
            raise RuntimeError('gpu_ids were given but CUDA is not available')
        net.to(gpu_ids[0])
        # if not amp:
        #     net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs for non-AMP training
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net
|
|
|
|
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
        debug (bool)      -- print the class name of every weighted layer as it is initialized

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.

    Raises:
        NotImplementedError -- if init_type is not one of the four supported methods.
    """
    def _initialize(module):  # applied to every submodule via net.apply
        name = module.__class__.__name__
        has_trainable_matrix = hasattr(module, 'weight') and ('Conv' in name or 'Linear' in name)
        if has_trainable_matrix:
            if debug:
                print(name)
            if init_type == 'normal':
                init.normal_(module.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(module.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                # kaiming ignores init_gain by design (gain derives from the nonlinearity).
                init.kaiming_normal_(module.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(module.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(module, 'bias') and module.bias is not None:
                init.constant_(module.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            # BatchNorm's weight is a vector, not a matrix; only a normal draw applies.
            init.normal_(module.weight.data, 1.0, init_gain)
            init.constant_(module.bias.data, 0.0)

    net.apply(_initialize)  # recursively initialize every submodule
|
|
|
|
class Normalize(nn.Module):
    """Lp-normalize a tensor along dim 1 (the channel/feature dimension)."""

    def __init__(self, power=2):
        super(Normalize, self).__init__()
        # Exponent of the norm; power=2 gives the usual L2 normalization.
        self.power = power

    def forward(self, x):
        p = self.power
        # Lp magnitude along dim 1, kept broadcastable against x via keepdim.
        magnitude = x.pow(p).sum(1, keepdim=True).pow(1. / p)
        # The small epsilon guards against division by zero for all-zero vectors.
        return x / (magnitude + 1e-7)
|
|
|
|
class PatchSampleF(nn.Module):
    """Sample feature patches from a list of feature maps (CUT's netF).

    Optionally projects each layer's patches through a small per-layer MLP
    (built lazily on the first forward pass) and L2-normalizes the result.
    """

    def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=None):
        # potential issues: currently, we use the same patch_ids for multiple images in the batch
        super(PatchSampleF, self).__init__()
        self.l2norm = Normalize(2)
        self.use_mlp = use_mlp
        self.nc = nc  # hard-coded MLP hidden/output width
        self.mlp_init = False  # MLPs are created lazily in forward (shapes come from feats)
        self.init_type = init_type
        self.init_gain = init_gain
        # None default instead of a shared mutable list ([] as a default argument
        # is a classic Python pitfall); behavior is unchanged for callers.
        self.gpu_ids = [] if gpu_ids is None else gpu_ids

    def create_mlp(self, feats):
        """Build one 2-layer MLP per feature map; the input width is read from each feat."""
        for mlp_id, feat in enumerate(feats):
            input_nc = feat.shape[1]
            mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
            if len(self.gpu_ids) > 0:
                mlp.cuda()
            setattr(self, 'mlp_%d' % mlp_id, mlp)
        init_net(self, self.init_type, self.init_gain, self.gpu_ids)
        self.mlp_init = True

    def forward(self, feats, num_patches=64, patch_ids=None):
        """Sample (and optionally project/normalize) patches from each feature map.

        Parameters:
            feats (list of Tensor) -- feature maps of shape (B, C, H, W)
            num_patches (int)      -- patches sampled per feature map; 0 keeps every location
            patch_ids (list|None)  -- reuse these indices (one entry per feat) instead of resampling

        Returns:
            (return_feats, return_ids) -- sampled features and the patch indices used,
            so a second call can sample the exact same locations.
        """
        return_ids = []
        return_feats = []
        if self.use_mlp and not self.mlp_init:
            self.create_mlp(feats)
        for feat_id, feat in enumerate(feats):
            B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
            # (B, C, H, W) -> (B, H*W, C): one row per spatial location.
            feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
            if num_patches > 0:
                if patch_ids is not None:
                    patch_id = patch_ids[feat_id]
                else:
                    # torch.randperm produces cudaErrorIllegalAddress for newer versions of PyTorch. https://github.com/taesungp/contrastive-unpaired-translation/issues/83
                    # patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
                    patch_id = np.random.permutation(feat_reshape.shape[1])
                    patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]  # .to(patch_ids.device)
                # NOTE: the same patch_id is applied to every image in the batch.
                x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)  # reshape(-1, x.shape[1])
            else:
                x_sample = feat_reshape
                patch_id = []
            if self.use_mlp:
                mlp = getattr(self, 'mlp_%d' % feat_id)
                x_sample = mlp(x_sample)
            return_ids.append(patch_id)
            x_sample = self.l2norm(x_sample)

            if num_patches == 0:
                # Restore spatial layout: (B, H*W, C) -> (B, C, H, W).
                x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
            return_feats.append(x_sample)
        return return_feats, return_ids
|
|
|
|
|
|
|
|
# def calculate_NCE_loss(src, tgt):
|
|
# n_layers = len(self.nce_layers)
|
|
# feat_q = self.netG(tgt, self.nce_layers, encode_only=True)
|
|
|
|
# if self.opt.flip_equivariance and self.flipped_for_equivariance:
|
|
# feat_q = [torch.flip(fq, [3]) for fq in feat_q]
|
|
|
|
# feat_k = self.netG(src, self.nce_layers, encode_only=True)
|
|
# feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
|
|
# feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)
|
|
|
|
# total_nce_loss = 0.0
|
|
# for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
|
|
# loss = crit(f_q, f_k) * self.opt.lambda_NCE
|
|
# total_nce_loss += loss.mean()
|
|
|
|
# return total_nce_loss / n_layers
|
|
|
|
class PatchNCELoss(nn.Module):
    """PatchNCE contrastive loss.

    Each query patch's positive is its corresponding key patch; its negatives
    are the other key patches from the same image in the minibatch.
    """

    def __init__(self, batch_size, nce_T=0.07):
        super().__init__()
        self.nce_T = nce_T  # temperature applied to all logits
        self.batch_size = batch_size
        # reduction='none' -> one loss value per patch; callers reduce as needed.
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        # masked_fill_ wants a uint8 mask before PyTorch 1.2 and a bool mask after.
        self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool

    def forward(self, feat_q, feat_k):
        total_patches = feat_q.shape[0]
        dim = feat_q.shape[1]
        # Gradients flow through the query branch only.
        feat_k = feat_k.detach()

        # Positive logit: dot product of each query patch with its own key patch,
        # computed as a batch of (1 x dim) @ (dim x 1) products.
        l_pos = torch.bmm(feat_q.view(total_patches, 1, -1),
                          feat_k.view(total_patches, -1, 1)).view(total_patches, 1)

        # Negative logits. In CUT/FastCUT negatives come only from patches of the
        # same image, so features are regrouped per image of the minibatch.
        # (Single-image translation would instead flatten the whole minibatch to
        # one group -- nce_includes_all_negatives_from_minibatch in the original.)
        groups = self.batch_size
        q_grouped = feat_q.view(groups, -1, dim)
        k_grouped = feat_k.view(groups, -1, dim)
        npatches = q_grouped.size(1)
        l_neg_curbatch = torch.bmm(q_grouped, k_grouped.transpose(2, 1))

        # Diagonal entries compare a patch with itself and are meaningless;
        # fill them with -10 so they contribute ~exp(-10) ~ 0 after softmax.
        diag_mask = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
        l_neg_curbatch.masked_fill_(diag_mask, -10.0)
        l_neg = l_neg_curbatch.view(-1, npatches)

        # Column 0 holds the positive logit, so the target class is always 0.
        logits = torch.cat((l_pos, l_neg), dim=1) / self.nce_T
        targets = torch.zeros(logits.size(0), dtype=torch.long, device=feat_q.device)
        return self.cross_entropy_loss(logits, targets)
|
|
|
|
if __name__ == "__main__":
    # Smoke test (requires a CUDA device): sample patches from two identical
    # feature maps and compute the PatchNCE loss between the pools.
    batch_size = 16
    feat_nc = 256
    n_patches = 64

    sampler = PatchSampleF(use_mlp=True, init_type='normal', init_gain=0.02, gpu_ids=["cuda:0"], nc=feat_nc)
    criterion = PatchNCELoss(batch_size)

    query_feats = [torch.ones((batch_size, feat_nc, 32, 32)).cuda()]
    key_feats = [torch.ones((batch_size, feat_nc, 32, 32)).cuda()]

    # if self.opt.flip_equivariance and self.flipped_for_equivariance:
    #     feat_q = [torch.flip(fq, [3]) for fq in feat_q]

    # Sample the key patches first, then reuse the same indices for the queries.
    key_pool, sample_ids = sampler(key_feats, n_patches, None)
    print(key_pool[0].shape)
    query_pool, _ = sampler(query_feats, n_patches, sample_ids)
    print(query_pool[0].shape)

    loss = criterion(query_pool[0], key_pool[0])
    print(loss)