mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Improve lot of types, imports and names
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import nn
|
||||
|
||||
from ..types import VisionTensor
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.tuner.tuning import Tuner
|
||||
from torch import Tensor
|
||||
from torch import Tensor, nn
|
||||
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
|
||||
|
||||
from .models.embedding_converter import EmbeddingConverter
|
||||
@@ -21,7 +21,7 @@ class EmbeddingConverterTrainer(pytorch_lightning.LightningModule):
|
||||
def __init__(self) -> None:
|
||||
super(EmbeddingConverterTrainer, self).__init__()
|
||||
self.embedding_converter = EmbeddingConverter()
|
||||
self.mse_loss = torch.nn.MSELoss()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
|
||||
def forward(self, source_embedding : Embedding) -> Embedding:
|
||||
return self.embedding_converter(source_embedding)
|
||||
|
||||
@@ -12,9 +12,9 @@ from .types import Batch, ImagePathList, ImagePathSet
|
||||
|
||||
|
||||
class DataLoaderVGG(TensorDataset):
|
||||
def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_folder_pattern : str, same_person_probability : float) -> None:
|
||||
def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_directory_pattern : str, same_person_probability : float) -> None:
|
||||
self.same_person_probability = same_person_probability
|
||||
self.directory_paths = glob.glob(dataset_folder_pattern.format(dataset_path))
|
||||
self.directory_paths = glob.glob(dataset_directory_pattern.format(dataset_path))
|
||||
self.image_paths, self.image_path_set = self.prepare_image_paths(dataset_image_pattern)
|
||||
self.dataset_total = len(self.image_paths)
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
@@ -3,8 +3,9 @@ import platform
|
||||
import cv2
|
||||
import numpy
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import Embedder, Embedding, Padding, Tensor, VisionFrame, VisionTensor
|
||||
from .types import Embedder, Embedding, Padding, VisionFrame, VisionTensor
|
||||
|
||||
|
||||
def is_windows() -> bool:
|
||||
@@ -35,25 +36,25 @@ def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame:
|
||||
return vision_frame
|
||||
|
||||
|
||||
def hinge_real_loss(tensor : Tensor) -> Tensor:
|
||||
real_loss = torch.relu(1 - tensor)
|
||||
def hinge_real_loss(input_tensor : Tensor) -> Tensor:
|
||||
real_loss = torch.relu(1 - input_tensor)
|
||||
real_loss = real_loss.mean(dim = [ 1, 2, 3 ])
|
||||
return real_loss
|
||||
|
||||
|
||||
def hinge_fake_loss(tensor : Tensor) -> Tensor:
|
||||
fake_loss = torch.relu(tensor + 1)
|
||||
def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
|
||||
fake_loss = torch.relu(input_tensor + 1)
|
||||
fake_loss = fake_loss.mean(dim = [ 1, 2, 3 ])
|
||||
return fake_loss
|
||||
|
||||
|
||||
def calc_id_embedding(id_embedder : Embedder, vision_tensor : VisionTensor, padding : Padding) -> Embedding:
|
||||
crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241]
|
||||
crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area')
|
||||
crop_vision_tensor = nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area')
|
||||
crop_vision_tensor[:, :, :padding[0], :] = 0
|
||||
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
|
||||
crop_vision_tensor[:, :, :, :padding[2]] = 0
|
||||
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
|
||||
source_embedding = id_embedder(crop_vision_tensor)
|
||||
source_embedding = torch.nn.functional.normalize(source_embedding, p = 2)
|
||||
source_embedding = nn.functional.normalize(source_embedding, p = 2)
|
||||
return source_embedding
|
||||
|
||||
@@ -2,7 +2,7 @@ import configparser
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
import torch.nn as nn
|
||||
from torch import nn
|
||||
|
||||
from face_swapper.src.types import VisionTensor
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import configparser
|
||||
from typing import Tuple
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import nn
|
||||
|
||||
from face_swapper.src.networks.attribute_modulator import AADGenerator
|
||||
from face_swapper.src.networks.encoder import UNet
|
||||
|
||||
@@ -18,7 +18,7 @@ class FaceSwapperLoss:
|
||||
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
self.mse_loss = torch.nn.MSELoss()
|
||||
self.mse_loss = nn.MSELoss()
|
||||
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
@@ -127,7 +127,7 @@ class FaceSwapperLoss:
|
||||
|
||||
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
|
||||
vision_tensor_norm = (vision_tensor + 1) * 0.5
|
||||
vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
|
||||
vision_tensor_norm = nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
|
||||
landmarks = self.landmarker(vision_tensor_norm)[2].view(-1, 203, 2)
|
||||
return landmarks
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from torch import Tensor, nn as nn
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.types import Embedding, TargetAttributes
|
||||
|
||||
@@ -19,13 +19,13 @@ class AADGenerator(nn.Module):
|
||||
|
||||
def forward(self, target_attributes : TargetAttributes, source_embedding : Embedding) -> Tensor:
|
||||
feature_map = self.upsample(source_embedding)
|
||||
feature_map_1 = torch.nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_2 = torch.nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_3 = torch.nn.functional.interpolate(self.res_block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_4 = torch.nn.functional.interpolate(self.res_block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_5 = torch.nn.functional.interpolate(self.res_block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_6 = torch.nn.functional.interpolate(self.res_block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_7 = torch.nn.functional.interpolate(self.res_block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_1 = nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_2 = nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_3 = nn.functional.interpolate(self.res_block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_4 = nn.functional.interpolate(self.res_block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_5 = nn.functional.interpolate(self.res_block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_6 = nn.functional.interpolate(self.res_block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
feature_map_7 = nn.functional.interpolate(self.res_block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
output = self.res_block_8(feature_map_7, target_attributes[7], source_embedding)
|
||||
return torch.tanh(output)
|
||||
|
||||
@@ -41,7 +41,7 @@ class AADLayer(nn.Module):
|
||||
self.instance_norm = nn.InstanceNorm2d(input_channels)
|
||||
self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1)
|
||||
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : Embedding) -> Tensor:
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor:
|
||||
feature_map = self.instance_norm(feature_map)
|
||||
gamma_attribute = self.conv_gamma(attribute_embedding)
|
||||
beta_attribute = self.conv_beta(attribute_embedding)
|
||||
@@ -59,7 +59,7 @@ class AADSequential(nn.Module):
|
||||
super(AADSequential, self).__init__()
|
||||
self.layers = nn.ModuleList(args)
|
||||
|
||||
def forward(self, feature_map: Tensor, attribute_embedding: Tensor, id_embedding: Embedding) -> Tensor:
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor:
|
||||
for layer in self.layers:
|
||||
if isinstance(layer, AADLayer):
|
||||
feature_map = layer(feature_map, attribute_embedding, id_embedding)
|
||||
@@ -99,7 +99,7 @@ class AADResBlock(nn.Module):
|
||||
)
|
||||
self.auxiliary_add_blocks = auxiliary_add_blocks
|
||||
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Tensor, id_embedding : Embedding) -> Tensor:
|
||||
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor:
|
||||
primary_feature = self.primary_add_blocks(feature_map, attribute_embedding, id_embedding)
|
||||
|
||||
if self.input_channels > self.output_channels:
|
||||
@@ -115,7 +115,7 @@ class PixelShuffleUpsample(nn.Module):
|
||||
self.conv = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 3, padding = 1)
|
||||
self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
|
||||
|
||||
def forward(self, temp : Tensor) -> Tensor:
|
||||
temp = self.conv(temp.view(temp.shape[0], -1, 1, 1))
|
||||
temp = self.pixel_shuffle(temp)
|
||||
return temp
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = self.conv(input_tensor.view(input_tensor.shape[0], -1, 1, 1))
|
||||
temp_tensor = self.pixel_shuffle(temp_tensor)
|
||||
return temp_tensor
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from torch import Tensor, nn as nn
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.types import TargetAttributes, VisionTensor
|
||||
|
||||
@@ -11,11 +11,11 @@ class Upsample(nn.Module):
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor:
|
||||
temp = self.deconv(temp)
|
||||
temp = self.batch_norm(temp)
|
||||
temp = self.leaky_relu(temp)
|
||||
return torch.cat((temp, skip_tensor), dim = 1)
|
||||
def forward(self, input_tensor : Tensor, skip_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = self.deconv(input_tensor)
|
||||
temp_tensor = self.batch_norm(temp_tensor)
|
||||
temp_tensor = self.leaky_relu(temp_tensor)
|
||||
return torch.cat((temp_tensor, skip_tensor), dim = 1)
|
||||
|
||||
|
||||
class DownSample(nn.Module):
|
||||
@@ -25,11 +25,11 @@ class DownSample(nn.Module):
|
||||
self.batch_norm = nn.BatchNorm2d(output_channels)
|
||||
self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
|
||||
|
||||
def forward(self, temp : Tensor) -> Tensor:
|
||||
temp = self.conv(temp)
|
||||
temp = self.batch_norm(temp)
|
||||
temp = self.leaky_relu(temp)
|
||||
return temp
|
||||
def forward(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = self.conv(input_tensor)
|
||||
temp_tensor = self.batch_norm(temp_tensor)
|
||||
temp_tensor = self.leaky_relu(temp_tensor)
|
||||
return temp_tensor
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
@@ -63,5 +63,5 @@ class UNet(nn.Module):
|
||||
upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
|
||||
upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
|
||||
upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
|
||||
output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
output = nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
|
||||
|
||||
+20
-21
@@ -1,32 +1,31 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Tuple, TypeAlias
|
||||
|
||||
import torch.nn
|
||||
from numpy.typing import NDArray
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.nn import Module
|
||||
|
||||
Batch = Tuple[Any, Any, Any]
|
||||
Loader = DataLoader[Tuple[Tensor, ...]]
|
||||
ImagePathList = List[str]
|
||||
ImagePathSet = Dict[str, ImagePathList]
|
||||
ImagePathList : TypeAlias = List[str]
|
||||
ImagePathSet : TypeAlias = Dict[str, ImagePathList]
|
||||
|
||||
SwapAttributes = Tuple[Tensor, ...]
|
||||
TargetAttributes = Tuple[Tensor, ...]
|
||||
DiscriminatorOutputs = List[List[Tensor]]
|
||||
SwapAttributes : TypeAlias = Tuple[Tensor, ...]
|
||||
TargetAttributes : TypeAlias = Tuple[Tensor, ...]
|
||||
DiscriminatorOutputs : TypeAlias = List[List[Tensor]]
|
||||
|
||||
Embedding = Tensor
|
||||
FaceLandmark203 = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
FaceLandmark203 : TypeAlias = Tensor
|
||||
|
||||
StateSet = OrderedDict[str, Any]
|
||||
Padding = Tuple[int, int, int, int]
|
||||
StateSet : TypeAlias = OrderedDict[str, Any]
|
||||
Padding : TypeAlias = Tuple[int, int, int, int]
|
||||
|
||||
LossTensor = Tensor
|
||||
VisionTensor = Tensor
|
||||
VisionFrame = NDArray[Any]
|
||||
VisionFrame : TypeAlias = NDArray[Any]
|
||||
LossTensor : TypeAlias = Tensor
|
||||
VisionTensor : TypeAlias = Tensor
|
||||
|
||||
GeneratorLossSet = Dict[str, Tensor]
|
||||
DiscriminatorLossSet = Dict[str, Tensor]
|
||||
Batch : TypeAlias = Tuple[VisionTensor, VisionTensor, Tensor]
|
||||
|
||||
Generator = torch.nn.Module
|
||||
Embedder = torch.nn.Module
|
||||
GeneratorLossSet : TypeAlias = Dict[str, Tensor]
|
||||
DiscriminatorLossSet : TypeAlias = Dict[str, Tensor]
|
||||
|
||||
Generator : TypeAlias = Module
|
||||
Embedder : TypeAlias = Module
|
||||
|
||||
Reference in New Issue
Block a user