mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-05-22 23:59:40 +02:00
Fix CI
This commit is contained in:
@@ -10,7 +10,7 @@ from torch.utils.data import Dataset
|
||||
from torchvision import io, transforms
|
||||
|
||||
from .helper import warp_tensor
|
||||
from .types import Batch, WarpTemplate, BatchMode
|
||||
from .types import Batch, BatchMode, WarpTemplate
|
||||
|
||||
|
||||
class DynamicDataset(Dataset[Tensor]):
|
||||
|
||||
@@ -24,7 +24,7 @@ def infer() -> None:
|
||||
generator = Generator(CONFIG_PARSER)
|
||||
generator.load_state_dict(state_dict)
|
||||
generator.eval()
|
||||
embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
embedder = torch.jit.load(config.get('embedder_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
|
||||
embedder.eval()
|
||||
|
||||
source_tensor = io.read_image(config.get('source_path'))
|
||||
|
||||
@@ -17,7 +17,6 @@ class Discriminator(nn.Module):
|
||||
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
|
||||
self.discriminators = self.create_discriminators()
|
||||
|
||||
|
||||
def create_discriminators(self) -> nn.ModuleList:
|
||||
discriminators = nn.ModuleList()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
from configparser import ConfigParser
|
||||
import warnings
|
||||
from typing import Tuple, cast
|
||||
from configparser import ConfigParser
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
@@ -17,7 +17,7 @@ from .helper import calc_embedding
|
||||
from .models.discriminator import Discriminator
|
||||
from .models.generator import Generator
|
||||
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MotionLoss, ReconstructionLoss
|
||||
from .types import Batch, BatchMode, Embedding, OptimizerSet, WarpTemplate
|
||||
from .types import Batch, Embedding, OptimizerSet
|
||||
|
||||
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user