From 860771e482eff97bf1e592aced49c2eb6b9fdc3d Mon Sep 17 00:00:00 2001 From: henryruhs Date: Wed, 12 Feb 2025 13:30:16 +0100 Subject: [PATCH] Fix weight identity --- face_swapper/README.md | 2 +- face_swapper/config.ini | 2 +- face_swapper/src/models/loss.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/face_swapper/README.md b/face_swapper/README.md index c261d02..0643c9e 100644 --- a/face_swapper/README.md +++ b/face_swapper/README.md @@ -64,7 +64,7 @@ kernel_size = 4 ``` [training.losses] weight_adversarial = 1 -weight_id = 20 +weight_identity = 20 weight_attribute = 10 weight_reconstruction = 10 weight_pose = 100 diff --git a/face_swapper/config.ini b/face_swapper/config.ini index effb748..faa6254 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -26,7 +26,7 @@ kernel_size = [training.losses] weight_adversarial = -weight_id = +weight_identity = weight_attribute = weight_reconstruction = weight_pose = diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 01a9d12..a209d49 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -29,7 +29,7 @@ class FaceSwapperLoss: def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet: source_tensor, target_tensor, is_same_person = batch weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial') - weight_id = CONFIG.getfloat('training.losses', 'weight_id') + weight_identity = CONFIG.getfloat('training.losses', 'weight_identity') weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute') weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction') weight_pose = CONFIG.getfloat('training.losses', 'weight_pose') @@ -53,7 +53,7 @@ class FaceSwapperLoss: generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial - generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id + generator_loss_set['loss_generator'] += generator_loss_set.get('loss_identity') * weight_identity generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose