mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
changes
This commit is contained in:
@@ -26,10 +26,8 @@ Setup
|
||||
This `config.ini` utilizes the VGGFace2 dataset to train the Face Swapper model.
|
||||
|
||||
```
|
||||
[preparing.dataset]
|
||||
dataset_path = .datasets/vggface2
|
||||
directory_pattern = {}/*
|
||||
image_pattern = {}/*.*g
|
||||
[training.dataset]
|
||||
file_pattern = .datasets/vggface2/**/*.jpg
|
||||
same_person_probability = 0.2
|
||||
```
|
||||
|
||||
@@ -37,6 +35,7 @@ same_person_probability = 0.2
|
||||
[training.loader]
|
||||
batch_size = 8
|
||||
num_workers = 8
|
||||
split_ratio = 0.95
|
||||
```
|
||||
|
||||
```
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
[preparing.dataset]
|
||||
dataset_path =
|
||||
directory_pattern =
|
||||
image_pattern =
|
||||
[training.dataset]
|
||||
file_pattern =
|
||||
same_person_probability =
|
||||
|
||||
[training.loader]
|
||||
batch_size =
|
||||
num_workers =
|
||||
split_ratio =
|
||||
|
||||
[training.model]
|
||||
id_embedder_path =
|
||||
|
||||
@@ -1,44 +1,30 @@
|
||||
import glob
|
||||
import os.path
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
from .types import Batch, ImagePathList, ImagePathSet
|
||||
from .types import Batch
|
||||
|
||||
|
||||
class DataLoader(Dataset[Tensor]):
|
||||
def __init__(self, dataset_path : str, dataset_image_pattern : str, dataset_directory_pattern : str, same_person_probability : float) -> None:
|
||||
def __init__(self, file_pattern : str, same_person_probability : float) -> None:
|
||||
self.same_person_probability = same_person_probability
|
||||
self.directory_paths = glob.glob(dataset_directory_pattern.format(dataset_path))
|
||||
self.image_paths, self.image_path_set = self.prepare_image_paths(dataset_image_pattern)
|
||||
self.file_paths = glob.glob(file_pattern)
|
||||
self.transforms = self.compose_transforms()
|
||||
|
||||
def __getitem__(self, index : int) -> Batch:
|
||||
source_image_path = self.image_paths[index]
|
||||
def __getitem__(self, index : int) -> Batch: # type:ignore[override]
|
||||
source_image_path = self.file_paths[index]
|
||||
|
||||
if random.random() > self.same_person_probability:
|
||||
if random.random() < self.same_person_probability:
|
||||
return self.prepare_same_person(source_image_path)
|
||||
|
||||
return self.prepare_different_person(source_image_path)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.image_paths)
|
||||
|
||||
#todo: remove this method - only use glob.glob in init()
|
||||
def prepare_image_paths(self, dataset_image_pattern : str) -> Tuple[ImagePathList, ImagePathSet]:
|
||||
image_paths = []
|
||||
image_path_set = {}
|
||||
|
||||
for directory_path in self.directory_paths:
|
||||
image_paths.extend(glob.glob(dataset_image_pattern.format(directory_path)))
|
||||
image_path_set[directory_path] = image_paths
|
||||
return image_paths, image_path_set
|
||||
return len(self.file_paths)
|
||||
|
||||
@staticmethod
|
||||
def compose_transforms() -> transforms:
|
||||
@@ -54,20 +40,15 @@ class DataLoader(Dataset[Tensor]):
|
||||
])
|
||||
|
||||
def prepare_different_person(self, source_image_path : str) -> Batch:
|
||||
is_same_person = torch.tensor(0)
|
||||
target_image_path = random.choice(self.image_paths)
|
||||
target_image_path = random.choice(self.file_paths)
|
||||
source_vision_frame = cv2.imread(source_image_path)
|
||||
target_vision_frame = cv2.imread(target_image_path)
|
||||
source_tensor = self.transforms(source_vision_frame)
|
||||
target_tensor = self.transforms(target_vision_frame)
|
||||
return source_tensor, target_tensor, is_same_person
|
||||
return source_tensor, target_tensor
|
||||
|
||||
def prepare_same_person(self, source_image_path : str) -> Batch:
|
||||
is_same_person = torch.tensor(1)
|
||||
#todo: why not like in prepare_different_person
|
||||
target_image_path = random.choice(self.image_path_set.get(os.path.dirname(source_image_path)))
|
||||
source_vision_frame = cv2.imread(source_image_path)
|
||||
target_vision_frame = cv2.imread(target_image_path)
|
||||
source_tensor = self.transforms(source_vision_frame)
|
||||
target_tensor = self.transforms(target_vision_frame)
|
||||
return source_tensor, target_tensor, is_same_person
|
||||
target_tensor = source_tensor.clone()
|
||||
return source_tensor, target_tensor
|
||||
|
||||
@@ -21,7 +21,7 @@ def export() -> None:
|
||||
model = Generator()
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
model.ir_version = ir_version
|
||||
model.ir_version = torch.tensor(ir_version)
|
||||
source_tensor = torch.randn(1, 512)
|
||||
target_tensor = torch.randn(1, 3, 256, 256)
|
||||
torch.onnx.export(model, (source_tensor, target_tensor), target_path, input_names = [ 'source', 'target' ], output_names = [ 'output' ], opset_version = opset_version)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import numpy
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from pytorch_msssim import ssim
|
||||
|
||||
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
|
||||
|
||||
@@ -41,6 +42,13 @@ def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor
|
||||
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
|
||||
crop_vision_tensor[:, :, :, :padding[2]] = 0
|
||||
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
|
||||
source_embedding = id_embedder(crop_vision_tensor)
|
||||
with torch.no_grad():
|
||||
source_embedding = id_embedder(crop_vision_tensor)
|
||||
source_embedding = nn.functional.normalize(source_embedding, p = 2)
|
||||
return source_embedding
|
||||
|
||||
|
||||
def calc_structural_similarity(swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Tensor:
    """
    Return the structural dissimilarity (1 - mean SSIM) between the swap and the target tensor.

    The data range passed to ssim() is measured on swap_tensor only (max - min)
    and converted to a plain Python float.

    :param swap_tensor: tensor produced by the generator
    :param target_tensor: tensor the swap is compared against
    :return: 1 minus the mean of the ssim() result
    """
    swap_maximum = torch.max(swap_tensor)
    swap_minimum = torch.min(swap_tensor)
    # float() turns the 0-dim tensor into a Python number for the data_range argument
    swap_data_range = float(swap_maximum - swap_minimum)
    similarity = ssim(swap_tensor, target_tensor, data_range = swap_data_range)
    return 1 - similarity.mean()
|
||||
|
||||
@@ -2,10 +2,9 @@ import configparser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from ..helper import calc_id_embedding, calc_structural_similarity, hinge_fake_loss, hinge_real_loss
|
||||
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -27,7 +26,7 @@ class FaceSwapperLoss:
|
||||
self.motion_extractor.eval()
|
||||
|
||||
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
|
||||
source_tensor, target_tensor, is_same_person = batch
|
||||
source_tensor, target_tensor = batch
|
||||
weight_adversarial = CONFIG.getfloat('training.losses', 'weight_adversarial')
|
||||
weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')
|
||||
weight_attribute = CONFIG.getfloat('training.losses', 'weight_attribute')
|
||||
@@ -39,7 +38,7 @@ class FaceSwapperLoss:
|
||||
'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs),
|
||||
'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor),
|
||||
'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes),
|
||||
'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
|
||||
'loss_reconstruction': self.calc_reconstruction_loss(source_tensor, target_tensor, swap_tensor)
|
||||
}
|
||||
|
||||
if weight_pose > 0:
|
||||
@@ -95,12 +94,22 @@ class FaceSwapperLoss:
|
||||
loss_attribute = torch.stack(loss_attributes).mean() * 0.5
|
||||
return loss_attribute
|
||||
|
||||
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
|
||||
loss_reconstruction = torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1)
|
||||
loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
|
||||
loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
|
||||
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
|
||||
loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
|
||||
def calc_reconstruction_loss(self, source_tensor : VisionTensor, target_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
    """
    Reconstruction loss that only counts samples whose source and target faces look alike.

    The identity embeddings of source and target are compared by cosine similarity,
    rescaled from [-1, 1] to [0, 1]. A sample with a similarity above 0.9 contributes
    the average of self.mse_loss() and calc_structural_similarity() between its swap
    and target; every other sample contributes zero. The result is the mean over all
    samples, so the zero entries lower the mean.

    :param source_tensor: batch of source faces (indexed along the first dimension)
    :param target_tensor: batch of target faces
    :param swap_tensor: batch of generated faces
    :return: scalar loss tensor
    """
    source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
    target_embedding = calc_id_embedding(self.id_embedder, target_tensor, (0, 0, 0, 0))
    # shift the cosine similarity from [-1, 1] into [0, 1]
    face_similarities = (torch.cosine_similarity(source_embedding, target_embedding) + 1) * 0.5
    sample_losses = []

    for index, face_similarity in enumerate(face_similarities):
        # written as "not >" so that a NaN similarity also takes the zero branch
        if not face_similarity.item() > 0.9:
            sample_losses.append(torch.zeros((), device = swap_tensor.device, dtype = swap_tensor.dtype))
            continue

        swap_sample = swap_tensor[index]
        target_sample = target_tensor[index]
        loss_mse = self.mse_loss(swap_sample, target_sample)
        # calc_structural_similarity() receives each sample with a leading batch dimension of one
        loss_ssim = calc_structural_similarity(swap_sample.unsqueeze(0), target_sample.unsqueeze(0))
        sample_losses.append((loss_mse + loss_ssim) * 0.5)

    return torch.stack(sample_losses).mean()
|
||||
|
||||
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import configparser
|
||||
import os
|
||||
from typing import Tuple
|
||||
from typing import Any, Tuple
|
||||
|
||||
import lightning
|
||||
import torch
|
||||
@@ -10,7 +10,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from torch import Tensor, nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader as TorchDataLoader, Subset
|
||||
from torch.utils.data import DataLoader as TorchDataLoader, Dataset, random_split
|
||||
|
||||
from .data_loader import DataLoader
|
||||
from .helper import calc_id_embedding
|
||||
@@ -42,7 +42,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
return generator_optimizer, discriminator_optimizer
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor, is_same_person = batch
|
||||
source_tensor, target_tensor = batch
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
swap_tensor = self.generator(source_embedding, target_tensor)
|
||||
@@ -75,7 +75,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
return generator_losses.get('loss_generator')
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor, _ = batch
|
||||
source_tensor, target_tensor = batch
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor = self.generator(source_embedding, target_tensor)
|
||||
output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0))
|
||||
@@ -91,7 +91,25 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
preview_items.append(torch.cat([ source_tensor, target_tensor, output_tensor] , dim = 2))
|
||||
|
||||
preview_grid = torchvision.utils.make_grid(torch.cat(preview_items, dim = 1).unsqueeze(0), normalize = True, scale_each = True)
|
||||
self.logger.experiment.add_image('preview', preview_grid, self.global_step)
|
||||
self.logger.experiment.add_image('preview', preview_grid, self.global_step) # type:ignore[attr-defined]
|
||||
|
||||
|
||||
def create_loaders(dataset : Dataset[Any]) -> Tuple[TorchDataLoader[Any], TorchDataLoader[Any]]: # type:ignore[type-arg]
    """
    Build the training and the validation loader for the given dataset.

    Batch size and worker count are read from the [training.loader] section of CONFIG,
    the dataset is divided by split_dataset(). The training loader shuffles and drops
    the last incomplete batch, the validation loader keeps the order and every sample.

    :param dataset: dataset to split and wrap
    :return: tuple of (training_loader, validation_loader)
    """
    batch_size = CONFIG.getint('training.loader', 'batch_size')
    num_workers = CONFIG.getint('training.loader', 'num_workers')
    # torch raises a ValueError for persistent_workers = True combined with num_workers = 0,
    # so workers are only kept alive when there are any
    persistent_workers = num_workers > 0

    training_dataset, validate_dataset = split_dataset(dataset)
    training_loader = TorchDataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = persistent_workers)
    validation_loader = TorchDataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = persistent_workers)
    return training_loader, validation_loader # type:ignore[return-value]
