support multi-gpu
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"GUI.py": 1642351532.4558506,
|
||||
"test.py": 1642733759.8015468,
|
||||
"train.py": 1643009568.7253726,
|
||||
"test.py": 1643529962.5602193,
|
||||
"train.py": 1643397924.974299,
|
||||
"components\\Generator.py": 1642347735.351465,
|
||||
"components\\projected_discriminator.py": 1642348101.4661522,
|
||||
"components\\pg_modules\\blocks.py": 1640773190.0,
|
||||
@@ -12,7 +12,7 @@
|
||||
"components\\pg_modules\\projector.py": 1642349764.3896568,
|
||||
"data_tools\\data_loader.py": 1611123530.660446,
|
||||
"data_tools\\data_loader_condition.py": 1625411562.8217106,
|
||||
"data_tools\\data_loader_VGGFace2HQ.py": 1642349144.749807,
|
||||
"data_tools\\data_loader_VGGFace2HQ.py": 1644234949.3769877,
|
||||
"data_tools\\StyleResize.py": 1624954084.7176485,
|
||||
"data_tools\\test_dataloader_dir.py": 1634041792.6743984,
|
||||
"losses\\PerceptualLoss.py": 1615020169.668723,
|
||||
@@ -23,7 +23,7 @@
|
||||
"test_scripts\\tester_common.py": 1625369535.199175,
|
||||
"test_scripts\\tester_FastNST.py": 1634041357.607633,
|
||||
"train_scripts\\trainer_base.py": 1642396105.3868554,
|
||||
"train_scripts\\trainer_FM.py": 1642826710.3532298,
|
||||
"train_scripts\\trainer_FM.py": 1643021959.3577182,
|
||||
"train_scripts\\trainer_naiv512.py": 1642315674.9740853,
|
||||
"utilities\\checkpoint_manager.py": 1611123530.6624403,
|
||||
"utilities\\figure.py": 1611123530.6634378,
|
||||
@@ -37,11 +37,11 @@
|
||||
"utilities\\transfer_checkpoint.py": 1642397157.0163105,
|
||||
"utilities\\utilities.py": 1634019485.0783668,
|
||||
"utilities\\yaml_config.py": 1611123530.6614666,
|
||||
"train_yamls\\train_512FM.yaml": 1642412254.0831068,
|
||||
"train_yamls\\train_512FM.yaml": 1643021615.8106658,
|
||||
"train_scripts\\trainer_2layer_FM.py": 1642826548.2530458,
|
||||
"train_yamls\\train_2layer_FM.yaml": 1642411635.5534878,
|
||||
"components\\Generator_reduce.py": 1642690262.572149,
|
||||
"insightface_func\\face_detect_crop_multi.py": 1638370471.789609,
|
||||
"components\\Generator_reduce.py": 1643021243.6658802,
|
||||
"insightface_func\\face_detect_crop_multi.py": 1643796928.6362474,
|
||||
"insightface_func\\face_detect_crop_single.py": 1638370471.7967434,
|
||||
"insightface_func\\__init__.py": 1624197300.011183,
|
||||
"insightface_func\\utils\\face_align_ffhqandnewarc.py": 1638370471.850638,
|
||||
@@ -51,9 +51,58 @@
|
||||
"test_scripts\\tester_common copy.py": 1625369535.199175,
|
||||
"test_scripts\\tester_video.py": 1642734397.3307388,
|
||||
"train_scripts\\trainer_cycleloss.py": 1642580463.495596,
|
||||
"train_scripts\\trainer_GramFM.py": 1643010471.1077821,
|
||||
"train_scripts\\trainer_GramFM.py": 1643095575.2628715,
|
||||
"utilities\\ImagenetNorm.py": 1642732910.5280058,
|
||||
"utilities\\reverse2original.py": 1642733688.7976837,
|
||||
"train_yamls\\train_cycleloss.yaml": 1642577741.345273,
|
||||
"train_yamls\\train_GramFM.yaml": 1643011210.82505
|
||||
"train_yamls\\train_GramFM.yaml": 1643398791.363959,
|
||||
"train_yamls\\train_512FM_Modulation.yaml": 1643022022.3165789,
|
||||
"face_crop.py": 1643789609.1834445,
|
||||
"face_crop_video.py": 1643815024.5516832,
|
||||
"similarity.py": 1643269705.1073737,
|
||||
"train_multigpu.py": 1644296706.054128,
|
||||
"components\\arcface_decoder.py": 1643396144.2575414,
|
||||
"components\\Generator_nobias.py": 1643179001.810856,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu.py": 1644299401.8480241,
|
||||
"data_tools\\data_loader_VGGFace2HQ_Rec.py": 1643398754.86898,
|
||||
"test_scripts\\tester_arcface_Rec.py": 1643431261.9333818,
|
||||
"test_scripts\\tester_image.py": 1643428951.5532105,
|
||||
"torch_utils\\custom_ops.py": 1640773190.0,
|
||||
"torch_utils\\misc.py": 1640773190.0,
|
||||
"torch_utils\\persistence.py": 1640773190.0,
|
||||
"torch_utils\\training_stats.py": 1640773190.0,
|
||||
"torch_utils\\utils_spectrum.py": 1640773190.0,
|
||||
"torch_utils\\__init__.py": 1640773190.0,
|
||||
"torch_utils\\ops\\bias_act.py": 1640773190.0,
|
||||
"torch_utils\\ops\\conv2d_gradfix.py": 1640773190.0,
|
||||
"torch_utils\\ops\\conv2d_resample.py": 1640773190.0,
|
||||
"torch_utils\\ops\\filtered_lrelu.py": 1640773190.0,
|
||||
"torch_utils\\ops\\fma.py": 1640773190.0,
|
||||
"torch_utils\\ops\\grid_sample_gradfix.py": 1640773190.0,
|
||||
"torch_utils\\ops\\upfirdn2d.py": 1640773190.0,
|
||||
"torch_utils\\ops\\__init__.py": 1640773190.0,
|
||||
"train_scripts\\trainer_arcface_rec.py": 1643399647.0182135,
|
||||
"train_scripts\\trainer_multigpu_base.py": 1644131205.772292,
|
||||
"train_scripts\\trainer_multi_gpu.py": 1644301774.3077753,
|
||||
"train_yamls\\train_arcface_rec.yaml": 1643398807.3434353,
|
||||
"train_yamls\\train_multigpu.yaml": 1644301838.3615713,
|
||||
"wandb\\run-20220129_032741-340btp9k\\files\\conda-environment.yaml": 1643398065.409959,
|
||||
"wandb\\run-20220129_032741-340btp9k\\files\\config.yaml": 1643398069.2392955,
|
||||
"wandb\\run-20220129_032939-2nmaozxq\\files\\conda-environment.yaml": 1643398182.647548,
|
||||
"wandb\\run-20220129_032939-2nmaozxq\\files\\config.yaml": 1643398186.3626983,
|
||||
"wandb\\run-20220129_033051-21z19tyg\\files\\conda-environment.yaml": 1643398254.9293146,
|
||||
"wandb\\run-20220129_033051-21z19tyg\\files\\config.yaml": 1643398259.2274177,
|
||||
"wandb\\run-20220129_033202-16la4gpu\\files\\conda-environment.yaml": 1643398325.8794518,
|
||||
"wandb\\run-20220129_033202-16la4gpu\\files\\config.yaml": 1643398324.9487782,
|
||||
"wandb\\run-20220129_034327-1bmseytq\\files\\conda-environment.yaml": 1643399010.865907,
|
||||
"wandb\\run-20220129_034327-1bmseytq\\files\\config.yaml": 1643399148.0268817,
|
||||
"wandb\\run-20220129_034859-2puk6sph\\files\\conda-environment.yaml": 1643399343.3508356,
|
||||
"wandb\\run-20220129_034859-2puk6sph\\files\\config.yaml": 1643399477.881678,
|
||||
"wandb\\run-20220129_035624-3hmwgcgw\\files\\conda-environment.yaml": 1643399787.8899708,
|
||||
"wandb\\run-20220129_035624-3hmwgcgw\\files\\config.yaml": 1643426465.6088357,
|
||||
"dnnlib\\util.py": 1640773190.0,
|
||||
"dnnlib\\__init__.py": 1640773190.0,
|
||||
"components\\Generator_ori.py": 1644229508.0031855,
|
||||
"losses\\cos.py": 1644229583.4023254,
|
||||
"data_tools\\data_loader_VGGFace2HQ_multigpu1.py": 1644297868.397411
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
# Simswap++
|
||||
|
||||
## Dependencies
|
||||
- python
|
||||
- python > 3.6
|
||||
- yaml (pip install pyyaml)
|
||||
- paramiko (For ssh file transportation)
|
||||
- pytorch > 1.8
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator.py
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 26th January 2022 2:36:41 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
from audioop import bias
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
class InstanceNorm(nn.Module):
    """Instance normalization without learnable affine parameters.

    Normalizes every (sample, channel) feature map over its spatial
    dimensions to zero mean and unit variance.  Deliberately avoids
    in-place ops, see:
    https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
    """

    def __init__(self, epsilon=1e-8):
        super(InstanceNorm, self).__init__()
        # Small constant added to the second moment for numerical stability.
        self.epsilon = epsilon

    def forward(self, x):
        # Center each feature map over its spatial extent (dims 2 and 3).
        centered = x - torch.mean(x, (2, 3), True)
        # Per-map second moment of the centered input (i.e. the variance).
        second_moment = torch.mean(torch.mul(centered, centered), (2, 3), True)
        # Scale by the reciprocal standard deviation.
        return centered * torch.rsqrt(second_moment + self.epsilon)
|
||||
|
||||
class ApplyStyle(nn.Module):
    """AdaIN-style affine modulation driven by a latent vector.

    Maps the latent to a per-channel scale and shift and applies
    ``x * (scale + 1) + shift``.
    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """

    def __init__(self, latent_size, channels):
        super(ApplyStyle, self).__init__()
        # One scale and one shift per feature channel.
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        # [batch, 2*C] -> [batch, 2, C, 1, 1]; index 0 is scale, 1 is shift.
        params = self.linear(latent).view(-1, 2, x.size(1), 1, 1)
        scale = params[:, 0]
        shift = params[:, 1]
        # "+ 1." makes a zero style output the identity transform.
        return x * (scale + 1.) + shift
|
||||
|
||||
class ResnetBlock_Adain(nn.Module):
    """Residual block whose two 3x3 convolutions are style-modulated.

    Computes ``x + f(x)`` where f is conv -> style -> act -> conv -> style.
    Convolutions are bias-free because each is followed by InstanceNorm
    plus an AdaIN shift, which would absorb any bias anyway.
    """

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        super(ResnetBlock_Adain, self).__init__()

        def build_conv():
            # Choose the padding scheme; 'zero' pads inside the conv itself.
            layers = []
            pad = 0
            if padding_type == 'reflect':
                layers.append(nn.ReflectionPad2d(1))
            elif padding_type == 'replicate':
                layers.append(nn.ReplicationPad2d(1))
            elif padding_type == 'zero':
                pad = 1
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=pad, bias=False), InstanceNorm()]
            return nn.Sequential(*layers)

        self.conv1 = build_conv()
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation

        self.conv2 = build_conv()
        self.style2 = ApplyStyle(latent_size, dim)

    def forward(self, x, dlatents_in_slice):
        residual = self.conv1(x)
        residual = self.style1(residual, dlatents_in_slice)
        residual = self.act1(residual)
        residual = self.conv2(residual)
        residual = self.style2(residual, dlatents_in_slice)
        return x + residual
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Encoder / AdaIN-bottleneck / decoder generator.

    Expected kwargs:
        g_conv_dim: latent size of the identity code fed to the AdaIN blocks.
        res_num: number of ResnetBlock_Adain blocks in the bottleneck.
    (``g_kernel_size`` is accepted but no longer read; it only ever fed an
    unused local.)
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        chn = kwargs["g_conv_dim"]
        res_num = kwargs["res_num"]

        padding_type = 'reflect'
        activation = nn.ReLU(True)

        self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
                                         nn.BatchNorm2d(64), activation)
        ### downsample (each stage halves the spatial resolution)
        self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
                                   nn.BatchNorm2d(128), activation)
        self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
                                   nn.BatchNorm2d(256), activation)
        self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False),
                                   nn.BatchNorm2d(512), activation)
        self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=False),
                                   nn.BatchNorm2d(512), activation)

        ### resnet blocks, each modulated by the identity latent
        self.BottleNeck = nn.Sequential(*[
            ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type, activation=activation)
            for _ in range(res_num)
        ])

        ### upsample (mirror of the encoder)
        self.up4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512), activation
        )
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256), activation
        )
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128), activation
        )
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64), activation
        )

        self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1, bias=False))

    def forward(self, input, id):
        """Translate ``input`` (B,3,H,W) conditioned on identity code ``id``.

        ``id`` must match the ``g_conv_dim`` latent size of the AdaIN blocks.
        """
        x = self.first_layer(input)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)

        # BUG FIX: the original loop did `x = self.BottleNeck[i](res, id)`,
        # re-feeding the encoder output to every block, so all bottleneck
        # blocks except the last were discarded.  Chain them instead.
        for block in self.BottleNeck:
            x = block(x, id)

        x = self.up4(x)
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        x = self.last_layer(x)

        return x
|
||||
@@ -0,0 +1,187 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: Generator.py
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 7th February 2022 6:25:07 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
from audioop import bias
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
class InstanceNorm(nn.Module):
    """Parameter-free instance normalization over spatial dimensions.

    Each (sample, channel) map is centered and scaled to unit variance.
    Kept free of in-place ops, see:
    https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
    """

    def __init__(self, epsilon=1e-8):
        super(InstanceNorm, self).__init__()
        # Stabilizer added to the variance before the reciprocal sqrt.
        self.epsilon = epsilon

    def forward(self, x):
        # Subtract the per-map spatial mean (keepdim for broadcasting).
        centered = x - torch.mean(x, (2, 3), True)
        # Variance of the centered map, then reciprocal standard deviation.
        inv_std = torch.rsqrt(torch.mean(torch.mul(centered, centered), (2, 3), True) + self.epsilon)
        return centered * inv_std
|
||||
|
||||
class ApplyStyle(nn.Module):
    """Per-channel affine modulation (AdaIN) computed from a latent code.

    Applies ``x * (scale + 1) + shift`` where scale/shift come from a
    single linear projection of the latent.
    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """

    def __init__(self, latent_size, channels):
        super(ApplyStyle, self).__init__()
        # Projects the latent to 2*channels values (scale then shift).
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        # Reshape [batch, 2*C] -> [batch, 2, C, 1, 1] for broadcasting.
        style = self.linear(latent).view(-1, 2, x.size(1), 1, 1)
        # "+ 1." biases the scale toward the identity mapping.
        return x * (style[:, 0] + 1.) + style[:, 1]
|
||||
|
||||
class ResnetBlock_Adain(nn.Module):
    """Style-modulated residual block: ``x + f(x)``.

    f is conv -> AdaIN style -> activation -> conv -> AdaIN style.
    Unlike the bias-free variant of this block elsewhere in the project,
    these convolutions keep their default bias term.
    """

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        super(ResnetBlock_Adain, self).__init__()

        def build_conv():
            # Select explicit padding module, or pad inside the conv for 'zero'.
            layers = []
            pad = 0
            if padding_type == 'reflect':
                layers.append(nn.ReflectionPad2d(1))
            elif padding_type == 'replicate':
                layers.append(nn.ReplicationPad2d(1))
            elif padding_type == 'zero':
                pad = 1
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=pad), InstanceNorm()]
            return nn.Sequential(*layers)

        self.conv1 = build_conv()
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation

        self.conv2 = build_conv()
        self.style2 = ApplyStyle(latent_size, dim)

    def forward(self, x, dlatents_in_slice):
        residual = self.conv1(x)
        residual = self.style1(residual, dlatents_in_slice)
        residual = self.act1(residual)
        residual = self.conv2(residual)
        residual = self.style2(residual, dlatents_in_slice)
        return x + residual
|
||||
|
||||
|
||||
class Generator(nn.Module):
    """Encoder / AdaIN-bottleneck / decoder generator (7x7 reflection stem).

    Expected kwargs:
        g_conv_dim: latent size of the identity code fed to the AdaIN blocks.
        res_num: number of ResnetBlock_Adain blocks in the bottleneck.
    (``g_kernel_size`` is accepted but no longer read; it only ever fed an
    unused local.)
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        chn = kwargs["g_conv_dim"]
        res_num = kwargs["res_num"]

        padding_type = 'reflect'
        activation = nn.ReLU(True)

        # 7x7 reflection-padded stem preserves resolution without border artifacts.
        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3),
                                         nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=False),
                                         nn.BatchNorm2d(64), activation)
        ### downsample (each stage halves the spatial resolution)
        self.down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
                                   nn.BatchNorm2d(128), activation)
        self.down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
                                   nn.BatchNorm2d(256), activation)
        self.down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False),
                                   nn.BatchNorm2d(512), activation)
        self.down4 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1, bias=False),
                                   nn.BatchNorm2d(512), activation)

        ### resnet blocks, each modulated by the identity latent
        self.BottleNeck = nn.Sequential(*[
            ResnetBlock_Adain(512, latent_size=chn, padding_type=padding_type, activation=activation)
            for _ in range(res_num)
        ])

        ### upsample (mirror of the encoder)
        self.up4 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512), activation
        )
        self.up3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256), activation
        )
        self.up2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128), activation
        )
        self.up1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64), activation
        )

        # 7x7 reflection-padded output head; bias kept (no norm follows it).
        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, 3, kernel_size=7, padding=0))

    def forward(self, input, id):
        """Translate ``input`` (B,3,H,W) conditioned on identity code ``id``.

        ``id`` must match the ``g_conv_dim`` latent size of the AdaIN blocks.
        """
        x = self.first_layer(input)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)

        # BUG FIX: the original loop did `x = self.BottleNeck[i](res, id)`,
        # re-feeding the encoder output to every block, so all bottleneck
        # blocks except the last were discarded.  Chain them instead.
        for block in self.BottleNeck:
            x = block(x, id)

        x = self.up4(x)
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        x = self.last_layer(x)

        return x
|
||||
@@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: arcface_decoder.py
|
||||
# Created Date: Saturday January 29th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 29th January 2022 2:55:39 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import init
|
||||
from torch.nn import functional as F
|
||||
|
||||
class Decoder(nn.Module):
    """Decode a 512-d embedding back into a 112x112 RGB image.

    Architecture: fc to a 7x7x512 feature map, then four x2
    bilinear-upsample + conv + BN + ReLU stages (512->512->256->128->64),
    then a final 3x3 conv to 3 channels.  Extra kwargs are accepted and
    ignored.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()

        # Shared activation instance, as in the rest of the project.
        relu = nn.ReLU(True)

        # 512 -> 7*7*512; reshaped to a spatial map in forward().
        self.fc = nn.Linear(512, 7 * 7 * 512)

        def up_stage(c_in, c_out):
            # One x2 upsample + 3x3 conv + BN + activation stage.
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear'),
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out), relu
            )

        self.up4 = up_stage(512, 512)
        self.up3 = up_stage(512, 256)
        self.up2 = up_stage(256, 128)
        self.up1 = up_stage(128, 64)

        self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))

    def forward(self, input):
        """Map (B, 512) embeddings to (B, 3, 112, 112) images."""
        feat = self.fc(input)
        feat = feat.view(feat.size(0), 512, 7, 7)
        for stage in (self.up4, self.up3, self.up2, self.up1):
            feat = stage(feat)
        return self.last_layer(feat)
|
||||
@@ -8,13 +8,15 @@ from torch.utils import data
|
||||
from torchvision import transforms as T
|
||||
# from StyleResize import StyleResize
|
||||
|
||||
|
||||
class data_prefetcher():
|
||||
def __init__(self, loader):
|
||||
def __init__(self, loader, cur_gpu):
|
||||
self.loader = loader
|
||||
self.dataiter = iter(loader)
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1)
|
||||
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1)
|
||||
self.stream = torch.cuda.Stream(device=cur_gpu)
|
||||
self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1)
|
||||
self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1)
|
||||
self.cur_gpu = cur_gpu
|
||||
# With Amp, it isn't necessary to manually convert data to half.
|
||||
# if args.fp16:
|
||||
# self.mean = self.mean.half()
|
||||
@@ -30,9 +32,9 @@ class data_prefetcher():
|
||||
self.src_image1, self.src_image2 = next(self.dataiter)
|
||||
|
||||
with torch.cuda.stream(self.stream):
|
||||
self.src_image1 = self.src_image1.cuda(non_blocking=True)
|
||||
self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True)
|
||||
self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
|
||||
self.src_image2 = self.src_image2.cuda(non_blocking=True)
|
||||
self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True)
|
||||
self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
|
||||
# With Amp, it isn't necessary to manually convert data to half.
|
||||
# if args.fp16:
|
||||
@@ -41,7 +43,7 @@ class data_prefetcher():
|
||||
# self.next_input = self.next_input.float()
|
||||
# self.next_input = self.next_input.sub_(self.mean).div_(self.std)
|
||||
def next(self):
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream)
|
||||
src_image1 = self.src_image1
|
||||
src_image2 = self.src_image2
|
||||
self.preload()
|
||||
@@ -102,6 +104,7 @@ class VGGFace2HQDataset(data.Dataset):
|
||||
return self.num_images
|
||||
|
||||
def GetLoader( dataset_roots,
|
||||
cur_gpu,
|
||||
batch_size=16,
|
||||
**kwargs
|
||||
):
|
||||
@@ -123,7 +126,7 @@ def GetLoader( dataset_roots,
|
||||
random_seed)
|
||||
content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
|
||||
drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)
|
||||
prefetcher = data_prefetcher(content_data_loader)
|
||||
prefetcher = data_prefetcher(content_data_loader,cur_gpu)
|
||||
return prefetcher
|
||||
|
||||
def denorm(x):
|
||||
|
||||
@@ -0,0 +1,195 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: data_loader_VGGFace2HQ copy.py
|
||||
# Created Date: Saturday January 29th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 29th January 2022 3:39:14 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import random
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from torch.utils import data
|
||||
from torchvision import transforms as T
|
||||
# from StyleResize import StyleResize
|
||||
|
||||
class data_prefetcher():
    """Iterator wrapper that preloads and normalizes batches on a side CUDA stream.

    While the GPU works on the current batch, the next one is copied to the
    (default) CUDA device and normalized with ImageNet mean/std on a
    dedicated stream.  The underlying iterator is restarted automatically
    at StopIteration, so `next()` never terminates on its own.
    """
    def __init__(self, loader):
        self.loader = loader
        self.dataiter = iter(loader)
        # Dedicated stream so host->device copies overlap compute.
        self.stream = torch.cuda.Stream()
        # ImageNet statistics, shaped (1, 3, 1, 1) to broadcast over NCHW.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.num_images = len(loader)
        self.preload()

    def preload(self):
        """Fetch the next batch and stage it on the side stream."""
        try:
            self.src_image1 = next(self.dataiter)
        except StopIteration:
            # Epoch boundary: restart the loader and keep going.
            self.dataiter = iter(self.loader)
            self.src_image1 = next(self.dataiter)

        with torch.cuda.stream(self.stream):
            # Async copy (non_blocking works with pin_memory loaders),
            # then in-place normalization, all on the side stream.
            self.src_image1 = self.src_image1.cuda(non_blocking=True)
            self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            #     self.next_input = self.next_input.float()
            # self.next_input = self.next_input.sub_(self.mean).div_(self.std)
    def next(self):
        """Return the preloaded batch and start preloading the next one."""
        # Make the side-stream work visible to the current stream before use.
        torch.cuda.current_stream().wait_stream(self.stream)
        src_image1 = self.src_image1
        self.preload()
        return src_image1

    def __len__(self):
        """Return the number of batches in the underlying loader."""
        return self.num_images
|
||||
|
||||
class VGGFace2HQDataset(data.Dataset):
    """Dataset over a two-level directory tree of VGGFace2-HQ face images.

    Collects every file matching ``image_dir/*/*`` (identity/image layout),
    shuffles the list with a fixed seed for reproducibility, and returns
    one transformed image per index.
    """

    def __init__(self,
                 image_dir,
                 img_transform,
                 subffix='jpg',
                 random_seed=1234):
        """Initialize and preprocess the VGGFace2 HQ dataset."""
        self.image_dir = image_dir
        self.img_transform = img_transform
        self.subffix = subffix
        self.random_seed = random_seed
        self.dataset = []
        self.preprocess()
        self.num_images = len(self.dataset)

    def preprocess(self):
        """Scan the directory tree and build the shuffled file list."""
        print("processing VGGFace2 HQ dataset images...")

        # Every file two levels below the root: <root>/<identity>/<image>.
        items = []
        for path in glob.glob(os.path.join(self.image_dir, '*/*')):
            print("processing %s" % path, end='\r')
            items.append(path)
        # Deterministic shuffle so runs are reproducible.
        random.seed(self.random_seed)
        random.shuffle(items)
        self.dataset = items
        print('Finished preprocessing the VGGFace2 HQ dataset, total dirs number: %d...' % len(self.dataset))

    def __getitem__(self, index):
        """Return one transformed image."""
        return self.img_transform(Image.open(self.dataset[index]))

    def __len__(self):
        """Return the number of images."""
        return self.num_images
|
||||
|
||||
def GetLoader( dataset_roots,
               batch_size=16,
               **kwargs
               ):
    """Build and return a GPU-prefetching loader over VGGFace2-HQ.

    Required kwargs: ``random_seed`` (dataset shuffle seed) and
    ``dataloader_workers`` (DataLoader worker count).
    """
    random_seed = kwargs["random_seed"]
    num_workers = kwargs["dataloader_workers"]

    # 112x112 tensors; normalization happens later inside the prefetcher.
    c_transforms = T.Compose([T.Resize((112, 112)), T.ToTensor()])

    content_dataset = VGGFace2HQDataset(
        dataset_roots,
        c_transforms,
        "jpg",
        random_seed)
    content_data_loader = data.DataLoader(dataset=content_dataset, batch_size=batch_size,
                                          drop_last=True, shuffle=True,
                                          num_workers=num_workers, pin_memory=True)
    return data_prefetcher(content_data_loader)
