diff --git a/face_swapper/src/models/discriminator.py b/face_swapper/src/models/discriminator.py index c8fc1c4..bd78ae0 100644 --- a/face_swapper/src/models/discriminator.py +++ b/face_swapper/src/models/discriminator.py @@ -1,10 +1,9 @@ import configparser from typing import List -from torch import nn +from torch import Tensor, nn from ..networks.nld import NLD -from ..types import VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -31,7 +30,7 @@ class Discriminator(nn.Module): return discriminators - def forward(self, input_tensor : VisionTensor) -> List[List[VisionTensor]]: + def forward(self, input_tensor : Tensor) -> List[List[Tensor]]: temp_tensor = input_tensor output_tensors = []