Files
SimSwapPlus/losses/PatchNCE.py
T
XHChen0528 9efbe03d3f update
2022-03-19 22:47:08 +08:00

229 lines
9.6 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: PatchNCE.py
# Created Date: Friday January 21st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 24th January 2022 9:56:24 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from packaging import version
import numpy as np
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None, debug=False, initialize_weights=True):
    """Move a network to its device and optionally initialize its weights.

    Parameters:
        net (network)             -- the network to be initialized
        init_type (str)           -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)         -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (list or None)    -- devices the network may run on, e.g. [0, 1, 2]; only the first
                                     entry is used here. None or an empty list leaves the network
                                     on its current device.
        debug (bool)              -- forwarded to init_weights
        initialize_weights (bool) -- when False, the weight initialization step is skipped

    Return the same network object that was passed in.
    """
    # ``None`` replaces the former mutable ``[]`` default; both mean "no device requested".
    if gpu_ids is not None and len(gpu_ids) > 0:
        assert torch.cuda.is_available(), 'gpu_ids were given but CUDA is not available'
        net.to(gpu_ids[0])
        # if not amp:
        # net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs for non-AMP training
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
        debug (bool)      -- when True, print the class name of every Conv/Linear module visited

    Modules are selected by class name: any module whose class name contains
    'Conv' or 'Linear' and that carries a weight is initialized with the chosen
    method and gets its bias (if any) set to zero. Modules whose class name
    contains 'BatchNorm2d' get weight ~ N(1.0, init_gain) and bias 0.

    Raises NotImplementedError when a Conv/Linear module is visited and
    init_type is not one of the supported names.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        is_conv_or_linear = classname.find('Conv') != -1 or classname.find('Linear') != -1
        if weight is not None and is_conv_or_linear:
            if debug:
                print(classname)
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # With affine=False both weight and bias are None, so each is guarded.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function <init_func>
class Normalize(nn.Module):
    """Divide a tensor by the power-th root of its summed powers along dim 1."""

    def __init__(self, power=2):
        super().__init__()
        # Exponent used both for the element-wise power and for the root.
        self.power = power

    def forward(self, x):
        """Return ``x`` scaled by ``1 / (norm + 1e-7)``.

        ``norm`` is ``(sum over dim 1 of x ** power) ** (1 / power)``, kept with
        a singleton dim 1 so it broadcasts against ``x``. The 1e-7 term keeps
        the division finite when the norm is zero.
        """
        exponent = self.power
        summed = torch.sum(torch.pow(x, exponent), 1, keepdim=True)
        norm = torch.pow(summed, 1. / exponent)
        return torch.div(x, norm + 1e-7)