|
||||
|
||||
def denorm(x):
    """Map a [-1, 1] image tensor back to [0, 1], clamping outliers."""
    return ((x + 1) / 2).clamp_(0, 1)
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc smoke test for a style-transfer data loader.
    # NOTE(review): this harness calls `getLoader`, which is not defined in
    # this module (the loader defined above is `GetLoader`, with a different
    # signature).  It appears to be copied from another project and will
    # raise NameError if executed as-is — confirm before relying on it.
    from torchvision.utils import save_image
    style_class = ["vangogh","picasso","samuel"]
    # Scene-category subdirectories (Places-style layout) for the content set.
    categories_names = \
        ['a/abbey', 'a/arch', 'a/amphitheater', 'a/aqueduct', 'a/arena/rodeo', 'a/athletic_field/outdoor',
         'b/badlands', 'b/balcony/exterior', 'b/bamboo_forest', 'b/barn', 'b/barndoor', 'b/baseball_field',
         'b/basilica', 'b/bayou', 'b/beach', 'b/beach_house', 'b/beer_garden', 'b/boardwalk', 'b/boathouse',
         'b/botanical_garden', 'b/bullring', 'b/butte', 'c/cabin/outdoor', 'c/campsite', 'c/campus',
         'c/canal/natural', 'c/canal/urban', 'c/canyon', 'c/castle', 'c/church/outdoor', 'c/chalet',
         'c/cliff', 'c/coast', 'c/corn_field', 'c/corral', 'c/cottage', 'c/courtyard', 'c/crevasse',
         'd/dam', 'd/desert/vegetation', 'd/desert_road', 'd/doorway/outdoor', 'f/farm', 'f/fairway',
         'f/field/cultivated', 'f/field/wild', 'f/field_road', 'f/fishpond', 'f/florist_shop/indoor',
         'f/forest/broadleaf', 'f/forest_path', 'f/forest_road', 'f/formal_garden', 'g/gazebo/exterior',
         'g/glacier', 'g/golf_course', 'g/greenhouse/indoor', 'g/greenhouse/outdoor', 'g/grotto', 'g/gorge',
         'h/hayfield', 'h/herb_garden', 'h/hot_spring', 'h/house', 'h/hunting_lodge/outdoor', 'i/ice_floe',
         'i/ice_shelf', 'i/iceberg', 'i/inn/outdoor', 'i/islet', 'j/japanese_garden', 'k/kasbah',
         'k/kennel/outdoor', 'l/lagoon', 'l/lake/natural', 'l/lawn', 'l/library/outdoor', 'l/lighthouse',
         'm/mansion', 'm/marsh', 'm/mausoleum', 'm/moat/water', 'm/mosque/outdoor', 'm/mountain',
         'm/mountain_path', 'm/mountain_snowy', 'o/oast_house', 'o/ocean', 'o/orchard', 'p/park',
         'p/pasture', 'p/pavilion', 'p/picnic_area', 'p/pier', 'p/pond', 'r/raft', 'r/railroad_track',
         'r/rainforest', 'r/rice_paddy', 'r/river', 'r/rock_arch', 'r/roof_garden', 'r/rope_bridge',
         'r/ruin', 's/schoolhouse', 's/sky', 's/snowfield', 's/swamp', 's/swimming_hole',
         's/synagogue/outdoor', 't/temple/asia', 't/topiary_garden', 't/tree_farm', 't/tree_house',
         'u/underwater/ocean_deep', 'u/utility_room', 'v/valley', 'v/vegetable_garden', 'v/viaduct',
         'v/village', 'v/vineyard', 'v/volcano', 'w/waterfall', 'w/watering_hole', 'w/wave',
         'w/wheat_field', 'z/zen_garden', 'a/alcove', 'a/apartment-building/outdoor', 'a/artists_loft',
         'b/building_facade', 'c/cemetery']

    # Hard-coded Windows paths from the original author's machine.
    s_datapath = "D:\\F_Disk\\data_set\\Art_Data\\data_art_backup"
    c_datapath = "D:\\Downloads\\data_large"
    savepath = "D:\\PatchFace\\PleaseWork\\multi-style-gan\\StyleTransfer\\dataloader_test"

    imsize = 512
    # NOTE(review): `getLoader` is undefined here — see note at top.
    s_datasetloader= getLoader(s_datapath,c_datapath,
                    style_class, categories_names,
                    crop_size=imsize, batch_size=16, num_workers=4)
    wocao = iter(s_datasetloader)
    for i in range(500):
        print("new batch")
        s_image,c_image,label = next(wocao)
        print(label)
        # print(label)
        # saved_image1 = torch.cat([denorm(image.data),denorm(hahh.data)],3)
        # save_image(denorm(image), "%s\\%d-label-%d.jpg"%(savepath,i), nrow=1, padding=1)
        pass
    # One-off cleanup script for broken (non-RGB) images, kept for reference:
    # import cv2
    # import os
    # for dir_item in categories_names:
    #     join_path = Path(contentdatapath,dir_item)
    #     if join_path.exists():
    #         print("processing %s"%dir_item,end='\r')
    #         images = join_path.glob('*.%s'%("jpg"))
    #         for item in images:
    #             temp_path = str(item)
    #             # temp = cv2.imread(temp_path)
    #             temp = Image.open(temp_path)
    #             if temp.layers<3:
    #                 print("remove broken image...")
    #                 print("image name:%s"%temp_path)
    #                 del temp
    #                 os.remove(item)
|
||||
@@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: data_loader_VGGFace2HQ copy.py
|
||||
# Created Date: Sunday February 6th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 8th February 2022 1:50:00 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torch.utils import data
|
||||
from torchvision import transforms as T
|
||||
# from StyleResize import StyleResize
|
||||
|
||||
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler yielding an endless, optionally shuffled stream of dataset indices.

    Indices are sharded across ``num_replicas`` processes: each process keeps
    every ``num_replicas``-th draw, offset by its ``rank``.  After each yield
    the just-used index is swapped with a random earlier index inside a window
    covering ``window_size`` of the dataset, so the order keeps re-shuffling
    over time without epoch boundaries.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        """Yield dataset indices forever; this generator never terminates."""
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            # Window measured in indices; swapping below only occurs when >= 2.
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
                if window >= 2:
                    # Swap the just-yielded slot with one drawn from the window
                    # behind it, mutating `order` in place for future passes.
                    j = (i - rnd.randint(window)) % order.size
                    order[i], order[j] = order[j], order[i]
            idx += 1
|
||||
|
||||
class data_prefetcher():
    """Overlap host-to-GPU batch transfer with compute via a side CUDA stream.

    Pulls ``(src_image1, src_image2)`` batches from *loader*, copies them to
    GPU *cur_gpu* asynchronously, and normalizes them in place with ImageNet
    mean/std statistics.
    """

    def __init__(self, loader, cur_gpu):
        self.loader = loader
        self.dataiter = iter(loader)
        # Dedicated stream so copy/normalize can overlap the default stream.
        self.stream = torch.cuda.Stream(device=cur_gpu)
        # ImageNet normalization constants, shaped (1, 3, 1, 1) for broadcasting.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1)
        self.cur_gpu = cur_gpu
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        # self.num_images = loader.__len__()
        self.preload()

    def preload(self):
        """Fetch the next batch and launch its asynchronous GPU transfer."""
        # NOTE(review): the StopIteration restart below is commented out, so a
        # finite loader will propagate StopIteration out of next(); this is
        # only safe when the loader uses an infinite sampler — confirm.
        # try:
        self.src_image1, self.src_image2 = next(self.dataiter)
        # except StopIteration:
        #     self.dataiter = iter(self.loader)
        #     self.src_image1, self.src_image2 = next(self.dataiter)

        with torch.cuda.stream(self.stream):
            self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True)
            self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
            self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True)
            self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            #     self.next_input = self.next_input.float()
            #     self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the current batch and start prefetching the following one."""
        # Make the consuming stream wait until the side-stream copy finished.
        torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream)
        src_image1 = self.src_image1
        src_image2 = self.src_image2
        self.preload()
        return src_image1, src_image2

    # def __len__(self):
    #     """Return the number of images."""
    #     return self.num_images
|
||||
|
||||
class VGGFace2HQDataset(data.Dataset):
    """Dataset for VGGFace2-HQ: one identity per sub-directory.

    Each item yields two images independently sampled (with replacement)
    from the same identity directory.
    """

    def __init__(self,
                image_dir,
                img_transform,
                subffix='jpg',
                random_seed=1234):
        """Initialize and preprocess the VGGFace2 HQ dataset.

        Args:
            image_dir: root directory; each sub-directory holds one identity.
            img_transform: callable applied to each loaded PIL image.
            subffix: image file extension to index (without the dot).
            random_seed: seed used to shuffle the identity list.
        """
        self.image_dir = image_dir
        self.img_transform = img_transform
        self.subffix = subffix
        self.dataset = []
        self.random_seed = random_seed
        self.preprocess()
        self.num_images = len(self.dataset)

    def preprocess(self):
        """Index the dataset: one list of image paths per identity directory."""
        print("processing VGGFace2 HQ dataset images...")
        temp_path = os.path.join(self.image_dir, '*/')
        pathes = glob.glob(temp_path)
        self.dataset = []
        for dir_item in pathes:
            # Bug fix: honor self.subffix instead of hard-coding '*.jpg'.
            images = glob.glob(os.path.join(dir_item, '*.%s' % self.subffix))
            print("processing %s" % dir_item, end='\r')
            # Bug fix: skip identity dirs with no matching images; an empty
            # list made __getitem__ crash on random.randint(0, -1).
            if images:
                self.dataset.append(images)
        random.seed(self.random_seed)
        random.shuffle(self.dataset)
        print('Finished preprocessing the VGGFace2 HQ dataset, total dirs number: %d...' % len(self.dataset))

    def __getitem__(self, index):
        """Return two images sampled (with replacement) from identity *index*."""
        dir_tmp1 = self.dataset[index]
        dir_tmp1_len = len(dir_tmp1)
        filename1 = dir_tmp1[random.randint(0, dir_tmp1_len - 1)]
        filename2 = dir_tmp1[random.randint(0, dir_tmp1_len - 1)]
        image1 = self.img_transform(Image.open(filename1))
        image2 = self.img_transform(Image.open(filename2))
        return image1, image2

    def __len__(self):
        """Return the number of identities (directories) indexed."""
        return self.num_images
|
||||
|
||||
def GetLoader(dataset_roots,
              rank,
              num_gpus,
              batch_size=16,
              **kwargs
              ):
    """Build and return a GPU-side prefetching data loader.

    Args:
        dataset_roots: root directory of the VGGFace2-HQ dataset.
        rank: this process's rank (also selects the CUDA device).
        num_gpus: total number of replicas sharing the sampler.
        batch_size: samples per batch.
        **kwargs: must provide 'random_seed' and 'dataloader_workers'.
    """
    seed = kwargs["random_seed"]
    workers = kwargs["dataloader_workers"]

    # Only ToTensor for now; normalization happens inside the prefetcher.
    transform = T.Compose([T.ToTensor()])

    dataset = VGGFace2HQDataset(
        dataset_roots,
        transform,
        "jpg",
        seed)
    device = torch.device('cuda', rank)
    sampler = InfiniteSampler(dataset=dataset, rank=rank, num_replicas=num_gpus, seed=seed)
    loader = data.DataLoader(dataset=dataset, batch_size=batch_size,
                             drop_last=False, shuffle=False,
                             num_workers=workers, pin_memory=True, sampler=sampler)
    # content_data_loader = data.DataLoader(dataset=content_dataset,batch_size=batch_size,
    #                     drop_last=False,shuffle=True,num_workers=num_workers,pin_memory=True)
    return data_prefetcher(loader, device)
|
||||
|
||||
def denorm(x):
    """Map a tensor from [-1, 1] back to [0, 1], clamping out-of-range values."""
    rescaled = (x + 1) / 2
    return rescaled.clamp_(0, 1)
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test for the multi-GPU prefetching loader.
    datapath = "D:\\Downloads\\data_large"

    # Bug fix: this demo called an undefined `getLoader` with the signature of
    # an unrelated style-transfer loader (and unpacked three values from a
    # two-image batch), so it crashed with NameError.  It now drives
    # GetLoader/data_prefetcher as actually defined in this module.
    prefetcher = GetLoader(datapath, rank=0, num_gpus=1, batch_size=16,
                           random_seed=1234, dataloader_workers=4)
    for i in range(500):
        print("new batch")
        src_image1, src_image2 = prefetcher.next()
        print(src_image1.shape, src_image2.shape)
|
||||
@@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: data_loader_VGGFace2HQ copy.py
|
||||
# Created Date: Sunday February 6th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 8th February 2022 1:24:27 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from torch.utils import data
|
||||
from torchvision import transforms as T
|
||||
# from StyleResize import StyleResize
|
||||
|
||||
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler yielding an endless, optionally shuffled stream of dataset indices.

    Indices are sharded across ``num_replicas`` processes: each process keeps
    every ``num_replicas``-th draw, offset by its ``rank``.  After each yield
    the just-used index is swapped with a random earlier index inside a window
    covering ``window_size`` of the dataset, so the order keeps re-shuffling
    over time without epoch boundaries.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        """Yield dataset indices forever; this generator never terminates."""
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            # Window measured in indices; swapping below only occurs when >= 2.
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
                if window >= 2:
                    # Swap the just-yielded slot with one drawn from the window
                    # behind it, mutating `order` in place for future passes.
                    j = (i - rnd.randint(window)) % order.size
                    order[i], order[j] = order[j], order[i]
            idx += 1
|
||||
|
||||
class data_prefetcher():
    """Overlap host-to-GPU batch transfer with compute via a side CUDA stream.

    Pulls ``(src_image1, src_image2)`` batches from *loader*, copies them to
    GPU *cur_gpu* asynchronously, and normalizes them in place with ImageNet
    mean/std statistics.
    """

    def __init__(self, loader, cur_gpu):
        self.loader = loader
        self.dataiter = iter(loader)
        # Dedicated stream so copy/normalize can overlap the default stream.
        self.stream = torch.cuda.Stream(device=cur_gpu)
        # ImageNet normalization constants, shaped (1, 3, 1, 1) for broadcasting.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda(device=cur_gpu).view(1,3,1,1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda(device=cur_gpu).view(1,3,1,1)
        self.cur_gpu = cur_gpu
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        # self.num_images = loader.__len__()
        self.preload()

    def preload(self):
        """Fetch the next batch and launch its asynchronous GPU transfer."""
        # NOTE(review): the StopIteration restart below is commented out, so a
        # finite loader will propagate StopIteration out of next(); this is
        # only safe when the loader uses an infinite sampler — confirm.
        # try:
        self.src_image1, self.src_image2 = next(self.dataiter)
        # except StopIteration:
        #     self.dataiter = iter(self.loader)
        #     self.src_image1, self.src_image2 = next(self.dataiter)

        with torch.cuda.stream(self.stream):
            self.src_image1 = self.src_image1.cuda(device= self.cur_gpu, non_blocking=True)
            self.src_image1 = self.src_image1.sub_(self.mean).div_(self.std)
            self.src_image2 = self.src_image2.cuda(device= self.cur_gpu, non_blocking=True)
            self.src_image2 = self.src_image2.sub_(self.mean).div_(self.std)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            #     self.next_input = self.next_input.float()
            #     self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the current batch and start prefetching the following one."""
        # Make the consuming stream wait until the side-stream copy finished.
        torch.cuda.current_stream(device= self.cur_gpu,).wait_stream(self.stream)
        src_image1 = self.src_image1
        src_image2 = self.src_image2
        self.preload()
        return src_image1, src_image2

    # def __len__(self):
    #     """Return the number of images."""
    #     return self.num_images
|
||||
|
||||
class VGGFace2HQDataset(data.Dataset):
    """Dataset for VGGFace2-HQ: one identity per sub-directory.

    Each item yields two images independently sampled (with replacement)
    from the same identity directory.
    """

    def __init__(self,
                image_dir,
                img_transform,
                subffix='jpg',
                random_seed=1234):
        """Initialize and preprocess the VGGFace2 HQ dataset.

        Args:
            image_dir: root directory; each sub-directory holds one identity.
            img_transform: callable applied to each loaded PIL image.
            subffix: image file extension to index (without the dot).
            random_seed: seed used to shuffle the identity list.
        """
        self.image_dir = image_dir
        self.img_transform = img_transform
        self.subffix = subffix
        self.dataset = []
        self.random_seed = random_seed
        self.preprocess()
        self.num_images = len(self.dataset)

    def preprocess(self):
        """Index the dataset: one list of image paths per identity directory."""
        print("processing VGGFace2 HQ dataset images...")
        temp_path = os.path.join(self.image_dir, '*/')
        pathes = glob.glob(temp_path)
        self.dataset = []
        for dir_item in pathes:
            # Bug fix: honor self.subffix instead of hard-coding '*.jpg'.
            images = glob.glob(os.path.join(dir_item, '*.%s' % self.subffix))
            print("processing %s" % dir_item, end='\r')
            # Bug fix: skip identity dirs with no matching images; an empty
            # list made __getitem__ crash on random.randint(0, -1).
            if images:
                self.dataset.append(images)
        random.seed(self.random_seed)
        random.shuffle(self.dataset)
        print('Finished preprocessing the VGGFace2 HQ dataset, total dirs number: %d...' % len(self.dataset))

    def __getitem__(self, index):
        """Return two images sampled (with replacement) from identity *index*."""
        dir_tmp1 = self.dataset[index]
        dir_tmp1_len = len(dir_tmp1)
        filename1 = dir_tmp1[random.randint(0, dir_tmp1_len - 1)]
        filename2 = dir_tmp1[random.randint(0, dir_tmp1_len - 1)]
        image1 = self.img_transform(Image.open(filename1))
        image2 = self.img_transform(Image.open(filename2))
        return image1, image2

    def __len__(self):
        """Return the number of identities (directories) indexed."""
        return self.num_images
|
||||
|
||||
def GetLoader(dataset_roots,
              rank,
              num_gpus,
              batch_size=16,
              **kwargs
              ):
    """Build and return a standard shuffling DataLoader over VGGFace2-HQ.

    Args:
        dataset_roots: root directory of the dataset.
        rank: process rank (kept for interface compatibility; unused here).
        num_gpus: replica count (kept for interface compatibility; unused here).
        batch_size: samples per batch.
        **kwargs: must provide 'random_seed' and 'dataloader_workers'.
    """
    data_root = dataset_roots
    random_seed = kwargs["random_seed"]
    num_workers = kwargs["dataloader_workers"]

    c_transforms = T.Compose([T.ToTensor()])

    content_dataset = VGGFace2HQDataset(
        data_root,
        c_transforms,
        "jpg",
        random_seed)
    # NOTE(review): the InfiniteSampler/data_prefetcher path used by the
    # multi-GPU variant is disabled here; a plain shuffling DataLoader is
    # returned instead.  The dead `device = torch.device('cuda', rank)`
    # local left over from that path has been removed.
    content_data_loader = data.DataLoader(dataset=content_dataset, batch_size=batch_size,
                                          drop_last=False, shuffle=True,
                                          num_workers=num_workers, pin_memory=True)
    return content_data_loader
|
||||
|
||||
def denorm(x):
    """Map a tensor from [-1, 1] back to [0, 1], clamping out-of-range values."""
    rescaled = (x + 1) / 2
    return rescaled.clamp_(0, 1)
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test for the single-process DataLoader variant.
    datapath = "D:\\Downloads\\data_large"

    # Bug fix: the demo called an undefined `getLoader` with the signature of
    # an unrelated style-transfer loader (NameError at runtime) and unpacked
    # three values from a two-image batch.  GetLoader here returns a plain
    # DataLoader yielding (image1, image2) pairs.
    loader = GetLoader(datapath, rank=0, num_gpus=1, batch_size=16,
                       random_seed=1234, dataloader_workers=4)
    batches = iter(loader)
    for i in range(500):
        print("new batch")
        src_image1, src_image2 = next(batches)
        print(src_image1.shape, src_image2.shape)
|
||||
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
from .util import EasyDict, make_cache_dir_path
|
||||
+491
@@ -0,0 +1,491 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Miscellaneous utility classes and functions."""
|
||||
|
||||
import ctypes
import fnmatch
import functools
import glob
import hashlib
import html
import importlib
import inspect
import io
import operator
import os
import pickle
import re
import shutil
import sys
import tempfile
import types
import urllib
import urllib.request
import uuid

from distutils.util import strtobool
from typing import Any, List, Tuple, Union

import numpy as np
import requests
|
||||
|
||||
|
||||
# Util classes
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        # Only reached when normal attribute lookup fails, so real
        # attributes and dict methods are never shadowed.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
|
||||
|
||||
|
||||
class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        # Keep references to the real streams so close() can restore them.
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        # Global side effect: all stdout/stderr traffic now routes through
        # this Logger instance until close() is called.
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None
|
||||
|
||||
|
||||
# Cache directories
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
# Explicit override for the cache root; None means "derive from environment".
_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    """Override the base directory used by make_cache_dir_path()."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    """Return a path under the dnnlib cache directory.

    Resolution order: an explicit set_cache_dir() value, then the
    DNNLIB_CACHE_DIR environment variable, then '.cache/dnnlib' under
    HOME, then USERPROFILE, finally the system temp directory.
    """
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    if 'DNNLIB_CACHE_DIR' in os.environ:
        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
    for home_var in ('HOME', 'USERPROFILE'):
        if home_var in os.environ:
            return os.path.join(os.environ[home_var], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
||||
|
||||
# Small util functions
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    s = int(np.rint(seconds))

    # Guard clauses from the largest unit down; each branch shows at most
    # three fields.
    if s >= 24 * 60 * 60:
        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
    if s >= 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
    if s >= 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    return "{0}s".format(s)
|
||||
|
||||
|
||||
def format_time_brief(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    s = int(np.rint(seconds))

    # Guard clauses from the largest unit down; each branch shows at most
    # two fields (briefer than format_time).
    if s >= 24 * 60 * 60:
        return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
    if s >= 60 * 60:
        return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
    if s >= 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    return "{0}s".format(s)
|
||||
|
||||
|
||||
def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer."""
    # Bug fix: this relied on distutils.util.strtobool; distutils was removed
    # from the standard library in Python 3.12 (PEP 632).  The same tokens
    # strtobool accepted are handled inline, and invalid input re-asks as
    # before.
    truthy = {"y", "yes", "t", "true", "on", "1"}
    falsy = {"n", "no", "f", "false", "off", "0"}
    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()
        if answer in truthy:
            return True
        if answer in falsy:
            return False
|
||||
|
||||
|
||||
def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements (1 for an empty tuple)."""
    # Idiom: replace the manual accumulation loop with a reduce; the initial
    # value 1 matches the original loop's starting accumulator, and
    # operator.mul works for any type supporting '*' (unlike math.prod,
    # which is numeric-only).
    return functools.reduce(operator.mul, t, 1)
|
||||
|
||||
|
||||
# Numpy-style dtype name -> ctypes type of identical byte size.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    # Resolve the type to a plain name string first.
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    dtype = np.dtype(type_str)
    ctype = _str_to_ctype[type_str]

    # Sanity check: both representations must occupy the same byte count.
    assert dtype.itemsize == ctypes.sizeof(ctype)

    return dtype, ctype
|
||||
|
||||
|
||||
def is_pickleable(obj: Any) -> bool:
    """Return True if *obj* can be pickled, False otherwise."""
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        # Bug fix: the original bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit; pickling failures are always
        # ordinary exceptions.
        return False
|
||||
|
||||
|
||||
# Functionality to import modules/objects by name, and call functions by name
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    obj_name = re.sub("^np.", "numpy.", obj_name)
    obj_name = re.sub("^tf.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    # Longest module prefix is tried first, e.g. "a.b.c" tries
    # ("a.b.c", ""), ("a.b", "c"), ("a", "b.c") in that order.
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    # NOTE: the bare excepts below are deliberate — any import-time error
    # (not just ImportError) means "this split is wrong, try the next one".
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # maybe some of the modules themselves contain errors?
    # Re-import each candidate so a genuine failure inside the module body
    # propagates instead of being reported as "not found".
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    # Let the AttributeError from the attribute traversal propagate.
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)
|
||||
|
||||
|
||||
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    obj = module
    # An empty name refers to the module itself.
    if obj_name != '':
        for part in obj_name.split("."):
            obj = getattr(obj, part)
    return obj
|
||||
|
||||
|
||||
def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    owning_module, local_name = get_module_from_obj_name(name)
    return get_obj_from_module(owning_module, local_name)
|
||||
|
||||
|
||||
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    resolved = get_obj_by_name(func_name)
    # The resolved object must be callable — fail loudly otherwise.
    assert callable(resolved)
    return resolved(*args, **kwargs)
|
||||
|
||||
|
||||
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    # Thin wrapper: classes are callable, so instantiating is just calling
    # the object resolved from the dotted name.
    return call_func_by_name(*args, func_name=class_name, **kwargs)
|
||||
|
||||
|
||||
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Return the directory of the source file that defines *obj_name*'s module."""
    containing_module = get_module_from_obj_name(obj_name)[0]
    source_file = inspect.getfile(containing_module)
    return os.path.dirname(source_file)
|
||||
|
||||
|
||||
def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    if not callable(obj):
        return False
    # A module-scope 'def' leaves the function's name in its module's
    # namespace; lambdas and nested functions do not.
    defining_module = sys.modules[obj.__module__]
    return obj.__name__ in vars(defining_module)
|
||||
|
||||
|
||||
def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified dotted name of a top-level function."""
    assert is_top_level_function(obj)
    module_name = obj.__module__
    if module_name == '__main__':
        # Scripts run directly report '__main__'; recover the real module
        # name from the script's file name instead.
        script_file = sys.modules[module_name].__file__
        module_name = os.path.splitext(os.path.basename(script_file))[0]
    return f"{module_name}.{obj.__name__}"
|
||||
|
||||
|
||||
# File system helpers
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.

    Each entry of *ignores* is an fnmatch pattern matched against bare file
    and directory names.  Returns a list of (absolute_path, relative_path)
    tuples; when *add_base_to_relative* is set, the relative paths are
    prefixed with the base name of *dir_path*.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    patterns = ignores if ignores is not None else []

    collected = []
    for root, dirs, files in os.walk(dir_path, topdown=True):
        kept_files = files
        for pattern in patterns:
            # Prune matching sub-directories via slice assignment so the
            # change is visible to os.walk (it re-reads the same list).
            dirs[:] = [d for d in dirs if not fnmatch.fnmatch(d, pattern)]
            kept_files = [f for f in kept_files if not fnmatch.fnmatch(f, pattern)]

        for name in kept_files:
            abs_path = os.path.join(root, name)
            rel_path = os.path.relpath(abs_path, dir_path)
            if add_base_to_relative:
                rel_path = os.path.join(base_name, rel_path)
            collected.append((abs_path, rel_path))

    return collected
|
||||
|
||||
|
||||
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.

    Will create all necessary intermediate directories for each destination.
    """
    for src_path, dst_path in files:
        target_dir_name = os.path.dirname(dst_path)
        # makedirs(exist_ok=True) avoids the exists()/makedirs() race of the
        # original check-then-create; the guard skips bare filenames, where
        # dirname() is '' and os.makedirs('') would raise FileNotFoundError.
        if target_dir_name:
            os.makedirs(target_dir_name, exist_ok=True)
        shutil.copyfile(src_path, dst_path)
|
||||
|
||||
|
||||
# URL helpers
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    Non-strings and strings without '://' are rejected outright; 'file://'
    URLs are accepted only when *allow_file_urls* is set.  Otherwise the URL
    must parse with a scheme and a dotted host, both as given and when
    resolved against its site root.
    """
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        # Re-check against the site root to reject URLs that only parse
        # because of path contents (e.g. dots in the path).
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except Exception:
        # Was a bare 'except:', which also swallowed KeyboardInterrupt and
        # SystemExit; malformed URLs are simply "not a URL".
        return False
    return True
|
||||
|
||||
|
||||
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Local paths and file:// URLs are served directly without downloading.
    Downloads are retried up to *num_attempts* times and, when *cache* is
    enabled, stored under *cache_dir* keyed by the URL's MD5 hash.  With
    *return_filename* the cached file's path is returned instead of a file
    object (requires *cache*).
    """
    assert num_attempts >= 1
    # return_filename needs a file on disk, which only exists when caching.
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname() but
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        # make_cache_dir_path is defined elsewhere in this module.
        cache_dir = make_cache_dir_path('downloads')

    # Cache entries are named '<md5-of-url>_<sanitized-remote-name>'.
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        # Count down so 'attempts_left == 0' marks the final attempt.
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be a Google Drive interstitial page
                    # rather than the payload; detect and follow/raise.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Retry through the confirmed-download link.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    # Prefer the server-supplied filename, fall back to the URL.
                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                # NOTE(review): bare except retries on *any* error, including
                # SystemExit; consider narrowing to Exception.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        # Write to a unique temp file first so concurrent callers never see a
        # partially-written cache entry.
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
|
||||
+275
@@ -0,0 +1,275 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: face_crop.py
|
||||
# Created Date: Tuesday February 1st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 2nd February 2022 4:13:28 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import sys
|
||||
import glob
|
||||
import json
|
||||
import tkinter
|
||||
from tkinter.filedialog import askdirectory
|
||||
|
||||
import threading
|
||||
import tkinter as tk
|
||||
import tkinter.ttk as ttk
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from insightface_func.face_detect_crop_multi import Face_detect_crop
|
||||
|
||||
class TextRedirector(object):
    """File-like adapter that appends written text to a read-only Tk text widget.

    An instance can replace sys.stdout so that print() output appears in the
    GUI log area instead of the console.
    """

    def __init__(self, widget, tag="stdout"):
        # widget: a tk.Text instance; tag: style tag applied to inserted text.
        self.widget = widget
        self.tag = tag

    def write(self, text):
        """Append *text* to the widget and scroll it into view."""
        # Parameter renamed from 'str', which shadowed the builtin.
        # The widget is kept disabled so users cannot type into the log;
        # enable it only for the programmatic insert.
        self.widget.configure(state="normal")
        self.widget.insert("end", text, (self.tag,))
        self.widget.configure(state="disabled")
        self.widget.see(tk.END)

    def flush(self):
        """No-op; present so the object satisfies the file protocol."""
        pass
