diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 5ed2821..4f5a2f5 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -24,6 +24,7 @@ class Generator(nn.Module): output_tensor = self.generator(source_embedding, target_features) target_feature = target_features[-1] output_mask = self.masker(target_tensor, target_feature) + output_tensor = output_tensor * output_mask + target_tensor * (1 - output_mask) return output_tensor, output_mask def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]: