Fix weight identity

This commit is contained in:
henryruhs
2025-02-12 13:30:16 +01:00
parent b42d2b06e7
commit 860771e482
3 changed files with 4 additions and 4 deletions
+1 -1
View File
@@ -64,7 +64,7 @@ kernel_size = 4
```
[training.losses]
weight_adversarial = 1
weight_id = 20
weight_identity = 20
weight_attribute = 10
weight_reconstruction = 10
weight_pose = 100
+1 -1
View File
@@ -26,7 +26,7 @@ kernel_size =
[training.losses]
weight_adversarial =
weight_id =
weight_identity =
weight_attribute =
weight_reconstruction =
weight_pose =
+2 -2
View File
@@ -29,7 +29,7 @@ class FaceSwapperLoss:
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
source_tensor, target_tensor, is_same_person = batch
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
weight_id = CONFIG.getfloat('training.losses', 'weight_id')
weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
@@ -53,7 +53,7 @@ class FaceSwapperLoss:
generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_identity') * weight_identity
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose