Revert changes in generator

This commit is contained in:
henryruhs
2025-02-14 16:12:15 +01:00
parent b69f69d015
commit a971506271
+3 -3
View File
@@ -17,9 +17,9 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
self.unet = UNet()
self.encoder = UNet()
self.generator = AADGenerator(id_channels, num_blocks)
self.unet.apply(init_weight)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
def forward(self, target : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]:
@@ -28,7 +28,7 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
return swap_tensor, target_attributes
def get_attributes(self, target : VisionTensor) -> TargetAttributes:
return self.unet(target)
return self.encoder(target)
def init_weight(module : nn.Module) -> None: