diff --git a/face_swapper/src/models/generator.py b/face_swapper/src/models/generator.py index 3fc26a4..5a378d0 100644 --- a/face_swapper/src/models/generator.py +++ b/face_swapper/src/models/generator.py @@ -4,7 +4,7 @@ from typing import Tuple from torch import nn from face_swapper.src.networks.attribute_modulator import AADGenerator -from face_swapper.src.networks.encoder import UNet +from face_swapper.src.networks.unet import UNet from face_swapper.src.types import Embedding, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() @@ -17,9 +17,9 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module): id_channels = CONFIG.getint('training.model.generator', 'id_channels') num_blocks = CONFIG.getint('training.model.generator', 'num_blocks') - self.encoder = UNet() + self.unet = UNet() self.generator = AADGenerator(id_channels, num_blocks) - self.encoder.apply(init_weight) + self.unet.apply(init_weight) self.generator.apply(init_weight) def forward(self, target : VisionTensor, source_embedding : Embedding) -> Tuple[VisionTensor, TargetAttributes]: @@ -28,7 +28,7 @@ class AdaptiveEmbeddingIntegrationNetwork(nn.Module): return swap_tensor, target_attributes def get_attributes(self, target : VisionTensor) -> TargetAttributes: - return self.encoder(target) + return self.unet(target) def init_weight(module : nn.Module) -> None: diff --git a/face_swapper/src/networks/encoder.py b/face_swapper/src/networks/encoder.py deleted file mode 100644 index d4e270e..0000000 --- a/face_swapper/src/networks/encoder.py +++ /dev/null @@ -1,67 +0,0 @@ -import torch -from torch import Tensor, nn - -from face_swapper.src.types import TargetAttributes, VisionTensor - - -class Upsample(nn.Module): - def __init__(self, input_channels : int, output_channels : int) -> None: - super(Upsample, self).__init__() - self.deconv = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) - self.batch_norm = 
nn.BatchNorm2d(output_channels) - self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) - - def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor: - temp_tensor = self.deconv(input_tensor) - temp_tensor = self.batch_norm(temp_tensor) - temp_tensor = self.leaky_relu(temp_tensor) - return torch.cat((temp_tensor, skip_tensor), dim = 1) - - -class DownSample(nn.Module): - def __init__(self, input_channels : int, output_channels : int) -> None: - super(DownSample, self).__init__() - self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) - self.batch_norm = nn.BatchNorm2d(output_channels) - self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) - - def forward(self, input_tensor : Tensor) -> Tensor: - temp_tensor = self.conv(input_tensor) - temp_tensor = self.batch_norm(temp_tensor) - temp_tensor = self.leaky_relu(temp_tensor) - return temp_tensor - - -class UNet(nn.Module): - def __init__(self) -> None: - super(UNet, self).__init__() - self.downsampler_1 = DownSample(3, 32) - self.downsampler_2 = DownSample(32, 64) - self.downsampler_3 = DownSample(64, 128) - self.downsampler_4 = DownSample(128, 256) - self.downsampler_5 = DownSample(256, 512) - self.downsampler_6 = DownSample(512, 1024) - self.bottleneck = DownSample(1024, 1024) - self.upsampler_1 = Upsample(1024, 1024) - self.upsampler_2 = Upsample(2048, 512) - self.upsampler_3 = Upsample(1024, 256) - self.upsampler_4 = Upsample(512, 128) - self.upsampler_5 = Upsample(256, 64) - self.upsampler_6 = Upsample(128, 32) - - def forward(self, target : VisionTensor) -> TargetAttributes: - downsample_feature_1 = self.downsampler_1(target) - downsample_feature_2 = self.downsampler_2(downsample_feature_1) - downsample_feature_3 = self.downsampler_3(downsample_feature_2) - downsample_feature_4 = self.downsampler_4(downsample_feature_3) - downsample_feature_5 = self.downsampler_5(downsample_feature_4) - downsample_feature_6 = 
class UNet(nn.Module):
    """U-Net style encoder/decoder that extracts multi-scale target attributes.

    The encoder is a stack of seven ``Down`` stages (3 -> 32 -> ... -> 1024 ->
    1024 channels); the last one acts as the bottleneck. The decoder is a stack
    of six ``Up`` stages, each of which concatenates its output with the
    encoder feature of matching resolution, which is why every decoder stage
    after the first takes twice the channels the previous one produced.
    """

    def __init__(self) -> None:
        super(UNet, self).__init__()
        self.down = self.create_down()
        self.up = self.create_up()

    @staticmethod
    def create_down() -> nn.ModuleList:
        """Build the encoder stages; the final entry is the bottleneck."""
        return nn.ModuleList(
            [
                Down(3, 32),
                Down(32, 64),
                Down(64, 128),
                Down(128, 256),
                Down(256, 512),
                Down(512, 1024),
                Down(1024, 1024)
            ])

    @staticmethod
    def create_up() -> nn.ModuleList:
        """Build the decoder stages.

        Input channels account for the skip concatenation performed by the
        previous ``Up`` stage (e.g. 1024 + 1024 = 2048 for the second stage).
        """
        return nn.ModuleList(
            [
                Up(1024, 1024),
                Up(2048, 512),
                Up(1024, 256),
                Up(512, 128),
                Up(256, 64),
                Up(128, 32)
            ])

    def forward(self, target_tensor : Tensor) -> TargetAttributes:
        """Run the encoder/decoder and collect the attribute tensors.

        Returns a tuple of eight tensors: the bottleneck feature, the six
        decoder features (coarse to fine, each already concatenated with its
        skip connection) and the last decoder feature bilinearly upsampled
        by a factor of two.
        """
        down_features = []
        temp_tensor = target_tensor

        for down in self.down:
            temp_tensor = down(temp_tensor)
            down_features.append(temp_tensor)

        bottleneck_tensor = down_features[-1]
        temp_tensor = bottleneck_tensor
        up_features = []

        # Walk the encoder features from deepest to shallowest, skipping the
        # bottleneck itself, so each decoder stage gets the skip tensor of
        # matching resolution.
        for up, skip_tensor in zip(self.up, reversed(down_features[:-1])):
            # Each stage must consume the previous stage's output; feeding the
            # bottleneck to every stage breaks the channel counts declared in
            # create_up().
            temp_tensor = up(temp_tensor, skip_tensor)
            up_features.append(temp_tensor)

        # temp_tensor is now the finest decoder feature.
        output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
        return bottleneck_tensor, *up_features, output_tensor


