Mirror of https://github.com/facefusion/facefusion-labs.git (synced 2026-04-19 15:56:37 +02:00)
Modernize to use ModuleList, Fix some types
@@ -1,21 +1,31 @@
 import torch
 import torch.nn as nn
-from torch import Tensor
+
+from embedding_converter.src.types import VisionTensor
 
 
 class EmbeddingConverter(nn.Module):
     def __init__(self) -> None:
         super(EmbeddingConverter, self).__init__()
-        self.fc1 = nn.Linear(512, 1024)
-        self.fc2 = nn.Linear(1024, 2048)
-        self.fc3 = nn.Linear(2048, 1024)
-        self.fc4 = nn.Linear(1024, 512)
+        self.layers = self.create_layers()
         self.activation = nn.LeakyReLU()
 
-    def forward(self, inputs : Tensor) -> Tensor:
-        norm_inputs = inputs / torch.norm(inputs)
-        outputs = self.activation(self.fc1(norm_inputs))
-        outputs = self.activation(self.fc2(outputs))
-        outputs = self.activation(self.fc3(outputs))
-        outputs = self.fc4(outputs)
-        return outputs
+    @staticmethod
+    def create_layers() -> nn.ModuleList:
+        layers = nn.ModuleList(
+        [
+            nn.Linear(512, 1024),
+            nn.Linear(1024, 2048),
+            nn.Linear(2048, 1024),
+            nn.Linear(1024, 512)
+        ])
+        return layers
+
+    def forward(self, input_tensor : VisionTensor) -> VisionTensor:
+        output_tensor = input_tensor / torch.norm(input_tensor)
+
+        for layer in self.layers[:-1]:
+            output_tensor = self.activation(layer(output_tensor))
+
+        output_tensor = self.layers[-1](output_tensor)
+        return output_tensor
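
Functionally, the ModuleList version matches the removed fc1-fc4 chain: LeakyReLU follows every layer except the last. A minimal smoke test, not part of the commit (the module path is assumed from the types import above):

import torch

from embedding_converter.src.models.embedding_converter import EmbeddingConverter

# Feed a random 512-dimensional embedding through the converter and
# confirm the output keeps the 512-dimensional shape.
converter = EmbeddingConverter()
input_tensor = torch.randn(1, 512)
output_tensor = converter(input_tensor)
assert output_tensor.shape == (1, 512)
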
@@ -11,7 +11,7 @@ from torch import Tensor
 from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
 
 from .models.embedding_converter import EmbeddingConverter
-from .types import Batch, Loader
+from .types import Batch, Embedding
 
 CONFIG = configparser.ConfigParser()
 CONFIG.read('config.ini')
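
For reference, the keys this file reads live in config.ini. A minimal sketch with illustrative values only; the real values and paths are not part of this diff:

[preparing.input]
source_path = .datasets/source.npy
target_path = .datasets/target.npy

[training.loader]
batch_size = 64
num_workers = 4
split_ratio = 0.8
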
@@ -23,20 +23,20 @@ class EmbeddingConverterTrainer(pytorch_lightning.LightningModule):
         self.embedding_converter = EmbeddingConverter()
         self.mse_loss = torch.nn.MSELoss()
 
-    def forward(self, source_embedding : Tensor) -> Tensor:
+    def forward(self, source_embedding : Embedding) -> Embedding:
         return self.embedding_converter(source_embedding)
 
     def training_step(self, batch : Batch, batch_index : int) -> Tensor:
-        source, target = batch
-        output = self(source)
-        loss_training = self.mse_loss(output, target)
+        source_tensor, target = batch
+        output_tensor = self(source_tensor)
+        loss_training = self.mse_loss(output_tensor, target)
         self.log('loss_training', loss_training, prog_bar = True)
         return loss_training
 
     def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
-        source, target = batch
-        output = self(source)
-        loss_validation = self.mse_loss(output, target)
+        source_tensor, target_tensor = batch
+        output_tensor = self(source_tensor)
+        loss_validation = self.mse_loss(output_tensor, target_tensor)
         self.log('loss_validation', loss_validation, prog_bar = True)
         return loss_validation
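
The commit does not show the training entry point; a hedged sketch of how this LightningModule is typically driven, using the create_loaders helper from the next hunk (epoch count is an illustrative assumption):

import pytorch_lightning

# Assumed entry point: wire the trainer module to the data loaders.
training_loader, validation_loader = create_loaders()
embedding_converter_trainer = EmbeddingConverterTrainer()
trainer = pytorch_lightning.Trainer(max_epochs = 100)
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
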
@@ -58,7 +58,7 @@ class EmbeddingConverterTrainer(pytorch_lightning.LightningModule):
         }
 
 
-def create_loaders() -> Tuple[Loader, Loader]:
+def create_loaders() -> Tuple[DataLoader, DataLoader]:
     loader_batch_size = CONFIG.getint('training.loader', 'batch_size')
     loader_num_workers = CONFIG.getint('training.loader', 'num_workers')
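
The body of create_loaders is unchanged by this commit and therefore cut off by the hunk; a plausible reconstruction, assuming it builds on the split_dataset helper shown below:

def create_loaders() -> Tuple[DataLoader, DataLoader]:
    loader_batch_size = CONFIG.getint('training.loader', 'batch_size')
    loader_num_workers = CONFIG.getint('training.loader', 'num_workers')
    training_dataset, validation_dataset = split_dataset()
    # Shuffle only the training split; validation order can stay fixed.
    training_loader = DataLoader(training_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = True)
    validation_loader = DataLoader(validation_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = False)
    return training_loader, validation_loader
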
@@ -73,9 +73,9 @@ def split_dataset() -> Tuple[Dataset[Any], Dataset[Any]]:
     input_target_path = CONFIG.get('preparing.input', 'target_path')
     loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
 
-    source_input = torch.from_numpy(numpy.load(input_source_path)).float()
-    target_input = torch.from_numpy(numpy.load(input_target_path)).float()
-    dataset = TensorDataset(source_input, target_input)
+    source_tensor = torch.from_numpy(numpy.load(input_source_path)).float()
+    target_tensor = torch.from_numpy(numpy.load(input_target_path)).float()
+    dataset = TensorDataset(source_tensor, target_tensor)
 
     dataset_size = len(dataset)
     training_size = int(loader_split_ratio * len(dataset))
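
The hunk ends before the actual split; presumably the function finishes with random_split along these lines (a sketch, not shown in the diff):

    # Derive the validation size from the remainder and split the dataset.
    validation_size = dataset_size - training_size
    training_dataset, validation_dataset = random_split(dataset, [ training_size, validation_size ])
    return training_dataset, validation_dataset
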
@@ -4,11 +4,11 @@ from numpy.typing import NDArray
 from torch import Tensor
 from torch.utils.data import DataLoader
 
-Batch = Tuple[Tensor, Tensor]
-Loader = DataLoader[Tuple[Tensor, ...]]
-
 Embedding = NDArray[Any]
 EmbeddingDataset = NDArray[Embedding]
 FaceLandmark5 = NDArray[Any]
 
 VisionFrame = NDArray[Any]
 VisionTensor = Tensor
+
+Batch = Tuple[VisionTensor, VisionTensor]
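
Under the new aliases, a Batch unpacks into two VisionTensor values, which keeps signatures self-documenting. An illustrative, hypothetical consumer:

import torch

from embedding_converter.src.types import Batch

def compute_batch_mse(batch : Batch) -> torch.Tensor:
    # A Batch is a (source, target) pair of VisionTensor values.
    source_tensor, target_tensor = batch
    return torch.nn.functional.mse_loss(source_tensor, target_tensor)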