Improve generator naming; flip args to source then target

This commit is contained in:
henryruhs
2025-02-18 11:00:30 +01:00
parent f63bc788ac
commit bf696be097
5 changed files with 24 additions and 22 deletions
+2 -2
View File
@@ -3,7 +3,7 @@ from os import makedirs
import torch
from .models.generator import AdaptiveEmbeddingIntegrationNetwork
from .models.generator import Generator
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
@@ -18,7 +18,7 @@ def export() -> None:
makedirs(directory_path, exist_ok = True)
state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
model = AdaptiveEmbeddingIntegrationNetwork()
model = Generator()
model.load_state_dict(state_dict)
model.eval()
model.ir_version = ir_version
+2 -2
View File
@@ -4,7 +4,7 @@ import cv2
import torch
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image
from .models.generator import AdaptiveEmbeddingIntegrationNetwork
from .models.generator import Generator
from .types import Embedder, Generator, VisionFrame
CONFIG = configparser.ConfigParser()
@@ -28,7 +28,7 @@ def infer() -> None:
output_path = CONFIG.get('inferencing', 'output_path')
state_dict = torch.load(generator_path, map_location = 'cpu').get('state_dict').get('generator')
generator = AdaptiveEmbeddingIntegrationNetwork()
generator = Generator()
generator.load_state_dict(state_dict)
generator.eval()
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
+14 -14
View File
@@ -1,34 +1,34 @@
import configparser
from typing import Tuple
from torch import nn
from torch import Tensor, nn
from ..networks.attribute_modulator import AADGenerator
from ..networks.unet import UNet
from ..types import Embedding, TargetAttributes, VisionTensor
from ..types import Embedding, Attributes
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
class Generator(nn.Module):
def __init__(self) -> None:
super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
super(Generator, self).__init__()
id_channels = CONFIG.getint('training.model.generator', 'id_channels')
num_blocks = CONFIG.getint('training.model.generator', 'num_blocks')
self.encoder = UNet()
self.generator = AADGenerator(id_channels, num_blocks)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
self.attribute_encoder = UNet()
self.attribute_generator = AADGenerator(id_channels, num_blocks)
self.attribute_encoder.apply(init_weight)
self.attribute_generator.apply(init_weight)
def forward(self, target : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]:
target_attributes = self.get_attributes(target)
swap_tensor = self.generator(target_attributes, source_embedding)
return swap_tensor, target_attributes
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tensor:
target_attributes = self.get_attributes(target_tensor)
output_tensor = self.attribute_generator(target_attributes, source_embedding)
return output_tensor
def get_attributes(self, target : VisionTensor) -> TargetAttributes:
return self.encoder(target)
def get_attributes(self, input_tensor : Tensor) -> Attributes:
return self.attribute_encoder(input_tensor)
def init_weight(module : nn.Module) -> None:
+5 -4
View File
@@ -15,7 +15,7 @@ from torch.utils.data import DataLoader
from .data_loader import DataLoaderVGG
from .helper import calc_id_embedding
from .models.discriminator import Discriminator
from .models.generator import AdaptiveEmbeddingIntegrationNetwork
from .models.generator import Generator
from .models.loss import FaceSwapperLoss
from .types import Batch, Embedding, TargetAttributes, VisionTensor
@@ -27,12 +27,12 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
def __init__(self) -> None:
super().__init__()
FaceSwapperLoss.__init__(self)
self.generator = AdaptiveEmbeddingIntegrationNetwork()
self.generator = Generator()
self.discriminator = Discriminator()
self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]:
output = self.generator(target_tensor, source_embedding)
output = self.generator(source_embedding, target_tensor)
return output
def configure_optimizers(self) -> Tuple[Optimizer, Optimizer]:
@@ -45,7 +45,8 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
source_tensor, target_tensor, is_same_person = batch
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
swap_tensor, target_attributes = self.generator(target_tensor, source_embedding)
swap_tensor = self.generator(source_embedding, target_tensor)
target_attributes = self.generator.get_attributes(target_tensor)
swap_attributes = self.generator.get_attributes(swap_tensor)
fake_discriminator_outputs = self.discriminator(swap_tensor)
+1
View File
@@ -14,6 +14,7 @@ SwapAttributes : TypeAlias = Tuple[Tensor, ...]
TargetAttributes : TypeAlias = Tuple[Tensor, ...]
DiscriminatorOutputs : TypeAlias = List[List[Tensor]]
Attributes : TypeAlias = Tuple[Tensor, ...]
Embedding : TypeAlias = Tensor
FaceLandmark203 : TypeAlias = Tensor