From ed0f6ae897b8a3ac790a17c9c9556bbca466b6b3 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sun, 23 Feb 2025 01:05:01 +0100 Subject: [PATCH] Use new loss code, Remove unused code, Remove old types, Ban VisionTensor naming --- face_swapper/src/helper.py | 18 +-- face_swapper/src/inferencing.py | 12 +- face_swapper/src/models/loss.py | 149 +----------------- .../src/networks/attribute_modulator.py | 8 +- face_swapper/src/training.py | 74 ++++----- face_swapper/src/types.py | 10 -- 6 files changed, 54 insertions(+), 217 deletions(-) diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py index 855a317..2e3cf60 100644 --- a/face_swapper/src/helper.py +++ b/face_swapper/src/helper.py @@ -2,19 +2,19 @@ import numpy import torch from torch import Tensor, nn -from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor +from .types import EmbedderModule, Embedding, Padding, VisionFrame -def convert_to_vision_tensor(vision_frame : VisionFrame) -> VisionTensor: - vision_tensor = torch.from_numpy(vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32)) - vision_tensor = vision_tensor / 255.0 - vision_tensor = (vision_tensor - 0.5) * 2 - vision_tensor = vision_tensor.unsqueeze(0) - return vision_tensor +def convert_to_tensor(vision_frame : VisionFrame) -> Tensor: + output_tensor = torch.from_numpy(vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32)) + output_tensor = output_tensor / 255.0 + output_tensor = (output_tensor - 0.5) * 2 + output_tensor = output_tensor.unsqueeze(0) + return output_tensor -def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame: - vision_frame = vision_tensor.detach().cpu().numpy()[0] +def convert_to_vision_frame(input_tensor : Tensor) -> VisionFrame: + vision_frame = input_tensor.detach().cpu().numpy()[0] vision_frame = vision_frame.transpose(1, 2, 0) vision_frame = (vision_frame + 1) * 127.5 vision_frame = vision_frame.clip(0, 255).astype(numpy.uint8) diff --git a/face_swapper/src/inferencing.py b/face_swapper/src/inferencing.py index d01510d..30d3da5 100644 --- a/face_swapper/src/inferencing.py +++ b/face_swapper/src/inferencing.py @@ -3,7 +3,7 @@ import configparser import cv2 import torch -from .helper import calc_embedding, convert_to_vision_frame, convert_to_vision_tensor +from .helper import calc_embedding, convert_to_vision_frame, convert_to_tensor from .models.generator import Generator from .types import EmbedderModule, GeneratorModule, VisionFrame @@ -12,11 +12,11 @@ CONFIG.read('config.ini') def run_swap(generator : GeneratorModule, embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: - source_vision_tensor = convert_to_vision_tensor(source_vision_frame) - target_vision_tensor = convert_to_vision_tensor(target_vision_frame) - source_embedding = calc_embedding(embedder, source_vision_tensor, (0, 0, 0, 0)) - output_vision_tensor = generator(source_embedding, target_vision_tensor)[0] - output_vision_frame = convert_to_vision_frame(output_vision_tensor) + source_tensor = convert_to_tensor(source_vision_frame) + target_tensor = convert_to_tensor(target_vision_frame) + source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0)) + output_tensor = generator(source_embedding, target_tensor)[0] + output_vision_frame = convert_to_vision_frame(output_tensor) return output_vision_frame diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 7e810e5..6df2c7b 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -6,153 +6,12 @@ from pytorch_msssim import ssim from torch import Tensor, nn from ..helper import calc_embedding -from ..types import Attributes, Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor +from ..types import Attributes, FaceLandmark203 CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -def hinge_real_loss(input_tensor : Tensor) -> Tensor: - real_loss = torch.relu(1 - input_tensor) - real_loss = real_loss.mean(dim = [ 1, 2, 3 ]) - return real_loss - - -def hinge_fake_loss(input_tensor : Tensor) -> Tensor: - fake_loss = torch.relu(input_tensor + 1) - fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ]) - return fake_loss - - -class FaceSwapperLoss: - def __init__(self) -> None: - embedder_path = CONFIG.get('training.model', 'embedder_path') - landmarker_path = CONFIG.get('training.model', 'landmarker_path') - motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path') - self.batch_size = CONFIG.getint('training.loader', 'batch_size') - self.mse_loss = nn.MSELoss() - self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] - self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] - self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] - - def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet: - weight_adversarial = CONFIG.getfloat('training.losses', 'adversarial_weight') - weight_identity = CONFIG.getfloat('training.losses', 'identity_weight') - weight_attribute = CONFIG.getfloat('training.losses', 'attribute_weight') - weight_reconstruction = CONFIG.getfloat('training.losses', 'reconstruction_weight') - weight_pose = CONFIG.getfloat('training.losses', 'pose_weight') - weight_gaze = CONFIG.getfloat('training.losses', 'gaze_weight') - source_tensor, target_tensor = batch - is_same_person = torch.tensor(0) if torch.equal(source_tensor, target_tensor) else torch.tensor(1) - generator_loss_set =\ - { - 'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs), - 'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor), - 'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes), - 'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person) - } - - generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor) - generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor) - - generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_identity') * weight_identity - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_gaze') * weight_gaze - return generator_loss_set - - def hinge_real_loss(input_tensor: Tensor) -> Tensor: - real_loss = torch.relu(1 - input_tensor) - real_loss = real_loss.mean(dim = [1, 2, 3]) - return real_loss - - def hinge_fake_loss(input_tensor: Tensor) -> Tensor: - fake_loss = torch.relu(input_tensor + 1) - fake_loss = fake_loss.mean(dim = [1, 2, 3]) - return fake_loss - - def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet: - discriminator_loss_set = {} - loss_fakes = [] - - for fake_discriminator_output in fake_discriminator_outputs: - loss_fakes.append(hinge_fake_loss(fake_discriminator_output[0])) - - loss_trues = [] - - for true_discriminator_output in real_discriminator_outputs: - loss_trues.append(hinge_real_loss(true_discriminator_output[0])) - - loss_fake = torch.stack(loss_fakes).mean() - loss_true = torch.stack(loss_trues).mean() - discriminator_loss_set['loss_discriminator'] = (loss_true + loss_fake) * 0.5 - return discriminator_loss_set - - def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor: - loss_adversarials = [] - - for discriminator_output in discriminator_outputs: - loss_adversarials.append(hinge_real_loss(discriminator_output[0]).mean()) - - loss_adversarial = torch.stack(loss_adversarials).mean() - return loss_adversarial - - def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor: - loss_attributes = [] - - for swap_attribute, target_attribute in zip(swap_attributes, target_attributes): - loss_attributes.append(torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()) - - loss_attribute = torch.stack(loss_attributes).mean() * 0.5 - return loss_attribute - - def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor: - loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1) - loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5 - loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4) - loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean() - loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5 - return loss_reconstruction - - def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: - swap_embedding = calc_embedding(self.embedder, swap_tensor, (30, 0, 10, 10)) - source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10)) - loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean() - return loss_identity - - def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor: - swap_motion_features = self.get_pose_features(swap_tensor) - target_motion_features = self.get_pose_features(target_tensor) - loss_pose = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) - - for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features): - loss_pose += self.mse_loss(swap_motion_feature, target_motion_feature) - - return loss_pose - - def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor: - swap_landmark = self.get_face_landmarks(swap_tensor) - target_landmark = self.get_face_landmarks(target_tensor) - left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198]) - right_gaze_loss = self.mse_loss(swap_landmark[:, 197], target_landmark[:, 197]) - gaze_loss = left_gaze_loss + right_gaze_loss - return gaze_loss - - def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203: - vision_tensor_norm = (vision_tensor + 1) * 0.5 - vision_tensor_norm = nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear') - landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2) - return landmarks - - def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]: - vision_tensor_norm = (vision_tensor + 1) * 0.5 - pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm) - rotation = torch.cat([ pitch, yaw, roll ], dim = 1) - return translation, scale, rotation - - class DiscriminatorLoss(nn.Module): def __init__(self) -> None: super().__init__() @@ -270,8 +129,8 @@ class PoseLoss(nn.Module): return pose_loss, weighted_pose_loss def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]: - vision_tensor_norm = (input_tensor + 1) * 0.5 - pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm) + input_tensor = (input_tensor + 1) * 0.5 + pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(input_tensor) rotation = torch.cat([ pitch, yaw, roll ], dim = 1) return translation, scale, rotation @@ -283,7 +142,7 @@ class GazeLoss(nn.Module): self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.mse_loss = nn.MSELoss() - def calc(self, target_tensor : VisionTensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: + def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]: gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight') output_face_landmark = self.detect_face_landmark(output_tensor) target_face_landmark = self.detect_face_landmark(target_tensor) diff --git a/face_swapper/src/networks/attribute_modulator.py b/face_swapper/src/networks/attribute_modulator.py index cdf44e1..f342d6c 100644 --- a/face_swapper/src/networks/attribute_modulator.py +++ b/face_swapper/src/networks/attribute_modulator.py @@ -1,7 +1,7 @@ import torch from torch import Tensor, nn -from ..types import Embedding, TargetAttributes +from ..types import Attributes, Embedding class AADGenerator(nn.Module): @@ -17,7 +17,7 @@ class AADGenerator(nn.Module): self.res_block_7 = AADResBlock(128, 64, 64, id_channels, num_blocks) self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks) - def forward(self, target_attributes : TargetAttributes, source_embedding : Embedding) -> Tensor: + def forward(self, target_attributes : Attributes, source_embedding : Embedding) -> Tensor: feature_map = self.upsample(source_embedding) feature_map_1 = nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) feature_map_2 = nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) @@ -59,10 +59,10 @@ class AADSequential(nn.Module): super().__init__() self.layers = nn.ModuleList(args) - def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor: + def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor: for layer in self.layers: if isinstance(layer, AADLayer): - feature_map = layer(feature_map, attribute_embedding, id_embedding) + feature_map = layer(feature_map, attribute_embedding, identity_embedding) else: feature_map = layer(feature_map) return feature_map diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index e5760cb..fb52b12 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -16,18 +16,18 @@ from .dataset import DynamicDataset from .helper import calc_embedding from .models.discriminator import Discriminator from .models.generator import Generator -from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss -from .types import Batch, Embedding, VisionTensor +from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss +from .types import Batch, Embedding CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') -class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): +class FaceSwapperTrainer(lightning.LightningModule): def __init__(self) -> None: super().__init__() - FaceSwapperLoss.__init__(self) automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization') + embedder_path = CONFIG.get('training.model', 'embedder_path') self.generator = Generator() self.discriminator = Discriminator() @@ -38,9 +38,10 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.identity_loss = IdentityLoss() self.pose_loss = PoseLoss() self.gaze_loss = GazeLoss() + self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.automatic_optimization = automatic_optimization - def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor: + def forward(self, target_tensor : Tensor, source_embedding : Embedding) -> Tensor: output_tensor = self.generator(source_embedding, target_tensor) return output_tensor @@ -61,34 +62,6 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): generator_output_attributes = self.generator.get_attributes(generator_output_tensor) discriminator_output_tensors = self.discriminator(generator_output_tensor) - generator_loss_set = self.calc_generator_loss(generator_output_tensor, target_attributes, generator_output_attributes, discriminator_output_tensors, batch) - generator_optimizer.zero_grad() - self.manual_backward(generator_loss_set.get('loss_generator')) - generator_optimizer.step() - - discriminator_source_tensors = self.discriminator(source_tensor) - discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) - - discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensors, discriminator_output_tensors) - discriminator_optimizer.zero_grad() - self.manual_backward(discriminator_loss_set.get('loss_discriminator')) - discriminator_optimizer.step() - - if self.global_step % preview_frequency == 0: - self.generate_preview(source_tensor, target_tensor, generator_output_tensor) - - self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True) - self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True) - self.log('loss_adversarial', generator_loss_set.get('loss_adversarial')) - self.log('loss_attribute', generator_loss_set.get('loss_attribute')) - self.log('loss_identity', generator_loss_set.get('loss_identity')) - self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction')) - self.log('loss_pose', generator_loss_set.get('loss_pose')) - self.log('loss_gaze', generator_loss_set.get('loss_gaze')) - - ############################################### - - discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors) adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors) attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes) reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor) @@ -97,15 +70,30 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): gaze_loss, weighted_gaze_loss = self.gaze_loss.calc(target_tensor, generator_output_tensor) generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss - self.log('generator_loss_new', generator_loss, prog_bar = True) - self.log('discriminator_loss_new', discriminator_loss, prog_bar = True) - self.log('adversarial_loss_new', adversarial_loss) - self.log('attribute_loss_new', attribute_loss) - self.log('reconstruction_loss_new', reconstruction_loss) - self.log('identity_loss_new', identity_loss) - self.log('pose_loss_new', pose_loss) - self.log('gaze_loss_new', gaze_loss) - return generator_loss_set.get('loss_generator') + generator_optimizer.zero_grad() + self.manual_backward(generator_loss) + generator_optimizer.step() + + discriminator_source_tensors = self.discriminator(source_tensor) + discriminator_output_tensors = self.discriminator(generator_output_tensor.detach()) + discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors) + + discriminator_optimizer.zero_grad() + self.manual_backward(discriminator_loss) + discriminator_optimizer.step() + + if self.global_step % preview_frequency == 0: + self.generate_preview(source_tensor, target_tensor, generator_output_tensor) + + self.log('generator_loss', generator_loss, prog_bar = True) + self.log('discriminator_loss', discriminator_loss, prog_bar = True) + self.log('adversarial_loss', adversarial_loss) + self.log('attribute_loss', attribute_loss) + self.log('reconstruction_loss', reconstruction_loss) + self.log('identity_loss', identity_loss) + self.log('pose_loss', pose_loss) + self.log('gaze_loss', gaze_loss) + return generator_loss def validation_step(self, batch : Batch, batch_index : int) -> Tensor: source_tensor, target_tensor = batch @@ -116,7 +104,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss): self.log('validation', validation) return validation - def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None: + def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> None: preview_limit = 8 preview_cells = [] diff --git a/face_swapper/src/types.py b/face_swapper/src/types.py index 83d6e6e..5dd0628 100644 --- a/face_swapper/src/types.py +++ b/face_swapper/src/types.py @@ -7,23 +7,13 @@ from torch.nn import Module Batch : TypeAlias = Tuple[Tensor, Tensor] -SwapAttributes : TypeAlias = Tuple[Tensor, ...] -TargetAttributes : TypeAlias = Tuple[Tensor, ...] -DiscriminatorOutputs : TypeAlias = List[List[Tensor]] - Attributes : TypeAlias = Tuple[Tensor, ...] Embedding : TypeAlias = Tensor FaceLandmark203 : TypeAlias = Tensor -StateSet : TypeAlias = OrderedDict[str, Any] Padding : TypeAlias = Tuple[int, int, int, int] VisionFrame : TypeAlias = NDArray[Any] -LossTensor : TypeAlias = Tensor -VisionTensor : TypeAlias = Tensor - -GeneratorLossSet : TypeAlias = Dict[str, Tensor] -DiscriminatorLossSet : TypeAlias = Dict[str, Tensor] GeneratorModule : TypeAlias = Module EmbedderModule : TypeAlias = Module