class Up(nn.Module):
    """Decoder stage: transposed conv -> batch norm -> leaky ReLU -> skip concat.

    With kernel 4, stride 2 and padding 1 the transposed convolution doubles
    the spatial size. The result has ``output_channels`` plus the skip
    tensor's channels.
    """

    def __init__(self, input_channels : int, output_channels : int) -> None:
        super(Up, self).__init__()
        self.conv_transpose = nn.ConvTranspose2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.batch_norm = nn.BatchNorm2d(output_channels)
        self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

    def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
        """Upsample ``input_tensor`` and concatenate ``skip_tensor`` on the channel axis."""
        temp_tensor = self.conv_transpose(input_tensor)
        temp_tensor = self.batch_norm(temp_tensor)
        temp_tensor = self.leaky_relu(temp_tensor)
        temp_tensor = torch.cat((temp_tensor, skip_tensor), dim = 1)
        return temp_tensor


class Down(nn.Module):
    """Encoder stage: strided conv -> batch norm -> leaky ReLU.

    With kernel 4, stride 2 and padding 1 the convolution halves an even
    spatial size.
    """

    def __init__(self, input_channels : int, output_channels : int) -> None:
        super(Down, self).__init__()
        self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
        self.batch_norm = nn.BatchNorm2d(output_channels)
        self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)

    def forward(self, input_tensor : Tensor) -> Tensor:
        """Downsample ``input_tensor`` to ``output_channels`` feature maps."""
        temp_tensor = self.conv(input_tensor)
        temp_tensor = self.batch_norm(temp_tensor)
        temp_tensor = self.leaky_relu(temp_tensor)
        return temp_tensor