mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Migrate most to self.config and self.context
This commit is contained in:
@@ -53,31 +53,37 @@ class AdversarialLoss(nn.Module):
|
||||
|
||||
|
||||
class AttributeLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'batch_size': config_parser.getint('training.loader', 'batch_size'),
|
||||
'attribute_weight': config_parser.getfloat('training.losses', 'attribute_weight')
|
||||
}
|
||||
|
||||
def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
attribute_weight = CONFIG.getfloat('training.losses', 'attribute_weight')
|
||||
temp_tensors = []
|
||||
|
||||
for target_attribute, output_attribute in zip(target_attributes, output_attributes):
|
||||
temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(batch_size, -1), dim = 1).mean()
|
||||
temp_tensor = torch.mean(torch.pow(output_attribute - target_attribute, 2).reshape(self.config.get('batch_size'), -1), dim = 1).mean()
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
attribute_loss = torch.stack(temp_tensors).mean() * 0.5
|
||||
weighted_attribute_loss = attribute_loss * attribute_weight
|
||||
weighted_attribute_loss = attribute_loss * self.config.get('attribute_weight')
|
||||
return attribute_loss, weighted_attribute_loss
|
||||
|
||||
|
||||
class ReconstructionLoss(nn.Module):
|
||||
def __init__(self, embedder : EmbedderModule) -> None:
|
||||
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'reconstruction_weight': config_parser.getfloat('training.losses', 'reconstruction_weight')
|
||||
}
|
||||
self.embedder = embedder
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
reconstruction_weight = CONFIG.getfloat('training.losses', 'reconstruction_weight')
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.embedder, target_tensor, (0, 0, 0, 0))
|
||||
has_similar_identity = torch.cosine_similarity(source_embedding, target_embedding) > 0.8
|
||||
@@ -88,27 +94,35 @@ class ReconstructionLoss(nn.Module):
|
||||
data_range = float(torch.max(output_tensor) - torch.min(output_tensor))
|
||||
visual_loss = 1 - ssim(output_tensor, target_tensor, data_range = data_range).mean()
|
||||
reconstruction_loss = (reconstruction_loss + visual_loss) * 0.5
|
||||
weighted_reconstruction_loss = reconstruction_loss * reconstruction_weight
|
||||
weighted_reconstruction_loss = reconstruction_loss * self.config.get('reconstruction_weight')
|
||||
return reconstruction_loss, weighted_reconstruction_loss
|
||||
|
||||
|
||||
class IdentityLoss(nn.Module):
|
||||
def __init__(self, embedder : EmbedderModule) -> None:
|
||||
def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'identity_weight': config_parser.getfloat('training.losses', 'identity_weight')
|
||||
}
|
||||
self.embedder = embedder
|
||||
|
||||
def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
identity_weight = CONFIG.getfloat('training.losses', 'identity_weight')
|
||||
output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
weighted_identity_loss = identity_loss * identity_weight
|
||||
weighted_identity_loss = identity_loss * self.config.get('identity_weight')
|
||||
return identity_loss, weighted_identity_loss
|
||||
|
||||
|
||||
class MotionLoss(nn.Module):
|
||||
def __init__(self, motion_extractor : MotionExtractorModule):
|
||||
def __init__(self, config_parser : ConfigParser, motion_extractor : MotionExtractorModule):
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'pose_weight': config_parser.getfloat('training.losses', 'pose_weight'),
|
||||
'expression_weight': config_parser.getfloat('training.losses', 'expression_weight')
|
||||
}
|
||||
self.motion_extractor = motion_extractor
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
@@ -120,7 +134,6 @@ class MotionLoss(nn.Module):
|
||||
return pose_loss, weighted_pose_loss, expression_loss, weighted_expression_loss
|
||||
|
||||
def calc_pose_loss(self, target_poses : Tuple[Tensor, ...], output_poses : Tuple[Tensor, ...]) -> Tuple[Tensor, Tensor]:
|
||||
pose_weight = CONFIG.getfloat('training.losses', 'pose_weight')
|
||||
temp_tensors = []
|
||||
|
||||
for target_pose, output_pose in zip(target_poses, output_poses):
|
||||
@@ -128,13 +141,12 @@ class MotionLoss(nn.Module):
|
||||
temp_tensors.append(temp_tensor)
|
||||
|
||||
pose_loss = torch.stack(temp_tensors).mean()
|
||||
weighted_pose_loss = pose_loss * pose_weight
|
||||
weighted_pose_loss = pose_loss * self.config.get('pose_weight')
|
||||
return pose_loss, weighted_pose_loss
|
||||
|
||||
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
expression_weight = CONFIG.getfloat('training.losses', 'expression_weight')
|
||||
expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
|
||||
weighted_expression_loss = expression_loss * expression_weight
|
||||
weighted_expression_loss = expression_loss * self.config.get('expression_weight')
|
||||
return expression_loss, weighted_expression_loss
|
||||
|
||||
def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
|
||||
@@ -148,13 +160,17 @@ class MotionLoss(nn.Module):
|
||||
|
||||
|
||||
class GazeLoss(nn.Module):
|
||||
def __init__(self, gazer : GazerModule) -> None:
|
||||
def __init__(self, config_parser : ConfigParser, gazer : GazerModule) -> None:
|
||||
super().__init__()
|
||||
self.config =\
|
||||
{
|
||||
'gaze_weight': config_parser.getfloat('training.losses', 'gaze_weight'),
|
||||
'output_size': config_parser.getint('training.model.generator', 'output_size')
|
||||
}
|
||||
self.gazer = gazer
|
||||
self.l1_loss = nn.L1Loss()
|
||||
|
||||
def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
|
||||
output_pitch, output_yaw = self.detect_gaze(output_tensor)
|
||||
target_pitch, target_yaw = self.detect_gaze(target_tensor)
|
||||
|
||||
@@ -162,12 +178,11 @@ class GazeLoss(nn.Module):
|
||||
yaw_loss = self.l1_loss(output_yaw, target_yaw)
|
||||
|
||||
gaze_loss = (pitch_loss + yaw_loss) * 0.5
|
||||
weighted_gaze_loss = gaze_loss * gaze_weight
|
||||
weighted_gaze_loss = gaze_loss * self.config.get('gaze_weight')
|
||||
return gaze_loss, weighted_gaze_loss
|
||||
|
||||
def detect_gaze(self, input_tensor : Tensor) -> Gaze:
|
||||
output_size = CONFIG.getint('training.model.generator', 'output_size')
|
||||
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * output_size).int()
|
||||
crop_sizes = (torch.tensor([ 0.235, 0.875, 0.0625, 0.8 ]) * self.config.get('output_size')).int()
|
||||
crop_tensor = input_tensor[:, :, crop_sizes[0]:crop_sizes[1], crop_sizes[2]:crop_sizes[3]]
|
||||
crop_tensor = (crop_tensor + 1) * 0.5
|
||||
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from configparser import ConfigParser
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
@@ -5,51 +7,55 @@ from ..types import Attributes, Embedding
|
||||
|
||||
|
||||
class AAD(nn.Module):
|
||||
def __init__(self, identity_channels : int, output_channels : int, output_size : int, num_blocks : int) -> None:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
super().__init__()
|
||||
self.identity_channels = identity_channels
|
||||
self.output_channels = output_channels
|
||||
self.output_size = output_size
|
||||
self.num_blocks = num_blocks
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(identity_channels, output_channels)
|
||||
self.config =\
|
||||
{
|
||||
'identity_channels': config_parser.getint('training.model.generator', 'identity_channels'),
|
||||
'output_channels': config_parser.getint('training.model.generator', 'output_channels'),
|
||||
'output_size': config_parser.getint('training.model.generator', 'output_size'),
|
||||
'num_blocks': config_parser.getint('training.model.generator', 'num_blocks')
|
||||
}
|
||||
self.config_parser = config_parser
|
||||
self.pixel_shuffle_up_sample = PixelShuffleUpSample(self.config.get('identity_channels'), self.config.get('output_channels'))
|
||||
self.layers = self.create_layers()
|
||||
|
||||
def create_layers(self) -> nn.ModuleList:
|
||||
layers = nn.ModuleList()
|
||||
|
||||
if self.output_size == 128:
|
||||
if self.config.get('output_size') == 128:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 1024, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(512, 512, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(512, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks'))
|
||||
])
|
||||
|
||||
if self.output_size == 256:
|
||||
if self.config.get('output_size') == 256:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 2048, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.identity_channels, self.num_blocks)
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(1024, 1024, 2048, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(1024, 1024, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks'))
|
||||
])
|
||||
|
||||
if self.output_size == 512:
|
||||
if self.config.get('output_size') == 512:
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 4096, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.identity_channels, self.num_blocks)
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(2048, 2048, 4096, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(2048, 2048, 2048, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(2048, 1024, 1024, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(1024, 512, 512, self.config.get('identity_channels'), self.config.get('num_blocks'))
|
||||
])
|
||||
|
||||
layers.extend(
|
||||
[
|
||||
AdaptiveFeatureModulation(512, 256, 256, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(256, 128, 128, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(128, 64, 64, self.identity_channels, self.num_blocks),
|
||||
AdaptiveFeatureModulation(64, 3, 64, self.identity_channels, self.num_blocks)
|
||||
AdaptiveFeatureModulation(512, 256, 256, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(256, 128, 128, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(128, 64, 64, self.config.get('identity_channels'), self.config.get('num_blocks')),
|
||||
AdaptiveFeatureModulation(64, 3, 64, self.config.get('identity_channels'), self.config.get('num_blocks'))
|
||||
])
|
||||
|
||||
return layers
|
||||
@@ -69,40 +75,43 @@ class AAD(nn.Module):
|
||||
class AdaptiveFeatureModulation(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int, attribute_channels : int, identity_channels : int, num_blocks : int) -> None:
|
||||
super().__init__()
|
||||
self.input_channels = input_channels
|
||||
self.output_channels = output_channels
|
||||
self.attribute_channels = attribute_channels
|
||||
self.identity_channels = identity_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.context =\
|
||||
{
|
||||
'input_channels': input_channels,
|
||||
'output_channels': output_channels,
|
||||
'attribute_channels': attribute_channels,
|
||||
'identity_channels': identity_channels,
|
||||
'num_blocks': num_blocks
|
||||
}
|
||||
self.primary_layers = self.create_primary_layers()
|
||||
self.shortcut_layers = self.create_shortcut_layers()
|
||||
|
||||
def create_primary_layers(self) -> nn.ModuleList:
|
||||
primary_layers = nn.ModuleList()
|
||||
|
||||
for index in range(self.num_blocks):
|
||||
for index in range(self.context.get('num_blocks')):
|
||||
primary_layers.extend(
|
||||
[
|
||||
FeatureModulation(self.input_channels, self.attribute_channels, self.identity_channels),
|
||||
FeatureModulation(self.context.get('input_channels'), self.context.get('attribute_channels'), self.context.get('identity_channels')),
|
||||
nn.ReLU(inplace = True)
|
||||
])
|
||||
|
||||
if index < self.num_blocks - 1:
|
||||
primary_layers.append(nn.Conv2d(self.input_channels, self.input_channels, kernel_size = 3, padding = 1, bias = False))
|
||||
if index < self.context.get('num_blocks') - 1:
|
||||
primary_layers.append(nn.Conv2d(self.context.get('input_channels'), self.context.get('input_channels'), kernel_size = 3, padding = 1, bias = False))
|
||||
else:
|
||||
primary_layers.append(nn.Conv2d(self.input_channels, self.output_channels, kernel_size = 3, padding = 1, bias = False))
|
||||
primary_layers.append(nn.Conv2d(self.context.get('input_channels'), self.context.get('output_channels'), kernel_size = 3, padding = 1, bias = False))
|
||||
|
||||
return primary_layers
|
||||
|
||||
def create_shortcut_layers(self) -> nn.ModuleList:
|
||||
shortcut_layers = nn.ModuleList()
|
||||
|
||||
if self.input_channels > self.output_channels:
|
||||
if self.context.get('input_channels') > self.context.get('output_channels'):
|
||||
shortcut_layers.extend(
|
||||
[
|
||||
FeatureModulation(self.input_channels, self.attribute_channels, self.identity_channels),
|
||||
FeatureModulation(self.context.get('input_channels'), self.context.get('attribute_channels'), self.context.get('identity_channels')),
|
||||
nn.ReLU(inplace = True),
|
||||
nn.Conv2d(self.input_channels, self.output_channels, kernel_size = 3, padding = 1, bias = False)
|
||||
nn.Conv2d(self.context.get('input_channels'), self.context.get('output_channels'), kernel_size = 3, padding = 1, bias = False)
|
||||
])
|
||||
|
||||
return shortcut_layers
|
||||
@@ -116,7 +125,7 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
else:
|
||||
primary_tensor = primary_layer(primary_tensor)
|
||||
|
||||
if self.input_channels > self.output_channels:
|
||||
if self.context.get('input_channels') > self.context.get('output_channels'):
|
||||
shortcut_tensor = input_tensor
|
||||
|
||||
for shortcut_layer in self.shortcut_layers:
|
||||
@@ -133,7 +142,10 @@ class AdaptiveFeatureModulation(nn.Module):
|
||||
class FeatureModulation(nn.Module):
|
||||
def __init__(self, input_channels : int, attribute_channels : int, identity_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.input_channels = input_channels
|
||||
self.context =\
|
||||
{
|
||||
'input_channels': input_channels
|
||||
}
|
||||
self.conv1 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.conv2 = nn.Conv2d(attribute_channels, input_channels, kernel_size = 1)
|
||||
self.conv3 = nn.Conv2d(input_channels, 1, kernel_size = 1)
|
||||
@@ -148,8 +160,8 @@ class FeatureModulation(nn.Module):
|
||||
attribute_shift = self.conv2(attribute_embedding)
|
||||
attribute_modulation = attribute_scale * temp_tensor + attribute_shift
|
||||
|
||||
identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.input_channels, 1, 1).expand_as(temp_tensor)
|
||||
identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.input_channels, 1, 1).expand_as(temp_tensor)
|
||||
identity_scale = self.linear2(identity_embedding).reshape(temp_tensor.shape[0], self.context.get('input_channels'), 1, 1).expand_as(temp_tensor)
|
||||
identity_shift = self.linear1(identity_embedding).reshape(temp_tensor.shape[0], self.context.get('input_channels'), 1, 1).expand_as(temp_tensor)
|
||||
identity_modulation = identity_scale * temp_tensor + identity_shift
|
||||
|
||||
temp_mask = torch.sigmoid(self.conv3(temp_tensor))
|
||||
|
||||
@@ -106,7 +106,7 @@ class UNet(nn.Module):
|
||||
class UpSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.conv_transpose = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.conv_transpose = nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
@@ -121,7 +121,7 @@ class UpSample(nn.Module):
|
||||
class DownSample(nn.Module):
|
||||
def __init__(self, input_channels : int, output_channels : int) -> None:
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.conv = nn.Conv2d(input_channels, output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
|
||||
@@ -39,10 +39,9 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.embedder = torch.jit.load(self.config.get('embedder_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.gazer = torch.jit.load(self.config.get('gazer_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(self.config.get('motion_extractor_path'), map_location = 'cpu').eval() # type:ignore[no-untyped-call]
|
||||
|
||||
self.generator = Generator(config_parser)
|
||||
self.discriminator = Discriminator(config_parser)
|
||||
self.discriminator_loss = DiscriminatorLoss(config_parser)
|
||||
self.discriminator_loss = DiscriminatorLoss()
|
||||
self.adversarial_loss = AdversarialLoss(config_parser)
|
||||
self.attribute_loss = AttributeLoss(config_parser)
|
||||
self.reconstruction_loss = ReconstructionLoss(config_parser, self.embedder)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from configparser import ConfigParser
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -7,18 +9,17 @@ from face_swapper.src.networks.unet import UNet
|
||||
|
||||
@pytest.mark.parametrize('output_size', [ 128, 256, 512 ])
|
||||
def test_aad_with_unet(output_size : int) -> None:
|
||||
identity_channels = 512
|
||||
output_channels = 1024
|
||||
if output_size == 128:
|
||||
output_channels = 2048
|
||||
if output_size == 256:
|
||||
output_channels = 4096
|
||||
if output_size == 512:
|
||||
output_channels = 8192
|
||||
num_blocks = 2
|
||||
config_parser = ConfigParser()
|
||||
config_parser['training.model.generator'] =\
|
||||
{
|
||||
'identity_channels': '512',
|
||||
'output_channels': str(output_size * 16),
|
||||
'output_size': str(output_size),
|
||||
'num_blocks': '2'
|
||||
}
|
||||
|
||||
generator = AAD(identity_channels, output_channels, output_size, num_blocks).eval()
|
||||
encoder = UNet(output_size).eval()
|
||||
generator = AAD(config_parser).eval()
|
||||
encoder = UNet(config_parser).eval()
|
||||
|
||||
source_tensor = torch.randn(1, 512)
|
||||
target_tensor = torch.randn(1, 3, output_size, output_size)
|
||||
|
||||
Reference in New Issue
Block a user