Mirror of https://github.com/facefusion/facefusion-labs.git, synced 2026-04-19 15:56:37 +02:00.
changes
This commit is contained in:
@@ -6,6 +6,7 @@ from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .helper import warp_tensor
|
||||
from .types import Batch
|
||||
|
||||
|
||||
@@ -35,6 +36,7 @@ class DynamicDataset(Dataset[Tensor]):
|
||||
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
|
||||
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
|
||||
transforms.ToTensor(),
|
||||
lambda temp: warp_tensor(temp.unsqueeze(0), '__vgg_face_hq__to__arcface_128_v2__').squeeze(0),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
|
||||
|
||||
@@ -1,10 +1,32 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import EmbedderModule, Embedding, Padding
|
||||
from .types import AlignmentMatrices, EmbedderModule, Embedding, Padding
|
||||
|
||||
# 2x3 affine matrices (the theta argument consumed by nn.functional.affine_grid
# in warp_tensor) that re-align a face crop from one template space to another.
# Keys follow the pattern '__<source_template>__to__<target_template>__'.
# NOTE(review): the numeric values look like precomputed least-squares fits
# between landmark templates — TODO confirm their provenance.
ALIGNMENT_MATRICES : AlignmentMatrices =\
{
	'__vgg_face_hq__to__arcface_128_v2__': torch.tensor(
	[
		[ 1.01305414, -0.00140513, -0.00585911 ],
		[ 0.00140513, 1.01305414, 0.11169602 ]
	], dtype = torch.float32),
	'__arcface_128_v2__to__arcface_112_v2__': torch.tensor(
	[
		[ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ],
		[ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ]
	], dtype = torch.float32)
}
|
||||
|
||||
|
||||
def warp_tensor(input_tensor : Tensor, alignment_matrix : str) -> Tensor:
	"""
	Warp a batch of images with a named affine alignment.

	:param input_tensor: batch of images shaped (batch, channels, height, width) — assumed 4D, as required by affine_grid; TODO confirm callers
	:param alignment_matrix: key into ALIGNMENT_MATRICES selecting the 2x3 affine matrix
	:return: warped tensor with the same shape and device as the input
	:raises KeyError: when alignment_matrix does not name a known matrix
	"""
	# Direct indexing raises a clear KeyError for an unknown key, instead of the
	# opaque AttributeError that .get(...).repeat(...) produced when .get() returned None.
	matrix = ALIGNMENT_MATRICES[alignment_matrix].repeat(input_tensor.shape[0], 1, 1)
	# align_corners is passed explicitly so it provably matches the grid_sample call below.
	grid = nn.functional.affine_grid(matrix.to(input_tensor.device), list(input_tensor.shape), align_corners = False)
	output_tensor = nn.functional.grid_sample(input_tensor, grid, align_corners = False, padding_mode = 'reflection')
	return output_tensor
|
||||
|
||||
|
||||
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
|
||||
crop_tensor = input_tensor[:, :, 15: 241, 15: 241]
|
||||
crop_tensor = warp_tensor(input_tensor, '__arcface_128_v2__to__arcface_112_v2__')
|
||||
crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area')
|
||||
crop_tensor[:, :, :padding[0], :] = 0
|
||||
crop_tensor[:, :, 112 - padding[1]:, :] = 0
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Tuple, TypeAlias
|
||||
from typing import Any, Dict, Tuple, TypeAlias
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
@@ -17,3 +17,5 @@ LandmarkerModule : TypeAlias = Module
|
||||
MotionExtractorModule : TypeAlias = Module
|
||||
|
||||
OptimizerConfig : TypeAlias = Any
|
||||
|
||||
AlignmentMatrices : TypeAlias = Dict[str, Tensor]
|
||||
|
||||
Reference in New Issue
Block a user