mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Fix weight identity
This commit is contained in:
@@ -64,7 +64,7 @@ kernel_size = 4
|
||||
```
|
||||
[training.losses]
|
||||
weight_adversarial = 1
|
||||
weight_id = 20
|
||||
weight_identity = 20
|
||||
weight_attribute = 10
|
||||
weight_reconstruction = 10
|
||||
weight_pose = 100
|
||||
|
||||
@@ -26,7 +26,7 @@ kernel_size =
|
||||
|
||||
[training.losses]
|
||||
weight_adversarial =
|
||||
weight_id =
|
||||
weight_identity =
|
||||
weight_attribute =
|
||||
weight_reconstruction =
|
||||
weight_pose =
|
||||
|
||||
@@ -29,7 +29,7 @@ class FaceSwapperLoss:
|
||||
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
|
||||
source_tensor, target_tensor, is_same_person = batch
|
||||
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
|
||||
weight_id = CONFIG.getfloat('training.losses', 'weight_id')
|
||||
weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')
|
||||
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
|
||||
weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
|
||||
weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
|
||||
@@ -53,7 +53,7 @@ class FaceSwapperLoss:
|
||||
generator_loss_set['loss_gaze'] = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
|
||||
|
||||
generator_loss_set['loss_generator'] = generator_loss_set.get('loss_adversarial') * weight_adversarial
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_id') * weight_id
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_identity') * weight_identity
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_attribute') * weight_attribute
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_reconstruction') * weight_reconstruction
|
||||
generator_loss_set['loss_generator'] += generator_loss_set.get('loss_pose') * weight_pose
|
||||
|
||||
Reference in New Issue
Block a user