Adjust naming and typing

This commit is contained in:
henryruhs
2025-02-19 08:14:18 +01:00
parent 39818a16df
commit c161da2f25
2 changed files with 10 additions and 10 deletions
+6 -6
View File
@@ -16,18 +16,18 @@ class Generator(nn.Module):
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
self.attribute_encoder = UNet()
self.attribute_generator = AADGenerator(id_channels, num_blocks)
self.attribute_encoder.apply(init_weight)
self.attribute_generator.apply(init_weight)
self.unet = UNet()
self.aad_generator = AADGenerator(id_channels, num_blocks)
self.unet.apply(init_weight)
self.aad_generator.apply(init_weight)
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
target_attributes = self.get_attributes(target_tensor)
output_tensor = self.attribute_generator(target_attributes, source_embedding)
output_tensor = self.aad_generator(target_attributes, source_embedding)
return output_tensor
def get_attributes(self, input_tensor : Tensor) -> Attributes:
return self.attribute_encoder(input_tensor)
return self.unet(input_tensor)
def init_weight(module : nn.Module) -> None:
+4 -4
View File
@@ -17,7 +17,7 @@ from .helper import calc_id_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import FaceSwapperLoss
from .types import Batch, Embedding, TargetAttributes, VisionTensor
from .types import Batch, Embedding, VisionTensor
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -31,9 +31,9 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.discriminator = Discriminator()
self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]:
output = self.generator(source_embedding, target_tensor)
return output
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor:
output_tensor = self.generator(source_embedding, target_tensor)
return output_tensor
def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]:
learning_rate = CONFIG.getfloat('training.trainer', 'learning_rate')