Use new loss code, Remove unused code, Remove old types, Ban VisionTensor naming

This commit is contained in:
henryruhs
2025-02-23 01:05:01 +01:00
parent 63e4bea3cd
commit ed0f6ae897
6 changed files with 54 additions and 217 deletions
+9 -9
View File
@@ -2,19 +2,19 @@ import numpy
import torch
from torch import Tensor, nn
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
from .types import EmbedderModule, Embedding, Padding, VisionFrame
def convert_to_vision_tensor(vision_frame : VisionFrame) -> VisionTensor:
    """
    Turn a vision frame into a batched float32 tensor.

    The last axis of the frame is reversed and moved to the front, values
    are rescaled so that 0..255 maps onto -1..1, and a leading axis of size
    one is prepended.
    """
    # astype() copies, so the reversed view becomes a regular array that torch.from_numpy accepts
    channel_first_frame = vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32)
    unit_tensor = torch.from_numpy(channel_first_frame) / 255.0
    return ((unit_tensor - 0.5) * 2).unsqueeze(0)
def convert_to_tensor(vision_frame : VisionFrame) -> Tensor:
    """
    Build the network input tensor for a vision frame.

    Steps: reverse the last axis, move it to the front, cast to float32,
    rescale 0..255 onto -1..1 and add a leading axis of size one.
    """
    reversed_frame = vision_frame[:, :, ::-1]
    float_frame = reversed_frame.transpose(2, 0, 1).astype(numpy.float32)
    scaled_tensor = (torch.from_numpy(float_frame) / 255.0 - 0.5) * 2
    return scaled_tensor.unsqueeze(0)
