Adjust config and namings

This commit is contained in:
henryruhs
2025-03-04 16:24:24 +01:00
parent 176dced1f6
commit 430c71d031
7 changed files with 18 additions and 17 deletions
+2 -1
View File
@@ -29,9 +29,9 @@ This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model.
[training.dataset]
file_pattern = .datasets/vggface2/**/*.jpg
warp_template = vgg_face_hq_to_arcface_128_v2
transform_size = 256
batch_mode = equal
batch_ratio = 0.2
resolution = 256
```
```
@@ -72,6 +72,7 @@ attribute_weight = 10
reconstruction_weight = 20
identity_weight = 20
gaze_weight = 0
gaze_scale_factor = 1
pose_weight = 0
expression_weight = 0
```
+2 -1
View File
@@ -1,6 +1,7 @@
[training.dataset]
file_pattern =
warp_template =
transform_size =
batch_mode =
batch_ratio =
@@ -26,7 +27,6 @@ num_filters =
num_layers =
num_discriminators =
kernel_size =
resolution =
[training.losses]
adversarial_weight =
@@ -34,6 +34,7 @@ attribute_weight =
reconstruction_weight =
identity_weight =
gaze_weight =
gaze_scale_factor =
pose_weight =
expression_weight =
+3 -3
View File
@@ -12,12 +12,12 @@ from .types import Batch, BatchMode, WarpTemplate
class DynamicDataset(Dataset[Tensor]):
def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_mode : BatchMode, batch_ratio : float, resolution : int) -> None:
def __init__(self, file_pattern : str, warp_template : WarpTemplate, transform_size : int, batch_mode : BatchMode, batch_ratio : float) -> None:
self.file_paths = glob.glob(file_pattern)
self.warp_template = warp_template
self.transform_size = transform_size
self.batch_mode = batch_mode
self.batch_ratio = batch_ratio
self.resolution = resolution
self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
@@ -39,7 +39,7 @@ class DynamicDataset(Dataset[Tensor]):
[
AugmentTransform(),
transforms.ToPILImage(),
transforms.Resize((self.resolution, self.resolution), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.Resize((self.transform_size, self.transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
WarpTransform(self.warp_template),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+1 -1
View File
@@ -27,7 +27,7 @@ def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor:
def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = warp_tensor(input_tensor, 'arcface_128_v2_to_arcface_112_v2')
crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area')
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
crop_tensor[:, :, :padding[0], :] = 0
crop_tensor[:, :, 112 - padding[1]:, :] = 0
crop_tensor[:, :, :, :padding[2]] = 0
+4 -4
View File
@@ -169,15 +169,15 @@ class GazeLoss(nn.Module):
return gaze_loss, weighted_gaze_loss
def detect_gaze(self, input_tensor : Tensor) -> Gaze:
resolution = CONFIG.getint('training.dataset', 'resolution')
scale_factor = resolution / 256
scale_factor = CONFIG.getint('training.losses', 'gaze_scale_factor')
y_min = int(60 * scale_factor)
y_max = int(224 * scale_factor)
x_min = int(16 * scale_factor)
x_max = int(205 * scale_factor)
crop_tensor = input_tensor[:, :, y_min: y_max, x_min: x_max]
crop_tensor = input_tensor[:, :, y_min:y_max, x_min:x_max]
crop_tensor = (crop_tensor + 1) * 0.5
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
crop_tensor = nn.functional.interpolate(crop_tensor, size = (448, 448), mode = 'bicubic')
crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')
pitch_tensor, yaw_tensor = self.gazer(crop_tensor)
return pitch_tensor, yaw_tensor
+4 -5
View File
@@ -28,9 +28,9 @@ class AAD(nn.Module):
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
for index, layer in enumerate(self.layers[:-1]):
temp_shape = target_attributes[index + 1].shape[2:]
temp_tensor = layer(temp_tensors, target_attributes[index], source_embedding)
temp_tensors = nn.functional.interpolate(temp_tensor, temp_shape, mode = 'bilinear', align_corners = False)
temp_size = target_attributes[index + 1].shape[2:]
temp_tensors = nn.functional.interpolate(temp_tensor, temp_size, mode = 'bilinear', align_corners = False)
temp_tensors = self.layers[-1](temp_tensors, target_attributes[-1], source_embedding)
output_tensor = torch.tanh(temp_tensors)
@@ -113,10 +113,9 @@ class FeatureModulation(nn.Module):
def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
temp_tensor = self.instance_norm(input_tensor)
temp_size = temp_tensor.shape[2:]
if attribute_embedding.shape[2:] != temp_tensor.shape[2:]:
attribute_embedding = nn.functional.interpolate(attribute_embedding, size = temp_tensor.shape[2:], mode = 'bilinear')
attribute_embedding = nn.functional.interpolate(attribute_embedding, size = temp_size, mode = 'bilinear')
attribute_scale = self.conv1(attribute_embedding)
attribute_shift = self.conv2(attribute_embedding)
attribute_modulation = attribute_scale * temp_tensor + attribute_shift
+2 -2
View File
@@ -198,15 +198,15 @@ def create_trainer() -> Trainer:
def train() -> None:
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
dataset_warp_template = cast(WarpTemplate, CONFIG.get('training.dataset', 'warp_template'))
dataset_transform_size = CONFIG.getint('training.dataset', 'transform_size')
dataset_batch_mode = cast(BatchMode, CONFIG.get('training.dataset', 'batch_mode'))
dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
dataset_resolution = CONFIG.getint('training.dataset', 'resolution')
output_resume_path = CONFIG.get('training.output', 'resume_path')
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_mode, dataset_batch_ratio, dataset_resolution)
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_transform_size, dataset_batch_mode, dataset_batch_ratio)
training_loader, validation_loader = create_loaders(dataset)
face_swapper_trainer = FaceSwapperTrainer()
trainer = create_trainer()