mirror of
https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
Fix UNet
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from face_swapper.src.types import TargetAttributes
|
||||
|
||||
|
||||
class UNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@@ -35,7 +35,7 @@ class UNet(nn.Module):
|
||||
UpSample(128, 32)
|
||||
])
|
||||
|
||||
def forward(self, target_tensor : Tensor) -> TargetAttributes:
|
||||
def forward(self, target_tensor : Tensor) -> Tuple[Tensor, ...]:
|
||||
down_features = []
|
||||
up_features = []
|
||||
temp_tensor = target_tensor
|
||||
@@ -49,8 +49,8 @@ class UNet(nn.Module):
|
||||
|
||||
for index, up_sample in enumerate(self.up_samples):
|
||||
down_index = -(index + 2)
|
||||
up_feature = up_sample(temp_tensor, down_features[down_index])
|
||||
up_features.append(up_feature)
|
||||
temp_tensor = up_sample(temp_tensor, down_features[down_index])
|
||||
up_features.append(temp_tensor)
|
||||
|
||||
output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
|
||||
return bottleneck_tensor, *up_features, output_tensor
|
||||
|
||||
Reference in New Issue
Block a user