From b69f69d015dac07c6adcc2950d165e1e5cfdc0be Mon Sep 17 00:00:00 2001 From: henryruhs Date: Fri, 14 Feb 2025 16:10:39 +0100 Subject: [PATCH] Fix UNet --- face_swapper/src/networks/unet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index 5d6060c..e5d0976 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -1,8 +1,8 @@ +from typing import Tuple + import torch from torch import Tensor, nn -from face_swapper.src.types import TargetAttributes - class UNet(nn.Module): def __init__(self) -> None: @@ -35,7 +35,7 @@ class UNet(nn.Module): UpSample(128, 32) ]) - def forward(self, target_tensor : Tensor) -> TargetAttributes: + def forward(self, target_tensor : Tensor) -> Tuple[Tensor, ...]: down_features = [] up_features = [] temp_tensor = target_tensor @@ -49,8 +49,8 @@ class UNet(nn.Module): for index, up_sample in enumerate(self.up_samples): down_index = -(index + 2) - up_feature = up_sample(temp_tensor, down_features[down_index]) - up_features.append(up_feature) + temp_tensor = up_sample(temp_tensor, down_features[down_index]) + up_features.append(temp_tensor) output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False) return bottleneck_tensor, *up_features, output_tensor