This commit is contained in:
henryruhs
2025-03-06 18:32:20 +01:00
parent b829d5e42c
commit 6bc44ad3d8
4 changed files with 5 additions and 6 deletions
+1 -1
View File
@@ -10,7 +10,7 @@ from torch.utils.data import Dataset
from torchvision import io, transforms
from .helper import warp_tensor
from .types import Batch, WarpTemplate, BatchMode
from .types import Batch, BatchMode, WarpTemplate
class DynamicDataset(Dataset[Tensor]):
+1 -1
View File
@@ -24,7 +24,7 @@ def infer() -> None:
generator = Generator(CONFIG_PARSER)
generator.load_state_dict(state_dict)
generator.eval()
embedder = torch.jit.load(embedder_path, map_location = 'cpu') # type:ignore[no-untyped-call]
embedder = torch.jit.load(config.get('embedder_path'), map_location = 'cpu') # type:ignore[no-untyped-call]
embedder.eval()
source_tensor = io.read_image(config.get('source_path'))
-1
View File
@@ -17,7 +17,6 @@ class Discriminator(nn.Module):
self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
self.discriminators = self.create_discriminators()
def create_discriminators(self) -> nn.ModuleList:
discriminators = nn.ModuleList()
+3 -3
View File
@@ -1,7 +1,7 @@
import os
from configparser import ConfigParser
import warnings
from typing import Tuple, cast
from configparser import ConfigParser
from typing import Tuple
import torch
import torchvision
@@ -17,7 +17,7 @@ from .helper import calc_embedding
from .models.discriminator import Discriminator
from .models.generator import Generator
from .models.loss import AdversarialLoss, AttributeLoss, DiscriminatorLoss, GazeLoss, IdentityLoss, MotionLoss, ReconstructionLoss
from .types import Batch, BatchMode, Embedding, OptimizerSet, WarpTemplate
from .types import Batch, Embedding, OptimizerSet
warnings.filterwarnings('ignore', category = UserWarning, module = 'torch')