def convert_to_vision_frame(vision_tensor : VisionTensor) -> VisionFrame:
vision_frame = vision_tensor.detach().cpu().numpy()[0]
def convert_to_vision_frame(input_tensor : Tensor) -> VisionFrame:
vision_frame = input_tensor.detach().cpu().numpy()[0]
vision_frame = vision_frame.transpose(1, 2, 0)
vision_frame = (vision_frame + 1) * 127.5
vision_frame = vision_frame.clip(0, 255).astype(numpy.uint8)
+6 -6
View File
@@ -3,7 +3,7 @@ import configparser
import cv2
import torch
from .helper import calc_embedding, convert_to_vision_frame, convert_to_vision_tensor
from .helper import calc_embedding, convert_to_vision_frame, convert_to_tensor
from .models.generator import Generator
from .types import EmbedderModule, GeneratorModule, VisionFrame
@@ -12,11 +12,11 @@ CONFIG.read('config.ini')
def run_swap(generator : GeneratorModule, embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
source_vision_tensor = convert_to_vision_tensor(source_vision_frame)
target_vision_tensor = convert_to_vision_tensor(target_vision_frame)
source_embedding = calc_embedding(embedder, source_vision_tensor, (0, 0, 0, 0))
output_vision_tensor = generator(source_embedding, target_vision_tensor)[0]
output_vision_frame = convert_to_vision_frame(output_vision_tensor)
source_tensor = convert_to_tensor(source_vision_frame)
target_tensor = convert_to_tensor(target_vision_frame)
source_embedding = calc_embedding(embedder, source_tensor, (0, 0, 0, 0))
output_tensor = generator(source_embedding, target_tensor)[0]
output_vision_frame = convert_to_vision_frame(output_tensor)
return output_vision_frame
+4 -145
View File
@@ -6,153 +6,12 @@ from pytorch_msssim import ssim
from torch import Tensor, nn
from ..helper import calc_embedding
from ..types import Attributes, Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, LossTensor, SwapAttributes, TargetAttributes, VisionTensor
from ..types import Attributes, FaceLandmark203
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def hinge_real_loss(input_tensor : Tensor) -> Tensor:
    """
    Hinge penalty for scores that should be at least 1.

    Returns relu(1 - score) averaged over dimensions 1, 2 and 3, leaving one
    value per entry of dimension 0.
    """
    margin_shortfall = torch.relu(1 - input_tensor)
    return margin_shortfall.mean(dim = [ 1, 2, 3 ])
def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
    """
    Hinge penalty for scores that should be at most -1.

    Returns relu(score + 1) averaged over dimensions 1, 2 and 3, leaving one
    value per entry of dimension 0.
    """
    margin_excess = torch.relu(input_tensor + 1)
    return margin_excess.mean(dim = [ 1, 2, 3 ])
class FaceSwapperLoss:
def __init__(self) -> None:
    """
    Read the model paths and batch size from CONFIG and load the models.

    The embedder, landmarker and motion extractor are TorchScript files
    loaded onto the CPU.
    """
    model_section = 'training.model'
    embedder_file = CONFIG.get(model_section, 'embedder_path')
    landmarker_file = CONFIG.get(model_section, 'landmarker_path')
    motion_extractor_file = CONFIG.get(model_section, 'motion_extractor_path')

    self.batch_size = CONFIG.getint('training.loader', 'batch_size')
    self.mse_loss = nn.MSELoss()
    self.embedder = torch.jit.load(embedder_file, map_location = 'cpu') # type:ignore[no-untyped-call]
    self.landmarker = torch.jit.load(landmarker_file, map_location = 'cpu') # type:ignore[no-untyped-call]
    self.motion_extractor = torch.jit.load(motion_extractor_file, map_location = 'cpu') # type:ignore[no-untyped-call]
def calc_generator_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes, swap_attributes : SwapAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> GeneratorLossSet:
    """
    Compute every generator loss term and their weighted total.

    The per-term weights are read from the 'training.losses' section of
    CONFIG. The returned dictionary holds the unweighted terms under
    'loss_adversarial', 'loss_identity', 'loss_attribute',
    'loss_reconstruction', 'loss_pose' and 'loss_gaze', plus the weighted
    sum under 'loss_generator'.
    """
    weight_adversarial = CONFIG.getfloat('training.losses', 'adversarial_weight')
    weight_identity = CONFIG.getfloat('training.losses', 'identity_weight')
    weight_attribute = CONFIG.getfloat('training.losses', 'attribute_weight')
    weight_reconstruction = CONFIG.getfloat('training.losses', 'reconstruction_weight')
    weight_pose = CONFIG.getfloat('training.losses', 'pose_weight')
    weight_gaze = CONFIG.getfloat('training.losses', 'gaze_weight')
    source_tensor, target_tensor = batch
    # The flag masks the pixel term of the reconstruction loss, so it must be 1
    # when source and target are identical (the old code had the values swapped).
    # NOTE(review): torch.equal compares the whole batch, so this is one flag per batch, not per sample.
    is_same_person = torch.tensor(1) if torch.equal(source_tensor, target_tensor) else torch.tensor(0)

    generator_loss_set =\
    {
        'loss_adversarial': self.calc_adversarial_loss(discriminator_outputs),
        'loss_identity': self.calc_identity_loss(source_tensor, swap_tensor),
        'loss_attribute': self.calc_attribute_loss(target_attributes, swap_attributes),
        'loss_reconstruction': self.calc_reconstruction_loss(swap_tensor, target_tensor, is_same_person)
    }
    generator_loss_set['loss_pose'] = self.calc_pose_loss(swap_tensor, target_tensor)
    generator_loss_set['loss_gaze'] = self.calc_gaze_loss(swap_tensor, target_tensor)

    # Index the dictionary directly: every key was just written, and .get() would type the value as optional
    loss_generator = generator_loss_set['loss_adversarial'] * weight_adversarial
    loss_generator += generator_loss_set['loss_identity'] * weight_identity
    loss_generator += generator_loss_set['loss_attribute'] * weight_attribute
    loss_generator += generator_loss_set['loss_reconstruction'] * weight_reconstruction
    loss_generator += generator_loss_set['loss_pose'] * weight_pose
    loss_generator += generator_loss_set['loss_gaze'] * weight_gaze
    generator_loss_set['loss_generator'] = loss_generator
    return generator_loss_set
def hinge_real_loss(input_tensor: Tensor) -> Tensor:
    """Return relu(1 - input_tensor) averaged over dimensions 1, 2 and 3."""
    return (1 - input_tensor).relu().mean(dim = [1, 2, 3])
def hinge_fake_loss(input_tensor: Tensor) -> Tensor:
    """Return relu(input_tensor + 1) averaged over dimensions 1, 2 and 3."""
    return (input_tensor + 1).relu().mean(dim = [1, 2, 3])
def calc_discriminator_loss(self, real_discriminator_outputs : DiscriminatorOutputs, fake_discriminator_outputs : DiscriminatorOutputs) -> DiscriminatorLossSet:
    """
    Compute the discriminator hinge loss.

    Only the first element of each discriminator output is scored. The fake
    and real hinge losses are each averaged over all outputs, and the
    result is the mean of the two, stored under 'loss_discriminator'.
    """
    fake_losses = [ hinge_fake_loss(fake_output[0]) for fake_output in fake_discriminator_outputs ]
    real_losses = [ hinge_real_loss(real_output[0]) for real_output in real_discriminator_outputs ]
    mean_fake_loss = torch.stack(fake_losses).mean()
    mean_real_loss = torch.stack(real_losses).mean()
    return\
    {
        'loss_discriminator': (mean_real_loss + mean_fake_loss) * 0.5
    }
def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
    """
    Compute the generator's adversarial loss.

    The real-side hinge loss of the first element of every discriminator
    output is reduced to a scalar, and the scalars are averaged.
    """
    per_output_losses = [ hinge_real_loss(output[0]).mean() for output in discriminator_outputs ]
    return torch.stack(per_output_losses).mean()
def calc_attribute_loss(self, target_attributes : TargetAttributes, swap_attributes : SwapAttributes) -> LossTensor:
    """
    Compute the attribute loss between swap and target attribute tensors.

    For every swap/target pair the squared difference is averaged per
    sample and then over the batch; the mean over all pairs is halved.
    """
    loss_attributes = []

    for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
        squared_error = torch.pow(swap_attribute - target_attribute, 2)
        # Use the tensor's own leading dimension instead of the configured batch
        # size, so a batch of a different size (e.g. the last partial one) still works
        per_sample_error = squared_error.reshape(squared_error.shape[0], -1).mean(dim = 1)
        loss_attributes.append(per_sample_error.mean())

    loss_attribute = torch.stack(loss_attributes).mean() * 0.5
    return loss_attribute
def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
    """
    Compute the reconstruction loss between the swap and the target.

    The result is the mean of two terms: half the per-sample mean squared
    error, masked by is_same_person, and one minus the mean SSIM.
    """
    squared_error = torch.pow(swap_tensor - target_tensor, 2)
    # Use the tensor's own leading dimension instead of the configured batch
    # size, so a batch of a different size (e.g. the last partial one) still works
    loss_reconstruction = squared_error.reshape(squared_error.shape[0], -1)
    loss_reconstruction = torch.mean(loss_reconstruction, dim = 1) * 0.5
    # 1e-4 keeps the division finite when the mask sums to zero
    # NOTE(review): with a scalar mask this sums the per-sample errors rather than averaging them - confirm intended
    loss_reconstruction = torch.sum(loss_reconstruction * is_same_person) / (is_same_person.sum() + 1e-4)
    # SSIM needs the value range of the data; it is taken from the swap tensor
    data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))
    loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = data_range).mean()
    loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
    return loss_reconstruction
