This commit is contained in:
harisreedhar
2025-03-10 13:24:39 +05:30
committed by henryruhs
parent 8f1f002c64
commit 1659805b08
5 changed files with 47 additions and 28 deletions
+7
View File
@@ -65,6 +65,13 @@ num_discriminators = 3
kernel_size = 4
```
```
[training.model.masker]
input_channels = 67
output_channels = 1
base_channels = 16
```
```
[training.losses]
adversarial_weight = 1.0
+5
View File
@@ -28,6 +28,11 @@ num_layers =
num_discriminators =
kernel_size =
[training.model.masker]
input_channels =
output_channels =
base_channels =
[training.losses]
adversarial_weight =
attribute_weight =
+1 -1
View File
@@ -14,7 +14,7 @@ class Generator(nn.Module):
super().__init__()
self.encoder = UNet(config_parser)
self.generator = AAD(config_parser)
self.masker = MaskNet(67, 1, 16)
self.masker = MaskNet(config_parser)
self.encoder.apply(init_weight)
self.generator.apply(init_weight)
self.masker.apply(init_weight)
+2 -2
View File
@@ -195,12 +195,12 @@ class MaskLoss(nn.Module):
def calc_mask(self, target_tensor : Tensor) -> Tensor:
target_tensor = torch.nn.functional.interpolate(target_tensor, (512, 512), mode = 'bilinear')
face_indices = torch.tensor([ 1, 2, 3, 4, 5, 10, 11, 12, 13 ]).to(target_tensor.device)
face_mask_regions = torch.tensor([ 1, 2, 3, 4, 5, 10, 11, 12, 13 ]).to(target_tensor.device)
with torch.no_grad():
output_tensor = self.parser(target_tensor)[0]
output_tensor = output_tensor.argmax(1)
output_tensor = torch.isin(output_tensor, face_indices).to(target_tensor.dtype)
output_tensor = torch.isin(output_tensor, face_mask_regions).to(target_tensor.dtype)
output_tensor = output_tensor.view(-1, 1, 512, 512)
output_tensor = torch.nn.functional.interpolate(output_tensor, (self.config_output_size, self.config_output_size), mode = 'bilinear')
return output_tensor
+32 -25
View File
@@ -1,62 +1,69 @@
from configparser import ConfigParser
import torch
from torch import Tensor, nn
class MaskNet(nn.Module):
def __init__(self, input_channels : int, output_channels : int, base_channels : int):
def __init__(self, config_parser : ConfigParser) -> None:
super().__init__()
self.down_samples = self.create_down_samples(input_channels, base_channels)
self.up_samples = self.create_up_samples(base_channels)
self.bottleneck = ResBlock(base_channels * 4)
self.conv = nn.Conv2d(base_channels, output_channels, kernel_size = 1)
self.config_input_channels = config_parser.getint('training.model.masker', 'input_channels')
self.config_output_channels = config_parser.getint('training.model.masker', 'output_channels')
self.config_base_channels = config_parser.getint('training.model.masker', 'base_channels')
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_base_channels)
self.up_samples = self.create_up_samples(self.config_base_channels)
self.bottleneck = BottleNeck(self.config_base_channels * 4)
self.conv = nn.Conv2d(self.config_base_channels, self.config_output_channels, kernel_size = 1)
self.sigmoid = nn.Sigmoid()
def create_down_samples(self, input_channels : int, base_channels: int) -> nn.ModuleList:
down_samples = nn.ModuleList(
[
DownSample(input_channels, base_channels),
DownSample(base_channels, base_channels * 2),
DownSample(base_channels * 2, base_channels * 4)
])
[
DownSample(input_channels, base_channels),
DownSample(base_channels, base_channels * 2),
DownSample(base_channels * 2, base_channels * 4)
])
return down_samples
def create_up_samples(self, base_channels : int) -> nn.ModuleList:
down_samples = nn.ModuleList(
[
UpSample(base_channels * 4, base_channels * 2),
UpSample(base_channels * 2, base_channels),
UpSample(base_channels, base_channels)
])
[
UpSample(base_channels * 4, base_channels * 2),
UpSample(base_channels * 2, base_channels),
UpSample(base_channels, base_channels)
])
return down_samples
def forward(self, target_tensor : Tensor, target_attribute : Tensor) -> Tensor:
output_tensor = torch.cat([ target_tensor, target_attribute ], dim=1)
output_tensor = torch.cat([ target_tensor, target_attribute ], dim = 1)
for down_sample in self.down_samples:
output_tensor = down_sample(output_tensor)
output_tensor = self.bottleneck(output_tensor)
for up_sample in self.up_samples:
output_tensor = up_sample(output_tensor)
output_tensor = self.conv(output_tensor)
output_tensor = self.activation(output_tensor)
return output_tensor
class ResBlock(nn.Module):
def __init__(self, channels: int):
class BottleNeck(nn.Module):
def __init__(self, channels : int):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
nn.Conv2d(channels, channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(channels),
nn.ReLU(inplace=True),
nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
nn.ReLU(inplace = True),
nn.Conv2d(channels, channels, kernel_size = 3, padding = 1, bias = False),
nn.BatchNorm2d(channels),
nn.ReLU(inplace=True)
nn.ReLU(inplace = True)
)
self.relu = nn.ReLU(inplace=True)
self.relu = nn.ReLU(inplace = True)
def forward(self, input_tensor: Tensor) -> Tensor:
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.conv(input_tensor) + input_tensor
output_tensor = self.relu(output_tensor)
return output_tensor
@@ -66,7 +73,7 @@ class UpSample(nn.Module):
def __init__(self, input_channels : int, output_channels : int) -> None:
super().__init__()
self.conv_transpose = nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 2, stride = 2)
self.relu = nn.ReLU(inplace=True)
self.relu = nn.ReLU(inplace = True)
def forward(self, input_tensor : Tensor) -> Tensor:
output_tensor = self.conv_transpose(input_tensor)