From 583d09e666844fcde3bc1273acb83037cb1e3483 Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Sun, 23 Mar 2025 18:01:38 +0530 Subject: [PATCH] change face-parser --- face_swapper/src/models/loss.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py index 2b1a36e..2549af0 100644 --- a/face_swapper/src/models/loss.py +++ b/face_swapper/src/models/loss.py @@ -196,14 +196,12 @@ class MaskLoss(nn.Module): return mask_loss, weighted_mask_loss def calc_mask(self, target_tensor : Tensor) -> Tensor: - target_tensor = torch.nn.functional.interpolate(target_tensor, (512, 512), mode = 'bilinear') - face_mask_regions = torch.tensor([ 1, 2, 3, 4, 5, 10, 11, 12, 13 ]).to(target_tensor.device) + target_tensor = torch.nn.functional.interpolate(target_tensor, (256, 256), mode = 'bilinear') + target_tensor = (target_tensor.clip(-1, 1) + 1) * 0.5 with torch.no_grad(): - output_tensor = self.face_parser(target_tensor)[0] - output_tensor = output_tensor.argmax(1) - output_tensor = torch.isin(output_tensor, face_mask_regions).to(target_tensor.dtype) - output_tensor = output_tensor.view(-1, 1, 512, 512) + output_tensor = self.face_parser(target_tensor) + output_tensor = output_tensor.clamp(0, 1) output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear') return output_tensor