diff --git a/face_swapper/src/networks/unet.py b/face_swapper/src/networks/unet.py index e5d0976..8a97df3 100644 --- a/face_swapper/src/networks/unet.py +++ b/face_swapper/src/networks/unet.py @@ -48,8 +48,8 @@ class UNet(nn.Module): temp_tensor = bottleneck_tensor for index, up_sample in enumerate(self.up_samples): - down_index = -(index + 2) - temp_tensor = up_sample(temp_tensor, down_features[down_index]) + skip_tensor = down_features[-(index + 2)] + temp_tensor = up_sample(temp_tensor, skip_tensor) up_features.append(temp_tensor) output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)