This commit is contained in:
harisreedhar
2025-03-23 18:38:48 +05:30
parent c85c755e00
commit 9ede8a2a7d
+2 -2
View File
@@ -19,12 +19,12 @@ class Generator(nn.Module):
self.generator.apply(init_weight)
self.masker.apply(init_weight)
def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask, ]:
def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask]:
output_tensor = self.generator(source_embedding, target_features)
target_feature = target_features[-1]
output_mask = self.masker(target_tensor, target_feature)
output_tensor = output_tensor * output_mask + target_tensor * (1 - output_mask)
return output_tensor, output_mask,
return output_tensor, output_mask
def encode_features(self, input_tensor : Tensor) -> Tuple[Feature, ...]:
return self.encoder(input_tensor)