diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 29fa196..e016b80 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -40,7 +40,7 @@ class FaceSwapperTrainer(lightning.LightningModule): self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call] self.automatic_optimization = False - def forward(self, target_tensor : Tensor, source_embedding : Embedding) -> Tensor: + def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor: output_tensor = self.generator(source_embedding, target_tensor) return output_tensor