|
||||
|
||||
#############################################################
|
||||
# Main Class
|
||||
#############################################################
|
||||
|
||||
class Application(tk.Frame):
    """Tkinter front-end for batch face detection, alignment and cropping.

    The user picks a source directory (or single image), a target directory,
    a crop size, an alignment mode and an output format; Crop then runs the
    insightface detector over every image and writes the cropped faces.
    Long-running actions run on worker threads so the UI stays responsive.
    """

    def __init__(self, master=None):
        tk.Frame.__init__(self, master,bg='black')
        # self.font_size = 16
        # (family, size) shared by all widgets.
        self.font_list = ("Times New Roman",14)
        self.padx = 5
        self.pady = 5
        self.window_init()

    def __label_text__(self, usr, root):
        # Formats a user/workspace status label (not referenced elsewhere in
        # this class as shown).
        return "User Name: %s\nWorkspace: %s"%(usr, root)

    def window_init(self):
        """Build the whole widget tree and redirect stdout into the log box."""
        cwd = os.getcwd()
        self.master.title('Face Crop - %s'%cwd)
        # self.master.iconbitmap('./utilities/_logo.ico')
        self.master.geometry("{}x{}".format(640, 600))

        font_list = self.font_list

        ################################# source path row ###############################################
        list_frame = tk.Frame(self.master)
        list_frame.pack(fill="both", padx=5,pady=5)
        list_frame.columnconfigure(0, weight=1)
        list_frame.columnconfigure(1, weight=1)
        list_frame.columnconfigure(2, weight=1)

        # Path of the input image/directory, bound to the entry below.
        self.img_path = tkinter.StringVar()

        tk.Label(list_frame, text="Image/Video Path:",font=font_list,justify="left")\
            .grid(row=0,column=0,sticky=tk.EW)

        tk.Entry(list_frame, textvariable= self.img_path, font=font_list)\
            .grid(row=0,column=1,sticky=tk.EW)


        tk.Button(list_frame, text = "Select Path", font=font_list,
            command = self.Select, bg='#F4A460', fg='#F5F5F5')\
            .grid(row=0,column=2,sticky=tk.EW)
        ################################# target path row ###############################################
        list_frame1 = tk.Frame(self.master)
        list_frame1.pack(fill="both", padx=5,pady=5)
        list_frame1.columnconfigure(0, weight=1)
        list_frame1.columnconfigure(1, weight=1)
        list_frame1.columnconfigure(2, weight=1)

        # Directory where cropped faces are written.
        self.save_path = tkinter.StringVar()

        tk.Label(list_frame1, text="Target Path:",font=font_list,justify="left")\
            .grid(row=0,column=0,sticky=tk.EW)

        tk.Entry(list_frame1, textvariable= self.save_path, font=font_list)\
            .grid(row=0,column=1,sticky=tk.EW)


        tk.Button(list_frame1, text = "Select Path", font=font_list,
            command = self.Select_Target, bg='#F4A460', fg='#F5F5F5')\
            .grid(row=0,column=2,sticky=tk.EW)

        ################################# option captions ###############################################
        label_frame = tk.Frame(self.master)
        label_frame.pack(fill="both", padx=5,pady=5)
        label_frame.columnconfigure(0, weight=1)
        label_frame.columnconfigure(1, weight=1)
        label_frame.columnconfigure(2, weight=1)

        tk.Label(label_frame, text="Crop Size:",font=font_list,justify="left")\
            .grid(row=0,column=0,sticky=tk.EW)

        tk.Label(label_frame, text="Align Mode:",font=font_list,justify="left")\
            .grid(row=0,column=1,sticky=tk.EW)

        tk.Label(label_frame, text="Target Format:",font=font_list,justify="left")\
            .grid(row=0,column=2,sticky=tk.EW)

        ################################# option comboboxes #############################################

        test_frame = tk.Frame(self.master)
        test_frame.pack(fill="both", padx=5,pady=5)
        test_frame.columnconfigure(0, weight=1)
        test_frame.columnconfigure(1, weight=1)
        test_frame.columnconfigure(2, weight=1)

        # Output crop resolution; defaults to index 1 (512).
        self.test_var = tkinter.StringVar()

        self.test_com = ttk.Combobox(test_frame, textvariable=self.test_var)
        self.test_com.grid(row=0,column=0,sticky=tk.EW)
        self.test_com["value"] = [256,512,768,1024]
        self.test_com.current(1)

        # Alignment template. NOTE(review): crop_task passes this string
        # straight to detect.prepare(mode=...); the CLI counterpart uses
        # 'none'/'ffhq' rather than 'VGGFace'/'ffhq' -- confirm the detector
        # accepts both spellings.
        self.align_var = tkinter.StringVar()
        self.align_com = ttk.Combobox(test_frame, textvariable=self.align_var)
        self.align_com.grid(row=0,column=1,sticky=tk.EW)
        self.align_com["value"] = ["VGGFace","ffhq"]
        self.align_com.current(0)

        # Output image format.
        self.format_var = tkinter.StringVar()

        self.format_com = ttk.Combobox(test_frame, textvariable=self.format_var)
        self.format_com.grid(row=0,column=2,sticky=tk.EW)
        self.format_com["value"] = ["png","jpg"]
        self.format_com.current(0)



        ################################# min-size scale ################################################
        scale_frame = tk.Frame(self.master)
        scale_frame.pack(fill="both", padx=5,pady=5)
        scale_frame.columnconfigure(0, weight=2)
        # NOTE(review): this configures label_frame, not scale_frame --
        # looks like a copy-paste slip; confirm the intended target.
        label_frame.columnconfigure(1, weight=1)
        # label_frame.columnconfigure(2, weight=1)

        tk.Label(scale_frame, text="Min Size:",font=font_list,justify="left")\
            .grid(row=0,column=0,sticky=tk.EW)
        # Face-size tolerance ratio passed to detect.prepare(ratio=...).
        self.min_scale = tkinter.StringVar()
        tk.Scale(scale_frame, from_=0.5, to=2.0, length=500, orient=tk.HORIZONTAL, variable= self.min_scale,\
            font=font_list, resolution=0.1).grid(row=0,column=1,sticky=tk.EW)

        ################################# crop button ###################################################
        test_frame1 = tk.Frame(self.master)
        test_frame1.pack(fill="both", padx=5,pady=5)
        test_frame1.columnconfigure(0, weight=1)
        # test_frame1.columnconfigure(1, weight=1)

        test_update_button = tk.Button(test_frame1, text = "Crop",
                font=font_list, command = self.Crop, bg='#F4A460', fg='#F5F5F5')
        test_update_button.grid(row=0,column=0,sticky=tk.EW)



        ################################# log area ######################################################

        text = tk.Text(self.master, wrap="word")
        text.pack(fill="both",expand="yes", padx=5,pady=5)


        # All print() output from here on appears in the log widget.
        sys.stdout = TextRedirector(text, "stdout")

        self.init_algorithm()
        self.master.protocol("WM_DELETE_WINDOW", self.on_closing)

    def init_algorithm(self):
        # Loads the insightface detector once; crop_task re-prepares it with
        # the currently selected options on every run.
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')


    # def __scaning_logs__(self):
    def Select(self):
        """Button handler: choose the source directory on a worker thread."""
        thread_update = threading.Thread(target=self.select_task)
        thread_update.start()

    def select_task(self):
        path = askdirectory()
        print("Selected source directory: %s"%path)
        self.img_path.set(path)

    def Select_Target(self):
        """Button handler: choose the target directory on a worker thread."""
        thread_update = threading.Thread(target=self.select_target_task)
        thread_update.start()

    def select_target_task(self):
        path = askdirectory()
        print("Selected target directory: %s"%path)
        self.save_path.set(path)

    def Crop(self):
        """Button handler: run the cropping pipeline on a worker thread."""
        thread_update = threading.Thread(target=self.crop_task)
        thread_update.start()

    def crop_task(self):
        """Detect, align and crop faces from every image in the source dir.

        Output files are named <index>_<face-index>.<format>; faces whose
        Laplacian variance falls below the blur threshold are skipped.
        """
        mode = self.align_com.get()
        crop_size = int(self.test_com.get())

        path = self.img_path.get()
        tg_path = self.save_path.get()
        tg_format = self.format_com.get()
        min_scale = float(self.min_scale.get())
        # Blur threshold is hard-coded here; the CLI script exposes it as
        # --blur (default 10.0). NOTE(review): 100.0 vs 10.0 -- confirm.
        blur_t = 100.0
        font = cv2.FONT_HERSHEY_SIMPLEX
        self.detect.prepare(ctx_id = 0, det_thresh=0.6,\
            det_size=(640,640),mode = mode,crop_size=crop_size,ratio=min_scale)
        if path and tg_path:
            imgs_list = []
            if os.path.isdir(path):
                print("Input a dir....")
                imgs = glob.glob(os.path.join(path,"*"))
                for item in imgs:
                    imgs_list.append(item)
                # print(imgs_list)
                index = 0
                for img in imgs_list:
                    print(img)
                    attr_img_ori= cv2.imread(img)
                    try:
                        attr_img_align_crop, _ = self.detect.get(attr_img_ori)
                        sub_index = 0
                        for face_i in attr_img_align_crop:
                            imageVar = cv2.Laplacian(face_i, cv2.CV_64F).var()
                            f_path =os.path.join(tg_path, str(index).zfill(6)+"_%d.%s"%(sub_index,tg_format))
                            if imageVar < blur_t:
                                print("Over blurry image!")
                                continue
                            # face_i = cv2.putText(face_i, '%.1f'%imageVar,(50, 50), font, 0.8, (15, 9, 255), 2)
                            cv2.imwrite(f_path,face_i)
                            sub_index += 1
                        index += 1
                    except:
                        # NOTE(review): bare except treats *any* failure
                        # (decode error, detector crash) as "no face".
                        print("Detect no face!")
                        continue
            else:
                # Single-image input is collected but, as written, never
                # processed -- only the directory branch runs the crop loop.
                print("Input an image....")
                imgs_list.append(path)
            print("Process finished!")
        else:
            print("Pathes are invalid!")

    def on_closing(self):

        # self.__save_config__()
        self.master.destroy()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Creating Application() without an explicit master lets Tk create the
    # default root window; mainloop() blocks until that window is closed.
    app = Application()
    app.mainloop()
|
||||
@@ -0,0 +1,92 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: face_crop.py
|
||||
# Created Date: Tuesday February 1st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Wednesday, 2nd February 2022 11:17:04 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import sys
|
||||
import glob
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from insightface_func.face_detect_crop_multi import Face_detect_crop
|
||||
|
||||
|
||||
def getParameters():
    """Build and parse the command-line options of the video face-crop tool."""
    # (flags, add_argument keyword options) pairs, applied in order.
    option_specs = [
        (('-p', '--save_path'), dict(type=str, default="./output/",
                                     help="The root path for saving cropped images")),
        (('-v', '--video'), dict(type=str, default="G:\\4K\\05.mp4",
                                 help="The path for input video")),
        (('-c', '--crop_size'), dict(type=int, default=512,
                                     help="expected image resolution")),
        (('-s', '--min_scale'), dict(type=float, default=0.7,
                                     help="tolerance range for the size of the captured face image")),
        (('-m', '--mode'), dict(type=str, default="none",
                                choices=['ffhq', 'none'], help="none:VGG crop, ffhq:FFHQ crop")),
        (('-f', '--format'), dict(type=str, default="png",
                                  choices=['jpg', 'png'], help="target file format")),
        (('-i', '--interval'), dict(type=int, default=20,
                                    help="number of frames interval")),
        (('-b', '--blur'), dict(type=float, default=10.0,
                                help="blur degree")),
    ]
    parser = argparse.ArgumentParser()
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
|
||||
|
||||
def main(config):
    """Detect, align and crop faces from every *interval*-th frame of a video.

    Crops are written under <save_path>/<video basename>/ as
    <frame_index>_<face_index>.<format>.  Faces whose Laplacian variance is
    below the blur threshold are skipped as too blurry.
    """
    mode = config.mode
    crop_size = config.crop_size
    video = config.video
    tg_path = config.save_path
    tg_format = config.format
    min_scale = config.min_scale
    blur_t = config.blur
    interval = config.interval

    detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
    detect.prepare(ctx_id=0, det_thresh=0.6,
                   det_size=(640, 640), mode=mode, crop_size=crop_size, ratio=min_scale)

    video_basename = os.path.splitext(os.path.basename(video))[0]
    save_path = os.path.join(tg_path, video_basename)
    # exist_ok avoids the check-then-create race of the original.
    os.makedirs(save_path, exist_ok=True)

    cap = cv2.VideoCapture(video)
    frame_index = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # BUG FIX: the original had no break on a failed read, so
            # isOpened() stayed True after EOF and the loop spun forever.
            break
        # BUG FIX: --interval was parsed but never applied; only run the
        # detector every `interval`-th frame as documented.
        if frame_index % interval == 0:
            img_align_crop = detect.get(frame)
            if img_align_crop:
                sub_index = 0
                for face_i in img_align_crop[0]:
                    imageVar = cv2.Laplacian(face_i, cv2.CV_64F).var()
                    if imageVar < blur_t:
                        print("Over blurry image!")
                        continue
                    f_path = os.path.join(save_path, str(frame_index).zfill(6) + "_%d.%s" % (sub_index, tg_format))
                    cv2.imwrite(f_path, face_i)
                    sub_index += 1
        frame_index += 1
    cap.release()
|
||||
|
||||
if __name__ == "__main__":
    # Parse CLI arguments and run the frame-by-frame face-cropping pipeline.
    config = getParameters()
    main(config)
|
||||
@@ -0,0 +1,17 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: cos.py
|
||||
# Created Date: Monday February 7th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 7th February 2022 6:26:23 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch
|
||||
|
||||
def cosin_metric(x1, x2):
    """Row-wise cosine similarity between two batches of vectors.

    Both inputs are (N, D) tensors; returns an (N,) tensor where entry i is
    the cosine of the angle between x1[i] and x2[i].
    """
    dot = (x1 * x2).sum(dim=1)
    norms = x1.norm(dim=1) * x2.norm(dim=1)
    return dot / norms
