Fix attribute loss

This commit is contained in:
henryruhs
2025-03-12 22:41:28 +01:00
parent 0d73bcf918
commit 7c75b0d898
+2 -2
View File
@@ -16,10 +16,10 @@ class Generator(nn.Module):
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Attribute, ...]]:
target_attributes = self.encode_attributes(target_tensor)
output_tensor = self.generator(source_embedding, target_attributes)
return output_tensor
return output_tensor, target_attributes
def encode_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]:
return self.encoder(input_tensor)