mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Rename id_embedder to embedder, rename many variables in the training step, and introduce a new IdentityLoss class
This commit is contained in:
@@ -40,7 +40,7 @@ split_ratio = 0.9995
|
||||
|
||||
```
|
||||
[training.model]
|
||||
id_embedder_path = .models/id_embedder.pt
|
||||
embedder_path = .models/arcface.pt
|
||||
landmarker_path = .models/landmarker.pt
|
||||
motion_extractor_path = .models/motion_extractor.pt
|
||||
```
|
||||
@@ -99,7 +99,7 @@ opset_version = 15
|
||||
```
|
||||
[inferencing]
|
||||
generator_path = .outputs/last.ckpt
|
||||
id_embedder_path = .models/id_embedder.pt
|
||||
embedder_path = .models/arcface.pt
|
||||
source_path = .assets/source.jpg
|
||||
target_path = .assets/target.jpg
|
||||
output_path = .outputs/output.jpg
|
||||
|
||||
+10
-10
@@ -34,13 +34,13 @@ def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
|
||||
return fake_loss
|
||||
|
||||
|
||||
def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor, padding : Padding) -> Embedding:
    """
    Crop, downscale and mask an image batch, then return its L2-normalized embedding.

    :param id_embedder: callable applied to the prepared 112x112 batch
    :param vision_tensor: 4D batch; indices 15..240 of its last two dimensions are kept
    :param padding: four widths zeroed on the resized crop, in the order
        start of dim 2, end of dim 2, start of dim 3, end of dim 3
        (presumably top, bottom, left, right for an N, C, H, W batch - confirm against callers)
    :return: output of id_embedder, normalized with the L2 norm (torch default dimension)
    """
    size = 112
    crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241]
    # the resize produces a fresh tensor, so the in-place masking below leaves vision_tensor untouched
    crop_vision_tensor = nn.functional.interpolate(crop_vision_tensor, size = (size, size), mode = 'area')
    # trailing borders use "size - width" instead of a negative index: a width of 0 must select nothing, whereas "-0:" would select everything
    crop_vision_tensor[:, :, :padding[0], :] = 0
    crop_vision_tensor[:, :, size - padding[1]:, :] = 0
    crop_vision_tensor[:, :, :, :padding[2]] = 0
    crop_vision_tensor[:, :, :, size - padding[3]:] = 0
    embedding = id_embedder(crop_vision_tensor)
    return nn.functional.normalize(embedding, p = 2)
|
||||
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
    """
    Crop, downscale and mask an image batch, then return its L2-normalized embedding.

    :param embedder: callable applied to the prepared 112x112 batch
    :param input_tensor: 4D batch; indices 15..240 of its last two dimensions are kept
    :param padding: four widths zeroed on the resized crop, in the order
        start of dim 2, end of dim 2, start of dim 3, end of dim 3
        (presumably top, bottom, left, right for an N, C, H, W batch - confirm against callers)
    :return: output of embedder, normalized with the L2 norm (torch default dimension)
    """
    size = 112
    crop_tensor = input_tensor[:, :, 15 : 241, 15 : 241]
    # the resize produces a fresh tensor, so the in-place masking below leaves input_tensor untouched
    crop_tensor = nn.functional.interpolate(crop_tensor, size = (size, size), mode = 'area')
    # trailing borders use "size - width" instead of a negative index: a width of 0 must select nothing, whereas "-0:" would select everything
    crop_tensor[:, :, :padding[0], :] = 0
    crop_tensor[:, :, size - padding[1]:, :] = 0
    crop_tensor[:, :, :, :padding[2]] = 0
    crop_tensor[:, :, :, size - padding[3]:] = 0
    embedding = embedder(crop_tensor)
    return nn.functional.normalize(embedding, p = 2)
|
||||
|
||||
@@ -3,7 +3,7 @@ import configparser
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor
|
||||
from .helper import calc_embedding, convert_to_vision_frame, convert_to_vision_tensor
|
||||
from .models.generator import Generator
|
||||
from .types import EmbedderModule, GeneratorModule, VisionFrame
|
||||
|
||||
@@ -11,10 +11,10 @@ CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def run_swap(generator : GeneratorModule, id_embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
|
||||
def run_swap(generator : GeneratorModule, embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
|
||||
source_vision_tensor = convert_to_vision_tensor(source_vision_frame)
|
||||
target_vision_tensor = convert_to_vision_tensor(target_vision_frame)
|
||||
source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0))
|
||||
source_embedding = calc_embedding(embedder, source_vision_tensor, (0, 0, 0, 0))
|
||||
output_vision_tensor = generator(source_embedding, target_vision_tensor)[0]
|
||||
output_vision_frame = convert_to_vision_frame(output_vision_tensor)
|
||||
return output_vision_frame
|
||||
@@ -22,7 +22,7 @@ def run_swap(generator : GeneratorModule, id_embedder : EmbedderModule, source_v
|
||||
|
||||
def infer() -> None:
|
||||
generator_path = CONFIG.get('inferencing', 'generator_path')
|
||||
id_embedder_path = CONFIG.get('inferencing', 'id_embedder_path')
|
||||
embedder_path = CONFIG.get('inferencing', 'embedder_path')
|
||||
source_path = CONFIG.get('inferencing', 'source_path')
|
||||
target_path = CONFIG.get('inferencing', 'target_path')
|
||||
output_path = CONFIG.get('inferencing', 'output_path')
|
||||
@@ -31,10 +31,10 @@ def infer() -> None:
|
||||
generator = Generator()
|
||||
generator.load_state_dict(state_dict)
|
||||
generator.eval()
|
||||
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
id_embedder.eval()
|
||||
embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
embedder.eval()
|
||||
|
||||
source_vision_frame = cv2.imread(source_path)
|
||||
target_vision_frame = cv2.imread(target_path)
|
||||
output_vision_frame = run_swap(generator, id_embedder, source_vision_frame, target_vision_frame)
|
||||
output_vision_frame = run_swap(generator, embedder, source_vision_frame, target_vision_frame)
|
||||
cv2.imwrite(output_path, output_vision_frame)
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
from pytorch_msssim import ssim
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss
|
||||
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -14,15 +14,15 @@ CONFIG.read('config.ini')
|
||||
|
||||
class FaceSwapperLoss:
|
||||
def __init__(self) -> None:
|
||||
id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
|
||||
embedder_path = CONFIG.get('training.model', 'embedder_path')
|
||||
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
|
||||
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
|
||||
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
|
||||
self.mse_loss = nn.MSELoss()
|
||||
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
self.id_embedder.eval()
|
||||
self.embedder.eval()
|
||||
self.landmarker.eval()
|
||||
self.motion_extractor.eval()
|
||||
|
||||
@@ -105,8 +105,8 @@ class FaceSwapperLoss:
|
||||
return loss_reconstruction
|
||||
|
||||
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
|
||||
swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
|
||||
swap_embedding = calc_embedding(self.embedder, swap_tensor, (30, 0, 10, 10))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
|
||||
loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
|
||||
return loss_identity
|
||||
|
||||
@@ -139,3 +139,17 @@ class FaceSwapperLoss:
|
||||
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
|
||||
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
|
||||
return translation, scale, rotation
|
||||
|
||||
|
||||
class IdentityLoss(torch.nn.Module):
    """
    Identity loss based on the cosine similarity of two embeddings.

    The embedder is a TorchScript module loaded on the CPU from the path stored under
    'embedder_path' in the 'training.model' config section, and switched to eval mode.
    """

    def __init__(self) -> None:
        super().__init__()
        embedder_path = CONFIG.get('training.model', 'embedder_path')
        self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
        self.embedder.eval()

    def calc_loss(self, source_tensor : Tensor, output_tensor : Tensor) -> Tensor:
        """
        Return the batch mean of (1 - cosine similarity) between both embeddings.

        Both tensors go through calc_embedding with the same padding, and cosine similarity
        is symmetric, so swapping the two arguments yields the same value.

        :param source_tensor: batch providing the reference identity
        :param output_tensor: batch whose identity is compared against the source
        :return: scalar tensor, 0 for identical directions and up to 2 for opposite ones
        """
        padding = (30, 0, 10, 10)
        output_embedding = calc_embedding(self.embedder, output_tensor, padding)
        source_embedding = calc_embedding(self.embedder, source_tensor, padding)
        return (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
|
||||
|
||||
@@ -13,10 +13,10 @@ from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, Dataset, random_split
|
||||
|
||||
from .dataset import DynamicDataset
|
||||
from .helper import calc_id_embedding
|
||||
from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import FaceSwapperLoss
|
||||
from .models.loss import FaceSwapperLoss, IdentityLoss
|
||||
from .types import Batch, Embedding, VisionTensor
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
@@ -27,9 +27,12 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
FaceSwapperLoss.__init__(self)
|
||||
automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
|
||||
|
||||
self.generator = Generator()
|
||||
self.discriminator = Discriminator()
|
||||
self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
|
||||
self.identity_loss = IdentityLoss()
|
||||
self.automatic_optimization = automatic_optimization
|
||||
|
||||
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor:
|
||||
output_tensor = self.generator(source_embedding, target_tensor)
|
||||
@@ -42,43 +45,57 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
|
||||
return generator_optimizer, discriminator_optimizer
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
preview_frequency = CONFIG.getfloat('training.trainer', 'preview_frequency')
|
||||
|
||||
source_tensor, target_tensor = batch
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
swap_tensor = self.generator(source_embedding, target_tensor)
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_attributes = self.generator.get_attributes(target_tensor)
|
||||
swap_attributes = self.generator.get_attributes(swap_tensor)
|
||||
fake_discriminator_outputs = self.discriminator(swap_tensor)
|
||||
generator_output_tensor = self.generator(source_embedding, target_tensor)
|
||||
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
|
||||
discriminator_output_tensor = self.discriminator(generator_output_tensor)
|
||||
|
||||
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, swap_attributes, fake_discriminator_outputs, batch)
|
||||
generator_loss_set = self.calc_generator_loss(generator_output_tensor, target_attributes, generator_output_attributes, discriminator_output_tensor, batch)
|
||||
generator_optimizer.zero_grad()
|
||||
self.manual_backward(generator_losses.get('loss_generator'))
|
||||
self.manual_backward(generator_loss_set.get('loss_generator'))
|
||||
generator_optimizer.step()
|
||||
|
||||
real_discriminator_outputs = self.discriminator(source_tensor)
|
||||
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
|
||||
discriminator_source_tensor = self.discriminator(source_tensor)
|
||||
discriminator_output_tensor = self.discriminator(generator_output_tensor.detach())
|
||||
|
||||
discriminator_losses = self.calc_discriminator_loss(real_discriminator_outputs, fake_discriminator_outputs)
|
||||
discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensor, discriminator_output_tensor)
|
||||
discriminator_optimizer.zero_grad()
|
||||
self.manual_backward(discriminator_losses.get('loss_discriminator'))
|
||||
self.manual_backward(discriminator_loss_set.get('loss_discriminator'))
|
||||
discriminator_optimizer.step()
|
||||
|
||||
if self.global_step % CONFIG.getint('training.trainer', 'preview_frequency') == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, swap_tensor)
|
||||
if self.global_step % preview_frequency == 0:
|
||||
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
|
||||
|
||||
self.log('loss_generator', generator_losses.get('loss_generator'), prog_bar = True)
|
||||
self.log('loss_discriminator', discriminator_losses.get('loss_discriminator'), prog_bar = True)
|
||||
self.log('loss_adversarial', generator_losses.get('loss_adversarial'))
|
||||
self.log('loss_attribute', generator_losses.get('loss_attribute'))
|
||||
self.log('loss_identity', generator_losses.get('loss_identity'))
|
||||
self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'))
|
||||
return generator_losses.get('loss_generator')
|
||||
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
|
||||
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True)
|
||||
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'))
|
||||
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
|
||||
self.log('loss_identity', generator_loss_set.get('loss_identity'), prog_bar = True)
|
||||
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
|
||||
|
||||
identity_loss = self.identity_loss.calc_loss(generator_output_tensor, source_tensor)
|
||||
generator_loss = self.calc_generator_loss_new(identity_loss)
|
||||
|
||||
self.log('loss_generator_new', generator_loss, prog_bar = True)
|
||||
self.log('loss_identity_new', identity_loss, prog_bar = True)
|
||||
return generator_loss_set.get('loss_generator')
|
||||
|
||||
def calc_generator_loss_new(self, identity_loss : Tensor) -> Tensor:
    """
    Scale the identity loss by the configured weight.

    :param identity_loss: identity loss tensor to be weighted
    :return: identity_loss multiplied by 'weight_identity' from the 'training.losses' config section
    """
    weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')
    return identity_loss * weight_identity
|
||||
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor = batch
|
||||
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor = self.generator(source_embedding, target_tensor)
|
||||
output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0))
|
||||
output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
|
||||
validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
|
||||
self.log('validation', validation)
|
||||
return validation
|
||||
|
||||
Reference in New Issue
Block a user