This commit is contained in:
harisreedhar
2025-03-11 17:42:29 +05:30
committed by henryruhs
parent 52b98b5be5
commit a7a21cd684
+1 -1
View File
@@ -39,7 +39,7 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
def overlay_mask(target_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
color_tensor = torch.zeros(*list(target_tensor.shape), dtype = target_tensor.dtype, device = target_tensor.device)
color_tensor = torch.zeros(*target_tensor.shape, dtype = target_tensor.dtype, device = target_tensor.device)
color_tensor[:, 2, :, :] = 1
mask_tensor = mask_tensor.repeat(1, 3, 1, 1).clamp(0, 0.8)
output_tensor = target_tensor * (1 - mask_tensor) + color_tensor * mask_tensor