mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Use new loss code, Remove unused code, Remove old types, Ban VisionTensor naming
This commit is contained in:
@@ -2,19 +2,19 @@ import numpy
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
|
||||
from .types import EmbedderModule, Embedding, Padding, VisionFrame
|
||||
|
||||
|
||||
def convert_to_vision_tensor(vision_frame : VisionFrame) -> VisionTensor:
|
||||
vision_tensor = torch.from_numpy(vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32))
|
||||
vision_tensor = vision_tensor / 255.0
|
||||
vision_tensor = (vision_tensor - 0.5) * 2
|
||||
vision_tensor = vision_tensor.unsqueeze(0)
|
||||
return vision_tensor
|
||||
def convert_to_tensor(vision_frame : VisionFrame) -> Tensor:
|
||||
output_tensor = torch.from_numpy(vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32))
|
||||
output_tensor = output_tensor / 255.0
|
||||
output_tensor = (output_tensor - 0.5) * 2
|
||||
output_tensor = output_tensor.unsqueeze(0)
|
||||
return output_tensor
|
||||
|
||||
|
||||
def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame:
|
||||
vision_frame = vision_tensor.detach().cpu().numpy()[0]
|
||||
def convert_to_vision_frame(input_tensor : Tensor) -> VisionFrame:
|
||||
vision_frame = input_tensor.detach().cpu().numpy()[0]
|
||||
vision_frame = vision_frame.transpose(1, 2, 0)
|
||||
vision_frame = (vision_frame + 1) * 127.5
|
||||
vision_frame = vision_frame.clip(0, 255).astype(numpy.uint8)
|
||||
|
||||
@@ -3,7 +3,7 @@ import configparser
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from .helper import calc_embedding, convert_to_vision_frame, convert_to_vision_tensor
|
||||
from .helper import calc_embedding, convert_to_vision_frame, convert_to_tensor
|
||||
from .models.generator import Generator
|
||||
from .types import EmbedderModule, GeneratorModule, VisionFrame
|
||||
|
||||
@@ -12,11 +12,11 @@ CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def run_swap(generator : GeneratorModule, embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
|
||||
source_vision_tensor = convert_to_vision_tensor(source_vision_frame)
|
||||
target_vision_tensor = convert_to_vision_tensor(target_vision_frame)
|
||||
source_embedding = calc_embedding(embedder, source_vision_tensor, (0, 0, 0, 0))
|
||||
output_vision_tensor = generator(source_embedding, target_vision_tensor)[0]
|
||||
output_vision_frame = convert_to_vision_frame(output_vision_tensor)
|
||||
source_tensor = convert_to_tensor(source_vision_frame)
|
||||
target_tensor = convert_to_tensor(target_vision_frame)
|
||||
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor = generator(source_embedding, target_tensor)[0]
|
||||
output_vision_frame = convert_to_vision_frame(output_tensor)
|
||||
return output_vision_frame
|
||||
|
||||
|
||||
|
||||
@@ -6,153 +6,12 @@ from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import Attributes, Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
from ..types import Attributes, FaceLandmark203
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def hinge_real_loss(input_tensor : Tensor) -> Tensor:
|
||||
real_loss = torch.relu(1 - input_tensor)
|
||||
real_loss = real_loss.mean(dim = [ 1, 2, 3 ])
|
||||
return real_loss
|
||||
|
||||
|
||||
def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
|
||||
fake_loss = torch.relu(input_tensor + 1)
|
||||
fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ])
|
||||
return fake_loss
|
||||
|
||||
|
||||
class FaceSwapperLoss:
|
||||
def __init__(self) -> None:
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
self.mse_loss = nn.MSELoss()
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
|
||||
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
|
||||
weight_adversarial = CONFIG.getfloat('training.losses', 'adversarial_weight')
|
||||
weight_identity = CONFIG.getfloat('training.losses', 'identity_weight')
|
||||
weight_attribute = CONFIG.getfloat('training.losses', 'attribute_weight')
|
||||
weight_reconstruction = CONFIG.getfloat('training.losses', 'reconstruction_weight')
|
||||
weight_pose = CONFIG.getfloat('training.losses', 'pose_weight')
|
||||
weight_gaze = CONFIG.getfloat('training.losses', 'gaze_weight')
|
||||
source_tensor, target_tensor = batch
|
||||
is_same_person = torch.tensor(0) if torch.equal(source_tensor, target_tensor) else torch.tensor(1)
|
||||
generator_loss_set =\
|
||||
{
|
||||
'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs),
|
||||
'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor),
|
||||
'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes),
|
||||
'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
|
||||
}
|
||||
|
||||
generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
|
||||
generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)
|
||||
|
||||
generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_identity') * weight_identity
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze
|
||||
return generator_loss_set
|
||||
|
||||
def hinge_real_loss(input_tensor: Tensor) -> Tensor:
|
||||
real_loss = torch.relu(1 - input_tensor)
|
||||
real_loss = real_loss.mean(dim = [1, 2, 3])
|
||||
return real_loss
|
||||
|
||||
def hinge_fake_loss(input_tensor: Tensor) -> Tensor:
|
||||
fake_loss = torch.relu(input_tensor + 1)
|
||||
fake_loss = fake_loss.mean(dim = [1, 2, 3])
|
||||
return fake_loss
|
||||
|
||||
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
|
||||
discriminator_loss_set = {}
|
||||
loss_fakes = []
|
||||
|
||||
for fake_discriminator_output in fake_discriminator_outputs:
|
||||
loss_fakes.append(hinge_fake_loss(fake_discriminator_output[0]))
|
||||
|
||||
loss_trues = []
|
||||
|
||||
for true_discriminator_output in real_discriminator_outputs:
|
||||
loss_trues.append(hinge_real_loss(true_discriminator_output[0]))
|
||||
|
||||
loss_fake = torch.stack(loss_fakes).mean()
|
||||
loss_true = torch.stack(loss_trues).mean()
|
||||
discriminator_loss_set['loss_discriminator'] = (loss_true + loss_fake) * 0.5
|
||||
return discriminator_loss_set
|
||||
|
||||
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
|
||||
loss_adversarials = []
|
||||
|
||||
for discriminator_output in discriminator_outputs:
|
||||
loss_adversarials.append(hinge_real_loss(discriminator_output[0]).mean())
|
||||
|
||||
loss_adversarial = torch.stack(loss_adversarials).mean()
|
||||
return loss_adversarial
|
||||
|
||||
def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
|
||||
loss_attributes = []
|
||||
|
||||
for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
|
||||
loss_attributes.append(torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean())
|
||||
|
||||
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
|
||||
return loss_attribute
|
||||
|
||||
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
|
||||
loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
|
||||
loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
|
||||
loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
|
||||
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
|
||||
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
|
||||
return loss_reconstruction
|
||||
|
||||
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
swap_embedding = calc_embedding(self.embedder, swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
|
||||
return loss_identity
|
||||
|
||||
def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
|
||||
swap_motion_features = self.get_pose_features(swap_tensor)
|
||||
target_motion_features = self.get_pose_features(target_tensor)
|
||||
loss_pose = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
|
||||
|
||||
for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
|
||||
loss_pose += self.mse_loss(swap_motion_feature, target_motion_feature)
|
||||
|
||||
return loss_pose
|
||||
|
||||
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
|
||||
swap_landmark = self.get_face_landmarks(swap_tensor)
|
||||
target_landmark = self.get_face_landmarks(target_tensor)
|
||||
left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
|
||||
right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197])
|
||||
gaze_loss = left_gaze_loss + right_gaze_loss
|
||||
return gaze_loss
|
||||
|
||||
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
|
||||
vision_tensor_norm = (vision_tensor + 1) * 0.5
|
||||
vision_tensor_norm = nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
|
||||
landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
|
||||
return landmarks
|
||||
|
||||
def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
vision_tensor_norm = (vision_tensor + 1) * 0.5
|
||||
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
return translation, scale, rotation
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@@ -270,8 +129,8 @@ class PoseLoss(nn.Module):
|
||||
return pose_loss, weighted_pose_loss
|
||||
|
||||
def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
vision_tensor_norm = (input_tensor + 1) * 0.5
|
||||
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
|
||||
input_tensor = (input_tensor + 1) * 0.5
|
||||
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(input_tensor)
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
return translation, scale, rotation
|
||||
|
||||
@@ -283,7 +142,7 @@ class GazeLoss(nn.Module):
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def calc(self, target_tensor : VisionTensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
|
||||
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
|
||||
output_face_landmark = self.detect_face_landmark(output_tensor)
|
||||
target_face_landmark = self.detect_face_landmark(target_tensor)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Embedding, TargetAttributes
|
||||
from ..types import Attributes, Embedding
|
||||
|
||||
|
||||
class AADGenerator(nn.Module):
|
||||
@@ -17,7 +17,7 @@ class AADGenerator(nn.Module):
|
||||
self.res_block_7 = AADResBlock(128, 64, 64, id_channels, num_blocks)
|
||||
self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks)
|
||||
|
||||
def forward(self, target_attributes : TargetAttributes, source_embedding : Embedding) -> Tensor:
|
||||
def forward(self, target_attributes : Attributes, source_embedding : Embedding) -> Tensor:
|
||||
feature_map = self.upsample(source_embedding)
|
||||
feature_map_1 = nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_2 = nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
@@ -59,10 +59,10 @@ class AADSequential(nn.Module):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(args)
|
||||
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor:
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
|
||||
for layer in self.layers:
|
||||
if isinstance(layer, AADLayer):
|
||||
feature_map = layer(feature_map, attribute_embedding, id_embedding)
|
||||
feature_map = layer(feature_map, attribute_embedding, identity_embedding)
|
||||
else:
|
||||
feature_map = layer(feature_map)
|
||||
return feature_map
|
||||
|
||||
@@ -16,18 +16,18 @@ from .dataset import DynamicDataset
|
||||
from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding, VisionTensor
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
|
||||
from .types import Batch, Embedding
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
class FaceSwapperTrainer(lightning.LightningModule):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
FaceSwapperLoss.__init__(self)
|
||||
automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
|
||||
self.generator = Generator()
|
||||
self.discriminator = Discriminator()
|
||||
@@ -38,9 +38,10 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.identity_loss = IdentityLoss()
|
||||
self.pose_loss = PoseLoss()
|
||||
self.gaze_loss = GazeLoss()
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.automatic_optimization = automatic_optimization
|
||||
|
||||
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor:
|
||||
def forward(self, target_tensor : Tensor, source_embedding : Embedding) -> Tensor:
|
||||
output_tensor = self.generator(source_embedding, target_tensor)
|
||||
return output_tensor
|
||||
|
||||
@@ -61,34 +62,6 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
||||
|
||||
generator_loss_set = self.calc_generator_loss(generator_output_tensor, target_attributes, generator_output_attributes, discriminator_output_tensors, batch)
|
||||
generator_optimizer.zero_grad()
|
||||
self.manual_backward(generator_loss_set.get('loss_generator'))
|
||||
generator_optimizer.step()
|
||||
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
|
||||
discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.manual_backward(discriminator_loss_set.get('loss_discriminator'))
|
||||
discriminator_optimizer.step()
|
||||
|
||||
if self.global_step % preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
|
||||
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
|
||||
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True)
|
||||
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'))
|
||||
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
|
||||
self.log('loss_identity', generator_loss_set.get('loss_identity'))
|
||||
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
|
||||
self.log('loss_pose', generator_loss_set.get('loss_pose'))
|
||||
self.log('loss_gaze', generator_loss_set.get('loss_gaze'))
|
||||
|
||||
###############################################
|
||||
|
||||
discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors)
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors)
|
||||
attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes)
|
||||
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
|
||||
@@ -97,15 +70,30 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
gaze_loss, weighted_gaze_loss = self.gaze_loss.calc(target_tensor, generator_output_tensor)
|
||||
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss
|
||||
|
||||
self.log('generator_loss_new', generator_loss, prog_bar = True)
|
||||
self.log('discriminator_loss_new', discriminator_loss, prog_bar = True)
|
||||
self.log('adversarial_loss_new', adversarial_loss)
|
||||
self.log('attribute_loss_new', attribute_loss)
|
||||
self.log('reconstruction_loss_new', reconstruction_loss)
|
||||
self.log('identity_loss_new', identity_loss)
|
||||
self.log('pose_loss_new', pose_loss)
|
||||
self.log('gaze_loss_new', gaze_loss)
|
||||
return generator_loss_set.get('loss_generator')
|
||||
generator_optimizer.zero_grad()
|
||||
self.manual_backward(generator_loss)
|
||||
generator_optimizer.step()
|
||||
|
||||
discriminator_source_tensors = self.discriminator(source_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
|
||||
discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors)
|
||||
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.manual_backward(discriminator_loss)
|
||||
discriminator_optimizer.step()
|
||||
|
||||
if self.global_step % preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
|
||||
self.log('generator_loss', generator_loss, prog_bar = True)
|
||||
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
|
||||
self.log('adversarial_loss', adversarial_loss)
|
||||
self.log('attribute_loss', attribute_loss)
|
||||
self.log('reconstruction_loss', reconstruction_loss)
|
||||
self.log('identity_loss', identity_loss)
|
||||
self.log('pose_loss', pose_loss)
|
||||
self.log('gaze_loss', gaze_loss)
|
||||
return generator_loss
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor = batch
|
||||
@@ -116,7 +104,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
self.log('validation', validation)
|
||||
return validation
|
||||
|
||||
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None:
|
||||
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> None:
|
||||
preview_limit = 8
|
||||
preview_cells = []
|
||||
|
||||
|
||||
@@ -7,23 +7,13 @@ from torch.nn import Module
|
||||
|
||||
Batch : TypeAlias = Tuple[Tensor, Tensor]
|
||||
|
||||
SwapAttributes : TypeAlias = Tuple[Tensor, ...]
|
||||
TargetAttributes : TypeAlias = Tuple[Tensor, ...]
|
||||
DiscriminatorOutputs : TypeAlias = List[List[Tensor]]
|
||||
|
||||
Attributes : TypeAlias = Tuple[Tensor, ...]
|
||||
Embedding : TypeAlias = Tensor
|
||||
FaceLandmark203 : TypeAlias = Tensor
|
||||
|
||||
StateSet : TypeAlias = OrderedDict[str, Any]
|
||||
Padding : TypeAlias = Tuple[int, int, int, int]
|
||||
|
||||
VisionFrame : TypeAlias = NDArray[Any]
|
||||
LossTensor : TypeAlias = Tensor
|
||||
VisionTensor : TypeAlias = Tensor
|
||||
|
||||
GeneratorLossSet : TypeAlias = Dict[str, Tensor]
|
||||
DiscriminatorLossSet : TypeAlias = Dict[str, Tensor]
|
||||
|
||||
GeneratorModule : TypeAlias = Module
|
||||
EmbedderModule : TypeAlias = Module
|
||||
|
||||
Reference in New Issue
Block a user