From a971506271d4f8aea271a87fa04e9a5f77f8a584 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Fri, 14 Feb 2025 16:12:15 +0100 Subject: [PATCH] Revert changes in generator --- face_swapper/src/models/generator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 5a378d0..e9b25e0 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -17,9 +17,9 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module): id_channels = CONFIG.getint('training.model.generator', 'id_channels') num_blocks = CONFIG.getint('training.model.generator', 'num_blocks') - self.unet = UNet() + self.encoder = UNet() self.generator = AADGenerator(id_channels, num_blocks) - self.unet.apply(init_weight) + self.encoder.apply(init_weight) self.generator.apply(init_weight) def forward(self, target : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]: @@ -28,7 +28,7 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module): return swap_tensor, target_attributes def get_attributes(self, target : VisionTensor) -> TargetAttributes: - return self.unet(target) + return self.encoder(target) def init_weight(module : nn.Module) -> None: