ugly training code

This commit is contained in:
harisreedhar
2024-12-11 14:28:13 +05:30
committed by henryruhs
parent 650268c06b
commit a1fd382659
+59 -28
View File
@@ -170,17 +170,19 @@ class FaceSwapper(pytorch_lightning.LightningModule):
def configure_optimizers(self) -> OptimizerLRScheduler:
generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = CONFIG.getfloat('training.discriminator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
return (
{
"optimizer": generator_optimizer,
"lr_scheduler": generator_scheduler
},
{
"optimizer": discriminator_optimizer,
"lr_scheduler": discriminator_scheduler
})
if CONFIG.getboolean('training.schedulers', 'enable'):
generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
return (
{
"optimizer": generator_optimizer,
"lr_scheduler": generator_scheduler
},
{
"optimizer": discriminator_optimizer,
"lr_scheduler": discriminator_scheduler
})
return generator_optimizer, discriminator_optimizer
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
@@ -256,7 +258,7 @@ class FaceSwapper(pytorch_lightning.LightningModule):
loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
loss_reconstruction = loss_reconstruction * 0.3 + loss_ssim * 0.7
generator_losses['loss_reconstruction'] = loss_reconstruction
generator_losses['loss_generator'] += CONFIG.getfloat('training.losses', 'weight_reconstruction')
generator_losses['loss_generator'] += loss_reconstruction * CONFIG.getfloat('training.losses', 'weight_reconstruction')
if CONFIG.getfloat('training.losses', 'weight_tsr') > 0:
# tsr loss
@@ -285,13 +287,13 @@ class FaceSwapper(pytorch_lightning.LightningModule):
loss_left_eye_open = L2_loss(swap_landmark_features[0], target_landmark_features[0])
loss_right_eye_open = L2_loss(swap_landmark_features[1], target_landmark_features[1])
loss_eye_open = loss_left_eye_open + loss_right_eye_open
generator_losses['loss_eye_open'] = loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open')
generator_losses['loss_generator'] += loss_eye_open
generator_losses['loss_eye_open'] = loss_eye_open
generator_losses['loss_generator'] += loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open')
# lip open loss
loss_lip_open = L2_loss(swap_landmark_features[2], target_landmark_features[2])
generator_losses['loss_lip_open'] = loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open')
generator_losses['loss_generator'] += loss_lip_open
generator_losses['loss_lip_open'] = loss_lip_open
generator_losses['loss_generator'] += loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open')
return generator_losses
@@ -359,30 +361,59 @@ class FaceSwapper(pytorch_lightning.LightningModule):
torchvision.utils.save_image(grid, f"previews/step_{self.global_step}.jpg")
self.logger.experiment.add_image("Generator Preview", grid, self.global_step)
def log_validation_preview(self):
validation_source_path = CONFIG.get('training.validation', 'sources')
validation_target_path = CONFIG.get('training.validation', 'targets')
sources = [read_image(os.path.join(validation_source_path, f)) for f in os.listdir(validation_source_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
targets = [read_image(os.path.join(validation_target_path, f)) for f in os.listdir(validation_target_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
read_images = lambda path: [read_image(os.path.join(path, f)) for f in sorted(os.listdir(path)) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:, :, ::-1] + 1) * 127.5
transforms = torchvision.transforms.Compose(
[
torchvision.transforms.Resize((256, 256), interpolation = torchvision.transforms.InterpolationMode.BICUBIC),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:,:,::-1] + 1) * 127.5
self.generator.eval()
results = []
sources = read_images(CONFIG.get('training.validation', 'sources'))
targets_front = read_images(CONFIG.get('training.validation', 'targets_front'))
targets_side = read_images(CONFIG.get('training.validation', 'targets_side'))
targets_makeup = read_images(CONFIG.get('training.validation', 'targets_makeup'))
targets_occlusion = read_images(CONFIG.get('training.validation', 'targets_occlusion'))
for source, target in zip(sources, targets):
self.generator.eval()
results_source = []
results_front = []
results_side = []
results_makeup = []
results_occlusion = []
for source, target_front, target_side, target_makeup, target_occlusion in zip(sources, targets_front, targets_side, targets_makeup, targets_occlusion):
source_tensor = transforms(source).unsqueeze(0).to(self.device).half()
target_tensor = transforms(target).unsqueeze(0).to(self.device).half()
source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
target_front_tensor = transforms(target_front).unsqueeze(0).to(self.device).half()
target_side_tensor = transforms(target_side).unsqueeze(0).to(self.device).half()
target_makeup_tensor = transforms(target_makeup).unsqueeze(0).to(self.device).half()
target_occlusion_tensor = transforms(target_occlusion).unsqueeze(0).to(self.device).half()
with torch.no_grad():
output, _ = self.generator(target_tensor, source_embedding)
results.append(numpy.hstack([to_numpy(source_tensor), to_numpy(target_tensor), to_numpy(output)]))
preview = numpy.vstack(results)
output_front, _ = self.generator(target_front_tensor, source_embedding)
output_side, _ = self.generator(target_side_tensor, source_embedding)
output_makeup, _ = self.generator(target_makeup_tensor, source_embedding)
output_occlusion, _ = self.generator(target_occlusion_tensor, source_embedding)
results_source.append(to_numpy(source_tensor))
results_front.append(numpy.hstack([to_numpy(target_front_tensor), to_numpy(output_front)]))
results_side.append(numpy.hstack([to_numpy(target_side_tensor), to_numpy(output_side)]))
results_makeup.append(numpy.hstack([to_numpy(target_makeup_tensor), to_numpy(output_makeup)]))
results_occlusion.append(numpy.hstack([to_numpy(target_occlusion_tensor), to_numpy(output_occlusion)]))
sources_vertical = numpy.vstack(results_source)
results_front_vertical = numpy.vstack(results_front)
results_side_vertical = numpy.vstack(results_side)
results_makeup_vertical = numpy.vstack(results_makeup)
results_occlusion_vertical = numpy.vstack(results_occlusion)
pad = numpy.zeros((sources_vertical.shape[0], 10, 3), dtype = sources_vertical.dtype)
preview = numpy.hstack([sources_vertical, pad, results_front_vertical, pad, results_side_vertical, pad, results_makeup_vertical, pad, results_occlusion_vertical])
os.makedirs("validation_previews", exist_ok=True)
cv2.imwrite(f"validation_previews/step_{self.global_step}.jpg", preview)
self.generator.train()