Improve generator namings, Flip args to source then target

This commit is contained in:
henryruhs
2025-02-18 11:00:30 +01:00
parent f63bc788ac
commit bf696be097
5 changed files with 24 additions and 22 deletions
+14 -14
View File
@@ -1,34 +1,34 @@
import configparser
from typing import Tuple
from torch import nn
from torch import Tensor, nn
from ..networks.attribute_modulator import AADGenerator
from ..networks.unet import UNet
from ..types import Embedding, TargetAttributes, VisionTensor
from ..types import Embedding, Attributes
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
class Generator(nn.Module):
def __init__(self) -> None:
super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
super(Generator, self).__init__()
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
self.encoder = UNet()
self.generator = AADGenerator(id_channels, num_blocks)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
self.attribute_encoder = UNet()
self.attribute_generator = AADGenerator(id_channels, num_blocks)
self.attribute_encoder.apply(init_weight)
self.attribute_generator.apply(init_weight)
def forward(self, target : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]:
target_attributes = self.get_attributes(target)
swap_tensor = self.generator(target_attributes, source_embedding)
return swap_tensor, target_attributes
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
target_attributes = self.get_attributes(target_tensor)
output_tensor = self.attribute_generator(target_attributes, source_embedding)
return output_tensor
def get_attributes(self, target : VisionTensor) -> TargetAttributes:
return self.encoder(target)
def get_attributes(self, input_tensor : Tensor) -> Attributes:
return self.attribute_encoder(input_tensor)
def init_weight(module : nn.Module) -> None: