mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Merge pull request #56 from facefusion/name-polishing
Add Attributes, polish names
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..networks.aad import AAD
|
||||
from ..networks.unet import UNet
|
||||
from ..types import Attributes, Embedding
|
||||
from ..types import Attribute, Embedding
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
@@ -20,7 +21,7 @@ class Generator(nn.Module):
|
||||
output_tensor = self.generator(source_embedding, target_attributes)
|
||||
return output_tensor
|
||||
|
||||
def get_attributes(self, input_tensor : Tensor) -> Attributes:
|
||||
def get_attributes(self, input_tensor : Tensor) -> Tuple[Attribute, ...]:
|
||||
return self.encoder(input_tensor)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from torch import Tensor, nn
|
||||
from torchvision import transforms
|
||||
|
||||
from ..helper import calc_embedding
|
||||
from ..types import Attributes, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule
|
||||
from ..types import Attribute, EmbedderModule, FaceParserModule, Gaze, GazerModule, MotionExtractorModule
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
|
||||
@@ -55,7 +55,7 @@ class AttributeLoss(nn.Module):
|
||||
self.config_batch_size = config_parser.getint('training.loader', 'batch_size')
|
||||
self.config_attribute_weight = config_parser.getfloat('training.losses', 'attribute_weight')
|
||||
|
||||
def forward(self, target_attributes : Attributes, output_attributes : Attributes) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, target_attributes : Tuple[Attribute, ...], output_attributes : Tuple[Attribute, ...]) -> Tuple[Tensor, Tensor]:
|
||||
temp_tensors = []
|
||||
|
||||
for target_attribute, output_attribute in zip(target_attributes, output_attributes):
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ..types import Attributes, Embedding
|
||||
from ..types import Attribute, Embedding
|
||||
|
||||
|
||||
class AAD(nn.Module):
|
||||
@@ -56,7 +57,7 @@ class AAD(nn.Module):
|
||||
|
||||
return layers
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_attributes : Attributes) -> Tensor:
|
||||
def forward(self, source_embedding : Embedding, target_attributes : Tuple[Attribute, ...]) -> Tensor:
|
||||
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
|
||||
|
||||
for index, layer in enumerate(self.layers[:-1]):
|
||||
|
||||
@@ -3,6 +3,8 @@ from configparser import ConfigParser
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.types import Attribute
|
||||
|
||||
|
||||
class MaskNet(nn.Module):
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
@@ -32,7 +34,7 @@ class MaskNet(nn.Module):
|
||||
UpSample(num_filters, num_filters)
|
||||
])
|
||||
|
||||
def forward(self, target_tensor : Tensor, target_attribute : Tensor) -> Tensor:
|
||||
def forward(self, target_tensor : Tensor, target_attribute : Attribute) -> Tensor:
|
||||
output_tensor = torch.cat([ target_tensor, target_attribute ], dim = 1)
|
||||
|
||||
for down_sample in self.down_samples:
|
||||
|
||||
@@ -54,11 +54,10 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.automatic_optimization = False
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensor = self.generator(source_embedding, target_tensor)
|
||||
target_attributes = self.generator.get_attributes(target_tensor)
|
||||
mask_tensor = self.masker(target_tensor, target_attributes[-1])
|
||||
target_attribute = self.generator.get_attributes(target_tensor)[-1]
|
||||
mask_tensor = self.masker(target_tensor, target_attribute)
|
||||
|
||||
return output_tensor, mask_tensor
|
||||
|
||||
@@ -127,7 +126,8 @@ class FaceSwapperTrainer(LightningModule):
|
||||
self.untoggle_optimizer(generator_optimizer)
|
||||
|
||||
self.toggle_optimizer(masker_optimizer)
|
||||
mask_tensor = self.masker(target_tensor, target_attributes[-1].detach())
|
||||
target_attribute = target_attributes[-1].detach()
|
||||
mask_tensor = self.masker(target_tensor, target_attribute)
|
||||
mask_loss = self.mask_loss(target_tensor, mask_tensor)
|
||||
|
||||
self.manual_backward(mask_loss)
|
||||
|
||||
@@ -6,7 +6,7 @@ from torch.nn import Module
|
||||
Batch : TypeAlias = Tuple[Tensor, Tensor]
|
||||
BatchMode = Literal['equal', 'same']
|
||||
|
||||
Attributes : TypeAlias = Tuple[Tensor, ...]
|
||||
Attribute : TypeAlias = Tensor
|
||||
Embedding : TypeAlias = Tensor
|
||||
Gaze : TypeAlias = Tuple[Tensor, Tensor]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user