Mask typing and naming related updates

This commit is contained in:
henryruhs
2025-03-14 08:30:55 +01:00
parent 33d00ac941
commit 904a447e06
+1 -1
View File
@@ -53,7 +53,7 @@ class FaceSwapperTrainer(LightningModule):
self.mask_loss = MaskLoss(config_parser, self.face_parser)
self.automatic_optimization = False
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tensor]:
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
with torch.no_grad():
output_tensor, target_attributes = self.generator(source_embedding, target_tensor)
target_attribute = target_attributes[-1]