mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Fix CI
This commit is contained in:
@@ -5,7 +5,7 @@ import numpy
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .types import Embedder, Embedding, Padding, VisionFrame, VisionTensor
|
||||
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
|
||||
|
||||
|
||||
def is_windows() -> bool:
|
||||
@@ -48,7 +48,7 @@ def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
|
||||
return fake_loss
|
||||
|
||||
|
||||
def calc_id_embedding(id_embedder : Embedder, vision_tensor : VisionTensor, padding : Padding) -> Embedding:
|
||||
def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor, padding : Padding) -> Embedding:
|
||||
crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241]
|
||||
crop_vision_tensor = nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area')
|
||||
crop_vision_tensor[:, :, :padding[0], :] = 0
|
||||
|
||||
@@ -5,13 +5,13 @@ import torch
|
||||
|
||||
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image
|
||||
from .models.generator import Generator
|
||||
from .types import Embedder, Generator, VisionFrame
|
||||
from .types import EmbedderModule, GeneratorModule, VisionFrame
|
||||
|
||||
CONFIG = configparser.ConfigParser()
|
||||
CONFIG.read('config.ini')
|
||||
|
||||
|
||||
def run_swap(generator : Generator, id_embedder : Embedder, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
|
||||
def run_swap(generator : GeneratorModule, id_embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
|
||||
source_vision_tensor = convert_to_vision_tensor(source_vision_frame)
|
||||
target_vision_tensor = convert_to_vision_tensor(target_vision_frame)
|
||||
source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0))
|
||||
|
||||
@@ -28,5 +28,5 @@ VisionTensor : TypeAlias = Tensor
|
||||
GeneratorLossSet : TypeAlias = Dict[str, Tensor]
|
||||
DiscriminatorLossSet : TypeAlias = Dict[str, Tensor]
|
||||
|
||||
Generator : TypeAlias = Module
|
||||
Embedder : TypeAlias = Module
|
||||
GeneratorModule : TypeAlias = Module
|
||||
EmbedderModule : TypeAlias = Module
|
||||
|
||||
Reference in New Issue
Block a user