diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 8397652..7e810e5 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -36,12 +36,12 @@ class FaceSwapperLoss: self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call] def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet: - weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial') - weight_identity = CONFIG.getfloat('training.losses', 'weight_identity') - weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute') - weight_reconstruction = CONFIG.getfloat('training.losses', 'weight_reconstruction') - weight_pose = CONFIG.getfloat('training.losses', 'weight_pose') - weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze') + weight_adversarial = CONFIG.getfloat('training.losses', 'adversarial_weight') + weight_identity = CONFIG.getfloat('training.losses', 'identity_weight') + weight_attribute = CONFIG.getfloat('training.losses', 'attribute_weight') + weight_reconstruction = CONFIG.getfloat('training.losses', 'reconstruction_weight') + weight_pose = CONFIG.getfloat('training.losses', 'pose_weight') + weight_gaze = CONFIG.getfloat('training.losses', 'gaze_weight') source_tensor, target_tensor = batch is_same_person = torch.tensor(0) if torch.equal(source_tensor, target_tensor) else torch.tensor(1) generator_loss_set =\