def calc_identity_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
    """
    Compute the identity loss between the source and the swap.

    Both tensors are embedded with the same padding, and the loss is one
    minus their cosine similarity, averaged.
    """
    embedding_padding = (30, 0, 10, 10)
    swap_embedding = calc_embedding(self.embedder, swap_tensor, embedding_padding)
    source_embedding = calc_embedding(self.embedder, source_tensor, embedding_padding)
    identity_similarity = torch.cosine_similarity(source_embedding, swap_embedding)
    return (1 - identity_similarity).mean()
def calc_pose_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
    """
    Compute the pose loss between the swap and the target.

    Sums the mean squared error of each pair of features returned by
    get_pose_features for the two tensors.
    """
    feature_pairs = zip(self.get_pose_features(swap_tensor), self.get_pose_features(target_tensor))
    # The accumulator starts as a zero scalar on the swap tensor's device and dtype
    loss_pose = swap_tensor.new_zeros(())

    for swap_feature, target_feature in feature_pairs:
        loss_pose += self.mse_loss(swap_feature, target_feature)

    return loss_pose
def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
    """
    Compute the gaze loss between the swap and the target.

    Landmarks are detected on both tensors, and the loss is the sum of the
    mean squared errors at landmark index 198 and landmark index 197.
    """
    swap_landmark = self.get_face_landmarks(swap_tensor)
    target_landmark = self.get_face_landmarks(target_tensor)
    # presumably 198 and 197 are the two gaze points of the 203-point layout - TODO confirm
    left_index, right_index = 198, 197
    left_gaze_loss = self.mse_loss(swap_landmark[:, left_index], target_landmark[:, left_index])
    right_gaze_loss = self.mse_loss(swap_landmark[:, right_index], target_landmark[:, right_index])
    return left_gaze_loss + right_gaze_loss