|
||||
|
||||
|
||||
def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]:
    """
    Randomly divide the dataset into a training and a validation part.

    The training share is the split_ratio from the [training.loader] section of CONFIG,
    truncated to a whole number of samples; the validation part receives the remainder.

    :param dataset: dataset to divide
    :return: tuple of (training_dataset, validate_dataset)
    """
    split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
    dataset_size = len(dataset) # type:ignore[arg-type]
    # int() truncates, so a fractional sample always ends up in the validation part
    training_size = int(split_ratio * dataset_size)
    validation_size = dataset_size - training_size
    training_subset, validation_subset = random_split(dataset, [ training_size, validation_size ])
    return training_subset, validation_subset
|
||||
|
||||
|
||||
def create_trainer() -> Trainer:
|
||||
@@ -106,7 +124,7 @@ def create_trainer() -> Trainer:
|
||||
logger = logger,
|
||||
log_every_n_steps = 10,
|
||||
max_epochs = trainer_max_epochs,
|
||||
precision = trainer_precision,
|
||||
precision = trainer_precision, # type:ignore[arg-type]
|
||||
callbacks =
|
||||
[
|
||||
ModelCheckpoint(
|
||||
@@ -123,17 +141,12 @@ def create_trainer() -> Trainer:
|
||||
|
||||
|
||||
def train() -> None:
|
||||
dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
|
||||
dataset_image_pattern = CONFIG.get('preparing.dataset', 'image_pattern')
|
||||
dataset_directory_pattern = CONFIG.get('preparing.dataset', 'directory_pattern')
|
||||
same_person_probability = CONFIG.getfloat('preparing.dataset', 'same_person_probability')
|
||||
batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
num_workers = CONFIG.getint('training.loader', 'num_workers')
|
||||
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
|
||||
same_person_probability = CONFIG.getfloat('training.dataset', 'same_person_probability')
|
||||
output_resume_path = CONFIG.get('training.output', 'resume_path')
|
||||
|
||||
dataset = DataLoader(dataset_path, dataset_image_pattern, dataset_directory_pattern, same_person_probability)
|
||||
training_loader = TorchDataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
validation_loader = TorchDataLoader(Subset(dataset, range(1000)), batch_size = batch_size, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
|
||||
dataset = DataLoader(dataset_file_pattern, same_person_probability)
|
||||
training_loader, validation_loader = create_loaders(dataset)
|
||||
face_swapper_trainer = FaceSwapperTrainer()
|
||||
trainer = create_trainer()
|
||||
|
||||
|
||||
@@ -5,10 +5,7 @@ from numpy.typing import NDArray
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
Batch : TypeAlias = Tuple[Tensor, Tensor, Tensor]
|
||||
|
||||
ImagePathList : TypeAlias = List[str]
|
||||
ImagePathSet : TypeAlias = Dict[str, ImagePathList]
|
||||
Batch : TypeAlias = Tuple[Tensor, Tensor]
|
||||
|
||||
SwapAttributes : TypeAlias = Tuple[Tensor, ...]
|
||||
TargetAttributes : TypeAlias = Tuple[Tensor, ...]
|
||||
|
||||
Reference in New Issue
Block a user