Mirror of https://github.com/facefusion/facefusion-labs.git (synced 2026-05-22 23:59:40 +02:00)
refactor/convert tensor (#76)
* change to convert template
* changes
* changes
* Conditional convert input tensor
* Conditional convert input tensor

---------

Co-authored-by: harisreedhar <h4harisreedhar.s.s@gmail.com>
This commit is contained in:
+1
-1
@@ -28,7 +28,7 @@ This `config.ini` utilizes the VGGFace2 dataset to train the HyperSwap model.
|
||||
```
|
||||
[training.dataset]
|
||||
file_pattern = .datasets/vggface2/**/*.jpg
|
||||
warp_template = vggfacehq_512_to_arcface_128
|
||||
convert_template = vggfacehq_512_to_arcface_128
|
||||
transform_size = 256
|
||||
batch_mode = equal
|
||||
batch_ratio = 0.2
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[training.dataset]
|
||||
file_pattern =
|
||||
warp_template =
|
||||
convert_template =
|
||||
transform_size =
|
||||
batch_mode =
|
||||
batch_ratio =
|
||||
|
||||
+15
-10
@@ -9,8 +9,8 @@ from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .helper import warp_tensor
|
||||
from .types import Batch, BatchMode, WarpTemplate
|
||||
from .helper import convert_tensor
|
||||
from .types import Batch, BatchMode, ConvertTemplate
|
||||
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
@@ -38,15 +38,20 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
return len(self.file_paths)
|
||||
|
||||
def compose_transforms(self) -> transforms:
|
||||
return transforms.Compose(
|
||||
__transforms__ =\
|
||||
[
|
||||
AugmentTransform(),
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize((self.config_transform_size, self.config_transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
WarpTransform(self.config_parser),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
transforms.ToTensor()
|
||||
]
|
||||
|
||||
if self.config_parser.get('training.dataset', 'convert_template'):
|
||||
__transforms__.append(ConvertTensorTransform(self.config_parser))
|
||||
|
||||
__transforms__.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
|
||||
|
||||
return transforms.Compose(__transforms__)
|
||||
|
||||
def prepare_different_batch(self, source_path : str) -> Batch:
|
||||
target_path = random.choice(self.file_paths)
|
||||
@@ -98,10 +103,10 @@ class AugmentTransform:
|
||||
])
|
||||
|
||||
|
||||
class WarpTransform:
|
||||
class ConvertTensorTransform:
|
||||
def __init__(self, config_parser : ConfigParser) -> None:
|
||||
self.config_warp_template = cast(WarpTemplate, config_parser.get('training.dataset', 'warp_template'))
|
||||
self.config_convert_template = cast(ConvertTemplate, config_parser.get('training.dataset', 'convert_template'))
|
||||
|
||||
def __call__(self, input_tensor : Tensor) -> Tensor:
|
||||
temp_tensor = input_tensor.unsqueeze(0)
|
||||
return warp_tensor(temp_tensor, self.config_warp_template).squeeze(0)
|
||||
return convert_tensor(temp_tensor, self.config_convert_template).squeeze(0)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import EmbedderModule, Embedding, Mask, Padding, WarpTemplate, WarpTemplateSet
|
||||
from .types import ConvertTemplate, ConvertTemplateSet, EmbedderModule, Embedding, Mask, Padding
|
||||
|
||||
WARP_TEMPLATE_SET : WarpTemplateSet =\
|
||||
CONVERT_TEMPLATE_SET : ConvertTemplateSet =\
|
||||
{
|
||||
'arcface_128_to_arcface_112_v2': torch.tensor(
|
||||
[
|
||||
@@ -23,15 +23,15 @@ WARP_TEMPLATE_SET : WarpTemplateSet =\
|
||||
}
|
||||
|
||||
|
||||
def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor:
|
||||
normed_warp_template = WARP_TEMPLATE_SET.get(warp_template).repeat(input_tensor.shape[0], 1, 1)
|
||||
affine_grid = nn.functional.affine_grid(normed_warp_template.to(input_tensor.device), list(input_tensor.shape))
|
||||
def convert_tensor(input_tensor : Tensor, convert_template : ConvertTemplate) -> Tensor:
|
||||
convert_matrix = CONVERT_TEMPLATE_SET.get(convert_template).repeat(input_tensor.shape[0], 1, 1)
|
||||
affine_grid = nn.functional.affine_grid(convert_matrix.to(input_tensor.device), list(input_tensor.shape))
|
||||
output_tensor = nn.functional.grid_sample(input_tensor, affine_grid, padding_mode = 'reflection')
|
||||
return output_tensor
|
||||
|
||||
|
||||
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
|
||||
crop_tensor = warp_tensor(input_tensor, 'arcface_128_to_arcface_112_v2')
|
||||
crop_tensor = convert_tensor(input_tensor, 'arcface_128_to_arcface_112_v2')
|
||||
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
|
||||
crop_tensor[:, :, :padding[0], :] = 0
|
||||
crop_tensor[:, :, 112 - padding[1]:, :] = 0
|
||||
|
||||
@@ -89,7 +89,6 @@ class HyperSwapTrainer(LightningModule):
|
||||
source_tensor, target_tensor = batch
|
||||
do_update = (batch_index + 1) % self.config_accumulate_size == 0
|
||||
generator_optimizer, discriminator_optimizer = self.optimizers() #type:ignore[attr-defined]
|
||||
|
||||
source_embedding = calc_embedding(self.generator_embedder, source_tensor, (0, 0, 0, 0))
|
||||
target_embedding = calc_embedding(self.generator_embedder, target_tensor, (0, 0, 0, 0))
|
||||
generator_target_features = self.generator.encode_features(target_tensor)
|
||||
|
||||
@@ -20,5 +20,5 @@ FaceMaskerModule : TypeAlias = Module
|
||||
|
||||
OptimizerSet : TypeAlias = Any
|
||||
|
||||
WarpTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
|
||||
WarpTemplateSet : TypeAlias = Dict[WarpTemplate, Tensor]
|
||||
ConvertTemplate = Literal['arcface_128_to_arcface_112_v2', 'ffhq_512_to_arcface_128', 'vggfacehq_512_to_arcface_128']
|
||||
ConvertTemplateSet : TypeAlias = Dict[ConvertTemplate, Tensor]
|
||||
|
||||
Reference in New Issue
Block a user