class PatchSampleF(nn.Module):
    """Sample per-position feature vectors from a list of feature maps.

    Each feature map is flattened over its two spatial axes, a subset of
    positions is selected, the selected vectors are optionally passed through
    a per-feature-map two-layer MLP, and the result is normalized with
    ``self.l2norm``.

    Parameters:
        use_mlp (bool)         -- when True, an MLP per feature map is created lazily on the first forward call
        init_type (str)        -- initialization method forwarded to init_net when the MLPs are created
        init_gain (float)      -- scaling factor forwarded to init_net
        nc (int)               -- hidden and output width of each MLP
        gpu_ids (list or None) -- forwarded to init_net; when non-empty the MLPs are also moved with .cuda().
                                  None means no devices (stored as an empty list).
    """

    def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=None):
        # potential issues: currently, we use the same patch_ids for multiple images in the batch
        super(PatchSampleF, self).__init__()
        self.l2norm = Normalize(2)
        self.use_mlp = use_mlp
        self.nc = nc  # hard-coded
        self.mlp_init = False
        self.init_type = init_type
        self.init_gain = init_gain
        # A fresh list per instance: the former ``gpu_ids=[]`` default was one
        # list object shared by every instance constructed without gpu_ids.
        self.gpu_ids = [] if gpu_ids is None else gpu_ids

    def create_mlp(self, feats):
        """Create one MLP per entry of ``feats`` and register it as ``mlp_<index>``.

        The input width of each MLP is taken from ``feat.shape[1]``. After all
        MLPs are attached, init_net is run on the whole module and
        ``self.mlp_init`` is set so this is done only once.
        """
        for mlp_id, feat in enumerate(feats):
            input_nc = feat.shape[1]
            mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
            if len(self.gpu_ids) > 0:
                mlp.cuda()
            setattr(self, 'mlp_%d' % mlp_id, mlp)
        init_net(self, self.init_type, self.init_gain, self.gpu_ids)
        self.mlp_init = True

    def forward(self, feats, num_patches=64, patch_ids=None):
        """Sample, project and normalize features from every feature map.

        Parameters:
            feats (list)        -- 4-D tensors; shape[0], shape[2] and shape[3] are read as B, H and W,
                                   and shape[1] is the per-position vector length
            num_patches (int)   -- when > 0, at most this many positions are kept per feature map
                                   (the same positions for every image in the batch);
                                   when 0, every position is kept and the output is reshaped
                                   back to [B, C_out, H, W]
            patch_ids (list)    -- optional per-feature-map position indices to reuse; when given
                                   (and num_patches > 0) they replace the random selection

        Returns:
            (return_feats, return_ids) -- two lists with one entry per feature map: the processed
            features and the position indices that were used (an empty list when no sampling was done).
        """
        return_ids = []
        return_feats = []
        if self.use_mlp and not self.mlp_init:
            self.create_mlp(feats)
        for feat_id, feat in enumerate(feats):
            B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
            # (B, C, H, W) -> (B, H*W, C): one row per spatial position.
            feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
            if num_patches > 0:
                if patch_ids is not None:
                    patch_id = patch_ids[feat_id]
                else:
                    # torch.randperm produces cudaErrorIllegalAddress for newer versions of PyTorch.
                    # https://github.com/taesungp/contrastive-unpaired-translation/issues/83
                    # patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
                    patch_id = np.random.permutation(feat_reshape.shape[1])
                    patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]  # .to(patch_ids.device)
                # Same positions for every image, then merge batch and patch axes.
                x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)  # reshape(-1, x.shape[1])
            else:
                x_sample = feat_reshape
                patch_id = []
            if self.use_mlp:
                mlp = getattr(self, 'mlp_%d' % feat_id)
                x_sample = mlp(x_sample)
            return_ids.append(patch_id)
            # NOTE(review): Normalize reduces over dim 1. For the sampled 2-D case that is the
            # feature axis; for num_patches == 0 the tensor is still (B, H*W, C) here, so dim 1
            # is the position axis -- confirm this is intended.
            x_sample = self.l2norm(x_sample)
            if num_patches == 0:
                # Restore the spatial layout; shape[-1] is read before the permute.
                x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
            return_feats.append(x_sample)
        return return_feats, return_ids
