diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 2ddf748..d083614 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -16,10 +16,10 @@ class Generator(nn.Module): self.encoder.apply(init_weight) self.generator.apply(init_weight) - def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Attribute, ...]]: target_attributes = self.encode_attributes(target_tensor) output_tensor = self.generator(source_embedding, target_attributes) - return output_tensor + return output_tensor, target_attributes def encode_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]: return self.encoder(input_tensor)