Mask typing and naming related updates (also adds a third down/up-sampling stage to MaskNet and widens its bottleneck to num_filters * 4)

This commit is contained in:
henryruhs
2025-03-14 08:09:47 +01:00
parent b5efcbe44a
commit 5234874bc7
5 changed files with 30 additions and 27 deletions
+4 -4
View File
@@ -1,7 +1,7 @@
import torch
from torch import Tensor, nn
from .types import EmbedderModule, Embedding, Padding, WarpTemplate, WarpTemplateSet
from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpTemplateSet
WARP_TEMPLATE_SET : WarpTemplateSet =\
{
@@ -38,9 +38,9 @@ def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : P
return embedding
def overlay_mask(input_tensor : Tensor, mask_tensor : Tensor) -> Tensor:
def overlay_mask(input_tensor : Tensor, input_mask : Mask) -> Tensor:
overlay_tensor = torch.zeros(*input_tensor.shape, dtype = input_tensor.dtype, device = input_tensor.device)
overlay_tensor[:, 2, :, :] = 1
mask_tensor = mask_tensor.repeat(1, 3, 1, 1).clamp(0, 0.8)
output_tensor = input_tensor * (1 - mask_tensor) + overlay_tensor * mask_tensor
input_mask = input_mask.repeat(1, 3, 1, 1).clamp(0, 0.8)
output_tensor = input_tensor * (1 - input_mask) + overlay_tensor * input_mask
return output_tensor
+4 -4
View File
@@ -7,7 +7,7 @@ from torch import Tensor, nn
from torchvision import transforms
from ..helper import calc_embedding
from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, MotionExtractorModule
from ..types import Attribute, EmbedderModule, FaceParserModule, GazerModule, Loss, Mask, MotionExtractorModule
class DiscriminatorLoss(nn.Module):
@@ -186,11 +186,11 @@ class MaskLoss(nn.Module):
self.face_parser = face_parser
self.mse_loss = nn.MSELoss()
def forward(self, target_tensor : Tensor, mask_tensor : Tensor) -> Loss:
def forward(self, target_tensor : Tensor, output_mask : Mask) -> Loss:
target_mask = self.calc_mask(target_tensor)
target_mask = target_mask.view(-1, self.config_output_size, self.config_output_size)
mask_tensor = mask_tensor.view(-1, self.config_output_size, self.config_output_size)
mask_loss = self.mse_loss(target_mask, mask_tensor)
output_mask = output_mask.view(-1, self.config_output_size, self.config_output_size)
mask_loss = self.mse_loss(target_mask, output_mask)
return mask_loss
def calc_mask(self, target_tensor : Tensor) -> Tensor:
+13 -11
View File
@@ -3,7 +3,7 @@ from configparser import ConfigParser
import torch
from torch import Tensor, nn
from ..types import Attribute
from ..types import Attribute, Mask
class MaskNet(nn.Module):
@@ -14,7 +14,7 @@ class MaskNet(nn.Module):
self.config_num_filters = config_parser.getint('training.model.masker', 'num_filters')
self.down_samples = self.create_down_samples(self.config_input_channels, self.config_num_filters)
self.up_samples = self.create_up_samples(self.config_num_filters)
self.bottleneck = BottleNeck(self.config_num_filters * 2)
self.bottleneck = BottleNeck(self.config_num_filters * 4)
self.conv = nn.Conv2d(self.config_num_filters, self.config_output_channels, kernel_size = 1)
self.sigmoid = nn.Sigmoid()
@@ -23,31 +23,33 @@ class MaskNet(nn.Module):
return nn.ModuleList(
[
DownSample(input_channels, num_filters),
DownSample(num_filters, num_filters * 2)
DownSample(num_filters, num_filters * 2),
DownSample(num_filters, num_filters * 4)
])
@staticmethod
def create_up_samples(num_filters : int) -> nn.ModuleList:
    """
    Build the up-sampling stages that are applied, in order, after the bottleneck.

    The bottleneck is constructed with ``num_filters * 4`` and the final
    convolution is constructed with ``num_filters`` input channels, so the
    stages step the width down 4x -> 2x -> 1x -> 1x.

    NOTE(review): assumes ``UpSample(input_channels, output_channels)`` and that
    the bottleneck preserves its channel count - both are defined outside this
    view, confirm before merging.

    :param num_filters: base channel width of the network
    :return: the up-sampling stages as a ``nn.ModuleList``
    """
    # The stages are chained sequentially in forward(), therefore every stage has to
    # accept exactly the channel count emitted by the one before it. Previously the
    # first stage emitted num_filters while the second one expected num_filters * 2.
    return nn.ModuleList(
    [
        UpSample(num_filters * 4, num_filters * 2),
        UpSample(num_filters * 2, num_filters),
        UpSample(num_filters, num_filters)
    ])
def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Tensor:
output_tensor = torch.cat([ input_tensor, input_attribute ], dim = 1)
def forward(self, input_tensor : Tensor, input_attribute : Attribute) -> Mask:
output_mask = torch.cat([ input_tensor, input_attribute ], dim = 1)
for down_sample in self.down_samples:
output_tensor = down_sample(output_tensor)
output_mask = down_sample(output_mask)
output_tensor = self.bottleneck(output_tensor)
output_mask = self.bottleneck(output_mask)
for up_sample in self.up_samples:
output_tensor = up_sample(output_tensor)
output_mask = up_sample(output_mask)
output_tensor = self.conv(output_tensor)
output_tensor = self.sigmoid(output_tensor)
return output_tensor
output_mask = self.conv(output_mask)
output_mask = self.sigmoid(output_mask)
return output_mask
class BottleNeck(nn.Module):
+8 -8
View File
@@ -18,7 +18,7 @@ from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MaskLoss, MotionLoss, ReconstructionLoss
from .networks.masknet import MaskNet
from .types import Batch, Embedding, OptimizerSet
from .types import Batch, Embedding, Mask, OptimizerSet
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
@@ -57,9 +57,9 @@ class FaceSwapperTrainer(LightningModule):
with torch.no_grad():
output_tensor, target_attributes = self.generator(source_embedding, target_tensor)
target_attribute = target_attributes[-1]
mask_tensor = self.masker(target_tensor, target_attribute)
output_mask = self.masker(target_tensor, target_attribute)
return output_tensor, mask_tensor
return output_tensor, output_mask
def configure_optimizers(self) -> Tuple[OptimizerSet, OptimizerSet, OptimizerSet]:
generator_optimizer = torch.optim.AdamW(self.generator.parameters(), lr = self.config_learning_rate, betas = (0.0, 0.999), weight_decay = 1e-4)
@@ -120,8 +120,8 @@ class FaceSwapperTrainer(LightningModule):
discriminator_loss = self.discriminator_loss(discriminator_source_tensors, discriminator_output_tensors)
generator_output_attribute = generator_output_attributes[-1]
mask_tensor = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach())
mask_loss = self.mask_loss(target_tensor, mask_tensor)
generator_output_mask = self.masker(generator_output_tensor.detach(), generator_output_attribute.detach())
mask_loss = self.mask_loss(target_tensor, generator_output_mask)
self.toggle_optimizer(generator_optimizer)
self.manual_backward(generator_loss)
@@ -145,7 +145,7 @@ class FaceSwapperTrainer(LightningModule):
self.untoggle_optimizer(masker_optimizer)
if self.global_step % self.config_preview_frequency == 0:
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, mask_tensor)
self.generate_preview(source_tensor, target_tensor, generator_output_tensor, generator_output_mask)
self.log('generator_loss', generator_loss, prog_bar = True)
self.log('discriminator_loss', discriminator_loss, prog_bar = True)
@@ -167,10 +167,10 @@ class FaceSwapperTrainer(LightningModule):
self.log('validation_score', validation_score, prog_bar = True)
return validation_score
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, mask_tensor : Tensor) -> None:
def generate_preview(self, source_tensor : Tensor, target_tensor : Tensor, output_tensor : Tensor, output_mask : Mask) -> None:
preview_limit = 8
preview_cells = []
overlay_tensor = overlay_mask(output_tensor, mask_tensor)
overlay_tensor = overlay_mask(output_tensor, output_mask)
for source_tensor, target_tensor, output_tensor, overlay_tensor in zip(source_tensor[:preview_limit], target_tensor[:preview_limit], output_tensor[:preview_limit], overlay_tensor[:preview_limit]):
preview_cell = torch.cat([ source_tensor, target_tensor, output_tensor, overlay_tensor ], dim = 2)
+1
View File
@@ -8,6 +8,7 @@ BatchMode = Literal['equal', 'same', 'different']
Attribute : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
Mask : TypeAlias = Tensor
Loss : TypeAlias = Tensor
Padding : TypeAlias = Tuple[int, int, int, int]