mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Merge pull request #66 from facefusion/remove-redundant-encoder-calc
Remove redundant encoder calculation
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from .training import FaceSwapperTrainer
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
from torchvision import io
|
||||
|
||||
from .helper import calc_embedding
|
||||
from .models.generator import Generator
|
||||
from .training import FaceSwapperTrainer
|
||||
|
||||
CONFIG_PARSER = configparser.ConfigParser()
|
||||
CONFIG_PARSER.read('config.ini')
|
||||
@@ -17,10 +17,7 @@ def infer() -> None:
|
||||
config_target_path = CONFIG_PARSER.get('inferencing', 'target_path')
|
||||
config_output_path = CONFIG_PARSER.get('inferencing', 'output_path')
|
||||
|
||||
state_dict = torch.load(config_generator_path).get('state_dict').get('generator')
|
||||
generator = Generator(CONFIG_PARSER)
|
||||
generator.load_state_dict(state_dict)
|
||||
generator.eval()
|
||||
generator = FaceSwapperTrainer.load_from_checkpoint(config_generator_path, config_parser = CONFIG_PARSER, map_location = 'cpu').eval()
|
||||
embedder = torch.jit.load(config_embedder_path, map_location = 'cpu').eval()
|
||||
|
||||
source_tensor = io.read_image(config_source_path)
|
||||
|
||||
@@ -19,8 +19,7 @@ class Generator(nn.Module):
|
||||
self.generator.apply(init_weight)
|
||||
self.masker.apply(init_weight)
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
|
||||
target_features = self.encode_features(target_tensor)
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor, target_features : Tuple[Feature, ...]) -> Tuple[Tensor, Mask]:
|
||||
output_tensor = self.generator(source_embedding, target_features)
|
||||
target_feature = target_features[-1]
|
||||
output_mask = self.masker(target_tensor, target_feature)
|
||||
|
||||
@@ -53,7 +53,8 @@ class FaceSwapperTrainer(LightningModule):
|
||||
|
||||
def forward(self, source_embedding : Embedding, target_tensor : Tensor) -> Tuple[Tensor, Mask]:
|
||||
with torch.no_grad():
|
||||
output_tensor, output_mask = self.generator(source_embedding, target_tensor)
|
||||
generator_target_features = self.generator.encode_features(target_tensor)
|
||||
output_tensor, output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
|
||||
|
||||
return output_tensor, output_mask
|
||||
|
||||
@@ -89,8 +90,8 @@ class FaceSwapperTrainer(LightningModule):
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor)
|
||||
generator_target_features = self.generator.encode_features(target_tensor)
|
||||
generator_output_tensor, generator_output_mask = self.generator(source_embedding, target_tensor, generator_target_features)
|
||||
generator_output_features = self.generator.encode_features(generator_output_tensor)
|
||||
discriminator_output_tensors = self.discriminator(generator_output_tensor)
|
||||
adversarial_loss, weighted_adversarial_loss = self.adversarial_loss(discriminator_output_tensors)
|
||||
@@ -138,7 +139,7 @@ class FaceSwapperTrainer(LightningModule):
|
||||
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
|
||||
source_tensor, target_tensor = batch
|
||||
source_embedding = calc_embedding(self.embedder, source_tensor, (0, 0, 0, 0))
|
||||
output_tensor, _ = self.generator(source_embedding, target_tensor)
|
||||
output_tensor, _ = self.forward(source_embedding, target_tensor)
|
||||
output_embedding = calc_embedding(self.embedder, output_tensor, (0, 0, 0, 0))
|
||||
validation_score = (nn.functional.cosine_similarity(source_embedding, output_embedding).mean() + 1) * 0.5
|
||||
self.log('validation_score', validation_score, sync_dist = True, prog_bar = True)
|
||||
|
||||
Reference in New Issue
Block a user