mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Fix attribute loss
This commit is contained in:
@@ -16,10 +16,10 @@ class Generator(nn.Module):
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tuple[Attribute, ...]]:
|
||||
target_attributes = self.encode_attributes(target_tensor)
|
||||
output_tensor = self.generator(source_embedding, target_attributes)
|
||||
return output_tensor
|
||||
return output_tensor, target_attributes
|
||||
|
||||
def encode_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]:
|
||||
return self.encoder(input_tensor)
|
||||
|
||||
Reference in New Issue
Block a user