mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
ugly training code
This commit is contained in:
@@ -170,17 +170,19 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
def configure_optimizers(self) -> OptimizerLRScheduler:
|
||||
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = CONFIG.getfloat('training.discriminator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
|
||||
generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
|
||||
return (
|
||||
{
|
||||
"optimizer": generator_optimizer,
|
||||
"lr_scheduler": generator_scheduler
|
||||
},
|
||||
{
|
||||
"optimizer": discriminator_optimizer,
|
||||
"lr_scheduler": discriminator_scheduler
|
||||
})
|
||||
if CONFIG.getboolean('training.schedulers', 'enable'):
|
||||
generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
|
||||
discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
|
||||
return (
|
||||
{
|
||||
"optimizer": generator_optimizer,
|
||||
"lr_scheduler": generator_scheduler
|
||||
},
|
||||
{
|
||||
"optimizer": discriminator_optimizer,
|
||||
"lr_scheduler": discriminator_scheduler
|
||||
})
|
||||
return generator_optimizer, discriminator_optimizer
|
||||
|
||||
|
||||
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
@@ -256,7 +258,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
|
||||
loss_reconstruction = loss_reconstruction * 0.3 + loss_ssim * 0.7
|
||||
generator_losses['loss_reconstruction'] = loss_reconstruction
|
||||
generator_losses['loss_generator'] += CONFIG.getfloat('training.losses', 'weight_reconstruction')
|
||||
generator_losses['loss_generator'] += loss_reconstruction * CONFIG.getfloat('training.losses', 'weight_reconstruction')
|
||||
|
||||
if CONFIG.getfloat('training.losses', 'weight_tsr') > 0:
|
||||
# tsr loss
|
||||
@@ -285,13 +287,13 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
loss_left_eye_open = L2_loss(swap_landmark_features[0], target_landmark_features[0])
|
||||
loss_right_eye_open = L2_loss(swap_landmark_features[1], target_landmark_features[1])
|
||||
loss_eye_open = loss_left_eye_open + loss_right_eye_open
|
||||
generator_losses['loss_eye_open'] = loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open')
|
||||
generator_losses['loss_generator'] += loss_eye_open
|
||||
generator_losses['loss_eye_open'] = loss_eye_open
|
||||
generator_losses['loss_generator'] += loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open')
|
||||
|
||||
# lip open loss
|
||||
loss_lip_open = L2_loss(swap_landmark_features[2], target_landmark_features[2])
|
||||
generator_losses['loss_lip_open'] = loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open')
|
||||
generator_losses['loss_generator'] += loss_lip_open
|
||||
generator_losses['loss_lip_open'] = loss_lip_open
|
||||
generator_losses['loss_generator'] += loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open')
|
||||
return generator_losses
|
||||
|
||||
|
||||
@@ -359,30 +361,59 @@ class FaceSwapper(pytorch_lightning.LightningModule):
|
||||
torchvision.utils.save_image(grid, f"previews/step_{self.global_step}.jpg")
|
||||
self.logger.experiment.add_image("Generator Preview", grid, self.global_step)
|
||||
|
||||
|
||||
def log_validation_preview(self):
|
||||
validation_source_path = CONFIG.get('training.validation', 'sources')
|
||||
validation_target_path = CONFIG.get('training.validation', 'targets')
|
||||
sources = [read_image(os.path.join(validation_source_path, f)) for f in os.listdir(validation_source_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
|
||||
targets = [read_image(os.path.join(validation_target_path, f)) for f in os.listdir(validation_target_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
|
||||
read_images = lambda path: [read_image(os.path.join(path, f)) for f in sorted(os.listdir(path)) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
|
||||
to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:, :, ::-1] + 1) * 127.5
|
||||
transforms = torchvision.transforms.Compose(
|
||||
[
|
||||
torchvision.transforms.Resize((256, 256), interpolation = torchvision.transforms.InterpolationMode.BICUBIC),
|
||||
torchvision.transforms.ToTensor(),
|
||||
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
])
|
||||
to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:,:,::-1] + 1) * 127.5
|
||||
self.generator.eval()
|
||||
results = []
|
||||
sources = read_images(CONFIG.get('training.validation', 'sources'))
|
||||
targets_front = read_images(CONFIG.get('training.validation', 'targets_front'))
|
||||
targets_side = read_images(CONFIG.get('training.validation', 'targets_side'))
|
||||
targets_makeup = read_images(CONFIG.get('training.validation', 'targets_makeup'))
|
||||
targets_occlusion = read_images(CONFIG.get('training.validation', 'targets_occlusion'))
|
||||
|
||||
for source, target in zip(sources, targets):
|
||||
|
||||
self.generator.eval()
|
||||
|
||||
results_source = []
|
||||
results_front = []
|
||||
results_side = []
|
||||
results_makeup = []
|
||||
results_occlusion = []
|
||||
|
||||
for source, target_front, target_side, target_makeup, target_occlusion in zip(sources, targets_front, targets_side, targets_makeup, targets_occlusion):
|
||||
source_tensor = transforms(source).unsqueeze(0).to(self.device).half()
|
||||
target_tensor = transforms(target).unsqueeze(0).to(self.device).half()
|
||||
source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
|
||||
target_front_tensor = transforms(target_front).unsqueeze(0).to(self.device).half()
|
||||
target_side_tensor = transforms(target_side).unsqueeze(0).to(self.device).half()
|
||||
target_makeup_tensor = transforms(target_makeup).unsqueeze(0).to(self.device).half()
|
||||
target_occlusion_tensor = transforms(target_occlusion).unsqueeze(0).to(self.device).half()
|
||||
|
||||
with torch.no_grad():
|
||||
output, _ = self.generator(target_tensor, source_embedding)
|
||||
results.append(numpy.hstack([to_numpy(source_tensor), to_numpy(target_tensor), to_numpy(output)]))
|
||||
preview = numpy.vstack(results)
|
||||
output_front, _ = self.generator(target_front_tensor, source_embedding)
|
||||
output_side, _ = self.generator(target_side_tensor, source_embedding)
|
||||
output_makeup, _ = self.generator(target_makeup_tensor, source_embedding)
|
||||
output_occlusion, _ = self.generator(target_occlusion_tensor, source_embedding)
|
||||
|
||||
results_source.append(to_numpy(source_tensor))
|
||||
results_front.append(numpy.hstack([to_numpy(target_front_tensor), to_numpy(output_front)]))
|
||||
results_side.append(numpy.hstack([to_numpy(target_side_tensor), to_numpy(output_side)]))
|
||||
results_makeup.append(numpy.hstack([to_numpy(target_makeup_tensor), to_numpy(output_makeup)]))
|
||||
results_occlusion.append(numpy.hstack([to_numpy(target_occlusion_tensor), to_numpy(output_occlusion)]))
|
||||
|
||||
sources_vertical = numpy.vstack(results_source)
|
||||
results_front_vertical = numpy.vstack(results_front)
|
||||
results_side_vertical = numpy.vstack(results_side)
|
||||
results_makeup_vertical = numpy.vstack(results_makeup)
|
||||
results_occlusion_vertical = numpy.vstack(results_occlusion)
|
||||
pad = numpy.zeros((sources_vertical.shape[0], 10, 3), dtype = sources_vertical.dtype)
|
||||
preview = numpy.hstack([sources_vertical, pad, results_front_vertical, pad, results_side_vertical, pad, results_makeup_vertical, pad, results_occlusion_vertical])
|
||||
|
||||
os.makedirs("validation_previews", exist_ok=True)
|
||||
cv2.imwrite(f"validation_previews/step_{self.global_step}.jpg", preview)
|
||||
self.generator.train()
|
||||
|
||||
Reference in New Issue
Block a user