Use new config for old code

This commit is contained in:
henryruhs
2025-02-23 00:39:48 +01:00
parent 14b9bccafe
commit 63e4bea3cd
+6 -6
View File
@@ -36,12 +36,12 @@ class FaceSwapperLoss:
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction')
weight_pose = CONFIG.getfloat('training.losses', 'weight_pose')
weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze')
weight_adversarial = CONFIG.getfloat('training.losses', 'adversarial_weight')
weight_identity = CONFIG.getfloat('training.losses', 'identity_weight')
weight_attribute = CONFIG.getfloat('training.losses', 'attribute_weight')
weight_reconstruction = CONFIG.getfloat('training.losses', 'reconstruction_weight')
weight_pose = CONFIG.getfloat('training.losses', 'pose_weight')
weight_gaze = CONFIG.getfloat('training.losses', 'gaze_weight')
source_tensor, target_tensor = batch
is_same_person = torch.tensor(0) if torch.equal(source_tensor, target_tensor) else torch.tensor(1)
generator_loss_set =\