diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 73db988..bb0b9be 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -34,7 +34,7 @@ class FaceSwapperLoss: weight_pose = CONFIG.getfloat('training.losses', 'weight_pose') weight_gaze = CONFIG.getfloat('training.losses', 'weight_gaze') source_tensor, target_tensor = batch - is_same_person = torch.tensor(0) if source_tensor == target_tensor else torch.tensor(1) + is_same_person = torch.tensor(0) if torch.equal(source_tensor, target_tensor) else torch.tensor(1) generator_loss_set =\ { 'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs),