diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 5d6060c..e5d0976 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -1,8 +1,8 @@ +from typing import Tuple + import torch from torch import Tensor, nn -from face_swapper.src.types import TargetAttributes - class UNet(nn.Module): def __init__(self) -> None: @@ -35,7 +35,7 @@ class UNet(nn.Module): UpSample(128, 32) ]) - def forward(self, target_tensor : Tensor) -> TargetAttributes: + def forward(self, target_tensor : Tensor) -> Tuple[Tensor, ...]: down_features = [] up_features = [] temp_tensor = target_tensor @@ -49,8 +49,8 @@ class UNet(nn.Module): for index, up_sample in enumerate(self.up_samples): down_index = -(index + 2) - up_feature = up_sample(temp_tensor, down_features[down_index]) - up_features.append(up_feature) + temp_tensor = up_sample(temp_tensor, down_features[down_index]) + up_features.append(temp_tensor) output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False) return bottleneck_tensor, *up_features, output_tensor