mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
change face-parser
This commit is contained in:
@@ -196,14 +196,12 @@ class MaskLoss(nn.Module):
|
||||
return mask_loss, weighted_mask_loss
|
||||
|
||||
def calc_mask(self, target_tensor : Tensor) -> Tensor:
|
||||
target_tensor = torch.nn.functional.interpolate(target_tensor, (512, 512), mode = 'bilinear')
|
||||
face_mask_regions = torch.tensor([ 1, 2, 3, 4, 5, 10, 11, 12, 13 ]).to(target_tensor.device)
|
||||
target_tensor = torch.nn.functional.interpolate(target_tensor, (256, 256), mode = 'bilinear')
|
||||
target_tensor = (target_tensor.clip(-1, 1) + 1) * 0.5
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensor = self.face_parser(target_tensor)[0]
|
||||
output_tensor = output_tensor.argmax(1)
|
||||
output_tensor = torch.isin(output_tensor, face_mask_regions).to(target_tensor.dtype)
|
||||
output_tensor = output_tensor.view(-1, 1, 512, 512)
|
||||
output_tensor = self.face_parser(target_tensor)
|
||||
output_tensor = output_tensor.clamp(0, 1)
|
||||
output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear')
|
||||
|
||||
return output_tensor
|
||||
|
||||
Reference in New Issue
Block a user