def get_face_landmarks(self, vision_tensor : VisionTensor) -> FaceLandmark203:
    """
    Run the landmarker on a tensor and return its landmarks.

    The tensor is shifted from -1..1 to 0..1 and resized bilinearly to
    224x224. The third landmarker output is viewed as (-1, 203, 2).
    """
    unit_range_tensor = (vision_tensor + 1) * 0.5
    resized_tensor = nn.functional.interpolate(unit_range_tensor, size = (224, 224), mode = 'bilinear')
    landmarker_outputs = self.landmarker(resized_tensor)
    return landmarker_outputs[2].view(-1, 203, 2)
def get_pose_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Run the motion extractor and return (translation, scale, rotation).

    The tensor is shifted from -1..1 to 0..1 first. Rotation is pitch, yaw
    and roll concatenated along dimension 1; the fifth and seventh
    extractor outputs are discarded.
    """
    unit_range_tensor = (vision_tensor + 1) * 0.5
    pitch, yaw, roll, translation, _, scale, _ = self.motion_extractor(unit_range_tensor)
    return translation, scale, torch.cat([ pitch, yaw, roll ], dim = 1)
class DiscriminatorLoss(nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -270,8 +129,8 @@ class PoseLoss(nn.Module):
return pose_loss, weighted_pose_loss
def get_motion_features(self, input_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
vision_tensor_norm = (input_tensor + 1) * 0.5
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(vision_tensor_norm)
input_tensor = (input_tensor + 1) * 0.5
pitch, yaw, roll, translation, expression, scale, _ = self.motion_extractor(input_tensor)
rotation = torch.cat([ pitch, yaw, roll ], dim = 1)
return translation, scale, rotation
@@ -283,7 +142,7 @@ class GazeLoss(nn.Module):
self.landmarker = torch.jit.load(landmarker_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.mse_loss = nn.MSELoss()
def calc(self, target_tensor : VisionTensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
def calc(self, target_tensor : Tensor, output_tensor : Tensor, ) -> Tuple[Tensor, Tensor]:
gaze_weight = CONFIG.getfloat('training.losses', 'gaze_weight')
output_face_landmark = self.detect_face_landmark(output_tensor)
target_face_landmark = self.detect_face_landmark(target_tensor)
@@ -1,7 +1,7 @@
import torch
from torch import Tensor, nn
from ..types import Embedding, TargetAttributes
from ..types import Attributes, Embedding
class AADGenerator(nn.Module):
@@ -17,7 +17,7 @@ class AADGenerator(nn.Module):
self.res_block_7 = AADResBlock(128, 64, 64, id_channels, num_blocks)
self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks)
def forward(self, target_attributes : TargetAttributes, source_embedding : Embedding) -> Tensor:
def forward(self, target_attributes : Attributes, source_embedding : Embedding) -> Tensor:
feature_map = self.upsample(source_embedding)
feature_map_1 = nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
feature_map_2 = nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
@@ -59,10 +59,10 @@ class AADSequential(nn.Module):
super().__init__()
self.layers = nn.ModuleList(args)
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, id_embedding : Embedding) -> Tensor:
def forward(self, feature_map : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
for layer in self.layers:
if isinstance(layer, AADLayer):
feature_map = layer(feature_map, attribute_embedding, id_embedding)
feature_map = layer(feature_map, attribute_embedding, identity_embedding)
else:
feature_map = layer(feature_map)
return feature_map
+31 -43
View File
@@ -16,18 +16,18 @@ from .dataset import DynamicDataset
from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, FaceSwapperLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
from .types import Batch, Embedding, VisionTensor
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, PoseLoss, ReconstructionLoss
from .types import Batch, Embedding
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
class FaceSwapperTrainer(lightning.LightningModule):
def __init__(self) -> None:
super().__init__()
FaceSwapperLoss.__init__(self)
automatic_optimization = CONFIG.getboolean('training.trainer', 'automatic_optimization')
embedder_path = CONFIG.get('training.model', 'embedder_path')
self.generator = Generator()
self.discriminator = Discriminator()
@@ -38,9 +38,10 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.identity_loss = IdentityLoss()
self.pose_loss = PoseLoss()
self.gaze_loss = GazeLoss()
self.embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
self.automatic_optimization = automatic_optimization
def forward(self, target_tensor : VisionTensor, source_embedding : Embedding) -> Tensor:
def forward(self, target_tensor : Tensor, source_embedding : Embedding) -> Tensor:
output_tensor = self.generator(source_embedding, target_tensor)
return output_tensor
@@ -61,34 +62,6 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
generator_output_attributes = self.generator.get_attributes(generator_output_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor)
generator_loss_set = self.calc_generator_loss(generator_output_tensor, target_attributes, generator_output_attributes, discriminator_output_tensors, batch)
generator_optimizer.zero_grad()
self.manual_backward(generator_loss_set.get('loss_generator'))
generator_optimizer.step()
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss_set = self.calc_discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_loss_set.get('loss_discriminator'))
discriminator_optimizer.step()
if self.global_step % preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
self.log('loss_generator', generator_loss_set.get('loss_generator'), prog_bar = True)
self.log('loss_discriminator', discriminator_loss_set.get('loss_discriminator'), prog_bar = True)
self.log('loss_adversarial', generator_loss_set.get('loss_adversarial'))
self.log('loss_attribute', generator_loss_set.get('loss_attribute'))
self.log('loss_identity', generator_loss_set.get('loss_identity'))
self.log('loss_reconstruction', generator_loss_set.get('loss_reconstruction'))
self.log('loss_pose', generator_loss_set.get('loss_pose'))
self.log('loss_gaze', generator_loss_set.get('loss_gaze'))
###############################################
discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors)
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss.calc(discriminator_output_tensors)
attribute_loss, weighted_attribute_loss = self.attribute_loss.calc(target_attributes, generator_output_attributes)
reconstruction_loss, weighted_reconstruction_loss = self.reconstruction_loss.calc(source_tensor, target_tensor, generator_output_tensor)
@@ -97,15 +70,30 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
gaze_loss, weighted_gaze_loss = self.gaze_loss.calc(target_tensor, generator_output_tensor)
generator_loss = weighted_adversarial_loss + weighted_attribute_loss + weighted_reconstruction_loss + weighted_identity_loss + weighted_pose_loss
self.log('generator_loss_new', generator_loss, prog_bar = True)
self.log('discriminator_loss_new', discriminator_loss, prog_bar = True)
self.log('adversarial_loss_new', adversarial_loss)
self.log('attribute_loss_new', attribute_loss)
self.log('reconstruction_loss_new', reconstruction_loss)
self.log('identity_loss_new', identity_loss)
self.log('pose_loss_new', pose_loss)
self.log('gaze_loss_new', gaze_loss)
return generator_loss_set.get('loss_generator')
generator_optimizer.zero_grad()
self.manual_backward(generator_loss)
generator_optimizer.step()
discriminator_source_tensors = self.discriminator(source_tensor)
discriminator_output_tensors = self.discriminator(generator_output_tensor.detach())
discriminator_loss = self.discriminator_loss.calc(discriminator_source_tensors, discriminator_output_tensors)
discriminator_optimizer.zero_grad()
self.manual_backward(discriminator_loss)
discriminator_optimizer.step()
if self.global_step % preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor)
self.log('generator_loss', generator_loss, prog_bar = True)
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
self.log('adversarial_loss', adversarial_loss)
self.log('attribute_loss', attribute_loss)
self.log('reconstruction_loss', reconstruction_loss)
self.log('identity_loss', identity_loss)
self.log('pose_loss', pose_loss)
self.log('gaze_loss', gaze_loss)
return generator_loss
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
@@ -116,7 +104,7 @@ class FaceSwapperTrainer(lightning.LightningModule, FaceSwapperLoss):
self.log('validation', validation)
return validation
def generate_preview(self, source_tensor : VisionTensor, target_tensor : VisionTensor, output_tensor : VisionTensor) -> None:
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor) -> None:
preview_limit = 8
preview_cells = []
-10
View File
@@ -7,23 +7,13 @@ from torch.nn import Module
Batch : TypeAlias = Tuple[Tensor, Tensor]
SwapAttributes : TypeAlias = Tuple[Tensor, ...]
TargetAttributes : TypeAlias = Tuple[Tensor, ...]
DiscriminatorOutputs : TypeAlias = List[List[Tensor]]
Attributes : TypeAlias = Tuple[Tensor, ...]
Embedding : TypeAlias = Tensor
FaceLandmark203 : TypeAlias = Tensor
StateSet : TypeAlias = OrderedDict[str, Any]
Padding : TypeAlias = Tuple[int, int, int, int]
VisionFrame : TypeAlias = NDArray[Any]
LossTensor : TypeAlias = Tensor
VisionTensor : TypeAlias = Tensor
GeneratorLossSet : TypeAlias = Dict[str, Tensor]
DiscriminatorLossSet : TypeAlias = Dict[str, Tensor]
GeneratorModule : TypeAlias = Module
EmbedderModule : TypeAlias = Module