From a1fd382659fde3a33cfaceee6ac356b3082ac7e1 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Wed, 11 Dec 2024 14:28:13 +0530 Subject: [PATCH] ugly training code --- face_swapper/src/training.py | 87 ++++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 28 deletions(-) diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index f7d1c37..342db36 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -170,17 +170,19 @@ class FaceSwapper(pytorch_lightning.LightningModule): def configure_optimizers(self) -> OptimizerLRScheduler: generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4) discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = CONFIG.getfloat('training.discriminator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4) - generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma')) - discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma')) - return ( - { - "optimizer": generator_optimizer, - "lr_scheduler": generator_scheduler - }, - { - "optimizer": discriminator_optimizer, - "lr_scheduler": discriminator_scheduler - }) + if CONFIG.getboolean('training.schedulers', 'enable'): + generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma')) + discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma')) + return ( + { + "optimizer": generator_optimizer, + "lr_scheduler": generator_scheduler + }, + { + "optimizer": discriminator_optimizer, + "lr_scheduler": discriminator_scheduler + }) + return generator_optimizer, discriminator_optimizer def training_step(self, batch : Batch, batch_index : int) -> Tensor: @@ -256,7 +258,7 @@ class FaceSwapper(pytorch_lightning.LightningModule): loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean() loss_reconstruction = loss_reconstruction * 0.3 + loss_ssim * 0.7 generator_losses['loss_reconstruction'] = loss_reconstruction - generator_losses['loss_generator'] += CONFIG.getfloat('training.losses', 'weight_reconstruction') + generator_losses['loss_generator'] += loss_reconstruction * CONFIG.getfloat('training.losses', 'weight_reconstruction') if CONFIG.getfloat('training.losses', 'weight_tsr') > 0: # tsr loss @@ -285,13 +287,13 @@ class FaceSwapper(pytorch_lightning.LightningModule): loss_left_eye_open = L2_loss(swap_landmark_features[0], target_landmark_features[0]) loss_right_eye_open = L2_loss(swap_landmark_features[1], target_landmark_features[1]) loss_eye_open = loss_left_eye_open + loss_right_eye_open - generator_losses['loss_eye_open'] = loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open') - generator_losses['loss_generator'] += loss_eye_open + generator_losses['loss_eye_open'] = loss_eye_open + generator_losses['loss_generator'] += loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open') # lip open loss loss_lip_open = L2_loss(swap_landmark_features[2], target_landmark_features[2]) - generator_losses['loss_lip_open'] = loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open') - generator_losses['loss_generator'] += loss_lip_open + generator_losses['loss_lip_open'] = loss_lip_open + generator_losses['loss_generator'] += loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open') return generator_losses @@ -359,30 +361,59 @@ class FaceSwapper(pytorch_lightning.LightningModule): torchvision.utils.save_image(grid, f"previews/step_{self.global_step}.jpg") self.logger.experiment.add_image("Generator Preview", grid, self.global_step) + def log_validation_preview(self): - validation_source_path = CONFIG.get('training.validation', 'sources') - validation_target_path = CONFIG.get('training.validation', 'targets') - sources = [read_image(os.path.join(validation_source_path, f)) for f in os.listdir(validation_source_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')] - targets = [read_image(os.path.join(validation_target_path, f)) for f in os.listdir(validation_target_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')] + read_images = lambda path: [read_image(os.path.join(path, f)) for f in sorted(os.listdir(path)) if f.lower().endswith('.jpg') or f.lower().endswith('.png')] + to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:, :, ::-1] + 1) * 127.5 transforms = torchvision.transforms.Compose( [ torchvision.transforms.Resize((256, 256), interpolation = torchvision.transforms.InterpolationMode.BICUBIC), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) - to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:,:,::-1] + 1) * 127.5 - self.generator.eval() - results = [] + sources = read_images(CONFIG.get('training.validation', 'sources')) + targets_front = read_images(CONFIG.get('training.validation', 'targets_front')) + targets_side = read_images(CONFIG.get('training.validation', 'targets_side')) + targets_makeup = read_images(CONFIG.get('training.validation', 'targets_makeup')) + targets_occlusion = read_images(CONFIG.get('training.validation', 'targets_occlusion')) - for source, target in zip(sources, targets): + + self.generator.eval() + + results_source = [] + results_front = [] + results_side = [] + results_makeup = [] + results_occlusion = [] + + for source, target_front, target_side, target_makeup, target_occlusion in zip(sources, targets_front, targets_side, targets_makeup, targets_occlusion): source_tensor = transforms(source).unsqueeze(0).to(self.device).half() - target_tensor = transforms(target).unsqueeze(0).to(self.device).half() source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0)) + target_front_tensor = transforms(target_front).unsqueeze(0).to(self.device).half() + target_side_tensor = transforms(target_side).unsqueeze(0).to(self.device).half() + target_makeup_tensor = transforms(target_makeup).unsqueeze(0).to(self.device).half() + target_occlusion_tensor = transforms(target_occlusion).unsqueeze(0).to(self.device).half() with torch.no_grad(): - output, _ = self.generator(target_tensor, source_embedding) - results.append(numpy.hstack([to_numpy(source_tensor), to_numpy(target_tensor), to_numpy(output)])) - preview = numpy.vstack(results) + output_front, _ = self.generator(target_front_tensor, source_embedding) + output_side, _ = self.generator(target_side_tensor, source_embedding) + output_makeup, _ = self.generator(target_makeup_tensor, source_embedding) + output_occlusion, _ = self.generator(target_occlusion_tensor, source_embedding) + + results_source.append(to_numpy(source_tensor)) + results_front.append(numpy.hstack([to_numpy(target_front_tensor), to_numpy(output_front)])) + results_side.append(numpy.hstack([to_numpy(target_side_tensor), to_numpy(output_side)])) + results_makeup.append(numpy.hstack([to_numpy(target_makeup_tensor), to_numpy(output_makeup)])) + results_occlusion.append(numpy.hstack([to_numpy(target_occlusion_tensor), to_numpy(output_occlusion)])) + + sources_vertical = numpy.vstack(results_source) + results_front_vertical = numpy.vstack(results_front) + results_side_vertical = numpy.vstack(results_side) + results_makeup_vertical = numpy.vstack(results_makeup) + results_occlusion_vertical = numpy.vstack(results_occlusion) + pad = numpy.zeros((sources_vertical.shape[0], 10, 3), dtype = sources_vertical.dtype) + preview = numpy.hstack([sources_vertical, pad, results_front_vertical, pad, results_side_vertical, pad, results_makeup_vertical, pad, results_occlusion_vertical]) + os.makedirs("validation_previews", exist_ok=True) cv2.imwrite(f"validation_previews/step_{self.global_step}.jpg", preview) self.generator.train()