change face-parser

This commit is contained in:
harisreedhar
2025-03-23 18:01:38 +05:30
parent 9153b4ce8f
commit 583d09e666
+4 -6
View File
@@ -196,14 +196,12 @@ class MaskLoss(nn.Module):
return mask_loss, weighted_mask_loss
def calc_mask(self, target_tensor : Tensor) -> Tensor:
target_tensor = torch.nn.functional.interpolate(target_tensor, (512, 512), mode = 'bilinear')
face_mask_regions = torch.tensor([ 1, 2, 3, 4, 5, 10, 11, 12, 13 ]).to(target_tensor.device)
target_tensor = torch.nn.functional.interpolate(target_tensor, (256, 256), mode = 'bilinear')
target_tensor = (target_tensor.clip(-1, 1) + 1) * 0.5
with torch.no_grad():
output_tensor = self.face_parser(target_tensor)[0]
output_tensor = output_tensor.argmax(1)
output_tensor = torch.isin(output_tensor, face_mask_regions).to(target_tensor.dtype)
output_tensor = output_tensor.view(-1, 1, 512, 512)
output_tensor = self.face_parser(target_tensor)
output_tensor = output_tensor.clamp(0, 1)
output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear')
return output_tensor