Restore RGB normalization

This commit is contained in:
henryruhs
2025-02-24 16:21:37 +01:00
parent 6eff69a41a
commit a951d700fc
2 changed files with 2 additions and 2 deletions
+1 -1
View File
@@ -28,5 +28,5 @@ class StaticDataset(Dataset[Tensor]):
transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
+1 -1
View File
@@ -34,7 +34,7 @@ class DynamicDataset(Dataset[Tensor]):
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def prepare_different_batch(self, source_image_path : str) -> Batch: