mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Revert changes in generator
This commit is contained in:
@@ -17,9 +17,9 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
|
||||
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
|
||||
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
|
||||
|
||||
self.unet = UNet()
|
||||
self.encoder = UNet()
|
||||
self.generator = AADGenerator(id_channels, num_blocks)
|
||||
self.unet.apply(init_weight)
|
||||
self.encoder.apply(init_weight)
|
||||
self.generator.apply(init_weight)
|
||||
|
||||
def forward(self, target : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]:
|
||||
@@ -28,7 +28,7 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
|
||||
return swap_tensor, target_attributes
|
||||
|
||||
def get_attributes(self, target : VisionTensor) -> TargetAttributes:
|
||||
return self.unet(target)
|
||||
return self.encoder(target)
|
||||
|
||||
|
||||
def init_weight(module : nn.Module) -> None:
|
||||
|
||||
Reference in New Issue
Block a user