Use skip_tensor variable

This commit is contained in:
henryruhs
2025-02-14 16:38:25 +01:00
parent 650551c19b
commit 88c4e53192
+2 -2
View File
@@ -48,8 +48,8 @@ class UNet(nn.Module):
temp_tensor = bottleneck_tensor
for index, up_sample in enumerate(self.up_samples):
down_index = -(index + 2)
temp_tensor = up_sample(temp_tensor, down_features[down_index])
skip_tensor = down_features[-(index + 2)]
temp_tensor = up_sample(temp_tensor, skip_tensor)
up_features.append(temp_tensor)
output_tensor = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)