This commit is contained in:
henryruhs
2025-02-14 16:10:39 +01:00
parent 953525e6b0
commit b69f69d015
+5 -5
View File
@@ -1,8 +1,8 @@
from typing import Tuple
import torch
from torch import Tensor, nn
from face_swapper.src.types import TargetAttributes
class UNet(nn.Module):
def __init__(self) -> None:
@@ -35,7 +35,7 @@ class UNet(nn.Module):
UpSample(128, 32)
])
def forward(self, target_tensor : Tensor) -> TargetAttributes:
def forward(self, target_tensor : Tensor) -> Tuple[Tensor, ...]:
down_features = []
up_features = []
temp_tensor = target_tensor
@@ -49,8 +49,8 @@ class UNet(nn.Module):
for index, up_sample in enumerate(self.up_samples):
down_index = -(index + 2)
up_feature = up_sample(temp_tensor, down_features[down_index])
up_features.append(up_feature)
temp_tensor = up_sample(temp_tensor, down_features[down_index])
up_features.append(temp_tensor)
output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
return bottleneck_tensor, *up_features, output_tensor