mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Fix some naming and types
This commit is contained in:
+28
-12
@@ -1,32 +1,48 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import AlignmentMatrices, EmbedderModule, Embedding, Padding
|
||||
from .types import WarpMatrixSet, EmbedderModule, Embedding, Padding
|
||||
|
||||
ALIGNMENT_MATRICES: AlignmentMatrices =\
|
||||
WARP_MATRIX_SET : WarpMatrixSet =\
|
||||
{
|
||||
'__vgg_face_hq__to__arcface_128_v2__': torch.tensor(
|
||||
'vgg_face_hq_to_arcface_128_v2': torch.tensor(
|
||||
[
|
||||
[
|
||||
[ 1.01305414, -0.00140513, -0.00585911 ],
|
||||
[ 0.00140513, 1.01305414, 0.11169602 ]
|
||||
], dtype = torch.float32),
|
||||
'__arcface_128_v2__to__arcface_112_v2__': torch.tensor(
|
||||
1.01305414,
|
||||
-0.00140513,
|
||||
-0.00585911
|
||||
],
|
||||
[
|
||||
[ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ],
|
||||
[ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ]
|
||||
], dtype = torch.float32)
|
||||
0.00140513,
|
||||
1.01305414,
|
||||
0.11169602
|
||||
]
|
||||
], dtype = torch.float32),
|
||||
'arcface_128_v2_to_arcface_112_v2': torch.tensor(
|
||||
[
|
||||
[
|
||||
8.75000016e-01,
|
||||
-1.07193451e-08,
|
||||
3.80446920e-10
|
||||
],
|
||||
[
|
||||
1.07193451e-08,
|
||||
8.75000016e-01,
|
||||
-1.25000007e-01
|
||||
]
|
||||
], dtype = torch.float32)
|
||||
}
|
||||
|
||||
|
||||
def warp_tensor(input_tensor : Tensor, alignment_matrix : str) -> Tensor:
|
||||
matrix = ALIGNMENT_MATRICES.get(alignment_matrix).repeat(input_tensor.shape[0], 1, 1)
|
||||
matrix = WARP_MATRIX_SET.get(alignment_matrix).repeat(input_tensor.shape[0], 1, 1)
|
||||
grid = nn.functional.affine_grid(matrix.to(input_tensor.device), list(input_tensor.shape))
|
||||
output_tensor = nn.functional.grid_sample(input_tensor, grid, align_corners = False, padding_mode = 'reflection')
|
||||
return output_tensor
|
||||
|
||||
|
||||
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
|
||||
crop_tensor = warp_tensor(input_tensor, '__arcface_128_v2__to__arcface_112_v2__')
|
||||
crop_tensor = warp_tensor(input_tensor, 'arcface_128_v2_to_arcface_112_v2')
|
||||
crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area')
|
||||
crop_tensor[:, :, :padding[0], :] = 0
|
||||
crop_tensor[:, :, 112 - padding[1]:, :] = 0
|
||||
|
||||
@@ -18,4 +18,4 @@ MotionExtractorModule : TypeAlias = Module
|
||||
|
||||
OptimizerConfig : TypeAlias = Any
|
||||
|
||||
AlignmentMatrices : TypeAlias = Dict[str, Tensor]
|
||||
WarpMatrixSet : TypeAlias = Dict[str, Tensor]
|
||||
|
||||
Reference in New Issue
Block a user