diff --git a/face_swapper/config.ini b/face_swapper/config.ini index 2f17749..ad69109 100644 --- a/face_swapper/config.ini +++ b/face_swapper/config.ini @@ -16,7 +16,6 @@ gazer_path = motion_extractor_path = [training.model.generator] -encoder_type = identity_channels = output_channels = output_size = diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 369edc2..7e1ec08 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -3,7 +3,7 @@ import configparser from torch import Tensor, nn from ..networks.aad import AAD -from ..networks.unet import UNet, UNetPro +from ..networks.unet import UNet from ..types import Attributes, Embedding CONFIG = configparser.ConfigParser() @@ -13,16 +13,12 @@ CONFIG.read('config.ini') class Generator(nn.Module): def __init__(self) -> None: super().__init__() - encoder_type = CONFIG.get('training.model.generator', 'encoder_type') identity_channels = CONFIG.getint('training.model.generator', 'identity_channels') output_channels = CONFIG.getint('training.model.generator', 'output_channels') output_size = CONFIG.getint('training.model.generator', 'output_size') num_blocks = CONFIG.getint('training.model.generator', 'num_blocks') - if encoder_type == 'unet': - self.encoder = UNet(output_size) - if encoder_type == 'unet-pro': - self.encoder = UNetPro(output_size) + self.encoder = UNet(output_size) self.generator = AAD(identity_channels, output_channels, output_size, num_blocks) self.encoder.apply(init_weight) self.generator.apply(init_weight) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index d2b6f40..6272ba9 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -169,7 +169,7 @@ class GazeLoss(nn.Module): return gaze_loss, weighted_gaze_loss def detect_gaze(self, input_tensor : Tensor) -> Gaze: - scale_factor = CONFIG.getint('training.losses', 'gaze_scale_factor') + scale_factor = CONFIG.getfloat('training.losses', 'gaze_scale_factor') y_min = int(60 * scale_factor) y_max = int(224 * scale_factor) x_min = int(16 * scale_factor) diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 8ab7707..187cbe6 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -2,9 +2,6 @@ from typing import Tuple import torch from torch import Tensor, nn -from torchvision import models -from torchvision.models import ResNet34_Weights - class UNet(nn.Module): def __init__(self, output_size : int) -> None: @@ -88,38 +85,6 @@ class UNet(nn.Module): return bottleneck_tensor, *up_features, output_tensor -class UNetPro(UNet): - def __init__(self, output_size : int) -> None: - super().__init__(output_size) - self.resnet = models.resnet34(weights = ResNet34_Weights.DEFAULT) - self.down_samples = self.create_down_samples() - self.up_samples = self.create_up_samples() - - def create_down_samples(self) -> nn.ModuleList: - down_samples = nn.ModuleList( - [ - nn.Sequential( - self.resnet.conv1, - self.resnet.bn1, - self.resnet.relu, - nn.Conv2d(64, 32, kernel_size = 1, bias = False), - nn.BatchNorm2d(32), - nn.LeakyReLU(0.1, inplace = True) - ), - DownSample(32, 64), - self.resnet.layer2, - self.resnet.layer3, - self.resnet.layer4, - DownSample(512, 1024), - DownSample(1024, 1024) - ]) - - if self.output_size == 512: - down_samples.append(DownSample(2048, 2048)) - - return down_samples - - class UpSample(nn.Module): def __init__(self, input_channels : int, output_channels : int) -> None: super().__init__() diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index d5e0245..fd35e34 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -80,7 +80,7 @@ class FaceSwapperTrainer(lightning.LightningModule): return generator_config, discriminator_config def training_step(self, batch : Batch, batch_index : int) -> Tensor: - preview_frequency = CONFIG.getfloat('training.trainer', 'preview_frequency') + preview_frequency = CONFIG.getint('training.trainer', 'preview_frequency') source_tensor, target_tensor = batch generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]