|
||||
BIN
Binary file not shown.
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: similarity.py
|
||||
# Created Date: Thursday January 27th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Thursday, 27th January 2022 3:48:25 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 21st January 2022 10:55:59 am
|
||||
# Last Modified: Sunday, 30th January 2022 4:05:17 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -30,22 +30,22 @@ def getParameters():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# general settings
|
||||
parser.add_argument('-v', '--version', type=str, default='2layerFM',
|
||||
parser.add_argument('-v', '--version', type=str, default='GramFM',
|
||||
help="version name for train, test, finetune")
|
||||
|
||||
parser.add_argument('-c', '--cuda', type=int, default=0) # >0 if it is set as -1, program will use CPU
|
||||
parser.add_argument('-s', '--checkpoint_step', type=int, default=310000,
|
||||
parser.add_argument('-s', '--checkpoint_step', type=int, default=480000,
|
||||
help="checkpoint epoch for test phase or finetune phase")
|
||||
|
||||
# test
|
||||
parser.add_argument('-t', '--test_script_name', type=str, default='video')
|
||||
parser.add_argument('-t', '--test_script_name', type=str, default='image')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=1)
|
||||
parser.add_argument('-n', '--node_name', type=str, default='localhost',
|
||||
choices=['localhost', '4card','8card','new4card'])
|
||||
|
||||
|
||||
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\dlrb2.jpeg')
|
||||
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\G2010.mp4',
|
||||
parser.add_argument('-i', '--id_imgs', type=str, default='G:\\swap_data\\ID\\dlrb2.jpeg')
|
||||
parser.add_argument('-a', '--attr_files', type=str, default='G:\\swap_data\\ID',
|
||||
help="file path for attribute images or video")
|
||||
|
||||
parser.add_argument('--use_specified_data', action='store_true')
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: tester_commonn.py
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 29th January 2022 12:41:01 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
import glob
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
from torchvision.utils import save_image
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from insightface_func.face_detect_crop_single import Face_detect_crop
|
||||
|
||||
class Tester(object):
    """Runs a trained generator over a folder of face images and saves a grid.

    For every input image, an ArcFace identity embedding is extracted and fed
    to the generator; input and output are stacked into one PNG under the
    configured sample directory.
    """

    def __init__(self, config, reporter):
        # config: dict-like experiment configuration (paths, cuda flag, ...).
        self.config = config
        # logger
        self.reporter = reporter

        # ImageNet normalization applied before ArcFace.
        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Kept as (3,1,1) so they broadcast over (N,3,H,W) batches when
        # un-normalizing for visualization.
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
        if self.config["cuda"] >=0:
            self.imagenet_std = self.imagenet_std .cuda()
            self.imagenet_mean = self.imagenet_mean.cuda()


    def __init_framework__(self):
        '''
        This function is designed to define the framework,
        and print the framework information into the log file
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]
        # The generator class is resolved dynamically from the config so
        # different architectures can be tested without code changes.
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.network = gen_class(**model_config["g_model"]["module_params"])

        # TODO replace below lines to define the model framework
        # NOTE(review): the generator is constructed twice; the first
        # instance above is discarded immediately.
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        # print and recorde model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        # NOTE(review): self.arcface_ckpt is assigned in test() before this
        # method is called -- calling __init_framework__ directly would fail.
        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        # train in GPU
        if self.config["cuda"] >=0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()

        model_path = os.path.join(self.config["project_checkpoints"],
                        "step%d_%s.pth"%(self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

    def test(self):
        """Generate results for all configured attribute images and save one grid PNG."""
        save_dir = self.config["test_samples_path"]
        ckp_step = self.config["checkpoint_step"]
        version = self.config["version"]
        attr_files = self.config["attr_files"]
        # Must be set before __init_framework__ reads it below.
        self.arcface_ckpt= self.config["arcface_ckpt"]
        imgs_list = []
        if os.path.isdir(attr_files):
            print("Input a dir....")
            imgs = glob.glob(os.path.join(attr_files,"**"), recursive=True)
            for item in imgs:
                imgs_list.append(item)
            print(imgs_list)
        else:
            print("Input an image....")
            imgs_list.append(attr_files)
        img_num = len(imgs_list)


        # models
        self.__init_framework__()

        mode = None
        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640),mode = mode)
        # Start time
        import datetime
        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        index = 0
        with torch.no_grad():
            # NOTE(review): [1:] skips the first glob entry -- presumably
            # because glob('**') returns the directory itself first; this
            # also drops the single-image case's only entry. Confirm.
            for img in imgs_list[1:]:
                print(img)
                attr_img_ori= cv2.imread(img)
                # try:
                #     attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
                # except:
                #     print("No face detected!")
                #     continue
                # attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0],cv2.COLOR_BGR2RGB))
                attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_ori,cv2.COLOR_BGR2RGB))
                # NOTE(review): .cuda() is unconditional here even though the
                # rest of the class honors config["cuda"] -- fails on CPU-only.
                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()

                # ArcFace expects 112x112 inputs.
                attr_img_arc= F.interpolate(attr_img,size=(112,112), mode='bicubic')
                attr_id = self.arcface(attr_img_arc)
                results = self.network(attr_id)

                # Un-normalize both input and output for visualization.
                results = results * self.imagenet_std + self.imagenet_mean
                results = results.clamp_(0, 1)
                attr = attr_img_arc * self.imagenet_std + self.imagenet_mean
                # Stack input above output (dim=2 is height).
                results = torch.concat((attr, results), dim=2)
                if index == 0:
                    final_img = results
                else:
                    final_img = torch.concat((final_img, results), dim=0)
                index += 1
            # Avoid overwriting earlier runs by suffixing an increasing mark.
            save_filename = os.path.join(save_dir, "ckp_%s_v_%s.png"%(ckp_step, version))
            mark = 0
            while(True):
                if os.path.exists(save_filename):
                    save_filename = os.path.join(save_dir, "ckp_%s_v_%s_%d.png"%(ckp_step, version,mark))
                    mark += 1
                else:
                    break
            save_image(final_img, save_filename, nrow=img_num//8)

        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
|
||||
@@ -0,0 +1,207 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: tester_commonn.py
|
||||
# Created Date: Saturday July 3rd 2021
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 29th January 2022 12:02:31 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2021 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
import glob
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision import transforms
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from insightface_func.face_detect_crop_single import Face_detect_crop
|
||||
|
||||
class Tester(object):
    """Single-image / directory face-swap tester.

    Builds the generator and a frozen ArcFace identity encoder, swaps the
    identity from ``config["id_imgs"]`` into every attribute image under
    ``config["attr_files"]`` (one image or a directory, scanned
    recursively), pastes the swapped face back into the original frame
    with a feathered mask, annotates identity distances, and writes the
    results to ``config["test_samples_path"]``.
    """

    def __init__(self, config, reporter):
        self.config = config
        # logger
        self.reporter = reporter

        # ArcFace expects ImageNet-normalized RGB input.
        self.transformer_Arcface = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # Kept as (3,1,1) so they broadcast over (B,3,H,W) when
        # de-normalizing the generator output.
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(3,1,1)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(3,1,1)

    def __init_framework__(self):
        '''
        Build the generator and the frozen ArcFace identity network, and
        print/log the framework information. ``self.arcface_ckpt`` must be
        set before calling (done at the top of ``test()``).
        '''
        #===============build models================#
        print("build models...")
        model_config = self.config["model_configs"]
        gscript_name = self.config["com_base"] + model_config["g_model"]["script"]
        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        # Construct the generator exactly once (the original built it twice
        # and threw the first instance away).
        self.network = gen_class(**model_config["g_model"]["module_params"])
        self.network = self.network.eval()
        # print and record model structure
        self.reporter.writeInfo("Model structure:")
        self.reporter.writeModel(self.network.__str__())

        # Frozen ArcFace identity encoder (never trained here).
        arcface_ckpt = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface_ckpt['model'].module
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        # move models to GPU
        if self.config["cuda"] >= 0:
            self.network = self.network.cuda()
            self.arcface = self.arcface.cuda()

        model_path = os.path.join(self.config["project_checkpoints"],
                        "step%d_%s.pth"%(self.config["checkpoint_step"],
                        self.config["checkpoint_names"]["generator_name"]))
        self.network.load_state_dict(torch.load(model_path))
        print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

    @staticmethod
    def _invert_affine(mat):
        """Return the inverse of a 2x3 affine transform matrix (cv2 convention)."""
        mat_rev = np.zeros([2,3])
        div1 = mat[0][0]*mat[1][1]-mat[0][1]*mat[1][0]
        mat_rev[0][0] = mat[1][1]/div1
        mat_rev[0][1] = -mat[0][1]/div1
        mat_rev[0][2] = -(mat[0][2]*mat[1][1]-mat[0][1]*mat[1][2])/div1
        div2 = mat[0][1]*mat[1][0]-mat[0][0]*mat[1][1]
        mat_rev[1][0] = mat[1][0]/div2
        mat_rev[1][1] = -mat[0][0]/div2
        mat_rev[1][2] = -(mat[0][2]*mat[1][0]-mat[0][0]*mat[1][2])/div2
        return mat_rev

    def test(self):
        """Run face swapping over all attribute images and save annotated results."""
        save_dir = self.config["test_samples_path"]
        ckp_step = self.config["checkpoint_step"]
        version = self.config["version"]
        id_imgs = self.config["id_imgs"]
        attr_files = self.config["attr_files"]
        self.arcface_ckpt = self.config["arcface_ckpt"]

        # Collect attribute images (a directory is scanned recursively).
        imgs_list = []
        if os.path.isdir(attr_files):
            print("Input a dir....")
            imgs_list = glob.glob(os.path.join(attr_files,"**"), recursive=True)
            print(imgs_list)
        else:
            print("Input an image....")
            imgs_list.append(attr_files)
        id_basename = os.path.splitext(os.path.basename(id_imgs))[0]

        # models
        self.__init_framework__()

        self.detect = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.detect.prepare(ctx_id = 0, det_thresh=0.6, det_size=(640,640), mode = None)

        # Align/crop the identity face and embed it with ArcFace.
        id_img = cv2.imread(id_imgs)
        id_img_align_crop, _ = self.detect.get(id_img,512)
        id_img_align_crop_pil = Image.fromarray(cv2.cvtColor(id_img_align_crop[0],cv2.COLOR_BGR2RGB))
        id_img = self.transformer_Arcface(id_img_align_crop_pil)
        id_img = id_img.unsqueeze(0).cuda()

        # create latent id (ArcFace takes 112x112 input)
        id_img = F.interpolate(id_img,size=(112,112), mode='bicubic')
        latend_id = self.arcface(id_img)
        latend_id = F.normalize(latend_id, p=2, dim=1)
        cos_loss = torch.nn.CosineSimilarity()
        font = cv2.FONT_HERSHEY_SIMPLEX

        # Start time
        import datetime
        print("Start to test at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
        print('Start =================================== test...')
        start_time = time.time()
        self.network.eval()
        with torch.no_grad():
            for img in imgs_list:
                print(img)
                attr_img_ori = cv2.imread(img)
                try:
                    attr_img_align_crop, mat = self.detect.get(attr_img_ori,512)
                except Exception:
                    # No face found / unreadable path (e.g. a sub-directory
                    # returned by the recursive glob) -- skip this entry.
                    continue
                attr_img_align_crop_pil = Image.fromarray(cv2.cvtColor(attr_img_align_crop[0],cv2.COLOR_BGR2RGB))
                attr_img = self.transformer_Arcface(attr_img_align_crop_pil).unsqueeze(0).cuda()

                # Identity distance between the source id and the attribute face.
                attr_img_arc = F.interpolate(attr_img,size=(112,112), mode='bicubic')
                attr_id = self.arcface(attr_img_arc)
                attr_id = F.normalize(attr_id, p=2, dim=1)
                cos_dis = 1 - cos_loss(latend_id, attr_id)

                mat = mat[0]
                results = self.network(attr_img, latend_id)

                # Identity distance between the source id and the swapped result.
                results_arc = F.interpolate(results,size=(112,112), mode='bicubic')
                results_arc = self.arcface(results_arc)
                results_arc = F.normalize(results_arc, p=2, dim=1)
                results_cos_dis = 1 - cos_loss(latend_id, results_arc)

                # De-normalize generator output to [0,1] HWC numpy.
                results = results * self.imagenet_std + self.imagenet_mean
                results = results.cpu().permute(0,2,3,1)[0,...]
                results = results.numpy()
                results = np.clip(results,0.0,1.0)

                # Warp the swapped crop back into the original frame.
                mat_rev = self._invert_affine(mat)
                orisize = (attr_img_ori.shape[1], attr_img_ori.shape[0])
                target_image = cv2.warpAffine(results, mat_rev, orisize)

                # Build a feathered paste-back mask (erode, then Gaussian blur).
                img_white = np.full((512,512), 255, dtype=float)
                img_white = cv2.warpAffine(img_white, mat_rev, orisize)
                img_white[img_white>20] = 255
                img_mask = img_white
                kernel = np.ones((40,40),np.uint8)
                img_mask = cv2.erode(img_mask,kernel,iterations = 1)
                kernel_size = (20, 20)
                blur_size = tuple(2*i+1 for i in kernel_size)
                img_mask = cv2.GaussianBlur(img_mask, blur_size, 0)
                img_mask /= 255
                img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1])

                # RGB->BGR and rescale to [0,255] before compositing.
                # NOTE: np.float was removed in NumPy 1.24 -- use np.float64.
                target_image = np.array(target_image, dtype=np.float64)[..., ::-1] * 255
                img1 = np.array(attr_img_ori, dtype=np.float64)
                img1 = img_mask * target_image + (1-img_mask) * img1
                final_img = img1.astype(np.uint8)

                # Annotate with the identity distances and save.
                attr_basename = os.path.splitext(os.path.basename(img))[0]
                final_img = cv2.putText(final_img, 'id dis=%.4f'%results_cos_dis, (50, 50), font, 0.8, (15, 9, 255), 2)
                final_img = cv2.putText(final_img, 'id--attr dis=%.4f'%cos_dis, (50, 80), font, 0.8, (15, 9, 255), 2)
                save_filename = os.path.join(save_dir,
                                "id_%s--attr_%s_ckp_%s_v_%s.png"%(id_basename,
                                attr_basename,ckp_step,version))
                cv2.imwrite(save_filename, final_img)

        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}]".format(elapsed))
|
||||
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
# empty
|
||||
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import glob
|
||||
import hashlib
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
import torch.utils.cpp_extension
|
||||
from torch.utils.file_baton import FileBaton
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Global options.
|
||||
|
||||
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Internal helper funcs.
|
||||
|
||||
def _find_compiler_bindir():
|
||||
patterns = [
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
|
||||
]
|
||||
for pattern in patterns:
|
||||
matches = sorted(glob.glob(pattern))
|
||||
if len(matches):
|
||||
return matches[-1]
|
||||
return None
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _get_mangled_gpu_name():
    """Return the current CUDA device name, lower-cased, with every
    character outside [a-z0-9_-] replaced by '-' so it is safe to embed
    in a build-directory name."""
    raw = torch.cuda.get_device_name().lower()
    return ''.join(ch if re.match('[a-z0-9_-]+', ch) else '-' for ch in raw)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Main entry point for compiling and loading C++/CUDA plugins.
|
||||
|
||||
_cached_plugins = dict()
|
||||
|
||||
def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
    """Compile and load a custom C++/CUDA PyTorch extension, with caching.

    Args:
        module_name: Name of the extension module to build and import.
        sources: Source file names (relative to `source_dir` if given).
        headers: Optional header file names, hashed into the build digest.
        source_dir: Optional directory prepended to every source/header path.
        **build_kwargs: Forwarded to `torch.utils.cpp_extension.load`.

    Returns:
        The imported extension module. Results are memoized per
        `module_name` in the module-level `_cached_plugins` dict.
    """
    assert verbosity in ['none', 'brief', 'full']
    if headers is None:
        headers = []
    if source_dir is not None:
        sources = [os.path.join(source_dir, fname) for fname in sources]
        headers = [os.path.join(source_dir, fname) for fname in headers]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
    verbose_build = (verbosity == 'full')

    # Compile and load.
    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        # (Windows only: if cl.exe is not already on PATH, locate MSVC.)
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
        # break the build or unnecessarily restrict what's available to nvcc.
        # Unset it to let nvcc decide based on what's available on the
        # machine.
        os.environ['TORCH_CUDA_ARCH_LIST'] = ''

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        #
        # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
        # around the *.cu dependency bug in ninja config.
        #
        all_source_files = sorted(sources + headers)
        all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
        if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):

            # Compute combined hash digest for all source files.
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())

            # Select cached build directory name.
            source_digest = hash_md5.hexdigest()
            build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')

            if not os.path.isdir(cached_build_dir):
                # Populate a temp dir first, then rename into place so a
                # concurrent builder never observes a half-copied directory.
                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
                os.makedirs(tmpdir)
                for src in all_source_files:
                    shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
                try:
                    os.replace(tmpdir, cached_build_dir) # atomic
                except OSError:
                    # source directory already exists, delete tmpdir and its contents.
                    shutil.rmtree(tmpdir)
                    if not os.path.isdir(cached_build_dir): raise

            # Compile.
            cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
                verbose=verbose_build, sources=cached_sources, **build_kwargs)
        else:
            # Sources span multiple directories: fall back to the default
            # (non-digest-cached) cpp_extension build path.
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)

        # Load.
        module = importlib.import_module(module_name)

    except:
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache dict.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,272 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import re
|
||||
import contextlib
|
||||
import numpy as np
|
||||
import torch
|
||||
import warnings
|
||||
import dnnlib
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
||||
# same constant is used multiple times.
|
||||
|
||||
# Cache of previously-constructed constant tensors, keyed by the value's
# bytes plus the requested shape/dtype/device/memory format. Avoids a
# fresh CPU=>GPU copy every time the same constant is requested.
_constant_cache = dict()

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a (cached) tensor holding `value`, broadcast to `shape` if given."""
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    cache_key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(cache_key)
    if cached is None:
        cached = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            # Broadcast against an empty tensor of the target shape.
            cached, _ = torch.broadcast_tensors(cached, torch.empty(shape))
        cached = cached.contiguous(memory_format=memory_format)
        _constant_cache[cache_key] = cached
    return cached
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Replace NaN/Inf with specified numerical values.
|
||||
|
||||
try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        # Fallback for torch < 1.8. NaNs are zeroed via
        # unsqueeze(0).nansum(0) (nansum treats NaN as 0), then +/-inf are
        # clamped to the dtype limits or the given bounds.
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        # Only nan=0 is supported by this emulation.
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Symbolic assert.
|
||||
|
||||
# Pick the assert primitive that survives torch.jit.trace(): torch._assert
# on >= 1.8, torch.Assert on 1.7.
try:
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert # 1.7.0
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Context manager to temporarily suppress known warnings in torch.jit.trace().
|
||||
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
|
||||
|
||||
@contextlib.contextmanager
def suppress_tracer_warnings():
    """Temporarily prepend an 'ignore' filter for torch.jit.TracerWarning.

    Cannot use warnings.catch_warnings because of
    https://bugs.python.org/issue29672. Uses try/finally so the filter is
    removed even when the wrapped block raises (the original leaked the
    filter on exceptions).
    """
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    try:
        yield
    finally:
        warnings.filters.remove(flt)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Assert that the shape of a tensor matches the given list of integers.
|
||||
# None indicates that the size of a dimension is allowed to vary.
|
||||
# Performs symbolic assertion when used in torch.jit.trace().
|
||||
|
||||
def assert_shape(tensor, ref_shape):
    """Assert that `tensor` has shape `ref_shape`.

    A None entry in `ref_shape` matches any size. When either side is a
    traced tensor, emits a symbolic assertion so the check survives
    torch.jit.trace().
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
    for dim_idx, (actual, expected) in enumerate(zip(tensor.shape, ref_shape)):
        if expected is None:
            continue
        if isinstance(expected, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(actual), expected), f'Wrong size for dimension {dim_idx}')
        elif isinstance(actual, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(actual, torch.as_tensor(expected)), f'Wrong size for dimension {dim_idx}: expected {expected}')
        elif actual != expected:
            raise AssertionError(f'Wrong size for dimension {dim_idx}: got {actual}, expected {expected}')
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Function decorator that calls torch.autograd.profiler.record_function().
|
||||
|
||||
def profiled_function(fn):
    """Decorator that wraps `fn` in torch.autograd.profiler.record_function
    so each call shows up as a named range in profiler traces.
    """
    import functools

    # functools.wraps preserves __name__, __doc__, __module__, etc.
    # (the original copied only __name__).
    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
||||
# indefinitely, shuffling items as it goes.
|
||||
|
||||
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler that loops over the dataset indefinitely.

    With shuffle enabled, indices are continuously re-shuffled within a
    sliding window whose size is `window_size` * len(dataset). Indices are
    strided across `num_replicas` distributed workers, each worker taking
    the positions congruent to its `rank`.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        step = 0
        while True:
            pos = step % order.size
            if step % self.num_replicas == self.rank:
                yield order[pos]
            if window >= 2:
                # Swap the current index with a random earlier one inside
                # the window to keep the ordering fresh across epochs.
                other = (pos - rnd.randint(window)) % order.size
                order[pos], order[other] = order[other], order[pos]
            step += 1
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Utilities for operating with torch.nn.Module parameters and buffers.
|
||||
|
||||
def params_and_buffers(module):
    """Return all parameters and buffers of `module` as a single list."""
    assert isinstance(module, torch.nn.Module)
    return [*module.parameters(), *module.buffers()]
|
||||
|
||||
def named_params_and_buffers(module):
    """Return (name, tensor) pairs for all parameters and buffers of `module`."""
    assert isinstance(module, torch.nn.Module)
    return [*module.named_parameters(), *module.named_buffers()]
|
||||
|
||||
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy matching parameters/buffers from `src_module` into `dst_module`.

    Copies in-place by name; each destination tensor keeps its own
    requires_grad flag. With require_all=True, every destination tensor
    must have a counterpart in the source module.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, dst_tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            dst_tensor.copy_(src_tensors[name].detach()).requires_grad_(dst_tensor.requires_grad)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Context manager for easily enabling/disabling DistributedDataParallel
|
||||
# synchronization.
|
||||
|
||||
@contextlib.contextmanager
def ddp_sync(module, sync):
    """Enable (sync=True) or suppress gradient synchronization for a
    DistributedDataParallel-wrapped module; a no-op for plain modules."""
    assert isinstance(module, torch.nn.Module)
    if not sync and isinstance(module, torch.nn.parallel.DistributedDataParallel):
        with module.no_sync():
            yield
    else:
        yield
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Check DistributedDataParallel consistency across processes.
|
||||
|
||||
def check_ddp_consistency(module, ignore_regex=None):
    """Assert that `module`'s parameters/buffers are identical across all
    distributed processes: broadcast each tensor from rank 0 and compare.

    Requires torch.distributed to be initialized. Tensors whose
    fully-qualified name matches `ignore_regex` are skipped.
    """
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            # Normalize NaN/Inf so the equality comparison below is meaningful.
            tensor = nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Print summary table of module hierarchy.
|
||||
|
||||
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run `module` on `inputs` and print a per-submodule summary table
    (parameter counts, buffer counts, output shapes and dtypes).

    Args:
        module: Module to summarize (must not be scripted).
        inputs: Positional arguments passed to the module's forward.
        max_nesting: Maximum submodule depth to include in the table.
        skip_redundant: Drop rows that contribute no unique params,
            buffers, or outputs.

    Returns:
        The outputs of `module(*inputs)`.
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    # nesting is a one-element list so the closures can mutate the depth.
    entries = []
    nesting = [0]
    def pre_hook(_mod, _inputs):
        nesting[0] += 1
    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    # Deduplicated by tensor identity so shared weights are counted once.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}

    # Filter out redundant entries.
    if skip_redundant:
        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]

    # Construct table.
    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        # Additional rows for modules with more than one tensor output.
        for idx in range(1, len(e.outputs)):
            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
    print()
    return outputs
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
# Added by Katja
|
||||
import os
|
||||
|
||||
def get_ckpt_path(run_dir):
    """Return the path of the network snapshot checkpoint inside `run_dir`."""
    return os.path.join(run_dir, 'network-snapshot.pkl')
|
||||
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
# empty
|
||||
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "bias_act.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
// Returns true if x and y have identical shapes and identical strides on
// every dimension of size >= 2 (strides of size-0/1 dims never affect
// addressing, so they are ignored).
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
{
    if (x.dim() != y.dim())
        return false;
    for (int64_t i = 0; i < x.dim(); i++)
    {
        if (x.size(i) != y.size(i))
            return false;
        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
            return false;
    }
    return true;
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
// Fused bias + activation op (forward and gradient passes).
//   x     - input tensor (any shape, CUDA).
//   b     - optional bias, rank 1, broadcast along `dim`; empty = no bias.
//   xref  - optional reference input for gradient computation (grad > 0).
//   yref  - optional reference output for gradient computation (grad > 0).
//   dy    - optional upstream gradient; empty = ones.
//   grad  - 0 for forward pass, >0 for the corresponding gradient pass.
//   act   - activation function id (see choose_bias_act_kernel).
// Returns a tensor y with the same shape/layout as x.
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
    TORCH_CHECK(grad >= 0, "grad must be non-negative");

    // Validate layout. The kernel indexes all tensors with a single flat
    // index, so every operand must share x's exact memory layout.
    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    torch::Tensor y = torch::empty_like(x);
    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");

    // Initialize CUDA kernel parameters. Empty optional tensors become
    // NULL pointers, which the kernel substitutes with defaults.
    bias_act_kernel_params p;
    p.x = x.data_ptr();
    p.b = (b.numel()) ? b.data_ptr() : NULL;
    p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
    p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
    p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
    p.y = y.data_ptr();
    p.grad = grad;
    p.act = act;
    p.alpha = alpha;
    p.gain = gain;
    p.clamp = clamp;
    p.sizeX = (int)x.numel();
    p.sizeB = (int)b.numel();
    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;

    // Choose CUDA kernel for x's scalar type.
    void* kernel;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        kernel = choose_bias_act_kernel<scalar_t>(p);
    });
    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");

    // Launch CUDA kernel. Each thread processes p.loopX elements.
    p.loopX = 4;
    int blockSize = 4 * 32;
    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
// Expose the fused op to Python as `<module>.bias_act`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("bias_act", &bias_act);
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
#include "bias_act.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Helpers.
|
||||
|
||||
// Maps the storage type T to the scalar type used for arithmetic inside the
// kernel. Half-precision inputs are promoted to float for the computation.
template <class T> struct InternalType;
template <> struct InternalType<double>    { typedef double scalar_t; };
template <> struct InternalType<float>     { typedef float scalar_t; };
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel.
|
||||
|
||||
// Fused bias + activation kernel.
// Template parameter A selects the activation (1=linear ... 9=swish, matching
// the cuda_idx values registered on the Python side); p.grad selects the pass:
// 0 = forward, 1 = first derivative, 2 = second derivative.
template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    int G = p.grad;
    scalar_t alpha = (scalar_t)p.alpha;
    scalar_t gain = (scalar_t)p.gain;
    scalar_t clamp = (scalar_t)p.clamp;
    scalar_t one = (scalar_t)1;
    scalar_t two = (scalar_t)2;
    scalar_t expRange = (scalar_t)80;       // |x| beyond this saturates exp-based activations
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements: each thread handles up to p.loopX strided elements.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load. Bias is indexed along the channel dimension via stepB/sizeB.
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        scalar_t yy = (gain != 0) ? yref / gain : 0;    // reference output before gain
        scalar_t y = 0;

        // Apply bias. In the forward pass the bias goes onto x; in gradient
        // passes it is folded into the saved reference input instead.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;
            // G == 2: second derivative of linear is zero (y stays 0).
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;   // gradient gated by saved output sign
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);                 // d/dx tanh = 1 - tanh^2
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish — gradients are computed from the saved input (xref) rather
        // than the saved output, and yref is recomputed for the clamp test.
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain (and chain rule factor dy in gradient passes).
        y *= gain * dy;

        // Clamp. Gradients are zeroed wherever the forward output saturated.
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
// Select the kernel instantiation matching the requested activation index.
// Returns NULL when p.act is out of range so the host side can report an error.
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    switch (p.act)
    {
        case 1: return (void*)bias_act_kernel<T, 1>;
        case 2: return (void*)bias_act_kernel<T, 2>;
        case 3: return (void*)bias_act_kernel<T, 3>;
        case 4: return (void*)bias_act_kernel<T, 4>;
        case 5: return (void*)bias_act_kernel<T, 5>;
        case 6: return (void*)bias_act_kernel<T, 6>;
        case 7: return (void*)bias_act_kernel<T, 7>;
        case 8: return (void*)bias_act_kernel<T, 8>;
        case 9: return (void*)bias_act_kernel<T, 9>;
        default: return NULL;
    }
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Template specializations.
|
||||
|
||||
// Explicit instantiations for the dtypes dispatched from the host wrapper.
template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
// Argument bundle passed by value to bias_act_kernel(). Pointers are typed
// void* because the dtype is resolved by the kernel template parameter.
struct bias_act_kernel_params
{
    const void* x;      // [sizeX] input tensor
    const void* b;      // [sizeB] bias, or NULL
    const void* xref;   // [sizeX] saved forward input (gradient passes), or NULL
    const void* yref;   // [sizeX] saved forward output (gradient passes), or NULL
    const void* dy;     // [sizeX] incoming gradient, or NULL
    void*       y;      // [sizeX] output tensor

    int         grad;   // 0 = forward, 1 = first derivative, 2 = second derivative
    int         act;    // activation index (cuda_idx)
    float       alpha;  // activation shape parameter (e.g. lrelu slope)
    float       gain;   // output scaling factor
    float       clamp;  // clamp to [-clamp, clamp]; negative disables

    int         sizeX;  // total number of elements in x
    int         sizeB;  // number of bias elements
    int         stepB;  // stride of x along the bias (channel) dimension
    int         loopX;  // elements processed per thread
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,209 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom PyTorch ops for efficient bias and activation."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import dnnlib
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
# Registry of supported activations. Each entry records:
#   func         -- reference PyTorch implementation used by _bias_act_ref()
#   def_alpha    -- default shape parameter (e.g. leaky-relu slope)
#   def_gain     -- default output gain
#   cuda_idx     -- activation index understood by the CUDA kernel (template A)
#   ref          -- which tensor the backward pass must save: 'x', 'y', or ''
#   has_2nd_grad -- whether a second-order gradient path exists for it
activation_funcs = {
    'linear':   dnnlib.EasyDict(func=lambda x, **_:         x,                                          def_alpha=0,    def_gain=1,             cuda_idx=1, ref='',  has_2nd_grad=False),
    'relu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.relu(x),                def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=2, ref='y', has_2nd_grad=False),
    'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_:  torch.nn.functional.leaky_relu(x, alpha),   def_alpha=0.2,  def_gain=np.sqrt(2),    cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh':     dnnlib.EasyDict(func=lambda x, **_:         torch.tanh(x),                              def_alpha=0,    def_gain=1,             cuda_idx=4, ref='y', has_2nd_grad=True),
    'sigmoid':  dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x),                           def_alpha=0,    def_gain=1,             cuda_idx=5, ref='y', has_2nd_grad=True),
    'elu':      dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.elu(x),                 def_alpha=0,    def_gain=1,             cuda_idx=6, ref='y', has_2nd_grad=True),
    'selu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.selu(x),                def_alpha=0,    def_gain=1,             cuda_idx=7, ref='y', has_2nd_grad=True),
    'softplus': dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.softplus(x),            def_alpha=0,    def_gain=1,             cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish':    dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x) * x,                       def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=9, ref='x', has_2nd_grad=True),
}
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_plugin = None                      # Lazily-loaded compiled extension module.
_null_tensor = torch.empty([0])     # Sentinel passed in place of absent tensors.

def _init():
    """Compile and load the bias_act CUDA extension on first use.

    Always returns True so it can sit inside a chained boolean condition.
    """
    global _plugin
    if _plugin is not None:
        return True
    _plugin = custom_ops.get_plugin(
        module_name='bias_act_plugin',
        sources=['bias_act.cpp', 'bias_act.cu'],
        headers=['bias_act.h'],
        source_dir=os.path.dirname(__file__),
        extra_cuda_cflags=['--use_fast_math'],
    )
    return True
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias-add, activation, gain, and clamp.

    Adds bias `b` to `x`, applies the activation `act`, scales by `gain`,
    and optionally clamps — all in one fused CUDA op when available, which
    is considerably faster than chaining the equivalent standard PyTorch
    ops. Supports first and second order gradients, but not third order.

    Args:
        x:     Input activation tensor; any shape.
        b:     Optional bias. Must be a 1D tensor of the same dtype as `x`
               whose length matches ``x.shape[dim]``.
        dim:   Dimension of `x` the bias is broadcast along. Ignored when
               `b` is None.
        act:   Activation name ('linear', 'relu', 'lrelu', 'tanh',
               'sigmoid', 'swish', ...). See `activation_funcs` for the
               full list; None is not allowed.
        alpha: Shape parameter for the activation, or None for its default.
        gain:  Output scaling factor, or None for the activation's default
               (see `activation_funcs`; specify 1 if unsure).
        clamp: Clamp outputs to ``[-clamp, +clamp]``, or None to disable.
        impl:  Implementation to use: 'cuda' (default) or 'ref'.

    Returns:
        Tensor of the same shape and dtype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    # The fused path requires a CUDA tensor and a successfully built plugin;
    # _init() is only invoked when the first two conditions hold.
    use_cuda = impl == 'cuda' and x.device.type == 'cuda' and _init()
    if use_cuda:
        func = _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
        return func.apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops."""
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    # Resolve defaults; clamp < 0 encodes "clamping disabled".
    alpha = float(spec.def_alpha if alpha is None else alpha)
    gain = float(spec.def_gain if gain is None else gain)
    clamp = float(-1 if clamp is None else clamp)

    # Add bias, broadcasting it along `dim`.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        bias_shape = [1] * x.ndim
        bias_shape[dim] = -1
        x = x + b.reshape(bias_shape)

    # Evaluate activation function.
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp)  # pylint: disable=invalid-unary-operand-type
    return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
# Cache of generated autograd.Function classes, keyed by op configuration.
_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.

    Returns an `autograd.Function` subclass specialized for the given
    (dim, act, alpha, gain, clamp) configuration; classes are cached and
    reused across calls with the same configuration.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            # Preserve channels-last layout when the input uses it.
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            # Skip the kernel entirely when the whole op is an identity.
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            # Save only what the backward pass actually needs for this
            # activation (per spec.ref / spec.has_2nd_grad).
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            # Bias gradient: reduce over every dimension except `dim`.
            if ctx.needs_input_grad[1]:
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op (expressed as its own Function to enable second-order grads).
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            # Second derivative through the activation (grad=2 kernel pass).
            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
                d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,198 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom replacement for `torch.nn.functional.conv2d` that supports
|
||||
arbitrarily high order gradients with zero performance penalty."""
|
||||
|
||||
import contextlib
|
||||
import torch
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
# pylint: disable=arguments-differ
|
||||
# pylint: disable=protected-access
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
enabled = False                     # Enable the custom op by setting this to true.
weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights.

@contextlib.contextmanager
def no_weight_gradients(disable=True):
    """Context manager that temporarily disables weight gradients.

    Args:
        disable: When False, the context manager is a no-op (the flag keeps
                 its current value).

    The previous value of the global flag is restored on exit. The restore is
    done in a ``finally`` block so an exception raised inside the ``with``
    body can no longer leak ``weight_gradients_disabled = True`` into
    unrelated code (bug in the original implementation).
    """
    global weight_gradients_disabled
    old = weight_gradients_disabled
    if disable:
        weight_gradients_disabled = True
    try:
        yield
    finally:
        weight_gradients_disabled = old
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Drop-in replacement for `torch.nn.functional.conv2d` that routes through
    the gradfix op when the custom implementation is enabled and applicable."""
    if not _should_use_custom_op(input):
        return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    op = _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups)
    return op.apply(input, weight, bias)
|
||||
|
||||
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    """Drop-in replacement for `torch.nn.functional.conv_transpose2d` that routes
    through the gradfix op when the custom implementation is enabled and applicable."""
    if not _should_use_custom_op(input):
        return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
    op = _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
    return op.apply(input, weight, bias)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _should_use_custom_op(input):
    """Return True when the custom gradfix op should replace the stock conv:
    globally enabled, cuDNN available, and the tensor lives on a CUDA device."""
    assert isinstance(input, torch.Tensor)
    return enabled and torch.backends.cudnn.enabled and (input.device.type == 'cuda')
|
||||
|
||||
def _tuple_of_ints(xs, ndim):
|
||||
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
||||
assert len(xs) == ndim
|
||||
assert all(isinstance(x, int) for x in xs)
|
||||
return xs
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
# Cache of generated autograd.Function classes, keyed by op configuration.
_conv2d_gradfix_cache = dict()
# Sentinel saved in place of tensors the backward pass does not need.
_null_tensor = torch.empty([0])

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    """Build (and cache) an `autograd.Function` implementing conv2d or
    conv_transpose2d for one fixed configuration, with a custom backward that
    supports arbitrarily high order gradients."""
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else: # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
    def calc_output_padding(input_shape, output_shape):
        # Output padding needed so the transposed (gradient) conv reproduces
        # exactly `input_shape`; only relevant when flipping direction.
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    # Forward & backward.
    class Conv2d(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            # Only save what the backward pass will actually read.
            ctx.save_for_backward(
                input if weight.requires_grad else _null_tensor,
                weight if input.requires_grad else _null_tensor,
            )
            ctx.input_shape = input.shape

            # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
            if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
                a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
                b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
                c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
                c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
                c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))

            # General case => cuDNN.
            if transpose:
                return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
            return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            input_shape = ctx.input_shape
            grad_input = None
            grad_weight = None
            grad_bias = None

            # Input gradient: the transposed conv of the same configuration.
            if ctx.needs_input_grad[0]:
                p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
                op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
                grad_input = op.apply(grad_output, weight, None)
                assert grad_input.shape == input_shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias

    # Gradient with respect to the weights.
    class Conv2dGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            ctx.save_for_backward(
                grad_output if input.requires_grad else _null_tensor,
                input if grad_output.requires_grad else _null_tensor,
            )
            ctx.grad_output_shape = grad_output.shape
            ctx.input_shape = input.shape

            # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
            if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
                a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
                b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
                c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
                return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))

            # General case => cuDNN.
            # NOTE(review): torch._C._jit_get_operation on these internal aten
            # names is a private API that was removed in newer PyTorch
            # releases — verify against the PyTorch version pinned by this
            # project before upgrading.
            name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
            flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
            return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_output_shape = ctx.grad_output_shape
            input_shape = ctx.input_shape
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output_shape

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
                op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
                grad2_input = op.apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input_shape

            return grad2_grad_output, grad2_input

    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""2D convolution with optional up/downsampling."""
|
||||
|
||||
import torch
|
||||
|
||||
from .. import misc
|
||||
from . import conv2d_gradfix
|
||||
from . import upfirdn2d
|
||||
from .upfirdn2d import _parse_padding
|
||||
from .upfirdn2d import _get_filter_size
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _get_weight_shape(w):
    """Return `w.shape` as a plain list of Python ints.

    Wrapped in `suppress_tracer_warnings` so the sizes can be treated as
    constants under the JIT tracer; `assert_shape` re-checks consistency.
    """
    with misc.suppress_tracer_warnings():  # this value will be treated as a constant
        dims = [int(d) for d in w.shape]
    misc.assert_shape(w, dims)
    return dims
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
    _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)

    # conv2d() actually performs correlation (flip_weight=True), so flip the
    # kernel spatially when true convolution is requested. 1x1 kernels are
    # symmetric and need no flip.
    needs_flip = not flip_weight and (kw > 1 or kh > 1)
    if needs_flip:
        w = w.flip([2, 3])

    # Execute using conv2d_gradfix.
    if transpose:
        return conv2d_gradfix.conv_transpose2d(x, w, stride=stride, padding=padding, groups=groups)
    return conv2d_gradfix.conv2d(x, w, stride=stride, padding=padding, groups=groups)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling upfirdn2d.setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Can be a single number
                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                        (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Adjust padding to account for up/downsampling, so the resampling filter
    # footprint is absorbed into the single up-front pad.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        # Swap in/out channel axes of the weight so it fits conv_transpose2d.
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        # Account for the implicit padding the transposed conv introduces;
        # pxt/pyt carry any negative remainder into the conv itself.
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)

    # Fallback: Generic reference implementation.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,300 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "filtered_lrelu.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
//------------------------------------------------------------------------
// Fused CUDA op: upsample `x` by `up` with FIR filter `fu`, add per-channel
// bias `b`, apply leaky ReLU (slope / gain / clamp), then filter with `fd`
// and downsample by `down`. Optionally reads a precomputed bit-packed
// activation-sign tensor `si` (gradient pass) or, when `writeSigns` is set,
// allocates and fills a fresh one (forward pass).
//
// Returns (y, signs, return_code):
//   return_code == 0  -> success.
//   return_code == -1 -> no specialized CUDA kernel exists for this
//                        configuration; y/signs are empty and the caller
//                        must fall back to the generic implementation.
static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
    torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
    int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
{
    // Set CUDA device.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    // Validate arguments.
    TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
    TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
    TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
    TORCH_CHECK(x.numel() > 0, "x is empty");
    TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
    TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
    TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
    TORCH_CHECK(fu.numel() > 0, "fu is empty");
    TORCH_CHECK(fd.numel() > 0, "fd is empty");
    TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
    TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");

    // Figure out how much shared memory is available on the device.
    int maxSharedBytes = 0;
    AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
    int sharedKB = maxSharedBytes >> 10;

    // Populate enough launch parameters to check if a CUDA kernel exists.
    filtered_lrelu_kernel_params p;
    p.up = up;
    p.down = down;
    p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
    p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
    filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
    if (!test_spec.exec)
    {
        // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
        return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
    }

    // Input/output element size.
    int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;

    // Input sizes. fut_*/fdt_* are the filter "tails" (taps - 1).
    int64_t xw = (int)x.size(3);
    int64_t xh = (int)x.size(2);
    int64_t fut_w = (int)fu.size(-1) - 1;
    int64_t fut_h = (int)fu.size(0) - 1;
    int64_t fdt_w = (int)fd.size(-1) - 1;
    int64_t fdt_h = (int)fd.size(0) - 1;

    // Logical size of upsampled buffer.
    int64_t cw = xw * up + (px0 + px1) - fut_w;
    int64_t ch = xh * up + (py0 + py1) - fut_h;
    TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
    TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");

    // Compute output size and allocate.
    int64_t yw = (cw - fdt_w + (down - 1)) / down;
    int64_t yh = (ch - fdt_h + (down - 1)) / down;
    TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
    TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());

    // Allocate sign tensor.
    torch::Tensor so;
    torch::Tensor s = si;
    bool readSigns = !!s.numel();
    int64_t sw_active = 0; // Active width of sign tensor.
    if (writeSigns)
    {
        sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
        int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
        int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
        TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
        s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
    }
    else if (readSigns)
        sw_active = s.size(3) << 2; // 4 signs per byte.

    // Validate sign tensor if in use.
    if (readSigns || writeSigns)
    {
        TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
        TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
        TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
        TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
        TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
        TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
    }

    // Populate rest of CUDA kernel parameters.
    p.x = x.data_ptr();
    p.y = y.data_ptr();
    p.b = b.data_ptr();
    p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
    p.fu = fu.data_ptr<float>();
    p.fd = fd.data_ptr<float>();
    p.pad0 = make_int2(px0, py0);
    p.gain = gain;
    p.slope = slope;
    p.clamp = clamp;
    p.flip = (flip_filters) ? 1 : 0;
    p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
    p.sOfs = make_int2(sx, sy);
    p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.

    // x, y, b strides are in bytes.
    p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
    p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
    p.bStride = sz * b.stride(0);

    // fu, fd strides are in elements.
    p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
    p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);

    // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
    bool index64b = false;
    if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
    if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
    if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
    if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
    if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
    if (s.numel() > INT_MAX) index64b = true;

    // Choose CUDA kernel.
    filtered_lrelu_kernel_spec spec = { 0 };
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
    {
        if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
        {
            // Choose kernel based on index type, datatype and sign read/write modes.
            if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
            else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
            else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
            else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
            else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
            else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
        }
    });
    TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found"); // This should not happen because we tested earlier that kernel exists.

    // Launch CUDA kernel.
    void* args[] = {&p};
    int bx = spec.numWarps * 32;
    int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
    int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
    int gz = p.yShape.z * p.yShape.w;

    // Repeat multiple horizontal tiles in a CTA?
    if (spec.xrep)
    {
        p.tilesXrep = spec.xrep;
        p.tilesXdim = gx;

        gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
        std::swap(gx, gy);
    }
    else
    {
        p.tilesXrep = 0;
        p.tilesXdim = 0;
    }

    // Launch filter setup kernel.
    AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));

    // Copy kernels to constant memory.
    if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
    else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
    else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));

    // Set cache and shared memory configurations for main kernel.
    AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
    if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
        AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
    AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));

    // Launch main kernel.
    const int maxSubGz = 65535; // CUDA maximum for block z dimension.
    for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
    {
        p.blockZofs = zofs;
        int subGz = std::min(maxSubGz, gz - zofs);
        AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
    }

    // Done.
    return std::make_tuple(y, so, 0);
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
// Standalone in-place leaky ReLU / clamp on tensor `x`, with optional read
// or write of the bit-packed activation-sign tensor used by the gradient
// pass. Returns the newly allocated sign tensor when `writeSigns` is true,
// otherwise an empty tensor. Used by the Python-side generic fallback path.
static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
{
    // Set CUDA device.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

    // Validate arguments.
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
    TORCH_CHECK(x.numel() > 0, "x is empty");
    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");

    // Output signs if we don't have sign input.
    torch::Tensor so;
    torch::Tensor s = si;
    bool readSigns = !!s.numel();
    if (writeSigns)
    {
        int64_t sw = x.size(3);
        sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
        s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
    }

    // Validate sign tensor if in use.
    if (readSigns || writeSigns)
    {
        TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
        TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
        TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
        TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
        TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
        TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
    }

    // Initialize CUDA kernel parameters.
    filtered_lrelu_act_kernel_params p;
    p.x = x.data_ptr();
    p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
    p.gain = gain;
    p.slope = slope;
    p.clamp = clamp;
    p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); // Strides in elements here (cf. byte strides in filtered_lrelu above).
    p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
    p.sOfs = make_int2(sx, sy);

    // Choose CUDA kernel.
    void* func = 0;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
    {
        if (writeSigns)
            func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
        else if (readSigns)
            func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
        else
            func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
    });
    TORCH_CHECK(func, "internal error - CUDA kernel not found");

    // Launch CUDA kernel.
    void* args[] = {&p};
    int bx = 128; // 4 warps per block.

    // Logical size of launch = writeSigns ? p.s : p.x
    uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
    uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
    uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
    gx = (gx - 1) / bx + 1;

    // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
    const uint32_t gmax = 65535;
    gy = std::min(gy, gmax);
    gz = std::min(gz, gmax);

    // Launch.
    AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
    return so;
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
// Python bindings for the torch extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
    m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,90 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
//------------------------------------------------------------------------
// CUDA kernel parameters.

// Parameter block for the fused upsample / bias / leaky-ReLU / downsample
// kernel. Filled in by the host-side launcher in filtered_lrelu.cpp.
struct filtered_lrelu_kernel_params
{
    // These parameters decide which kernel to use.
    int up; // upsampling ratio (1, 2, 4)
    int down; // downsampling ratio (1, 2, 4)
    int2 fuShape; // [size, 1] | [size, size]
    int2 fdShape; // [size, 1] | [size, size]

    int _dummy; // Alignment.

    // Rest of the parameters.
    const void* x; // Input tensor.
    void* y; // Output tensor.
    const void* b; // Bias tensor.
    unsigned char* s; // Sign tensor in/out. NULL if unused.
    const float* fu; // Upsampling filter.
    const float* fd; // Downsampling filter.

    int2 pad0; // Left/top padding.
    float gain; // Additional gain factor.
    float slope; // Leaky ReLU slope on negative side.
    float clamp; // Clamp after nonlinearity.
    int flip; // Filter kernel flip for gradient computation.

    int tilesXdim; // Original number of horizontal output tiles.
    int tilesXrep; // Number of horizontal tiles per CTA.
    int blockZofs; // Block z offset to support large minibatch, channel dimensions.

    int4 xShape; // [width, height, channel, batch]
    int4 yShape; // [width, height, channel, batch]
    int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
    int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
    int swLimit; // Active width of sign tensor in bytes.

    longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
    longlong4 yStride; //
    int64_t bStride; //
    longlong3 fuStride; //
    longlong3 fdStride; //
};

// Parameter block for the standalone in-place activation kernel
// (filtered_lrelu_act in filtered_lrelu.cpp).
struct filtered_lrelu_act_kernel_params
{
    void* x; // Input/output, modified in-place.
    unsigned char* s; // Sign tensor in/out. NULL if unused.

    float gain; // Additional gain factor.
    float slope; // Leaky ReLU slope on negative side.
    float clamp; // Clamp after nonlinearity.

    int4 xShape; // [width, height, channel, batch]
    longlong4 xStride; // Input/output tensor strides, same order as in shape.
    int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
    int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
};

//------------------------------------------------------------------------
// CUDA kernel specialization.

// Describes one compiled kernel variant and its launch configuration.
struct filtered_lrelu_kernel_spec
{
    void* setup; // Function for filter kernel setup.
    void* exec; // Function for main operation.
    int2 tileOut; // Width/height of launch tile.
    int numWarps; // Number of warps per thread block, determines launch block size.
    int xrep; // For processing multiple horizontal tiles per thread block.
    int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
};

//------------------------------------------------------------------------
// CUDA kernel selection. Implemented by the .cu translation units; a spec
// with exec == NULL means no kernel exists for the requested configuration.

template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,275 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
from . import upfirdn2d
|
||||
from . import bias_act
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_plugin = None
|
||||
|
||||
def _init():
    """Lazily compile and load the filtered_lrelu CUDA plugin.

    Compilation happens only on the first call; subsequent calls reuse the
    cached module. Always returns True once the plugin is available.
    """
    global _plugin
    if _plugin is not None:
        return True
    _plugin = custom_ops.get_plugin(
        module_name='filtered_lrelu_plugin',
        sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
        headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
        source_dir=os.path.dirname(__file__),
        extra_cuda_cflags=['--use_fast_math'],
    )
    return True
||||
|
||||
def _get_filter_size(f):
|
||||
if f is None:
|
||||
return 1, 1
|
||||
assert isinstance(f, torch.Tensor)
|
||||
assert 1 <= f.ndim <= 2
|
||||
return f.shape[-1], f.shape[0] # width, height
|
||||
|
||||
def _parse_padding(padding):
|
||||
if isinstance(padding, int):
|
||||
padding = [padding, padding]
|
||||
assert isinstance(padding, (list, tuple))
|
||||
assert all(isinstance(x, (int, np.integer)) for x in padding)
|
||||
padding = [int(x) for x in padding]
|
||||
if len(padding) == 2:
|
||||
px, py = padding
|
||||
padding = [px, px, py, py]
|
||||
px0, px1, py0, py1 = padding
|
||||
return px0, px1, py0, py1
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
    r"""Filtered leaky ReLU for a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Add channel-specific bias if provided (`b`).

    2. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    3. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    5. Multiply each value by the provided gain factor (`gain`).

    6. Apply leaky ReLU activation function to each value.

    7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.

    8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
       it so that the footprint of all output pixels lies within the input image.

    9. Downsample the image by keeping every Nth pixel (`down`).

    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x:           Float32/float16/float64 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        fu:          Float32 upsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        fd:          Float32 downsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        b:           Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                     as `x`. The length of vector must match the channel dimension of `x`.
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor. (default: 1).
        padding:     Padding with respect to the upsampled image. Can be a single number
                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        gain:        Overall scaling factor for signal magnitude (default: sqrt(2)).
        slope:       Slope on the negative side of leaky ReLU (default: 0.2).
        clamp:       Maximum magnitude for leaky ReLU output (default: None).
        flip_filter: False = convolution, True = correlation (default: False).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    # Use the fused CUDA op only when requested, running on a CUDA device,
    # and the plugin compiles/loads successfully.
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
    # Otherwise fall back to the slow reference implementation.
    return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
    existing `upfirdn2d()` and `bias_act()` ops.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    fu_w, fu_h = _get_filter_size(fu)
    fd_w, fd_h = _get_filter_size(fd)
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
        misc.assert_shape(b, [x.shape[1]])
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    assert slope == float(slope) and slope >= 0
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)

    # Calculate output size.
    batch_size, channels, in_h, in_w = x.shape
    in_dtype = x.dtype
    out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
    out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down

    # Compute using existing ops.
    x = bias_act.bias_act(x=x, b=b) # Apply bias.
    x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
    x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Leaky ReLU, gain, clamp.
    x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.

    # Check output shape & dtype.
    misc.assert_shape(x, [batch_size, channels, out_h, out_w])
    assert x.dtype == in_dtype
    return x
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_filtered_lrelu_cuda_cache = dict()
|
||||
|
||||
def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
|
||||
"""Fast CUDA implementation of `filtered_lrelu()` using custom ops.
|
||||
"""
|
||||
assert isinstance(up, int) and up >= 1
|
||||
assert isinstance(down, int) and down >= 1
|
||||
px0, px1, py0, py1 = _parse_padding(padding)
|
||||
assert gain == float(gain) and gain > 0
|
||||
gain = float(gain)
|
||||
assert slope == float(slope) and slope >= 0
|
||||
slope = float(slope)
|
||||
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
|
||||
clamp = float(clamp if clamp is not None else 'inf')
|
||||
|
||||
# Lookup from cache.
|
||||
key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
|
||||
if key in _filtered_lrelu_cuda_cache:
|
||||
return _filtered_lrelu_cuda_cache[key]
|
||||
|
||||
# Forward op.
|
||||
class FilteredLReluCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
||||
|
||||
# Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
|
||||
if fu is None:
|
||||
fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
if fd is None:
|
||||
fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
||||
assert 1 <= fu.ndim <= 2
|
||||
assert 1 <= fd.ndim <= 2
|
||||
|
||||
# Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
|
||||
if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
|
||||
fu = fu.square()[None]
|
||||
if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
|
||||
fd = fd.square()[None]
|
||||
|
||||
# Missing sign input tensor.
|
||||
if si is None:
|
||||
si = torch.empty([0])
|
||||
|
||||
# Missing bias tensor.
|
||||
if b is None:
|
||||
b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
|
||||
|
||||
# Construct internal sign tensor only if gradients are needed.
|
||||
write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
|
||||
|
||||
# Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
|
||||
x = x.contiguous()
|
||||
strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
|
||||
if any(a < b for a, b in zip(strides[:-1], strides[1:])):
|
||||
warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
|
||||
|
||||
# Call C++/Cuda plugin if datatype is supported.
|
||||
if x.dtype in [torch.float16, torch.float32]:
|
||||
if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
|
||||
warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
|
||||
y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
|
||||
else:
|
||||
return_code = -1
|
||||
|
||||
# No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
|
||||
# only the bit-packed sign tensor is retained for gradient computation.
|
||||
if return_code < 0:
|
||||
warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
|
||||
|
||||
y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
|
||||
y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
|
||||
so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
|
||||
y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
|
||||
|
||||
# Prepare for gradient computation.
|
||||
ctx.save_for_backward(fu, fd, (si if si.numel() else so))
|
||||
ctx.x_shape = x.shape
|
||||
ctx.y_shape = y.shape
|
||||
ctx.s_ofs = sx, sy
|
||||
return y
|
||||
|
||||
@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
    """Backward pass of the fused filtered leaky-ReLU.

    The gradient w.r.t. the input is expressed as another filtered_lrelu
    with the up/down factors swapped, padding adjusted to undo the forward
    op, and the saved bit-packed sign tensor replayed instead of being
    recomputed. Only gradients w.r.t. x (arg 0) and b (arg 3) are supported.
    """
    fu, fd, si = ctx.saved_tensors
    _, _, xh, xw = ctx.x_shape
    _, _, yh, yw = ctx.y_shape
    sx, sy = ctx.s_ofs

    # Filters, sign tensor, and sign offsets are never differentiated.
    for fixed_idx in (1, 2, 4, 5, 6):
        assert not ctx.needs_input_grad[fixed_idx]
    dx = None
    db = None

    if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
        # Padding of the transposed op so its output shape matches x.
        pad = [
            (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
            xw * up - yw * down + px0 - (up - 1),
            (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
            xh * up - yh * down + py0 - (up - 1),
        ]
        bwd_gain = gain * (up ** 2) / (down ** 2)
        bwd_flip = not flip_filter
        # Shift sign-tensor read offsets to account for the swapped filters.
        sx = sx - (fu.shape[-1] - 1) + px0
        sy = sy - (fu.shape[0] - 1) + py0
        bwd_op = _filtered_lrelu_cuda(up=down, down=up, padding=pad, gain=bwd_gain, slope=slope, clamp=None, flip_filter=bwd_flip)
        dx = bwd_op.apply(dy, fd, fu, None, si, sx, sy)

    if ctx.needs_input_grad[3]:
        # Bias gradient: reduce over batch and both spatial dimensions.
        db = dx.sum([0, 2, 3])

    return dx, None, None, db, None, None, None
|
||||
|
||||
# Add to cache.
|
||||
_filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
|
||||
return FilteredLReluCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include "filtered_lrelu.cu"

// Template/kernel specializations for no signs mode (no gradients required).
// Explicit instantiations kept in their own translation unit. The trailing
// two template bools select the sign-buffer mode; comparing this file with
// the sibling sign-read/sign-write variants suggests they are
// <signWrite, signRead> -- confirm against filtered_lrelu.cu.

// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);

// Copy filters to constant memory.
template cudaError_t copy_filters<false, false>(cudaStream_t stream);
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include "filtered_lrelu.cu"

// Template/kernel specializations for sign read mode (backward pass replays
// the bit-packed signs saved by the forward pass). The trailing two template
// bools appear to be <signWrite, signRead> = <false, true> here -- confirm
// against filtered_lrelu.cu.

// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);

// Copy filters to constant memory.
template cudaError_t copy_filters<false, true>(cudaStream_t stream);
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include "filtered_lrelu.cu"

// Template/kernel specializations for sign write mode (forward pass records
// the bit-packed signs needed later by the backward pass). The trailing two
// template bools appear to be <signWrite, signRead> = <true, false> here --
// confirm against filtered_lrelu.cu.

// Full op, 32-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Full op, 64-bit indexing.
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);

// Activation/signs only for generic variant. 64-bit indexing.
template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);

// Copy filters to constant memory.
template cudaError_t copy_filters<true, false>(cudaStream_t stream);
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
|
||||
|
||||
import torch
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def fma(a, b, c): # => a * b + c
    """Fused multiply-add `a * b + c`, routed through a custom autograd op
    whose gradients are slightly faster than `torch.addcmul()`'s."""
    return _FusedMultiplyAdd.apply(a, b, c)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
||||
@staticmethod
|
||||
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
||||
out = torch.addcmul(c, a, b)
|
||||
ctx.save_for_backward(a, b)
|
||||
ctx.c_shape = c.shape
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dout): # pylint: disable=arguments-differ
|
||||
a, b = ctx.saved_tensors
|
||||
c_shape = ctx.c_shape
|
||||
da = None
|
||||
db = None
|
||||
dc = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
da = _unbroadcast(dout * b, a.shape)
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
db = _unbroadcast(dout * a, b.shape)
|
||||
|
||||
if ctx.needs_input_grad[2]:
|
||||
dc = _unbroadcast(dout, c_shape)
|
||||
|
||||
return da, db, dc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _unbroadcast(x, shape):
|
||||
extra_dims = x.ndim - len(shape)
|
||||
assert extra_dims >= 0
|
||||
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
||||
if len(dim):
|
||||
x = x.sum(dim=dim, keepdim=True)
|
||||
if extra_dims:
|
||||
x = x.reshape(-1, *x.shape[extra_dims+1:])
|
||||
assert x.shape == shape
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom replacement for `torch.nn.functional.grid_sample` that
|
||||
supports arbitrarily high order gradients between the input and output.
|
||||
Only works on 2D images and assumes
|
||||
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
|
||||
|
||||
import torch
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
# pylint: disable=arguments-differ
|
||||
# pylint: disable=protected-access
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
enabled = False # Module-level switch: set True to route grid_sample() through the custom autograd op below.
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def grid_sample(input, grid):
    """2D grid sample with `mode='bilinear'`, `padding_mode='zeros'`,
    `align_corners=False`.

    Routes through the custom double-differentiable autograd op when the
    module switch is on, otherwise falls back to stock PyTorch.
    """
    if not _should_use_custom_op():
        return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
    return _GridSample2dForward.apply(input, grid)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _should_use_custom_op():
    """Return the module-level `enabled` flag controlling the custom op."""
    return enabled
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _GridSample2dForward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, grid):
|
||||
assert input.ndim == 4
|
||||
assert grid.ndim == 4
|
||||
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
ctx.save_for_backward(input, grid)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, grid = ctx.saved_tensors
|
||||
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
|
||||
return grad_input, grad_grid
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _GridSample2dBackward(torch.autograd.Function):
    """Differentiable wrapper around `aten::grid_sampler_2d_backward`.

    Wrapping the backward computation in its own autograd.Function is what
    enables second-order gradients through grid_sample.
    """
    @staticmethod
    def forward(ctx, grad_output, input, grid):
        # NOTE(review): relies on a private torch API. On newer PyTorch
        # versions _jit_get_operation returns an (op, overload_names) tuple
        # and the op takes an extra output_mask argument -- confirm against
        # the torch version this project pins.
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        # The trailing 0, 0, False presumably select bilinear / zeros /
        # align_corners=False, matching the forward call above -- verify
        # against the ATen enum values.
        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        # Only grid is needed for the double-backward pass.
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            # d(grad_input)/d(grad_output) is grid_sample itself.
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        # Second-order gradient w.r.t. grid is not implemented.
        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,107 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "upfirdn2d.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
// Upsample x by (upx, upy), pad by (padx0, padx1, pady0, pady1), convolve
// with the 2D FIR filter f (optionally flipped), downsample by
// (downx, downy), and scale by gain. x is a rank-4 CUDA tensor; f is a
// rank-2 float32 tensor on the same device. Returns a new tensor in x's
// suggested memory format.
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
    // Validate arguments. The kernels index with 32-bit ints, hence the
    // INT_MAX checks on element counts and memory footprints.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
    TORCH_CHECK(x.numel() > 0, "x has zero size");
    TORCH_CHECK(f.numel() > 0, "f has zero size");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
    TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");

    // Create output tensor, preserving x's memory format (contiguous or
    // channels-last).
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
    TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");

    // Initialize CUDA kernel parameters. Sizes/strides are packed into
    // int4 fields in (x, y, z, w) = (W, H, C, N) order.
    upfirdn2d_kernel_params p;
    p.x = x.data_ptr();
    p.f = f.data_ptr<float>();
    p.y = y.data_ptr();
    p.up = make_int2(upx, upy);
    p.down = make_int2(downx, downy);
    p.pad0 = make_int2(padx0, pady0);
    p.flip = (flip) ? 1 : 0;
    p.gain = gain;
    p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
    p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
    p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
    p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
    // inStride.z == 1 indicates channels-last layout (see the s == 1 cases
    // in choose_upfirdn2d_kernel): channels then form the "minor" dimension,
    // otherwise N*C together form the "major" dimension.
    p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
    p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;

    // Choose CUDA kernel based on dtype, layout, and filter geometry.
    upfirdn2d_kernel_spec spec;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        spec = choose_upfirdn2d_kernel<scalar_t>(p);
    });

    // Set looping options.
    p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
    p.loopMinor = spec.loopMinor;
    p.loopX = spec.loopX;
    p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
    p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;

    // Compute grid size. tileOutW < 0 marks the generic large-filter kernel.
    dim3 blockSize, gridSize;
    if (spec.tileOutW < 0) // large
    {
        blockSize = dim3(4, 32, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
            p.launchMajor);
    }
    else // small
    {
        blockSize = dim3(256, 1, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
            p.launchMajor);
    }

    // Launch CUDA kernel on the current stream.
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
// Python binding: exposes the raw upfirdn2d op defined above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,384 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
#include "upfirdn2d.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Helpers.
|
||||
|
||||
// Maps the storage type T to the accumulator type used inside the kernels:
// double accumulates in double; float and half both accumulate in float.
template <class T> struct InternalType;
template <> struct InternalType<double> { typedef double scalar_t; };
template <> struct InternalType<float> { typedef float scalar_t; };
template <> struct InternalType<c10::Half> { typedef float scalar_t; };

// Integer floor division (rounds toward negative infinity, unlike C's
// truncating '/'). NOTE(review): call sites in this file pass the
// upsampling factor as b, which is >= 1; behavior for b <= 0 is not
// relied upon here.
static __device__ __forceinline__ int floor_div(int a, int b)
{
    int t = 1 - a / b;
    return (a + t * b) / b - t;
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Generic CUDA implementation for large filters.
|
||||
|
||||
// Generic upfirdn2d kernel for arbitrary (large) filters. Each thread walks
// output pixels directly from global memory using the explicit strides in p,
// looping over major (batch/minor-packed) index, minor index, and X as
// configured by p.loopMajor / p.loopMinor / p.loopX. Note: conditions use
// bitwise |/& on boolean-valued ints, so all operands are evaluated.
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;

    // Calculate thread index.
    int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorBase / p.launchMinor;
    minorBase -= outY * p.launchMinor;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Setup Y receptive field: input row range [inY, inY+h) and the
    // corresponding starting filter tap for this output row.
    int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
    int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
    int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
    int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
    if (p.flip)
        filterY = p.filterSize.y - 1 - filterY;

    // Loop over major, minor, and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
    {
        // Unpack (batch, channel) from the combined major/minor index.
        int nc = major * p.sizeMinor + minor;
        int n = nc / p.inSize.z;
        int c = nc - n * p.inSize.z;
        for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
        {
            // Setup X receptive field.
            int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
            int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
            int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
            int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
            if (p.flip)
                filterX = p.filterSize.x - 1 - filterX;

            // Initialize pointers. Filter taps advance by the upsampling
            // factor per input pixel; direction depends on p.flip.
            const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
            const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
            int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
            int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;

            // Inner loop: accumulate input * filter over the receptive field.
            scalar_t v = 0;
            for (int y = 0; y < h; y++)
            {
                for (int x = 0; x < w; x++)
                {
                    v += (scalar_t)(*xp) * (scalar_t)(*fp);
                    xp += p.inStride.x;
                    fp += filterStepX;
                }
                // Rewind X and advance one input row.
                xp += p.inStride.y - w * p.inStride.x;
                fp += filterStepY - w * filterStepX;
            }

            // Store result.
            v *= p.gain;
            ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
        }
    }
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Specialized CUDA implementation for small filters.
|
||||
|
||||
// Tiled upfirdn2d kernel specialized at compile time for small filters.
// The filter and an input tile are staged in shared memory (sf, sx below);
// each thread block produces a tileOutW x tileOutH x loopMinor output tile.
// Conditions use bitwise |/& on boolean-valued ints (all operands evaluated).
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    // Input tile extent needed to cover one output tile.
    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
    __shared__ volatile scalar_t sf[filterH][filterW];
    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];

    // Calculate tile index.
    int minorBase = blockIdx.x;
    int tileOutY = minorBase / p.launchMinor;
    minorBase -= tileOutY * p.launchMinor;
    minorBase *= loopMinor;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorBase = blockIdx.z * p.loopMajor;
    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Load filter into shared memory, pre-flipped so the inner loop can
    // always index it the same way; zero-pad up to the compile-time size.
    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
    {
        int fy = tapIdx / filterW;
        int fx = tapIdx - fy * filterW;
        scalar_t v = 0;
        if (fx < p.filterSize.x & fy < p.filterSize.y)
        {
            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
        }
        sf[fy][fx] = v;
    }

    // Loop over major and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    {
        int baseNC = major * p.sizeMinor + minorBase;
        int n = baseNC / p.inSize.z;
        int baseC = baseNC - n * p.inSize.z;
        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
        {
            // Load input pixels for this tile into shared memory,
            // zero-filling anything outside the input bounds.
            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
            int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
            int tileInX = floor_div(tileMidX, upx);
            int tileInY = floor_div(tileMidY, upy);
            __syncthreads(); // previous iteration must be done reading sx
            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
            {
                int relC = inIdx;
                int relInX = relC / loopMinor;
                int relInY = relInX / tileInW;
                relC -= relInX * loopMinor;
                relInX -= relInY * tileInW;
                int c = baseC + relC;
                int inX = tileInX + relInX;
                int inY = tileInY + relInY;
                scalar_t v = 0;
                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
                sx[relInY][relInX][relC] = v;
            }

            // Loop over output pixels.
            __syncthreads(); // tile load must complete before use
            for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
            {
                int relC = outIdx;
                int relOutX = relC / loopMinor;
                int relOutY = relOutX / tileOutW;
                relC -= relOutX * loopMinor;
                relOutX -= relOutY * tileOutW;
                int c = baseC + relC;
                int outX = tileOutX + relOutX;
                int outY = tileOutY + relOutY;

                // Setup receptive field.
                int midX = tileMidX + relOutX * downx;
                int midY = tileMidY + relOutY * downy;
                int inX = floor_div(midX, upx);
                int inY = floor_div(midY, upy);
                int relInX = inX - tileInX;
                int relInY = inY - tileInY;
                int filterX = (inX + 1) * upx - midX - 1; // flipped
                int filterY = (inY + 1) * upy - midY - 1; // flipped

                // Inner loop: fully unrolled accumulation from shared memory.
                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
                {
                    scalar_t v = 0;
                    #pragma unroll
                    for (int y = 0; y < filterH / upy; y++)
                        #pragma unroll
                        for (int x = 0; x < filterW / upx; x++)
                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
                    v *= p.gain;
                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
                }
            }
        }
    }
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
|
||||
{
|
||||
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
|
||||
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
|
||||
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
|
||||
|
||||
// No up/downsampling.
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
|
||||
// 2x upsampling.
|
||||
if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
||||
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
||||
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
|
||||
}
|
||||
if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
|
||||
// 2x downsampling.
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
|
||||
if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
|
||||
if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
|
||||
if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
|
||||
if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
|
||||
if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
|
||||
if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
|
||||
if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
|
||||
}
|
||||
|
||||
// 4x upsampling.
|
||||
if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
|
||||
}
|
||||
if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
|
||||
}
|
||||
|
||||
// 4x downsampling (inefficient).
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
|
||||
if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
|
||||
}
|
||||
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
|
||||
{
|
||||
// contiguous
|
||||
if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
|
||||
if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
|
||||
// channels_last
|
||||
if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
|
||||
if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
|
||||
}
|
||||
return spec;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Template specializations.
|
||||
|
||||
// Explicit instantiations for every element type exposed through the PyTorch binding.
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
//
|
||||
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
// and proprietary rights in and to this software, related documentation
|
||||
// and any modifications thereto. Any use, reproduction, disclosure or
|
||||
// distribution of this software and related documentation without an express
|
||||
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
// Argument bundle passed from the host launcher to the upfirdn2d CUDA kernels.
struct upfirdn2d_kernel_params
{
    const void*     x;              // Input tensor data; element type erased (selected by the kernel template parameter).
    const float*    f;              // FIR filter taps, always float32.
    void*           y;              // Output tensor data.

    int2            up;             // Upsampling factor per axis.
    int2            down;           // Downsampling factor per axis.
    int2            pad0;           // Leading padding per axis.
    int             flip;           // NOTE(review): presumably mirrors the Python `flip_filter` flag — confirm polarity in upfirdn2d.cu.
    float           gain;           // Overall scaling factor for signal magnitude.

    int4            inSize;         // [width, height, channel, batch]
    int4            inStride;       // Element strides of the input, same axis order as inSize.
    int2            filterSize;     // [width, height]
    int2            filterStride;   // Element strides of the filter.
    int4            outSize;        // [width, height, channel, batch]
    int4            outStride;      // Element strides of the output.
    int             sizeMinor;      // NOTE(review): minor/major loop extents — presumably derived from channel/batch dims by the host; confirm in upfirdn2d.cu.
    int             sizeMajor;

    int             loopMinor;      // Per-launch loop counts; filled in by the host from the chosen upfirdn2d_kernel_spec.
    int             loopMajor;
    int             loopX;
    int             launchMinor;    // Launch-grid extents corresponding to the loop counts above.
    int             launchMajor;
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel specialization.
|
||||
|
||||
// One concrete kernel specialization plus its launch geometry, as selected by
// choose_upfirdn2d_kernel() for a given up/down/filter-size configuration.
struct upfirdn2d_kernel_spec
{
    void*   kernel;     // Device function pointer, cast from the templated kernel.
    int     tileOutW;   // Output tile width of the specialization (matches the kernel's template tile parameters).
    int     tileOutH;   // Output tile height of the specialization.
    int     loopMinor;  // Minor-dimension loop count baked into the specialization.
    int     loopX;      // Horizontal loop count (1 for all "small" specializations).
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
// Selects the most specialized kernel available for the geometry described by p.
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
@@ -0,0 +1,389 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Custom PyTorch ops for efficient resampling of 2D images."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
from . import conv2d_gradfix
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_plugin = None
|
||||
|
||||
def _init():
    """Lazily compile and load the `upfirdn2d_plugin` CUDA extension.

    The plugin is built at most once per process and cached in the module-level
    `_plugin` global. Always returns True so it can be used inline in a
    short-circuiting condition.
    """
    global _plugin
    if _plugin is not None:
        return True
    _plugin = custom_ops.get_plugin(
        module_name='upfirdn2d_plugin',
        sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],
        headers=['upfirdn2d.h'],
        source_dir=os.path.dirname(__file__),
        extra_cuda_cflags=['--use_fast_math'],
    )
    return True
|
||||
|
||||
def _parse_scaling(scaling):
|
||||
if isinstance(scaling, int):
|
||||
scaling = [scaling, scaling]
|
||||
assert isinstance(scaling, (list, tuple))
|
||||
assert all(isinstance(x, int) for x in scaling)
|
||||
sx, sy = scaling
|
||||
assert sx >= 1 and sy >= 1
|
||||
return sx, sy
|
||||
|
||||
def _parse_padding(padding):
|
||||
if isinstance(padding, int):
|
||||
padding = [padding, padding]
|
||||
assert isinstance(padding, (list, tuple))
|
||||
assert all(isinstance(x, int) for x in padding)
|
||||
if len(padding) == 2:
|
||||
padx, pady = padding
|
||||
padding = [padx, padx, pady, pady]
|
||||
padx0, padx1, pady0, pady1 = padding
|
||||
return padx0, padx1, pady0, pady1
|
||||
|
||||
def _get_filter_size(f):
    """Return `(width, height)` of FIR filter *f*; `(1, 1)` for the identity filter."""
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    width = f.shape[-1]
    height = f.shape[0]
    # int() on symbolic sizes triggers tracer warnings; suppress and re-check the shape.
    with misc.suppress_tracer_warnings():
        width = int(width)
        height = int(height)
    misc.assert_shape(f, [height, width][:f.ndim])
    assert width >= 1 and height >= 1
    return width, height
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f:           Torch tensor, numpy array, or python list of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable),
                     `[]` (impulse), or
                     `None` (identity).
        device:      Result device (default: cpu).
        normalize:   Normalize the filter so that it retains the magnitude
                     for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        separable:   Return a separable filter? (default: select automatically).

    Returns:
        Float32 tensor of the shape
        `[filter_height, filter_width]` (non-separable) or
        `[filter_taps]` (separable). The input is never modified.
    """
    # Validate. `torch.as_tensor` aliases `f` when it is already a float32 tensor.
    if f is None:
        f = 1
    f = torch.as_tensor(f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Separable? A short 1D tap list is expanded to a full 2D filter via outer product.
    if separable is None:
        separable = (f.ndim == 1 and f.numel() >= 8)
    if f.ndim == 1 and not separable:
        f = f.ger(f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        # Out-of-place division: the previous in-place `f /= f.sum()` silently
        # mutated a caller-supplied float32 tensor on the separable/2D path.
        f = f / f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    f = f * (gain ** (f.ndim / 2))
    f = f.to(device=device)
    return f
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Pad, upsample, FIR-filter, and downsample a batch of 2D images.

    Per channel, the op (1) upsamples by inserting `up - 1` zeros after each
    pixel, (2) applies zero-padding or cropping (`padding`; negative values
    crop), (3) convolves with the 2D FIR filter `f`, shrinking so that every
    output pixel's footprint lies within the input, and (4) downsamples by
    keeping every Nth pixel (`down`). This mirrors scipy.signal.upfirdn(),
    is considerably faster than the equivalent sequence of standard PyTorch
    ops, and supports gradients of arbitrary order.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        up:          Integer upsampling factor, single int or `[x, y]` (default: 1).
        down:        Integer downsampling factor, single int or `[x, y]` (default: 1).
        padding:     Padding with respect to the upsampled image: single number,
                     `[x, y]`, or `[x_before, x_after, y_before, y_after]` (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use, `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    # Use the compiled plugin only when requested, on a CUDA tensor, and when it loads.
    want_cuda = (impl == 'cuda' and x.device.type == 'cuda')
    if want_cuda and _init():
        op = _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
        return op.apply(x, f)
    return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.

    Same contract as `upfirdn2d()`: `x` is `[batch, channels, height, width]`,
    `f` is a float32 filter (2D non-separable, 1D separable, or None for
    identity). Used on CPU tensors and whenever the CUDA plugin is unavailable.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Check that upsampled buffer is not smaller than the filter.
    upW = in_width * upx + padx0 + padx1
    upH = in_height * upy + pady0 + pady1
    assert upW >= f.shape[-1] and upH >= f.shape[0]

    # Upsample by inserting zeros: interleave each pixel with (up - 1) zeros
    # via reshape -> pad the singleton dims -> flatten back.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad or crop: positive padding is applied with F.pad, negative padding
    # (cropping) with slicing.
    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
    x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]

    # Setup filter: apply half the gain per dimension, then flip unless
    # correlation was requested (conv2d below computes correlation).
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter, one group per channel. A separable (1D) filter
    # is applied as two passes: vertical (unsqueeze(2)) then horizontal (unsqueeze(3)).
    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
    if f.ndim == 4:
        x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
    else:
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)

    # Downsample by throwing away pixels.
    x = x[:, :, ::downy, ::downx]
    return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
# Cache of autograd Function classes, keyed by the full op configuration.
_upfirdn2d_cuda_cache = dict()

def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Fast CUDA implementation of `upfirdn2d()` using custom ops.

    Returns a `torch.autograd.Function` subclass specialized for the given
    configuration; call its `.apply(x, f)` to run the op. Classes are cached
    per configuration so repeated calls (including recursive backward calls)
    reuse the same class.
    """
    # Parse arguments.
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Lookup from cache.
    key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
    if key in _upfirdn2d_cuda_cache:
        return _upfirdn2d_cuda_cache[key]

    # Forward op. The class closes over the parsed configuration above.
    class Upfirdn2dCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, f): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4
            if f is None:
                f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            if f.ndim == 1 and f.shape[0] == 1:
                f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1.
            assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
            y = x
            if f.ndim == 2:
                # Non-separable: single fused plugin call.
                y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
            else:
                # Separable: two 1D passes (x axis first, then y); gain is applied
                # only on the second pass so it is not counted twice.
                y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0)
                y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain)
            ctx.save_for_backward(f)
            ctx.x_shape = x.shape
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            f, = ctx.saved_tensors
            _, _, ih, iw = ctx.x_shape
            _, _, oh, ow = dy.shape
            fw, fh = _get_filter_size(f)
            # Padding of the transposed op, chosen so the gradient's spatial
            # shape matches the forward input.
            p = [
                fw - padx0 - 1,
                iw * upx - ow * downx + padx0 - upx + 1,
                fh - pady0 - 1,
                ih * upy - oh * downy + pady0 - upy + 1,
            ]
            dx = None
            df = None

            if ctx.needs_input_grad[0]:
                # Gradient of upfirdn2d is upfirdn2d with up/down swapped and the filter flipped.
                dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)

            # Filter gradients are not implemented.
            assert not ctx.needs_input_grad[1]
            return dx, df

    # Add to cache.
    _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
    return Upfirdn2dCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Filter a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape matches the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        padding:     Padding with respect to the output: single number, `[x, y]`,
                     or `[x_before, x_after, y_before, y_after]` (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use, `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    px0, px1, py0, py1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    # Grow the user padding by half the filter extent so the output keeps the
    # input's spatial size before the extra padding is applied.
    full_padding = [
        px0 + fw // 2,
        px1 + (fw - 1) // 2,
        py0 + fh // 2,
        py1 + (fh - 1) // 2,
    ]
    return upfirdn2d(x, f, padding=full_padding, flip_filter=flip_filter, gain=gain, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Upsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a multiple of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        up:          Integer upsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 2).
        padding:     Padding with respect to the output. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    upx, upy = _parse_scaling(up)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    # Center the filter on the upsampled grid so the output size is exactly input * up.
    p = [
        padx0 + (fw + upx - 1) // 2,
        padx1 + (fw - upx) // 2,
        pady0 + (fh + upy - 1) // 2,
        pady1 + (fh - upy) // 2,
    ]
    # gain * upx * upy compensates for the zeros inserted by upsampling so DC magnitude is preserved.
    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Downsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a fraction of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        down:        Integer downsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 2).
        padding:     Padding with respect to the input. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    # Center the filter so the kept samples stay aligned with the input grid.
    p = [
        padx0 + (fw - downx + 1) // 2,
        padx1 + (fw - downx) // 2,
        pady0 + (fh - downy + 1) // 2,
        pady1 + (fh - downy) // 2,
    ]
    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,251 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Facilities for pickling Python code alongside other data.
|
||||
|
||||
The pickled code is automatically imported into a separate Python module
|
||||
during unpickling. This way, any previously exported pickles will remain
|
||||
usable even if the original code is no longer available, or if the current
|
||||
version of the code is not consistent with what was originally pickled."""
|
||||
|
||||
import sys
|
||||
import pickle
|
||||
import io
|
||||
import inspect
|
||||
import copy
|
||||
import uuid
|
||||
import types
|
||||
import dnnlib
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
# Module-level registry state shared by all persistence helpers below.
_version            = 6         # internal version number
_decorators         = set()     # {decorator_class, ...}
_import_hooks       = []        # [hook_function, ...]
_module_to_src_dict = dict()    # {module: src, ...}
_src_to_module_dict = dict()    # {src: module, ...}
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. A typical use case is to first unpickle a previous
    instance of a persistent class, and then upgrade it to use the latest
    version of the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class  # Already decorated; decorating twice would nest wrappers.

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)  # Full source text of the defining module.

    class Decorator(orig_class):
        # Source/name snapshot captured at decoration time, shared by all instances.
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Record construction arguments so equivalent instances can be re-created later.
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
            assert orig_class.__name__ in orig_module.__dict__
            # Fail fast at construction time if the instance cannot be pickled.
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            # Deep copy so callers cannot mutate the recorded arguments.
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            # Rewrite the default (func, args, state) reduction so the pickle
            # carries the module source and class name instead of a bare class reference.
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0)  # Pad to at least (func, args, state).
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    Decorator.__name__ = orig_class.__name__
    _decorators.add(Decorator)  # Register so is_persistent() recognizes the wrapper.
    return Decorator
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.
    """
    # `obj` may itself be a registered decorator class, or an instance of one.
    try:
        directly_registered = obj in _decorators
    except TypeError:
        directly_registered = False  # unhashable objects cannot be registered classes
    return directly_registered or type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def import_hook(hook):
    r"""Register an import hook that runs whenever a persistent object is
    unpickled. A typical use case is patching the pickled source code to
    avoid errors when the API of an imported module has changed.

    The hook must have the signature:

        hook(meta) -> modified meta

    where `meta` is a `dnnlib.EasyDict` with the fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src  Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta
    """
    assert callable(hook)
    _import_hooks.append(hook)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _reconstruct_persistent_obj(meta):
    r"""Hook called internally by the `pickle` module to rebuild a
    persistent object from its pickled metadata.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)

    # Let registered import hooks patch the metadata before reconstruction.
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None

    assert meta.version == _version
    module = _src_to_module(meta.module_src)

    # Only class instances are supported at the moment.
    assert meta.type == 'class'
    target_class = persistent_class(module.__dict__[meta.class_name])
    obj = target_class.__new__(target_class)

    # Restore state, preferring a custom __setstate__ if the class defines one.
    restore_fn = getattr(obj, '__setstate__', None)
    if callable(restore_fn):
        restore_fn(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _module_to_src(module):
    r"""Return the source code of the given Python module, caching the
    result so repeated queries are cheap and the src <-> module mapping
    stays consistent.
    """
    cached = _module_to_src_dict.get(module)
    if cached is not None:
        return cached
    src = inspect.getsource(module)
    _module_to_src_dict[module] = src
    _src_to_module_dict[src] = module
    return src
|
||||
|
||||
def _src_to_module(src):
    r"""Return a Python module corresponding to the given source code,
    creating and executing it on first use and caching it thereafter.
    """
    cached = _src_to_module_dict.get(src)
    if cached is not None:
        return cached
    # Create an empty module under a unique name and register it so that
    # pickling/importing machinery can find it by name.
    module_name = "_imported_module_" + uuid.uuid4().hex
    module = types.ModuleType(module_name)
    sys.modules[module_name] = module
    _module_to_src_dict[module] = src
    _src_to_module_dict[src] = module
    exec(src, module.__dict__) # pylint: disable=exec-used
    return module
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _check_pickleable(obj):
|
||||
r"""Check that the given object is pickleable, raising an exception if
|
||||
it is not. This function is expected to be considerably more efficient
|
||||
than actually pickling the object.
|
||||
"""
|
||||
def recurse(obj):
|
||||
if isinstance(obj, (list, tuple, set)):
|
||||
return [recurse(x) for x in obj]
|
||||
if isinstance(obj, dict):
|
||||
return [[recurse(x), recurse(y)] for x, y in obj.items()]
|
||||
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
|
||||
return None # Python primitive types are pickleable.
|
||||
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
|
||||
return None # NumPy arrays and PyTorch tensors are pickleable.
|
||||
if is_persistent(obj):
|
||||
return None # Persistent objects are pickleable, by virtue of the constructor check.
|
||||
return obj
|
||||
with io.BytesIO() as f:
|
||||
pickle.dump(recurse(obj), f)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,268 @@
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
||||
# and proprietary rights in and to this software, related documentation
|
||||
# and any modifications thereto. Any use, reproduction, disclosure or
|
||||
# distribution of this software and related documentation without an express
|
||||
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
||||
|
||||
"""Facilities for reporting and collecting training statistics across
|
||||
multiple processes and devices. The interface is designed to minimize
|
||||
synchronization overhead as well as the amount of boilerplate in user
|
||||
code."""
|
||||
|
||||
import re
|
||||
import numpy as np
|
||||
import torch
|
||||
import dnnlib
|
||||
|
||||
from . import misc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
# Module-level state shared by report(), Collector, and _sync().
_num_moments = 3 # Moments tracked per statistic: [num_scalars, sum_of_scalars, sum_of_squares].
_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
_counter_dtype = torch.float64 # Data type to use for the internal counters.
_rank = 0 # Rank of the current process.
_sync_device = None # Device to use for multiprocess communication. None = single-process.
_sync_called = False # Has _sync() been called yet?
_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def init_multiprocessing(rank, sync_device):
    r"""Initialize `torch_utils.training_stats` for collecting statistics
    across multiple processes.

    Must be called after `torch.distributed.init_process_group()` and
    before `Collector.update()`. Unnecessary when multi-process collection
    is not needed.

    Args:
        rank:        Rank of the current process.
        sync_device: PyTorch device used for inter-process communication,
                     or None to disable multi-process collection.
                     Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called  # must be configured before the first _sync()
    _rank = rank
    _sync_device = sync_device
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
def report(name, value):
    r"""Broadcast the given set of scalars to all interested instances of
    `Collector`, across device and process boundaries.

    Expected to be extremely cheap; safe to call from anywhere in the
    training loop, loss function, or inside a `torch.nn.Module`.

    Warning: the current implementation expects the set of unique names to
    be consistent across processes. Make sure `report()` is called at least
    once for each unique name by each process, and in the same order. If a
    process has no scalars to broadcast, it can do `report(name, [])`.

    Args:
        name:   Arbitrary string naming the statistic. Averages are
                accumulated separately for each unique name.
        value:  Arbitrary set of scalars: list, tuple, NumPy array,
                PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    per_device = _counters.setdefault(name, dict())

    elems = torch.as_tensor(value)
    if elems.numel() == 0:
        return value

    # Reduce the scalars to [count, sum, sum-of-squares] on their own device.
    elems = elems.detach().flatten().to(_reduce_dtype)
    moments = torch.stack([
        torch.ones_like(elems).sum(),  # num_scalars
        elems.sum(),                   # sum_of_scalars
        elems.square().sum(),          # sum_of_squares
    ])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    # Accumulate into the running counter for this name/device pair.
    device = moments.device
    if device not in per_device:
        per_device[device] = torch.zeros_like(moments)
    per_device[device].add_(moments)
    return value
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def report0(name, value):
    r"""Like `report()`, but only rank 0 contributes its scalars; every
    other process broadcasts an empty set. See `report()` for details.
    """
    report(name, value if _rank == 0 else [])
    return value
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class Collector:
    r"""Collects the scalars broadcasted by `report()` and `report0()` and
    computes their long-term averages (mean and standard deviation) over
    user-defined periods of time.

    The averages are accumulated in internal counters that are invisible to
    the user until `update()` copies them into the user-visible state and
    resets them for the next round. Queries such as `mean()`, `std()` and
    `as_dict()` therefore reflect the scalars collected between the last
    two calls to `update()`.

    Args:
        regex:          Regular expression selecting which statistics to
                        collect. The default collects everything.
        keep_previous:  Whether to retain the previous averages if no
                        scalars were collected on a given round
                        (default: True).
    """
    def __init__(self, regex='.*', keep_previous=True):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()   # name => cumulative moments seen so far
        self._moments = dict()      # name => moments since the previous update()
        self.update()               # absorb anything reported before construction
        self._moments.clear()

    def names(self):
        r"""Return the names of all statistics broadcasted so far that
        match the regular expression given at construction time.
        """
        return [name for name in _counters if self._regex.fullmatch(name)]

    def update(self):
        r"""Copy the internal counters into the user-visible state and
        reset them for the next round.

        With `keep_previous=True` (construction time), statistics that
        received no scalars since the last update keep their previous
        averages.

        Performs several GPU-to-CPU transfers and one
        `torch.distributed.all_reduce()`; intended to be called
        periodically in the main training loop, typically once every
        N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for name, cumulative in _sync(self.names()):
            previous = self._cumulative.setdefault(name, torch.zeros([_num_moments], dtype=_counter_dtype))
            delta = cumulative - previous
            previous.copy_(cumulative)
            if float(delta[0]) != 0:  # were any scalars collected this round?
                self._moments[name] = delta

    def _get_delta(self, name):
        r"""Return the raw moments accumulated for the given statistic
        between the last two calls to `update()`, or zeros if no scalars
        were collected.
        """
        assert self._regex.fullmatch(name)
        if name not in self._moments:
            self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        return self._moments[name]

    def num(self, name):
        r"""Return how many scalars were accumulated for the given
        statistic between the last two calls to `update()` (zero if none).
        """
        return int(self._get_delta(name)[0])

    def mean(self, name):
        r"""Return the mean of the scalars accumulated for the given
        statistic between the last two calls to `update()`, or NaN if no
        scalars were collected.
        """
        count, total, _ = self._get_delta(name)
        if int(count) == 0:
            return float('nan')
        return float(total / count)

    def std(self, name):
        r"""Return the standard deviation of the scalars accumulated for
        the given statistic between the last two calls to `update()`, or
        NaN if no scalars were collected.
        """
        delta = self._get_delta(name)
        count = int(delta[0])
        if count == 0 or not np.isfinite(float(delta[1])):
            return float('nan')
        if count == 1:
            return float(0)
        mean = float(delta[1] / delta[0])
        raw_var = float(delta[2] / delta[0])
        # Var[X] = E[X^2] - E[X]^2, clamped at 0 against rounding error.
        return np.sqrt(max(raw_var - np.square(mean), 0))

    def as_dict(self):
        r"""Return the averages accumulated between the last two calls to
        `update()` as a nested `dnnlib.EasyDict`:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        stats = dnnlib.EasyDict()
        for name in self.names():
            stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
        return stats

    def __getitem__(self, name):
        r"""Convenience getter: `collector[name]` is a synonym for
        `collector.mean(name)`.
        """
        return self.mean(name)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _sync(names):
    r"""Synchronize the global cumulative counters across devices and
    processes. Called internally by `Collector.update()`.
    """
    if not names:
        return []
    global _sync_called
    _sync_called = True

    # Fold each name's per-device counters into a single delta, resetting
    # the device counters to zero as we go.
    target_device = _sync_device if _sync_device is not None else torch.device('cpu')
    deltas = []
    for name in names:
        total = torch.zeros([_num_moments], dtype=_counter_dtype, device=target_device)
        for counter in _counters[name].values():
            total.add_(counter.to(target_device))
            counter.copy_(torch.zeros_like(counter))
        deltas.append(total)
    deltas = torch.stack(deltas)

    # Sum deltas across ranks.
    if _sync_device is not None:
        torch.distributed.all_reduce(deltas)

    # Fold the reduced deltas into the CPU-side cumulative counters.
    deltas = deltas.cpu()
    for idx, name in enumerate(names):
        cumulative = _cumulative.setdefault(name, torch.zeros([_num_moments], dtype=_counter_dtype))
        cumulative.add_(deltas[idx])

    # Return name-value pairs.
    return [(name, _cumulative[name]) for name in names]
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
@@ -0,0 +1,155 @@
|
||||
import torch
|
||||
from torch.fft import fftn
|
||||
|
||||
|
||||
def roll_quadrants(data, backwards=False):
    """
    Shift low frequencies to the center of the fourier transform,
    i.e. [-N/2, ..., +N/2] -> [0, ..., N-1].
    Args:
        data: fourier transform, (NxHxW)
        backwards: bool, if True shift high frequencies back to center

    Returns:
        Shifted fourier transform.
    """
    spatial_dims = data.ndim - 1  # batch dimension excluded

    if spatial_dims != 2:
        raise AttributeError(f'Data must be 2d but it is {spatial_dims}d.')
    if any(extent % 2 == 0 for extent in data.shape[1:]):
        raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.')

    # Swap left/right and top/bottom halves by rolling each spatial dim.
    roll_dims = tuple(range(1, spatial_dims + 1))
    half_sizes = [extent // 2 for extent in data.shape[1:]]  # (N-1)/2 for odd N
    if backwards:
        half_sizes = [-size for size in half_sizes]
    return data.roll(half_sizes, dims=roll_dims)
|
||||
|
||||
|
||||
def batch_fft(data, normalize=False):
    """
    Compute the fourier transform of a batch of 2d inputs.
    Args:
        data: input tensor, (NxHxW)
        normalize: bool, use orthonormal fft normalization if True

    Returns:
        Batch fourier transform of the input data (complex tensor).
    """
    spatial_dims = data.ndim - 1  # batch dimension excluded
    if spatial_dims != 2:
        raise AttributeError(f'Data must be 2d but it is {spatial_dims}d.')

    fft_dims = tuple(range(1, spatial_dims + 1))
    norm_mode = 'ortho' if normalize else 'backward'

    # fftn expects complex input; promote real tensors with a zero imaginary part.
    if not torch.is_complex(data):
        data = torch.complex(data, torch.zeros_like(data))
    return fftn(data, dim=fft_dims, norm=norm_mode)
|
||||
|
||||
|
||||
def azimuthal_average(image, center=None):
    # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    Calculate the azimuthally averaged radial profile.
    Requires low frequencies to be at the center of the image.
    Args:
        image: Batch of 2D images, NxHxW
        center: The [x,y] pixel coordinates used as the center. The default is
                None, which then uses the center of the image (including
                fracitonal pixels).

    Returns:
        Azimuthal average over the image around the center
    """
    # Check input shapes
    assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \
                                                 f'(but it is len(center)={len(center)}.'
    # Calculate the indices from the image
    H, W = image.shape[-2:]
    # NOTE(review): default indexing='ij', so h varies along rows and w along columns.
    h, w = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))

    if center is None:
        center = torch.tensor([(w.max() - w.min()) / 2.0, (h.max() - h.min()) / 2.0])

    # Compute radius for each pixel wrt center (euclidean norm over the x/y offsets)
    r = torch.stack([w-center[0], h-center[1]]).norm(2, 0)

    # Get sorted radii; reorder the flattened image values the same way
    r_sorted, ind = r.flatten().sort()
    i_sorted = image.flatten(-2, -1)[..., ind]

    # Get the integer part of the radii (bin size = 1)
    r_int = r_sorted.long() # attribute to the smaller integer

    # Find all pixels that fall within each radial bin.
    deltar = r_int[1:] - r_int[:-1] # Assumes all radii represented, computes bin change between subsequent radii
    rind = torch.where(deltar)[0] # location of changed radius

    # compute number of elements in each bin
    nind = rind + 1 # number of elements = idx + 1
    nind = torch.cat([torch.tensor([0]), nind, torch.tensor([H*W])]) # add borders
    nr = nind[1:] - nind[:-1] # number of radius bin, i.e. counter for bins belonging to each radius

    # Cumulative sum to figure out sums for each radius bin
    if H % 2 == 0:
        # Even sizes are explicitly unsupported; get_spectrum pads even inputs
        # to odd sizes before calling this function.
        raise NotImplementedError('Not sure if implementation correct, please check')
        rind = torch.cat([torch.tensor([0]), rind, torch.tensor([H * W - 1])]) # add borders
    else:
        rind = torch.cat([rind, torch.tensor([H * W - 1])]) # add borders
    csim = i_sorted.cumsum(-1, dtype=torch.float64) # integrate over all values with smaller radius
    tbin = csim[..., rind[1:]] - csim[..., rind[:-1]]
    # add mean
    tbin = torch.cat([csim[:, 0:1], tbin], 1)

    radial_prof = tbin / nr.to(tbin.device) # normalize by counted bins

    return radial_prof
|
||||
|
||||
|
||||
def get_spectrum(data, normalize=False):
    """Compute the azimuthally averaged power spectrum of a batch of 2d images.

    Args:
        data: input tensor, (NxHxW)
        normalize: bool, use orthonormal fft normalization if True

    Returns:
        Radial profile of the power spectrum, one row per batch element.
    """
    spatial_dims = data.ndim - 1  # batch dimension excluded
    if spatial_dims != 2:
        raise AttributeError(f'Data must be 2d but it is {spatial_dims}d.')

    freq = batch_fft(data, normalize=normalize)
    power_spec = freq.real ** 2 + freq.imag ** 2

    size = data.shape[1]
    if size % 2 == 0:
        # Duplicate the N/2 row/column so it lands at the end of the spectrum
        # instead of being averaged with the mean value; this also makes the
        # spatial sizes odd, as required by roll_quadrants/azimuthal_average.
        mid = size // 2
        power_spec = torch.cat([power_spec[:, :mid + 1], power_spec[:, mid:mid + 1], power_spec[:, mid + 1:]], dim=1)
        power_spec = torch.cat([power_spec[:, :, :mid + 1], power_spec[:, :, mid:mid + 1], power_spec[:, :, mid + 1:]], dim=2)

    power_spec = roll_quadrants(power_spec)
    return azimuthal_average(power_spec)
|
||||
|
||||
|
||||
def plot_std(mean, std, x=None, ax=None, **kwargs):
    """Plot a mean curve with a shaded +/- std band.

    Args:
        mean: sequence of mean values
        std: sequence of standard deviations, same length as mean
        x: x-axis values; defaults to a normalized [0, 1] axis
        ax: matplotlib axes to draw on; a new figure is created if None
        **kwargs: forwarded to `ax.plot` (line styling)

    Returns:
        The matplotlib axes that was drawn on.
    """
    import matplotlib.pyplot as plt
    if ax is None:
        fig, ax = plt.subplots(1)

    # Shade the error band in the same color as the line.
    band_kwargs = {
        'alpha': 0.3
    }
    if 'c' in kwargs:
        band_kwargs['color'] = kwargs['c']
    elif 'color' in kwargs:
        band_kwargs['color'] = kwargs['color']

    if x is None:
        x = torch.linspace(0, 1, len(mean)) # use normalized x axis
    ax.plot(x, mean, **kwargs)
    ax.fill_between(x, mean - std, mean + std, **band_kwargs)

    return ax
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Tuesday April 28th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 24th January 2022 3:32:47 pm
|
||||
# Last Modified: Saturday, 29th January 2022 3:25:24 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -31,24 +31,24 @@ def getParameters():
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
# general settings
|
||||
parser.add_argument('-v', '--version', type=str, default='GramFM',
|
||||
parser.add_argument('-v', '--version', type=str, default='arcface_rec',
|
||||
help="version name for train, test, finetune")
|
||||
parser.add_argument('-t', '--tag', type=str, default='Gram_Feature_match',
|
||||
parser.add_argument('-t', '--tag', type=str, default='arcface_rec',
|
||||
help="tag for current experiment")
|
||||
|
||||
parser.add_argument('-p', '--phase', type=str, default="train",
|
||||
choices=['train', 'finetune','debug'],
|
||||
help="The phase of current project")
|
||||
|
||||
parser.add_argument('-c', '--cuda', type=int, default=1) # <0 if it is set as -1, program will use CPU
|
||||
parser.add_argument('-c', '--cuda', type=int, default=0) # <0 if it is set as -1, program will use CPU
|
||||
parser.add_argument('-e', '--ckpt', type=int, default=74,
|
||||
help="checkpoint epoch for test phase or finetune phase")
|
||||
|
||||
# training
|
||||
parser.add_argument('--experiment_description', type=str,
|
||||
default="使用3作为feature, 尝试使用gram矩阵来计算feature matching")
|
||||
default="用arcface作编码器,进行图像重构")
|
||||
|
||||
parser.add_argument('--train_yaml', type=str, default="train_GramFM.yaml")
|
||||
parser.add_argument('--train_yaml', type=str, default="train_arcface_rec.yaml")
|
||||
|
||||
# system logger
|
||||
parser.add_argument('--logger', type=str,
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: train.py
|
||||
# Created Date: Tuesday April 28th 2020
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 8th February 2022 1:05:05 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2020 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
from torch.backends import cudnn
|
||||
from utilities.json_config import readConfig, writeConfig
|
||||
from utilities.reporter import Reporter
|
||||
from utilities.yaml_config import getConfigYaml
|
||||
|
||||
|
||||
def str2bool(v):
    """Parse a string as a boolean: case-insensitive 'true' -> True, anything else -> False.

    Fix: the original `v.lower() in ('true')` is a substring test against the
    string 'true' (parentheses without a comma are not a tuple), so inputs
    like 'rue' or 't' incorrectly returned True. Compare exactly instead.
    """
    return v.lower() == 'true'
|
||||
|
||||
####################################################################################
|
||||
# To configure the seting of training\finetune\test
|
||||
#
|
||||
####################################################################################
|
||||
def getParameters():
    """Parse and return the command-line arguments for training/finetuning.

    Returns an `argparse.Namespace` with the general settings (version, tag,
    phase, gpu list, checkpoint epoch), training settings (experiment
    description, yaml config name), and the system logger choice.
    """
    parser = argparse.ArgumentParser()
    # general settings
    parser.add_argument('-v', '--version', type=str, default='multigpu2',
                        help="version name for train, test, finetune")
    parser.add_argument('-t', '--tag', type=str, default='multigpu',
                        help="tag for current experiment")

    parser.add_argument('-p', '--phase', type=str, default="train",
                        choices=['train', 'finetune','debug'],
                        help="The phase of current project")

    # List of GPU indices to use; CUDA_VISIBLE_DEVICES is set from this in main().
    parser.add_argument('-c', '--gpus', type=int, nargs='+', default=[0,1]) # <0 if it is set as -1, program will use CPU
    parser.add_argument('-e', '--ckpt', type=int, default=74,
                        help="checkpoint epoch for test phase or finetune phase")

    # training
    parser.add_argument('--experiment_description', type=str,
                        default="测试多GPU训练")

    parser.add_argument('--train_yaml', type=str, default="train_multigpu.yaml")

    # system logger
    parser.add_argument('--logger', type=str,
                        default="wandb", choices=['tensorboard', 'wandb','none'], help='system logger')

    # # logs (does not need to be changed most of the time)
    # parser.add_argument('--dataloader_workers', type=int, default=6)
    # parser.add_argument('--use_tensorboard', type=str2bool, default='True',
    #                     choices=['True', 'False'], help='enable the tensorboard')
    # parser.add_argument('--log_step', type=int, default=100)
    # parser.add_argument('--sample_step', type=int, default=100)

    # # template (once editing is finished, it should be deleted)
    # parser.add_argument('--str_parameter', type=str, default="default", help='str parameter')
    # parser.add_argument('--str_parameter_choices', type=str,
    #                     default="default", choices=['choice1', 'choice2','choice3'], help='str parameter with choices list')
    # parser.add_argument('--int_parameter', type=int, default=0, help='int parameter')
    # parser.add_argument('--float_parameter', type=float, default=0.0, help='float parameter')
    # parser.add_argument('--bool_parameter', type=str2bool, default='True', choices=['True', 'False'], help='bool parameter')
    # parser.add_argument('--list_str_parameter', type=str, nargs='+', default=["element1","element2"], help='str list parameter')
    # parser.add_argument('--list_int_parameter', type=int, nargs='+', default=[0,1], help='int list parameter')
    return parser.parse_args()
|
||||
|
||||
# Configuration keys excluded when the runtime config is serialized/compared —
# machine- or run-specific values that should not be persisted.
# Fix: "test_dataset_path" was listed twice; the duplicate is removed
# (membership semantics are unchanged).
ignoreKey = [
    "dataloader_workers",
    "log_root_path",
    "project_root",
    "project_summary",
    "project_checkpoints",
    "project_samples",
    "project_scripts",
    "reporter_path",
    "use_specified_data",
    "specified_data_paths",
    "dataset_path", "cuda",
    "test_script_name",
    "test_dataloader",
    "test_dataset_path",
    "save_test_result",
    "test_batch_size",
    "node_name",
    "checkpoint_epoch",
    "test_dataset_name",
    "use_my_test_date"]
|
||||
|
||||
####################################################################################
|
||||
# This function will create the related directories before the
|
||||
# training\fintune\test starts
|
||||
# Your_log_root (version name)
|
||||
# |---summary/...
|
||||
# |---samples/... (save evaluated images)
|
||||
# |---checkpoints/...
|
||||
# |---scripts/...
|
||||
#
|
||||
####################################################################################
|
||||
def createDirs(sys_state):
    """Create the per-version log directory tree and record the paths.

    Mutates `sys_state` in place, adding: project_root, project_summary,
    project_checkpoints, project_samples, project_scripts, reporter_path.
    Requires `sys_state["log_root_path"]` and `sys_state["version"]`.

    Improvement: the repeated `if not os.path.exists(...): os.makedirs(...)`
    pattern is race-prone (TOCTOU); `os.makedirs(..., exist_ok=True)` is
    atomic and equivalent.
    """
    # the base dir
    os.makedirs(sys_state["log_root_path"], exist_ok=True)

    # create dirs
    sys_state["project_root"] = os.path.join(sys_state["log_root_path"],
                                             sys_state["version"])
    project_root = sys_state["project_root"]
    os.makedirs(project_root, exist_ok=True)

    # one subdirectory per artifact kind, recorded under its config key
    for key, subdir in (("project_summary", "summary"),
                        ("project_checkpoints", "checkpoints"),
                        ("project_samples", "samples"),
                        ("project_scripts", "scripts")):
        sys_state[key] = os.path.join(project_root, subdir)
        os.makedirs(sys_state[key], exist_ok=True)

    # report file path (not a directory; created lazily by the Reporter)
    sys_state["reporter_path"] = os.path.join(project_root, sys_state["version"] + "_report")
|
||||
|
||||
def main():
    """Entry point: parse configs, prepare the project directories and launch training.

    Workflow:
        1. Parse command-line arguments and merge them into ``sys_state``.
        2. Restrict the visible GPUs according to ``config.gpus``.
        3. Depending on ``config.phase`` ("train" or "finetune"), load the
           yaml config / the saved json config and create the project dirs.
        4. Dynamically import the requested trainer module and start training.
    """
    config = getParameters()
    # speed up the program (cudnn autotuner picks the fastest conv algorithms)
    cudnn.benchmark = True

    from utilities.logo_class import logo_class
    logo_class.print_group_logo()

    sys_state = {}

    # set the GPU number; must happen before any CUDA context is created
    gpus = [str(i) for i in config.gpus]
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpus)

    # read system environment paths
    env_config = readConfig('env/env.json')
    env_config = env_config["path"]

    # obtain all configurations from argparse
    sys_state.update(vars(config))

    #=======================Train Phase=========================#
    if config.phase == "train":
        # read training configurations from the yaml file
        yaml_config = getConfigYaml(os.path.join(env_config["train_config_path"], config.train_yaml))
        sys_state.update(yaml_config)

        # create related dirs
        sys_state["log_root_path"] = env_config["train_log_root"]
        createDirs(sys_state)

        # create reporter file
        reporter = Reporter(sys_state["reporter_path"])

        # save the config json so a later finetune run can restore it
        config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
        writeConfig(config_json, sys_state)

        # save the dependent scripts
        # TODO and copy the scripts to the project dir

        # save the trainer script into [train_logs_root]\[version name]\scripts\
        file1 = os.path.join(env_config["train_scripts_path"],
                            "trainer_%s.py"%sys_state["train_script_name"])
        tgtfile1 = os.path.join(sys_state["project_scripts"],
                            "trainer_%s.py"%sys_state["train_script_name"])
        shutil.copyfile(file1,tgtfile1)

        # save the yaml file
        file1 = os.path.join(env_config["train_config_path"], config.train_yaml)
        tgtfile1 = os.path.join(sys_state["project_scripts"], config.train_yaml)
        shutil.copyfile(file1,tgtfile1)

        # TODO replace below lines, here to save the critical scripts

    #=====================Finetune Phase=====================#
    elif config.phase == "finetune":
        sys_state["log_root_path"] = env_config["train_log_root"]
        sys_state["project_root"] = os.path.join(sys_state["log_root_path"], sys_state["version"])

        # restore the configuration saved at train time, but keep the keys
        # listed in ignoreKey (current-machine overrides) from this invocation
        config_json = os.path.join(sys_state["project_root"], env_config["config_json_name"])
        train_config = readConfig(config_json)
        for key, value in train_config.items():
            if key not in ignoreKey:
                sys_state[key] = value

        createDirs(sys_state)
        reporter = Reporter(sys_state["reporter_path"])
        # scripts archived at train time are imported from the project dir
        sys_state["com_base"] = "train_logs.%s.scripts."%sys_state["version"]

    # get the dataset path (shallow copy so env_config stays untouched)
    sys_state["dataset_paths"] = dict(env_config["dataset_paths"])

    # resolve the trainer module to import
    moduleName = "train_scripts.trainer_" + sys_state["train_script_name"]
    if config.phase == "finetune":
        moduleName = sys_state["com_base"] + "trainer_" + sys_state["train_script_name"]

    # print some important information
    # TODO
    print("Start to run training script: {}".format(moduleName))
    print("Traning version: %s"%sys_state["version"])
    print("Dataloader Name: %s"%sys_state["dataloader"])
    # print("Image Size: %d"%sys_state["imsize"])
    print("Batch size: %d"%(sys_state["batch_size"]))
    print("GPUs:", gpus)

    # Load the training script and start to train
    # NOTE(review): `reporter` is only defined for phase train/finetune; any
    # other phase reaches this line and raises NameError — confirm intended.
    reporter.writeConfig(sys_state)

    package = __import__(moduleName, fromlist=True)
    trainerClass= getattr(package, 'Trainer')
    trainer = trainerClass(sys_state, reporter)
    trainer.train()
|
||||
|
||||
|
||||
# Script entry point: only run the launcher when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
|
||||
@@ -5,7 +5,7 @@
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Monday, 24th January 2022 6:23:16 pm
|
||||
# Last Modified: Tuesday, 25th January 2022 3:25:56 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
@@ -13,17 +13,16 @@
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from utilities.plot import plot_batch
|
||||
from utilities.utilities import Gram
|
||||
|
||||
from train_scripts.trainer_base import TrainerBase
|
||||
|
||||
from utilities.utilities import Gram
|
||||
|
||||
class Trainer(TrainerBase):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_naiv512.py
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Saturday, 29th January 2022 3:54:06 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import shutil
|
||||
from cv2 import sqrt
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision.utils import save_image
|
||||
|
||||
from train_scripts.trainer_base import TrainerBase
|
||||
|
||||
class Trainer(TrainerBase):
    """Single-GPU trainer that regresses a face image directly from its
    ArcFace identity vector, optimized with a plain L1 reconstruction loss.
    """

    def __init__(self,
                config,
                reporter):
        """Store config/reporter via the base class and cache the ImageNet
        normalization statistics used to de-normalize sampled images."""
        super(Trainer, self).__init__(config, reporter)

        import inspect
        print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))

        # ImageNet channel std/mean, shaped (3,1,1) for broadcasting over CHW
        self.img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
        self.img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)

    # TODO modify this function to build your models
    def init_framework(self):
        '''
        Build the generator and the frozen ArcFace identity extractor,
        and record the framework information into the log file.

        Side effects:
            self.gen     -- generator module (moved to GPU when config["cuda"] >= 0)
            self.arcface -- frozen ArcFace backbone, eval mode, no gradients
        '''
        #===============build models================#
        print("build models...")
        # TODO [import models here]

        model_config = self.config["model_configs"]

        if self.config["phase"] == "train":
            gscript_name = "components." + model_config["g_model"]["script"]

            # archive the generator script next to the experiment so the
            # exact architecture can be re-imported when finetuning
            file1 = os.path.join("components", model_config["g_model"]["script"]+".py")
            tgtfile1 = os.path.join(self.config["project_scripts"], model_config["g_model"]["script"]+".py")
            shutil.copyfile(file1,tgtfile1)

        elif self.config["phase"] == "finetune":
            # import the copy archived at train time
            gscript_name = self.config["com_base"] + model_config["g_model"]["script"]

        class_name = model_config["g_model"]["class_name"]
        package = __import__(gscript_name, fromlist=True)
        gen_class = getattr(package, class_name)
        self.gen = gen_class(**model_config["g_model"]["module_params"])

        # print and record model structure
        self.reporter.writeInfo("Generator structure:")
        self.reporter.writeModel(self.gen.__str__())

        # the checkpoint stores a DataParallel-wrapped model; unwrap .module
        arcface1 = torch.load(self.arcface_ckpt, map_location=torch.device("cpu"))
        self.arcface = arcface1['model'].module

        # train in GPU
        if self.config["cuda"] >=0:
            self.gen = self.gen.cuda()
            self.arcface= self.arcface.cuda()

        # identity network is a fixed feature extractor
        self.arcface.eval()
        self.arcface.requires_grad_(False)

        # if in finetune phase, load the pretrained checkpoint
        if self.config["phase"] == "finetune":
            model_path = os.path.join(self.config["project_checkpoints"],
                                    "step%d_%s.pth"%(self.config["checkpoint_step"],
                                    self.config["checkpoint_names"]["generator_name"]))
            self.gen.load_state_dict(torch.load(model_path))

            # BUGFIX: report the restored step; previously this formatted the
            # checkpoint directory path into the "step {}" placeholder.
            print('loaded trained backbone model step {}...!'.format(self.config["checkpoint_step"]))

    # TODO modify this function to configurate the optimizer of your pipeline
    def __setup_optimizers__(self):
        """Create the optimizer over trainable generator parameters and,
        in finetune phase, restore its saved state.

        Raises:
            NotImplementedError: if config["optim_type"] is not "Adam".
        """
        g_train_opt = self.config['g_optim_config']

        g_optim_params = []
        for k, v in self.gen.named_parameters():
            if v.requires_grad:
                g_optim_params.append(v)
            else:
                # surface silently-frozen parameters in the log
                self.reporter.writeInfo(f'Params {k} will not be optimized.')
                print(f'Params {k} will not be optimized.')

        optim_type = self.config['optim_type']

        if optim_type == 'Adam':
            self.g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
        else:
            # fixed typo: "supperted" -> "supported"
            raise NotImplementedError(
                f'optimizer {optim_type} is not supported yet.')
        # self.optimizers.append(self.optimizer_g)
        if self.config["phase"] == "finetune":
            # NOTE(review): loading keys on config["optimizer_names"] while
            # saving (in train()) uses config["checkpoint_names"] — confirm
            # both yaml entries resolve to the same name.
            opt_path = os.path.join(self.config["project_checkpoints"],
                                    "step%d_optim_%s.pth"%(self.config["checkpoint_step"],
                                    self.config["optimizer_names"]["generator_name"]))
            self.g_optimizer.load_state_dict(torch.load(opt_path))

            # BUGFIX: report the restored step, not the checkpoint dir path
            print('loaded trained optimizer step {}...!'.format(self.config["checkpoint_step"]))

    # TODO modify this function to evaluate your model
    # Evaluate the checkpoint
    def __evaluation__(self,
                step = 0,
                **kwargs
                ):
        """Generate images from the ArcFace vectors of kwargs["src1"] and
        save a de-normalized 8-per-row grid into the sample directory."""
        src_image1 = kwargs["src1"]
        self.gen.eval()
        with torch.no_grad():
            id_vector_src1 = self.arcface(src_image1)
            img_fake = self.gen(id_vector_src1).cpu()
            # undo ImageNet normalization and clip to the valid image range
            img_fake = img_fake * self.img_std
            img_fake = img_fake + self.img_mean
            img_fake = img_fake.clamp_(0, 1)
            print("Save test data")
            save_image(img_fake,
                    os.path.join(self.sample_dir, 'step_'+str(step+1)+'.jpg'),
                    nrow=8)

    def train(self):
        """Main optimization loop: L1-reconstruct images from their ArcFace
        identity vectors, log progress, and periodically sample/checkpoint."""
        ckpt_dir = self.config["project_checkpoints"]
        log_freq = self.config["log_step"]
        model_freq = self.config["model_save_step"]
        sample_freq = self.config["sample_step"]
        total_step = self.config["total_step"]
        random_seed = self.config["dataset_params"]["random_seed"]

        self.batch_size = self.config["batch_size"]
        self.sample_dir = self.config["project_samples"]
        # must be set before super().train() triggers init_framework()
        self.arcface_ckpt= self.config["arcface_ckpt"]

        super().train()

        #===============build losses===================#
        # TODO replace below lines to build your losses
        # MSE_loss = torch.nn.MSELoss()
        l1_loss = torch.nn.L1Loss()

        start_time = time.time()

        # Caculate the epoch number
        print("Total step = %d"%total_step)
        random.seed(random_seed)
        # NOTE(review): randindex is never read in this trainer, but the
        # shuffle advances the global RNG state — kept for reproducibility.
        randindex = [i for i in range(self.batch_size)]
        random.shuffle(randindex)
        import datetime
        for step in range(self.start, total_step):
            self.gen.train()
            src_image1 = self.train_loader.next()

            # identity vector is detached: only the generator is optimized
            latent_id = self.arcface(src_image1)
            img_fake = self.gen(latent_id.detach())
            loss = l1_loss(img_fake, src_image1)

            self.g_optimizer.zero_grad()
            loss.backward()
            self.g_optimizer.step()

            # Print out log info
            if (step + 1) % log_freq == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))

                epochinformation="[{}], Elapsed [{}], Step [{}/{}], Reconstruction: {:.4f}". \
                    format(self.config["version"], elapsed, step, total_step, loss.item())
                print(epochinformation)
                self.reporter.writeInfo(epochinformation)

                if self.config["logger"] == "tensorboard":
                    self.logger.add_scalar('Rec_loss', loss.item(), step)
                elif self.config["logger"] == "wandb":
                    self.logger.log({"Rec_loss": loss.item()}, step = step)

            if (step + 1) % sample_freq == 0:
                self.__evaluation__(
                    step = step,
                    **{
                        "src1": src_image1
                    })

            #===============adjust learning rate============#
            # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
            #     print("Learning rate decay")
            #     for p in self.optimizer.param_groups:
            #         p['lr'] *= self.config["lr_decay"]
            #         print("Current learning rate is %f"%p['lr'])

            #===============save checkpoints================#
            if (step+1) % model_freq==0:

                torch.save(self.gen.state_dict(),
                        os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
                                    self.config["checkpoint_names"]["generator_name"])))

                # BUGFIX: append ".pth" so the finetune loader (which looks for
                # "step%d_optim_%s.pth") can actually find the optimizer state.
                torch.save(self.g_optimizer.state_dict(),
                        os.path.join(ckpt_dir, 'step{}_optim_{}.pth'.format(step + 1,
                                    self.config["checkpoint_names"]["generator_name"])))

                print("Save step %d model checkpoint!"%(step+1))
                torch.cuda.empty_cache()

                self.__evaluation__(
                    step = step,
                    **{
                        "src1": src_image1
                    })
|
||||
@@ -0,0 +1,508 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_naiv512.py
|
||||
# Created Date: Sunday January 9th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Tuesday, 8th February 2022 2:29:34 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch_utils import misc
|
||||
from torch_utils import training_stats
|
||||
from torch_utils.ops import conv2d_gradfix
|
||||
from torch_utils.ops import grid_sample_gradfix
|
||||
|
||||
from utilities.plot import plot_batch
|
||||
from losses.cos import cosin_metric
|
||||
from train_scripts.trainer_multigpu_base import TrainerBase
|
||||
|
||||
|
||||
class Trainer(TrainerBase):
    # Multi-GPU entry point. The actual training logic lives in the
    # module-level train_loop() below; this class only spawns one worker
    # process per configured GPU (DDP-style, file-based rendezvous).

    def __init__(self,
                config,
                reporter):
        super(Trainer, self).__init__(config, reporter)

        import inspect
        print("Current training script -----------> %s"%inspect.getfile(inspect.currentframe()))

    def train(self):
        # Launch processes.
        num_gpus = len(self.config["gpus"])
        print('Launching processes...')
        # 'spawn' start method is required for CUDA in child processes
        torch.multiprocessing.set_start_method('spawn')
        with tempfile.TemporaryDirectory() as temp_dir:
            # temp_dir hosts the file used by torch.distributed file:// init;
            # spawn() passes the worker rank as the first argument of train_loop
            torch.multiprocessing.spawn(fn=train_loop, args=(self.config, self.reporter, temp_dir), nprocs=num_gpus)
|
||||
|
||||
# TODO modify this function to build your models
|
||||
def init_framework(config, reporter, device, rank):
    '''
    This function is designed to define the framework,
    and print the framework information into the log file

    Builds generator + discriminator from the config-selected component
    scripts, loads the frozen ArcFace backbone, moves everything to the
    per-rank CUDA device and (in finetune phase) restores checkpoints.

    Args:
        config:   merged configuration dict (see main()).
        reporter: log writer with writeInfo()/writeModel().
        device:   torch.device('cuda', rank) for this worker.
        rank:     this worker's process index.

    Returns:
        (gen, dis, arcface) — arcface is frozen and in eval mode.
    '''
    #===============build models================#
    print("build models...")
    # TODO [import models here]
    torch.cuda.set_device(rank)
    torch.cuda.empty_cache()
    model_config = config["model_configs"]

    if config["phase"] == "train":
        gscript_name = "components." + model_config["g_model"]["script"]
        dscript_name = "components." + model_config["d_model"]["script"]

    elif config["phase"] == "finetune":
        # finetune imports the script copies archived in the project dir
        gscript_name = config["com_base"] + model_config["g_model"]["script"]
        dscript_name = config["com_base"] + model_config["d_model"]["script"]

    # NOTE(review): any phase other than train/finetune leaves
    # gscript_name/dscript_name unbound (NameError below) — confirm intended.
    class_name = model_config["g_model"]["class_name"]
    package = __import__(gscript_name, fromlist=True)
    gen_class = getattr(package, class_name)
    gen = gen_class(**model_config["g_model"]["module_params"])

    # print and record model structure
    reporter.writeInfo("Generator structure:")
    reporter.writeModel(gen.__str__())

    class_name = model_config["d_model"]["class_name"]
    package = __import__(dscript_name, fromlist=True)
    dis_class = getattr(package, class_name)
    dis = dis_class(**model_config["d_model"]["module_params"])

    # print and record model structure
    reporter.writeInfo("Discriminator structure:")
    reporter.writeModel(dis.__str__())
    # the arcface checkpoint stores a DataParallel-wrapped model; unwrap .module
    arcface1 = torch.load(config["arcface_ckpt"], map_location=torch.device("cpu"))
    arcface = arcface1['model'].module

    # train in GPU

    gen = gen.to(device)
    dis = dis.to(device)
    arcface= arcface.to(device)
    # identity network is a fixed feature extractor
    arcface.requires_grad_(False)
    arcface.eval()

    # if in finetune phase, load the pretrained checkpoint
    if config["phase"] == "finetune":
        model_path = os.path.join(config["project_checkpoints"],
                                "step%d_%s.pth"%(config["checkpoint_step"],
                                config["checkpoint_names"]["generator_name"]))
        gen.load_state_dict(torch.load(model_path))

        model_path = os.path.join(config["project_checkpoints"],
                                "step%d_%s.pth"%(config["checkpoint_step"],
                                config["checkpoint_names"]["discriminator_name"]))
        dis.load_state_dict(torch.load(model_path))

        # NOTE(review): message says "step {}" but formats the checkpoint
        # directory path — presumably meant config["checkpoint_step"].
        print('loaded trained backbone model step {}...!'.format(config["project_checkpoints"]))
    return gen, dis, arcface
|
||||
|
||||
# TODO modify this function to configurate the optimizer of your pipeline
|
||||
def setup_optimizers(config, reporter, gen, dis, rank):
    """Build the generator/discriminator optimizers and, in finetune phase,
    restore their saved states.

    Args:
        config:   merged configuration dict; reads 'g_optim_config',
                  'd_optim_config', 'optim_type', 'phase' (and checkpoint
                  keys when finetuning).
        reporter: log writer with writeInfo().
        gen, dis: generator / discriminator modules.
        rank:     this worker's CUDA device index.

    Returns:
        (g_optimizer, d_optimizer)

    Raises:
        NotImplementedError: if config['optim_type'] is not 'Adam'.
    """
    torch.cuda.set_device(rank)
    torch.cuda.empty_cache()
    g_train_opt = config['g_optim_config']
    d_train_opt = config['d_optim_config']

    def _trainable_params(model):
        # Collect only parameters that still require gradients; frozen
        # parameters are reported so silent freezing does not go unnoticed.
        params = []
        for k, v in model.named_parameters():
            if v.requires_grad:
                params.append(v)
            else:
                reporter.writeInfo(f'Params {k} will not be optimized.')
                print(f'Params {k} will not be optimized.')
        return params

    g_optim_params = _trainable_params(gen)
    d_optim_params = _trainable_params(dis)

    optim_type = config['optim_type']

    if optim_type == 'Adam':
        g_optimizer = torch.optim.Adam(g_optim_params,**g_train_opt)
        d_optimizer = torch.optim.Adam(d_optim_params,**d_train_opt)
    else:
        # fixed typo: "supperted" -> "supported"
        raise NotImplementedError(
            f'optimizer {optim_type} is not supported yet.')
    # self.optimizers.append(self.optimizer_g)
    if config["phase"] == "finetune":
        opt_path = os.path.join(config["project_checkpoints"],
                                "step%d_optim_%s.pth"%(config["checkpoint_step"],
                                config["optimizer_names"]["generator_name"]))
        g_optimizer.load_state_dict(torch.load(opt_path))

        opt_path = os.path.join(config["project_checkpoints"],
                                "step%d_optim_%s.pth"%(config["checkpoint_step"],
                                config["optimizer_names"]["discriminator_name"]))
        d_optimizer.load_state_dict(torch.load(opt_path))

        # BUGFIX: report the restored step; previously the checkpoint
        # directory path was formatted into the "step {}" placeholder.
        print('loaded trained optimizer step {}...!'.format(config["checkpoint_step"]))
    return g_optimizer, d_optimizer
|
||||
|
||||
|
||||
def train_loop(
        rank,
        config,
        reporter,
        temp_dir
    ):
    """Per-process distributed training loop (one call per GPU, via spawn()).

    Alternates a discriminator step (hinge loss) and a generator step
    (adversarial + identity-cosine + feature-match, plus L1 reconstruction on
    even steps where source and identity images coincide). Gradients are
    all-reduced manually across ranks; rank 0 handles logging, sampling and
    checkpointing.

    Args:
        rank:     worker index, supplied first by torch.multiprocessing.spawn.
        config:   merged configuration dict.
        reporter: log writer.
        temp_dir: directory hosting the torch.distributed file:// init file.
    """
    version = config["version"]

    ckpt_dir = config["project_checkpoints"]
    sample_dir = config["project_samples"]

    log_freq = config["log_step"]
    model_freq = config["model_save_step"]
    sample_freq = config["sample_step"]
    total_step = config["total_step"]
    random_seed = config["dataset_params"]["random_seed"]

    # loss weights
    id_w = config["id_weight"]
    rec_w = config["reconstruct_weight"]
    feat_w = config["feature_match_weight"]
    num_gpus = len(config["gpus"])
    # per-GPU batch: global batch size split evenly across workers
    batch_gpu = config["batch_size"] // num_gpus

    # file-based rendezvous; gloo on Windows (no nccl there), nccl elsewhere
    init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
    if os.name == 'nt':
        init_method = 'file:///' + init_file.replace('\\', '/')
        torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=num_gpus)
    else:
        init_method = f'file://{init_file}'
        torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank)
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)

    # de-normalization constants only needed where samples are saved (rank 0)
    if rank == 0:
        img_std = torch.Tensor([0.229, 0.224, 0.225]).view(3,1,1)
        img_mean = torch.Tensor([0.485, 0.456, 0.406]).view(3,1,1)

    cudnn_benchmark = True

    # Initialize. Seeds differ per rank so workers see different data noise.
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = False       # Improves numerical accuracy.
    torch.backends.cudnn.allow_tf32 = False             # Improves numerical accuracy.
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.

    # Create dataloader.
    if rank == 0:
        print('Loading training set...')

    dataset = config["dataset_paths"][config["dataset_name"]]
    #================================================#
    print("Prepare the train dataloader...")
    # the dataloader class is resolved dynamically from the config name
    dlModulename = config["dataloader"]
    package = __import__("data_tools.data_loader_%s"%dlModulename, fromlist=True)
    dataloaderClass = getattr(package, 'GetLoader')
    dataloader_class= dataloaderClass
    dataloader = dataloader_class(dataset,
                                rank,
                                num_gpus,
                                batch_gpu,
                                **config["dataset_params"])

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    gen, dis, arcface = init_framework(config, reporter, device, rank)

    # Check for existing checkpoint

    # Print network summary tables.
    # if rank == 0:
    #     attr = torch.empty([batch_gpu, 3, 512, 512], device=device)
    #     id = torch.empty([batch_gpu, 3, 112, 112], device=device)
    #     latent = misc.print_module_summary(arcface, [id])
    #     img = misc.print_module_summary(gen, [attr, latent])
    #     misc.print_module_summary(dis, [img, None])
    #     del attr
    #     del id
    #     del latent
    #     del img
    #     torch.cuda.empty_cache()

    # Distribute across GPUs: broadcast rank-0 weights so all replicas start
    # identical (gradient averaging keeps them in sync afterwards).
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    for module in [gen, dis, arcface]:
        if module is not None and num_gpus > 1:
            for param in misc.params_and_buffers(module):
                torch.distributed.broadcast(param, src=0)

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    #===============build losses===================#
    # TODO replace below lines to build your losses
    # MSE_loss = torch.nn.MSELoss()
    l1_loss = torch.nn.L1Loss()
    cos_loss = cosin_metric

    g_optimizer, d_optimizer = setup_optimizers(config, reporter, gen, dis, rank)

    # Initialize logs (rank 0 only).
    if rank == 0:
        print('Initializing logs...')
    if rank == 0:
        #==============build tensorboard=================#
        if config["logger"] == "tensorboard":
            import torch.utils.tensorboard as tensorboard
            tensorboard_writer = tensorboard.SummaryWriter(config["project_summary"])
            logger = tensorboard_writer
        elif config["logger"] == "wandb":
            import wandb
            wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
                    tags=[config["tag"]], name=version)

            wandb.config = {
                "total_step": config["total_step"],
                "batch_size": config["batch_size"]
            }
            logger = wandb

    # randindex permutes the identity batch on odd steps (cross-identity pairs)
    random.seed(random_seed)
    randindex = [i for i in range(batch_gpu)]
    random.shuffle(randindex)

    # set the start point for training loop
    if config["phase"] == "finetune":
        start = config["checkpoint_step"]
    else:
        start = 0
    if rank == 0:
        import datetime
        start_time = time.time()

        # Caculate the epoch number
        print("Total step = %d"%total_step)

        print("Start to train at %s"%(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        from utilities.logo_class import logo_class
        logo_class.print_start_training()
    # the projected-discriminator feature backbone stays frozen
    dis.feature_network.requires_grad_(False)
    for step in range(start, total_step):
        gen.train()
        dis.train()
        dis.feature_network.eval()
        # interval 0: discriminator update; interval 1: generator update
        for interval in range(2):
            random.shuffle(randindex)
            src_image1, src_image2 = dataloader.next()
            # if rank ==0:

            #     elapsed = time.time() - start_time
            #     elapsed = str(datetime.timedelta(seconds=elapsed))
            #     print("dataloader:",elapsed)

            # even steps: same-identity (enables reconstruction loss);
            # odd steps: shuffled identities (pure swap)
            if step%2 == 0:
                img_id = src_image2
            else:
                img_id = src_image2[randindex]

            # ArcFace expects 112x112 inputs; embeddings are L2-normalized
            img_id_112 = F.interpolate(img_id,size=(112,112), mode='bicubic')
            latent_id = arcface(img_id_112)
            latent_id = F.normalize(latent_id, p=2, dim=1)

            if interval == 0:

                # hinge loss discriminator step on detached fakes
                img_fake = gen(src_image1, latent_id)
                gen_logits,_ = dis(img_fake.detach(), None)
                loss_Dgen = (F.relu(torch.ones_like(gen_logits) + gen_logits)).mean()

                real_logits,_ = dis(src_image2,None)
                loss_Dreal = (F.relu(torch.ones_like(real_logits) - real_logits)).mean()

                loss_D = loss_Dgen + loss_Dreal
                d_optimizer.zero_grad(set_to_none=True)
                loss_D.backward()
                with torch.autograd.profiler.record_function('discriminator_opt'):
                    # params = [param for param in dis.parameters() if param.grad is not None]
                    # if len(params) > 0:
                    #     flat = torch.cat([param.grad.flatten() for param in params])
                    #     if num_gpus > 1:
                    #         torch.distributed.all_reduce(flat)
                    #         flat /= num_gpus
                    #     misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
                    #     grads = flat.split([param.numel() for param in params])
                    #     for param, grad in zip(params, grads):
                    #         param.grad = grad.reshape(param.shape)
                    # manual gradient averaging across ranks (instead of DDP):
                    # flatten all grads, all-reduce once, scatter back
                    params = [param for param in dis.parameters() if param.grad is not None]
                    flat = torch.cat([param.grad.flatten() for param in params])
                    torch.distributed.all_reduce(flat)
                    flat /= num_gpus
                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
                    grads = flat.split([param.numel() for param in params])
                    for param, grad in zip(params, grads):
                        param.grad = grad.reshape(param.shape)
                d_optimizer.step()
                # if rank ==0:

                #     elapsed = time.time() - start_time
                #     elapsed = str(datetime.timedelta(seconds=elapsed))
                #     print("Discriminator training:",elapsed)
            else:

                # model.netD.requires_grad_(True)
                img_fake = gen(src_image1, latent_id)
                # G loss
                gen_logits,feat = dis(img_fake, None)

                loss_Gmain = (-gen_logits).mean()
                img_fake_down = F.interpolate(img_fake, size=(112,112), mode='bicubic')
                latent_fake = arcface(img_fake_down)
                latent_fake = F.normalize(latent_fake, p=2, dim=1)
                # identity preservation: cosine distance between embeddings
                loss_G_ID = (1 - cos_loss(latent_fake, latent_id)).mean()
                real_feat = dis.get_feature(src_image1)
                feat_match_loss = l1_loss(feat["3"],real_feat["3"])
                loss_G = loss_Gmain + loss_G_ID * id_w + \
                    feat_match_loss * feat_w
                if step%2 == 0:
                    #G_Rec (only valid when identity matches the source image)
                    loss_G_Rec = l1_loss(img_fake, src_image1)
                    loss_G += loss_G_Rec * rec_w

                g_optimizer.zero_grad(set_to_none=True)
                loss_G.backward()
                with torch.autograd.profiler.record_function('generator_opt'):
                    # same manual cross-rank gradient averaging as above
                    params = [param for param in gen.parameters() if param.grad is not None]
                    flat = torch.cat([param.grad.flatten() for param in params])
                    torch.distributed.all_reduce(flat)
                    flat /= num_gpus
                    misc.nan_to_num(flat, nan=0, posinf=1e5, neginf=-1e5, out=flat)
                    grads = flat.split([param.numel() for param in params])
                    for param, grad in zip(params, grads):
                        param.grad = grad.reshape(param.shape)
                g_optimizer.step()
                # if rank ==0:

                #     elapsed = time.time() - start_time
                #     elapsed = str(datetime.timedelta(seconds=elapsed))
                #     print("Generator training:",elapsed)

        # Print out log info
        # NOTE(review): loss_G_Rec is only assigned on even steps — if a
        # finetune run starts at an odd checkpoint_step the first log line
        # would hit an unbound name; confirm checkpoint_step parity.
        if rank == 0 and (step + 1) % log_freq == 0:
            elapsed = time.time() - start_time
            elapsed = str(datetime.timedelta(seconds=elapsed))

            epochinformation="[{}], Elapsed [{}], Step [{}/{}], \
                G_ID: {:.4f}, G_loss: {:.4f}, Rec_loss: {:.4f}, Fm_loss: {:.4f}, \
                D_loss: {:.4f}, D_fake: {:.4f}, D_real: {:.4f}". \
                format(version, elapsed, step, total_step, \
                loss_G_ID.item(), loss_G.item(), loss_G_Rec.item(), feat_match_loss.item(), \
                loss_D.item(), loss_Dgen.item(), loss_Dreal.item())
            print(epochinformation)
            reporter.writeInfo(epochinformation)

            if config["logger"] == "tensorboard":
                logger.add_scalar('G/G_loss', loss_G.item(), step)
                logger.add_scalar('G/G_Rec', loss_G_Rec.item(), step)
                logger.add_scalar('G/G_feat_match', feat_match_loss.item(), step)
                logger.add_scalar('G/G_ID', loss_G_ID.item(), step)
                logger.add_scalar('D/D_loss', loss_D.item(), step)
                logger.add_scalar('D/D_fake', loss_Dgen.item(), step)
                logger.add_scalar('D/D_real', loss_Dreal.item(), step)
            elif config["logger"] == "wandb":
                logger.log({"G_Loss": loss_G.item()}, step = step)
                logger.log({"G_Rec": loss_G_Rec.item()}, step = step)
                logger.log({"G_feat_match": feat_match_loss.item()}, step = step)
                logger.log({"G_ID": loss_G_ID.item()}, step = step)
                logger.log({"D_loss": loss_D.item()}, step = step)
                logger.log({"D_fake": loss_Dgen.item()}, step = step)
                logger.log({"D_real": loss_Dreal.item()}, step = step)
            torch.cuda.empty_cache()

        # sample grid: first row = sources, then per-source swap results
        if rank == 0 and ((step + 1) % sample_freq == 0 or (step+1) % model_freq==0):
            gen.eval()
            with torch.no_grad():
                imgs = []
                zero_img = (torch.zeros_like(src_image1[0,...]))
                imgs.append(zero_img.cpu().numpy())
                save_img = ((src_image1.cpu())* img_std + img_mean).numpy()
                for r in range(batch_gpu):
                    imgs.append(save_img[r,...])
                arcface_112 = F.interpolate(src_image2,size=(112,112), mode='bicubic')
                id_vector_src1 = arcface(arcface_112)
                id_vector_src1 = F.normalize(id_vector_src1, p=2, dim=1)

                for i in range(batch_gpu):

                    imgs.append(save_img[i,...])
                    # swap every identity onto the i-th source image
                    image_infer = src_image1[i, ...].repeat(batch_gpu, 1, 1, 1)
                    img_fake = gen(image_infer, id_vector_src1).cpu()

                    img_fake = img_fake * img_std
                    img_fake = img_fake + img_mean
                    img_fake = img_fake.numpy()
                    for j in range(batch_gpu):
                        imgs.append(img_fake[j,...])
                print("Save test data")
                imgs = np.stack(imgs, axis = 0).transpose(0,2,3,1)
                plot_batch(imgs, os.path.join(sample_dir, 'step_'+str(step+1)+'.jpg'))
            torch.cuda.empty_cache()

        #===============adjust learning rate============#
        # if (epoch + 1) in self.config["lr_decay_step"] and self.config["lr_decay_enable"]:
        #     print("Learning rate decay")
        #     for p in self.optimizer.param_groups:
        #         p['lr'] *= self.config["lr_decay"]
        #         print("Current learning rate is %f"%p['lr'])

        #===============save checkpoints================#
        if rank == 0 and (step+1) % model_freq==0:

            torch.save(gen.state_dict(),
                    os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
                                config["checkpoint_names"]["generator_name"])))
            torch.save(dis.state_dict(),
                    os.path.join(ckpt_dir, 'step{}_{}.pth'.format(step + 1,
                                config["checkpoint_names"]["discriminator_name"])))

            # NOTE(review): optimizer files are saved without a ".pth"
            # extension while setup_optimizers() loads
            # "step%d_optim_%s.pth" — finetune restore will not find them.
            torch.save(g_optimizer.state_dict(),
                    os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
                                config["checkpoint_names"]["generator_name"])))

            torch.save(d_optimizer.state_dict(),
                    os.path.join(ckpt_dir, 'step{}_optim_{}'.format(step + 1,
                                config["checkpoint_names"]["discriminator_name"])))
            print("Save step %d model checkpoint!"%(step+1))
            torch.cuda.empty_cache()
    print("Rank %d process done!"%rank)
    # keep all ranks alive until everyone has finished
    torch.distributed.barrier()
|
||||
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: trainer_base.py
|
||||
# Created Date: Sunday January 16th 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Sunday, 6th February 2022 3:06:45 pm
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
class TrainerBase(object):
    """Abstract base class for training pipelines.

    Subclasses are expected to override the template hooks
    ``__init_framework__`` (build models), ``__setup_optimizers__``
    (build optimizers) and ``__evaluation__`` (evaluate a checkpoint).
    ``train`` drives the common setup sequence; ``__create_dataloader__``
    builds the train dataloader from the config.
    """

    def __init__(self, config, reporter):
        """Store the run configuration and the logger.

        Args:
            config: dict-like training configuration (keys such as
                "phase", "checkpoint_step", "dataset_paths", ...).
            reporter: logging object used to report training progress.
        """
        self.config = config
        # logger
        self.reporter = reporter

        # The scaffolding below is intentionally left commented out as a
        # template for subclasses that need an evaluation dataloader and
        # a tensorboard/wandb logger.

        #========build evaluation dataloader=============#
        # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
        # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]

        # #================================================#
        # print("Prepare the evaluation dataloader...")
        # dlModulename = config["eval_dataloader"]
        # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
        # dataloaderClass = getattr(package, 'EvalDataset')
        # dataloader = dataloaderClass(eval_dataset,
        #                             config["eval_batch_size"])
        # self.eval_loader= dataloader

        # self.eval_iter = len(dataloader)//config["eval_batch_size"]
        # if len(dataloader)%config["eval_batch_size"]>0:
        #     self.eval_iter+=1

        # #==============build tensorboard=================#
        # if self.config["logger"] == "tensorboard":
        #     from utilities.utilities import build_tensorboard
        #     tensorboard_writer = build_tensorboard(self.config["project_summary"])
        #     self.logger = tensorboard_writer
        # elif self.config["logger"] == "wandb":
        #     import wandb
        #     wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
        #                 tags=[self.config["tag"]], name=self.config["version"])

        #     wandb.config = {
        #         "total_step": self.config["total_step"],
        #         "batch_size": self.config["batch_size"]
        #     }
        #     self.logger = wandb

    # TODO modify this function to build your models
    def __init_framework__(self):
        '''
        This function is designed to define the framework,
        and print the framework information into the log file
        '''
        #===============build models================#
        pass

    # TODO modify this function to configurate the optimizer of your pipeline
    def __setup_optimizers__(self):
        pass

    # TODO modify this function to evaluate your model
    # Evaluate the checkpoint
    def __evaluation__(self,
                       step=0,
                       **kwargs
                       ):
        pass

    def __create_dataloader__(self,
                              config,
                              cur_gpu
                              ):
        """Build and return the training dataloader for the given GPU.

        The dataloader module is resolved dynamically from
        ``data_tools.data_loader_<config["dataloader"]>`` and its
        ``GetLoader`` class is instantiated with the dataset path, the
        current GPU index, the batch size and any extra dataset params.
        """
        # Data loader
        #============build train dataloader==============#
        # TODO to modify the key: "your_train_dataset" to get your train dataset path
        dataset = config["dataset_paths"][config["dataset_name"]]
        #================================================#
        print("Prepare the train dataloader...")
        dlModulename = config["dataloader"]
        # Dynamic import keyed by the config so different dataloader
        # implementations can be swapped without code changes.
        package = __import__("data_tools.data_loader_%s" % dlModulename, fromlist=True)
        dataloaderClass = getattr(package, 'GetLoader')
        dataloader_class = dataloaderClass
        dataloader = dataloader_class(dataset,
                                      cur_gpu,
                                      config["batch_size"],
                                      **config["dataset_params"])

        return dataloader

    def train(self):
        """Run the common training setup sequence.

        Builds the framework and optimizers via the subclass hooks,
        initializes the starting step (resuming from
        ``config["checkpoint_step"]`` when fine-tuning), and prints the
        start-of-training banner.
        """
        #===============build framework================#
        # BUG FIX: was `self.init_framework()`, which does not exist on
        # this class — the hook is defined as `__init_framework__`
        # (matching the `__setup_optimizers__` call below).
        self.__init_framework__()

        #===============build optimizer================#
        # Optimizer
        # TODO replace below lines to build your optimizer
        print("build the optimizer...")
        self.__setup_optimizers__()

        # set the start point for training loop
        if self.config["phase"] == "finetune":
            self.start = self.config["checkpoint_step"]
        else:
            self.start = 0

        # Start time
        import datetime
        print("Start to train at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        from utilities.logo_class import logo_class
        logo_class.print_start_training()
@@ -0,0 +1,42 @@
|
||||
# Related scripts
# Training script selector: resolves to train_scripts/trainer_<name>.py.
train_script_name: arcface_rec

# models' scripts
model_configs:
  g_model:
    script: arcface_decoder   # components/<script>.py that defines the model
    class_name: Decoder       # class to instantiate from that script
    module_params:            # kwargs forwarded to the model constructor
      diffaug: False
      interp224: False
      backbone_kwargs: {}


# Path to the pretrained ArcFace identity-encoder checkpoint.
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar

# Training information
batch_size: 64

# Dataset
# Dataloader selector: resolves to data_tools/data_loader_<name>.py.
dataloader: VGGFace2HQ_Rec
dataset_name: vggface2_hq   # key into the dataset_paths mapping
dataset_params:             # extra kwargs forwarded to the dataloader
  random_seed: 1234
  dataloader_workers: 8

# Dataset

# Optimizer
optim_type: Adam
g_optim_config:             # Adam hyper-parameters for the generator
  lr: 0.0008
  betas: [ 0, 0.99]
  eps: !!float 1e-8

# Log
log_step: 200          # steps between log outputs
model_save_step: 2000  # steps between checkpoint saves
total_step: 100000     # total training steps
sample_step: 500       # steps between sample-image dumps
checkpoint_names:
  generator_name: Decoder
@@ -0,0 +1,63 @@
|
||||
# Related scripts
# Training script selector: resolves to train_scripts/trainer_<name>.py.
train_script_name: multi_gpu

# models' scripts
model_configs:
  g_model:
    script: Generator_ori          # components/<script>.py that defines the generator
    class_name: Generator          # class to instantiate from that script
    module_params:                 # kwargs forwarded to the generator constructor
      g_conv_dim: 512
      g_kernel_size: 3
      res_num: 9

  d_model:
    script: projected_discriminator
    class_name: ProjectedDiscriminator
    module_params:                 # kwargs forwarded to the discriminator constructor
      diffaug: False
      interp224: False
      backbone_kwargs: {}

# Path to the pretrained ArcFace identity-encoder checkpoint.
arcface_ckpt: arcface_ckpt/arcface_checkpoint.tar

# Training information
batch_size: 24

# Dataset
# Dataloader selector: resolves to data_tools/data_loader_<name>.py.
dataloader: VGGFace2HQ_multigpu
dataset_name: vggface2_hq   # key into the dataset_paths mapping
dataset_params:             # extra kwargs forwarded to the dataloader
  random_seed: 1234
  dataloader_workers: 8

# Evaluation dataloader settings (see the commented eval scaffolding in
# the trainer base class).
eval_dataloader: DIV2K_hdf5
eval_dataset_name: DF2K_H5_Eval
eval_batch_size: 2

# Dataset

# Optimizer
optim_type: Adam
g_optim_config:             # Adam hyper-parameters for the generator
  lr: 0.0004
  betas: [ 0, 0.99]
  eps: !!float 1e-8

d_optim_config:             # Adam hyper-parameters for the discriminator
  lr: 0.0004
  betas: [ 0, 0.99]
  eps: !!float 1e-8

# Loss weights
id_weight: 20.0             # identity-preservation loss weight
reconstruct_weight: 10.0    # pixel reconstruction loss weight
feature_match_weight: 10.0  # discriminator feature-matching loss weight

# Log
log_step: 300          # steps between log outputs
model_save_step: 10000 # steps between checkpoint saves
sample_step: 1000      # steps between sample-image dumps
total_step: 1000000    # total training steps
checkpoint_names:
  generator_name: Generator
  discriminator_name: Discriminator
Reference in New Issue
Block a user