# def calculate_NCE_loss(src, tgt):
# n_layers = len(self.nce_layers)
# feat_q = self.netG(tgt, self.nce_layers, encode_only=True)
# if self.opt.flip_equivariance and self.flipped_for_equivariance:
# feat_q = [torch.flip(fq, [3]) for fq in feat_q]
# feat_k = self.netG(src, self.nce_layers, encode_only=True)
# feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
# feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)
# total_nce_loss = 0.0
# for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
# loss = crit(f_q, f_k) * self.opt.lambda_NCE
# total_nce_loss += loss.mean()
# return total_nce_loss / n_layers
class PatchNCELoss(nn.Module):
    """Contrastive cross-entropy loss between query and key patch features.

    Every query row is scored against its own key row (the positive) and
    against the other key rows of the same group (the negatives); the loss is
    the cross entropy of those scores with the positive as the target class.
    """

    def __init__(self, batch_size, nce_T=0.07):
        super().__init__()
        # Number of groups the patch rows are split into for the negatives.
        self.batch_size = batch_size
        # Temperature: the logits are divided by this value.
        self.nce_T = nce_T
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        # Boolean masks are used from PyTorch 1.2.0 on, uint8 masks before that.
        if version.parse(torch.__version__) < version.parse('1.2.0'):
            self.mask_dtype = torch.uint8
        else:
            self.mask_dtype = torch.bool

    def forward(self, feat_q, feat_k):
        """Return the un-reduced loss, one value per row of ``feat_q``.

        ``feat_q`` and ``feat_k`` are indexed as (rows, dim); the row count
        must be divisible by ``self.batch_size`` for the grouping views below.
        ``feat_k`` is detached, so no gradient reaches the keys.
        """
        total_rows, feat_dim = feat_q.shape[0], feat_q.shape[1]
        feat_k = feat_k.detach()

        # Positive logit: dot product of each query row with its own key row.
        positives = torch.bmm(
            feat_q.view(total_rows, 1, -1),
            feat_k.view(total_rows, -1, 1),
        ).view(total_rows, 1)

        # Negative logits.
        # Should the negatives from the other samples of a minibatch be utilized?
        # In CUT and FastCUT, it was found best to only include negatives from
        # the same image (--nce_includes_all_negatives_from_minibatch False).
        # For single-image translation the minibatch consists of crops from the
        # "same" high-resolution image, so there the whole minibatch would be
        # used (group count 1). That option is disabled here: the group count
        # is always the batch size.
        groups = self.batch_size

        # Reshape features to (groups, rows per group, dim).
        q_grouped = feat_q.view(groups, -1, feat_dim)
        k_grouped = feat_k.view(groups, -1, feat_dim)
        rows_per_group = q_grouped.size(1)
        negatives = torch.bmm(q_grouped, k_grouped.transpose(2, 1))

        # Diagonal entries compare a query with its own key and are therefore
        # meaningless as negatives; overwrite them with -10.0 so that they
        # contribute almost nothing after the softmax.
        self_mask = torch.eye(rows_per_group, device=q_grouped.device, dtype=self.mask_dtype).unsqueeze(0)
        negatives.masked_fill_(self_mask, -10.0)
        negatives = negatives.view(-1, rows_per_group)

        # Column 0 holds the positive, so the target class is 0 for every row.
        logits = torch.cat((positives, negatives), dim=1) / self.nce_T
        targets = torch.zeros(logits.size(0), dtype=torch.long, device=q_grouped.device)
        return self.cross_entropy_loss(logits, targets)
if __name__ == "__main__":
    # Smoke test: sample patches from two constant feature maps with a shared
    # set of patch ids and print the resulting shapes and the per-patch loss.
    batch = 16
    nc = 256
    num_patches = 64

    # Run on the first GPU when CUDA is available; otherwise fall back to the
    # CPU instead of crashing on the first .cuda() call.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    gpu_ids = ["cuda:0"] if use_cuda else []

    netF = PatchSampleF(use_mlp=True, init_type='normal', init_gain=0.02, gpu_ids=gpu_ids, nc=nc)
    feat_q = [torch.ones((batch, nc, 32, 32), device=device)]
    # if self.opt.flip_equivariance and self.flipped_for_equivariance:
    # feat_q = [torch.flip(fq, [3]) for fq in feat_q]
    crit = PatchNCELoss(batch)
    feat_k = [torch.ones((batch, nc, 32, 32), device=device)]

    # Sample the keys first, then reuse the same patch ids for the queries.
    feat_k_pool, sample_ids = netF(feat_k, num_patches, None)
    print(feat_k_pool[0].shape)
    feat_q_pool, _ = netF(feat_q, num_patches, sample_ids)
    print(feat_q_pool[0].shape)
    loss = crit(feat_q_pool[0], feat_k_pool[0])
    print(loss)