Introduce Transform classes, Add albumentations

This commit is contained in:
henryruhs
2025-03-03 18:23:18 +01:00
parent 9dc1031fa5
commit e5a4a54e61
+35 -6
View File
@@ -2,6 +2,7 @@ import glob
import os
import random
import albumentations
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import io, transforms
@@ -37,16 +38,12 @@ class DynamicDataset(Dataset[Tensor]):
[
transforms.ToPILImage(),
transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1)),
AlterTransform(),
transforms.ToTensor(),
transforms.Lambda(self.warp_tensor),
WarpTransform(self.warp_template),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def warp_tensor(self, temp_tensor : Tensor) -> Tensor:
return warp_tensor(temp_tensor.unsqueeze(0), self.warp_template).squeeze(0)
def prepare_different_batch(self, source_path : str) -> Batch:
target_path = random.choice(self.file_paths)
source_tensor = io.read_image(source_path)
@@ -69,3 +66,35 @@ class DynamicDataset(Dataset[Tensor]):
target_tensor = io.read_image(target_path)
target_tensor = self.transforms(target_tensor)
return source_tensor, target_tensor
class AlterTransform:
def __init__(self) -> None:
self.transforms = self.compose_transforms()
def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.numpy().transpose(1, 2, 0)
return self.transforms(temp_tensor).get('image')
@staticmethod
def compose_transforms() -> albumentations.Compose:
return albumentations.Compose(
[
albumentations.RandomBrightnessContrast(p = 0.3),
albumentations.OneOf(
[
albumentations.MotionBlur(p = 0.1),
albumentations.MedianBlur(p = 0.1)
], p = 0.3),
albumentations.ColorJitter(p = 0.1),
])
class WarpTransform:
def __init__(self, warp_template : WarpTemplate) -> None:
self.warp_template = warp_template
def __call__(self, input_tensor : Tensor) -> Tensor:
temp_tensor = input_tensor.unsqueeze(0)
return warp_tensor(temp_tensor, self.warp_template).squeeze(0)