#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: PatchNCE.py
# Created Date: Friday January 21st 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Monday, 24th January 2022 9:56:24 am
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from packaging import version
import numpy as np


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights.

    Parameters:
        net (network)       -- the network to be initialized
        init_type (str)     -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)   -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list)  -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        # if not amp:
        #     net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs for non-AMP training
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net


def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights.

    Parameters:
        net (network)       -- network to be initialized
        init_type (str)     -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)   -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper, but xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if debug:
                print(classname)
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    net.apply(init_func)  # apply the initialization function


class Normalize(nn.Module):
    """L-p normalization (L2 by default) along the channel dimension (dim 1)."""

    def __init__(self, power=2):
        super(Normalize, self).__init__()
        self.power = power

    def forward(self, x):
        norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
        out = x.div(norm + 1e-7)
        return out
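
# Illustrative usage (added sketch, not part of the original file): Normalize(2)
# rescales every row of an (N, C) feature tensor to unit L2 norm along dim 1, so the
# dot products taken later in PatchNCELoss behave like cosine similarities.
#
#   l2norm = Normalize(power=2)
#   v = torch.tensor([[3.0, 4.0]])   # L2 norm is 5
#   l2norm(v)                        # ~= tensor([[0.6000, 0.8000]])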
class PatchSampleF(nn.Module):
    """Sample features at random spatial locations from each feature map, optionally
    project them through a per-layer two-layer MLP, and L2-normalize the result."""

    def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
        # potential issues: currently, we use the same patch_ids for multiple images in the batch
        super(PatchSampleF, self).__init__()
        self.l2norm = Normalize(2)
        self.use_mlp = use_mlp
        self.nc = nc  # hard-coded
        self.mlp_init = False
        self.init_type = init_type
        self.init_gain = init_gain
        self.gpu_ids = gpu_ids

    def create_mlp(self, feats):
        for mlp_id, feat in enumerate(feats):
            input_nc = feat.shape[1]
            mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
            if len(self.gpu_ids) > 0:
                mlp.cuda()
            setattr(self, 'mlp_%d' % mlp_id, mlp)
        init_net(self, self.init_type, self.init_gain, self.gpu_ids)
        self.mlp_init = True

    def forward(self, feats, num_patches=64, patch_ids=None):
        return_ids = []
        return_feats = []
        if self.use_mlp and not self.mlp_init:
            self.create_mlp(feats)
        for feat_id, feat in enumerate(feats):
            B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
            feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)  # (B, H*W, C)
            if num_patches > 0:
                if patch_ids is not None:
                    patch_id = patch_ids[feat_id]
                else:
                    # torch.randperm produces cudaErrorIllegalAddress for newer versions of PyTorch.
                    # https://github.com/taesungp/contrastive-unpaired-translation/issues/83
                    # patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
                    patch_id = np.random.permutation(feat_reshape.shape[1])
                    patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]  # .to(patch_ids.device)
                x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)  # reshape(-1, x.shape[1])
            else:
                x_sample = feat_reshape
                patch_id = []
            if self.use_mlp:
                mlp = getattr(self, 'mlp_%d' % feat_id)
                x_sample = mlp(x_sample)
            return_ids.append(patch_id)
            x_sample = self.l2norm(x_sample)

            if num_patches == 0:
                x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
            return_feats.append(x_sample)
        return return_feats, return_ids


# def calculate_NCE_loss(src, tgt):
#     n_layers = len(self.nce_layers)
#     feat_q = self.netG(tgt, self.nce_layers, encode_only=True)

#     if self.opt.flip_equivariance and self.flipped_for_equivariance:
#         feat_q = [torch.flip(fq, [3]) for fq in feat_q]

#     feat_k = self.netG(src, self.nce_layers, encode_only=True)
#     feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
#     feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)

#     total_nce_loss = 0.0
#     for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
#         loss = crit(f_q, f_k) * self.opt.lambda_NCE
#         total_nce_loss += loss.mean()

#     return total_nce_loss / n_layers


class PatchNCELoss(nn.Module):
    """Patch-wise InfoNCE loss: each query patch should match its corresponding key
    patch (positive) rather than the other patches from the same image (negatives)."""

    def __init__(self, batch_size, nce_T=0.07):
        super().__init__()
        self.nce_T = nce_T
        self.batch_size = batch_size
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool

    def forward(self, feat_q, feat_k):
        num_patches = feat_q.shape[0]
        dim = feat_q.shape[1]
        feat_k = feat_k.detach()

        # pos logit
        l_pos = torch.bmm(
            feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
        l_pos = l_pos.view(num_patches, 1)

        # neg logit

        # Should the negatives from the other samples of a minibatch be utilized?
        # In CUT and FastCUT, we found that it's best to only include negatives
        # from the same image. Therefore, we set
        # --nce_includes_all_negatives_from_minibatch as False.
        # However, for single-image translation, the minibatch consists of
        # crops from the "same" high-resolution image.
        # Therefore, we will include the negatives from the entire minibatch.
        # if self.opt.nce_includes_all_negatives_from_minibatch:
        #     # reshape features as if they are all negatives of minibatch of size 1.
        #     batch_dim_for_bmm = 1
        # else:
        batch_dim_for_bmm = self.batch_size

        # reshape features to batch size
        feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
        feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        # Diagonal entries are the similarity between identical features, and hence meaningless.
        # Fill the diagonal with a very small number (-10), so its contribution after the softmax is negligible.
        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
        l_neg_curbatch.masked_fill_(diagonal, -10.0)
        l_neg = l_neg_curbatch.view(-1, npatches)

        out = torch.cat((l_pos, l_neg), dim=1) / self.nce_T

        loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long, device=feat_q.device))

        return loss
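
# Shape sketch for PatchNCELoss.forward (added for illustration; B = batch_size,
# S = patches sampled per image, C = feature dim, so feat_q and feat_k are (B*S, C)):
#
#   l_pos : (B*S, 1)      similarity of each query patch with its own key patch
#   l_neg : (B*S, S)      similarities with the other patches of the same image,
#                         with the self-similarity on the diagonal masked to -10
#   out   : (B*S, 1 + S)  logits divided by the temperature nce_T (0.07 by default)
#
# The cross-entropy target is class 0 for every row, i.e. the positive logit in the
# first column has to dominate the negatives.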
if __name__ == "__main__":
    batch = 16
    nc = 256
    num_patches = 64
    netF = PatchSampleF(use_mlp=True, init_type='normal', init_gain=0.02, gpu_ids=["cuda:0"], nc=nc)
    feat_q = [torch.ones((batch, nc, 32, 32)).cuda()]
    # if self.opt.flip_equivariance and self.flipped_for_equivariance:
    #     feat_q = [torch.flip(fq, [3]) for fq in feat_q]
    crit = PatchNCELoss(batch)
    feat_k = [torch.ones((batch, nc, 32, 32)).cuda()]
    feat_k_pool, sample_ids = netF(feat_k, num_patches, None)
    print(feat_k_pool[0].shape)
    feat_q_pool, _ = netF(feat_q, num_patches, sample_ids)
    print(feat_q_pool[0].shape)
    loss = crit(feat_q_pool[0], feat_k_pool[0])
    print(loss)
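    # --- Added usage note (not in the original file) ---------------------------
    # `loss` holds one InfoNCE value per sampled patch, i.e. shape (batch * num_patches,).
    # A training loop would typically reduce it to a scalar before backprop, e.g.:
    print(loss.mean())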