Rename id_embedder to embedder, Tons of naming in training step, Introduce new IdentityLoss class

This commit is contained in:
henryruhs
2025-02-22 12:37:47 +01:00
parent 2220f5ef08
commit 085c493e18
5 changed files with 80 additions and 49 deletions
+2 -2
View File
@@ -40,7 +40,7 @@ split_ratio = 0.9995
```
[training.model]
id_embedder_path = .models/id_embedder.pt
embedder_path = .models/arcface.pt
landmarker_path = .models/landmarker.pt
motion_extractor_path = .models/motion_extractor.pt
```
@@ -99,7 +99,7 @@ opset_version = 15
```
[inferencing]
generator_path = .outputs/last.ckpt
id_embedder_path = .models/id_embedder.pt
embedder_path = .models/arcface.pt
source_path = .assets/source.jpg
target_path = .assets/target.jpg
output_path = .outputs/output.jpg
+10 -10
View File
@@ -34,13 +34,13 @@ def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
return fake_loss
def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor, padding : Padding) -> Embedding:
crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241]
crop_vision_tensor = nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area')
crop_vision_tensor[:, :, :padding[0], :] = 0
crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0
crop_vision_tensor[:, :, :, :padding[2]] = 0
crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0
source_embedding = id_embedder(crop_vision_tensor)
source_embedding = nn.functional.normalize(source_embedding, p = 2)
return source_embedding
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = input_tensor[:, :, 15: 241, 15: 241]
crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area')
crop_tensor[:, :, :padding[0], :] = 0
crop_tensor[:, :, 112 - padding[1]:, :] = 0
crop_tensor[:, :, :, :padding[2]] = 0
crop_tensor[:, :, :, 112 - padding[3]:] = 0
embedding = embedder(crop_tensor)
embedding = nn.functional.normalize(embedding, p = 2)
return embedding
+7 -7
View File
@@ -3,7 +3,7 @@ import configparser
import cv2
import torch
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor
from .helper import calc_embedding, convert_to_vision_frame, convert_to_vision_tensor
from .models.generator import Generator
from .types import EmbedderModule, GeneratorModule, VisionFrame
@@ -11,10 +11,10 @@ CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
# Run one face swap: embed the source frame, feed the embedding and the target tensor to the generator, and convert the result back to a frame.
# This span interleaves the replaced lines (id_embedder / calc_id_embedding) with the renamed lines that supersede them (embedder / calc_embedding).
def run_swap(generator : GeneratorModule, id_embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
def run_swap(generator : GeneratorModule, embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
source_vision_tensor = convert_to_vision_tensor(source_vision_frame)
target_vision_tensor = convert_to_vision_tensor(target_vision_frame)
source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0))
# an all-zero padding tuple means none of the border bands are blanked before embedding
source_embedding = calc_embedding(embedder, source_vision_tensor, (0, 0, 0, 0))
# only element [0] of the generator result is converted - presumably the single item of a batch of one; TODO confirm
output_vision_tensor = generator(source_embedding, target_vision_tensor)[0]
output_vision_frame = convert_to_vision_frame(output_vision_tensor)
return output_vision_frame
@@ -22,7 +22,7 @@ def run_swap(generator : GeneratorModule, id_embedder : EmbedderModule, source_v
def infer() -> None:
generator_path = CONFIG.get('inferencing', 'generator_path')
id_embedder_path = CONFIG.get('inferencing', 'id_embedder_path')
embedder_path = CONFIG.get('inferencing', 'embedder_path')
source_path = CONFIG.get('inferencing', 'source_path')
target_path = CONFIG.get('inferencing', 'target_path')
output_path = CONFIG.get('inferencing', 'output_path')
@@ -31,10 +31,10 @@ def infer() -> None:
generator = Generator()
generator.load_state_dict(state_dict)
generator.eval()
id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
id_embedder.eval()
embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
embedder.eval()
source_vision_frame = cv2.imread(source_path)
target_vision_frame = cv2.imread(target_path)
output_vision_frame = run_swap(generator, id_embedder, source_vision_frame, target_vision_frame)
output_vision_frame = run_swap(generator, embedder, source_vision_frame, target_vision_frame)
cv2.imwrite(output_path, output_vision_frame)
+20 -6
View File
@@ -5,7 +5,7 @@ import torch
from pytorch_msssim import ssim
from torch import Tensor, nn
from ..helper import calc_id_embedding, hinge_fake_loss, hinge_real_loss
from ..helper import calc_embedding, hinge_fake_loss, hinge_real_loss
from ..types import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
CONFIG = configparser.ConfigParser()
@@ -14,15 +14,15 @@ CONFIG.read('config.ini')
class FaceSwapperLoss:
def __init__(self) -> None:
id_embedder_path = CONFIG.get('training.model', 'id_embedder_path')
embedder_path = CONFIG.get('training.model', 'embedder_path')
landmarker_path = CONFIG.get('training.model', 'landmarker_path')
motion_extractor_path = CONFIG.get('training.model', 'motion_extractor_path')
self.batch_size = CONFIG.getint('training.loader', 'batch_size')
self.mse_loss = nn.MSELoss()
self.id_embedder = torch.jit.load(id_embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.motion_extractor = torch.jit.load(motion_extractor_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.id_embedder.eval()
self.embedder.eval()
self.landmarker.eval()
self.motion_extractor.eval()
@@ -105,8 +105,8 @@ class FaceSwapperLoss:
return loss_reconstruction
# Identity loss: one minus the cosine similarity between the source and swap embeddings, reduced with mean().
# This span interleaves the replaced calc_id_embedding calls with the calc_embedding calls that supersede them.
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
swap_embedding = calc_id_embedding(self.id_embedder, swap_tensor, (30, 0, 10, 10))
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (30, 0, 10, 10))
# both tensors get the same (30, 0, 10, 10) padding, so identical border bands are zeroed before embedding
swap_embedding = calc_embedding(self.embedder, swap_tensor, (30, 0, 10, 10))
source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding)).mean()
return loss_identity
@@ -139,3 +139,17 @@ class FaceSwapperLoss:
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
return translation, scale, rotation
class IdentityLoss(torch.nn.Module):
	"""
	Identity loss between a source tensor and a generator output tensor.

	The embedder is loaded with torch.jit.load from the 'embedder_path' key of
	the 'training.model' config section and switched to eval mode in the
	constructor.
	"""

	def __init__(self) -> None:
		super().__init__()
		embedder_path = CONFIG.get('training.model', 'embedder_path')
		# NOTE(review): if the loaded embedder is an nn.Module it is registered as a submodule here, so a later train() call on this module or a parent would propagate to it - confirm eval mode is preserved during training
		self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
		self.embedder.eval()

	def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tensor:
		"""
		Make the module callable; delegates to calc_loss with the same arguments.

		:param source_tensor: tensor carrying the reference identity
		:param output_tensor: tensor produced by the generator
		:return: the value returned by calc_loss
		"""
		return self.calc_loss(source_tensor, output_tensor)

	def calc_loss(self, source_tensor : Tensor, output_tensor : Tensor) -> Tensor:
		"""
		Return one minus the cosine similarity of both embeddings, reduced with mean().

		:param source_tensor: tensor carrying the reference identity
		:param output_tensor: tensor produced by the generator
		:return: the mean of (1 - cosine similarity) over the embeddings
		"""
		# the same (30, 0, 10, 10) padding is applied to both tensors before embedding
		output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
		source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
		loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
		return loss
+41 -24
View File
@@ -13,10 +13,10 @@ from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset, random_split
from .dataset import DynamicDataset
from .helper import calc_id_embedding
from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import FaceSwapperLoss
from .models.loss import FaceSwapperLoss, IdentityLoss
from .types import Batch, Embedding, VisionTensor
CONFIG = configparser.ConfigParser()
@@ -27,9 +27,12 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
# Build the trainer: initialise both base classes, create generator and discriminator, and set automatic_optimization from the config.
# This span interleaves the replaced direct assignment of self.automatic_optimization with the lines that supersede it.
def __init__(self) -> None:
super().__init__()
# FaceSwapperLoss.__init__ is invoked explicitly in addition to super().__init__()
FaceSwapperLoss.__init__(self)
automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
self.generator = Generator()
self.discriminator = Discriminator()
self.automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
# NOTE(review): FaceSwapperLoss.__init__ already loads an embedder from the 'training.model' / 'embedder_path' key; IdentityLoss() loads the same path a second time - confirm the duplicate load is intended
self.identity_loss = IdentityLoss()
self.automatic_optimization = automatic_optimization
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor:
output_tensor = self.generator(source_embedding, target_tensor)
@@ -42,43 +45,57 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
return generator_optimizer, discriminator_optimizer
# One manually optimised training step: update the generator, then update the discriminator on the detached generator output, then log the loss values.
# This span interleaves the replaced statements (swap_tensor / generator_losses / discriminator_losses naming) with the statements that supersede them.
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
# NOTE(review): the replaced preview check further down read this key with getint; getfloat turns the modulo into a float operation - confirm the float read is intended, otherwise use getint
preview_frequency = CONFIG.getfloat('training.trainer', 'preview_frequency')
source_tensor, target_tensor = batch
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
swap_tensor = self.generator(source_embedding, target_tensor)
# an all-zero padding tuple means no border band is blanked for the embedding that conditions the generator
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
target_attributes = self.generator.get_attributes(target_tensor)
swap_attributes = self.generator.get_attributes(swap_tensor)
fake_discriminator_outputs = self.discriminator(swap_tensor)
generator_output_tensor = self.generator(source_embedding, target_tensor)
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
discriminator_output_tensor = self.discriminator(generator_output_tensor)
generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, swap_attributes, fake_discriminator_outputs, batch)
generator_loss_set = self.calc_generator_loss(generator_output_tensor, target_attributes, generator_output_attributes, discriminator_output_tensor, batch)
generator_optimizer.zero_grad()
self.manual_backward(generator_losses.get('loss_generator'))
self.manual_backward(generator_loss_set.get('loss_generator'))
generator_optimizer.step()
real_discriminator_outputs = self.discriminator(source_tensor)
fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
discriminator_source_tensor = self.discriminator(source_tensor)
# detach() cuts the generator output from its graph so the discriminator backward pass does not reach the generator
discriminator_output_tensor = self.discriminator(generator_output_tensor.detach())
discriminator_losses = self.calc_discriminator_loss(real_discriminator_outputs, fake_discriminator_outputs)
discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensor, discriminator_output_tensor)
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_losses.get('loss_discriminator'))
self.manual_backward(discriminator_loss_set.get('loss_discriminator'))
discriminator_optimizer.step()
if self.global_step % CONFIG.getint('training.trainer', 'preview_frequency') == 0:
self.generate_preview(source_tensor, target_tensor, swap_tensor)
if self.global_step % preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
self.log('loss_generator', generator_losses.get('loss_generator'), prog_bar = True)
self.log('loss_discriminator', discriminator_losses.get('loss_discriminator'), prog_bar = True)
self.log('loss_adversarial', generator_losses.get('loss_adversarial'))
self.log('loss_attribute', generator_losses.get('loss_attribute'))
self.log('loss_identity', generator_losses.get('loss_identity'))
self.log('loss_reconstruction', generator_losses.get('loss_reconstruction'))
return generator_losses.get('loss_generator')
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True)
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'))
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
self.log('loss_identity', generator_loss_set.get('loss_identity'), prog_bar = True)
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
# NOTE(review): IdentityLoss.calc_loss declares (source_tensor, output_tensor) but is called here with the generator output first. Cosine similarity is symmetric and both tensors get the same padding, so the value is unchanged, but the call should read calc_loss(source_tensor, generator_output_tensor)
identity_loss = self.identity_loss.calc_loss(generator_output_tensor, source_tensor)
generator_loss = self.calc_generator_loss_new(identity_loss)
# the *_new values are only logged; no backward pass is run on them in this method
self.log('loss_generator_new', generator_loss, prog_bar = True)
self.log('loss_identity_new', identity_loss, prog_bar = True)
# the returned value is still the 'loss_generator' entry of generator_loss_set
return generator_loss_set.get('loss_generator')
def calc_generator_loss_new(self, identity_loss : Tensor) -> Tensor:
	"""
	Scale the identity loss by the 'weight_identity' value of the 'training.losses' config section.

	:param identity_loss: identity loss tensor to be weighted
	:return: identity_loss multiplied by the configured float weight
	"""
	return identity_loss * CONFIG.getfloat('training.losses', 'weight_identity')
# Validation metric: embed the source and the generator output without border blanking and log their rescaled mean cosine similarity as 'validation'.
# This span interleaves the replaced calc_id_embedding calls with the calc_embedding calls that supersede them.
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
source_embedding = calc_id_embedding(self.id_embedder, source_tensor, (0, 0, 0, 0))
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
output_tensor = self.generator(source_embedding, target_tensor)
output_embedding = calc_id_embedding(self.id_embedder, output_tensor, (0, 0, 0, 0))
output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
# (x + 1) * 0.5 maps the mean cosine similarity from the range [-1, 1] to [0, 1]
validation = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
self.log('validation', validation)
return validation