This commit is contained in:
henryruhs
2025-02-18 11:05:58 +01:00
parent e1ba81f220
commit 39818a16df
3 changed files with 6 additions and 6 deletions
+2 -2
View File
@@ -5,7 +5,7 @@ import numpy
import torch
from torch import Tensor, nn
from .types import Embedder, Embedding, Padding, VisionFrame, VisionTensor
from .types import EmbedderModule, Embedding, Padding, VisionFrame, VisionTensor
def is_windows() -> bool:
@@ -48,7 +48,7 @@ def hinge_fake_loss(input_tensor : Tensor) -> Tensor:
return fake_loss
def calc_id_embedding(id_embedder : Embedder, vision_tensor : VisionTensor, padding : Padding) -> Embedding:
def calc_id_embedding(id_embedder : EmbedderModule, vision_tensor : VisionTensor, padding : Padding) -> Embedding:
crop_vision_tensor = vision_tensor[:, :, 15 : 241, 15 : 241]
crop_vision_tensor = nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'area')
crop_vision_tensor[:, :, :padding[0], :] = 0
+2 -2
View File
@@ -5,13 +5,13 @@ import torch
from .helper import calc_id_embedding, convert_to_vision_frame, convert_to_vision_tensor, read_image
from .models.generator import Generator
from .types import Embedder, Generator, VisionFrame
from .types import EmbedderModule, GeneratorModule, VisionFrame
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def run_swap(generator : Generator, id_embedder : Embedder, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
def run_swap(generator : GeneratorModule, id_embedder : EmbedderModule, source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame:
source_vision_tensor = convert_to_vision_tensor(source_vision_frame)
target_vision_tensor = convert_to_vision_tensor(target_vision_frame)
source_embedding = calc_id_embedding(id_embedder, source_vision_tensor, (0, 0, 0, 0))
+2 -2
View File
@@ -28,5 +28,5 @@ VisionTensor : TypeAlias = Tensor
GeneratorLossSet : TypeAlias = Dict[str, Tensor]
DiscriminatorLossSet : TypeAlias = Dict[str, Tensor]
Generator : TypeAlias = Module
Embedder : TypeAlias = Module
GeneratorModule : TypeAlias = Module
EmbedderModule : TypeAlias = Module