mirror of https://github.com/facefusion/facefusion-labs.git
synced 2026-04-19 15:56:37 +02:00
changes
@@ -31,6 +31,7 @@ file_pattern = .datasets/vggface2/**/*.jpg
 warp_template = vgg_face_hq_to_arcface_128_v2
 batch_mode = equal
 batch_ratio = 0.2
+resolution = 256
 ```
 
 ```

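The new `resolution` key sits next to the other dataset options and is read later in this commit via `CONFIG.getint('training.dataset', 'resolution')`. A minimal sketch of reading it with Python's standard `configparser`, assuming an INI file named `config.ini` (the filename is not shown in this diff):

```
from configparser import ConfigParser

# Minimal sketch, not the repository's code: the filename below is an assumption,
# only the section and key names appear in this commit.
CONFIG = ConfigParser()
CONFIG.read('config.ini')

resolution = CONFIG.getint('training.dataset', 'resolution')  # 256 with the default shown above
batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')  # 0.2 with the default shown above
print(resolution, batch_ratio)
```
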
@@ -26,6 +26,7 @@ num_filters =
 num_layers =
 num_discriminators =
 kernel_size =
+resolution =
 
 [training.losses]
 adversarial_weight =

@@ -12,11 +12,12 @@ from .types import Batch, BatchMode, WarpTemplate
 
 
 class DynamicDataset(Dataset[Tensor]):
-	def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_mode : BatchMode, batch_ratio : float) -> None:
+	def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_mode : BatchMode, batch_ratio : float, resolution : int) -> None:
 		self.file_paths = glob.glob(file_pattern)
 		self.warp_template = warp_template
 		self.batch_mode = batch_mode
 		self.batch_ratio = batch_ratio
+		self.resolution = resolution
 		self.transforms = self.compose_transforms()
 
 	def __getitem__(self, index : int) -> Batch:

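`DynamicDataset` now takes the target resolution as an explicit constructor argument. A hedged usage sketch, using the config defaults shown above (the import path for `DynamicDataset` is not visible in this diff and is assumed to be available):

```
# Illustrative call only; mirrors how train() constructs the dataset later in this commit.
dataset = DynamicDataset(
	file_pattern = '.datasets/vggface2/**/*.jpg',
	warp_template = 'vgg_face_hq_to_arcface_128_v2',
	batch_mode = 'equal',
	batch_ratio = 0.2,
	resolution = 256
)
```
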
@@ -38,7 +39,7 @@ class DynamicDataset(Dataset[Tensor]):
 		[
 			AugmentTransform(),
 			transforms.ToPILImage(),
-			transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
+			transforms.Resize((self.resolution, self.resolution), interpolation = transforms.InterpolationMode.BICUBIC),
 			transforms.ToTensor(),
 			WarpTransform(self.warp_template),
 			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

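With `Resize` parameterized on `self.resolution`, every sample is scaled to the configured size instead of a hard-coded 256. A standalone sketch of the effect, leaving out the repository-specific `AugmentTransform` and `WarpTransform` steps:

```
import torch
from torchvision import transforms

resolution = 128  # illustrative value; the config default in this commit is 256
pipeline = transforms.Compose(
[
	transforms.ToPILImage(),
	transforms.Resize((resolution, resolution), interpolation = transforms.InterpolationMode.BICUBIC),
	transforms.ToTensor(),
	transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

image_tensor = torch.rand(3, 311, 289)  # arbitrary input size
print(pipeline(image_tensor).shape)  # torch.Size([3, 128, 128])
```
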
@@ -169,7 +169,13 @@ class GazeLoss(nn.Module):
 		return gaze_loss, weighted_gaze_loss
 
 	def detect_gaze(self, input_tensor : Tensor) -> Gaze:
-		crop_tensor = input_tensor[:, :, 60: 224, 16: 205]
+		resolution = CONFIG.getint('training.dataset', 'resolution')
+		scale_factor = resolution / 256
+		y_min = int(60 * scale_factor)
+		y_max = int(224 * scale_factor)
+		x_min = int(16 * scale_factor)
+		x_max = int(205 * scale_factor)
+		crop_tensor = input_tensor[:, :, y_min: y_max, x_min: x_max]
 		crop_tensor = (crop_tensor + 1) * 0.5
 		crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
 		crop_tensor = nn.functional.interpolate(crop_tensor, size = (448, 448), mode = 'bicubic')

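The old crop box `[60:224, 16:205]` was calibrated for 256×256 inputs; scaling every coordinate by `resolution / 256` keeps the same relative eye region at other sizes. A small worked sketch of that arithmetic (standalone, not the repository's code):

```
def scale_crop_box(resolution : int) -> tuple:
	# Same arithmetic as the new detect_gaze(): the 256-pixel reference box scales linearly.
	scale_factor = resolution / 256
	return (int(60 * scale_factor), int(224 * scale_factor), int(16 * scale_factor), int(205 * scale_factor))


print(scale_crop_box(256))  # (60, 224, 16, 205) -- identical to the old hard-coded crop
print(scale_crop_box(512))  # (120, 448, 32, 410)
```
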
@@ -28,8 +28,9 @@ class AAD(nn.Module):
 		temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
 
 		for index, layer in enumerate(self.layers[:-1]):
+			temp_shape = target_attributes[index + 1].shape[2:]
 			temp_tensor = layer(temp_tensors, target_attributes[index], source_embedding)
-			temp_tensors = nn.functional.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)
+			temp_tensors = nn.functional.interpolate(temp_tensor, temp_shape, mode = 'bilinear', align_corners = False)
 
 		temp_tensors = self.layers[-1](temp_tensors, target_attributes[-1], source_embedding)
 		output_tensor = torch.tanh(temp_tensors)

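Interpolating to the next attribute map's actual spatial shape, rather than always doubling, keeps the decoder aligned with the encoder features when the configured resolution does not produce exact power-of-two feature sizes. A dummy-tensor sketch of the difference (standalone, not the repository's code):

```
import torch
import torch.nn.functional as F

temp_tensor = torch.randn(1, 64, 25, 25)
next_attribute = torch.randn(1, 64, 51, 51)  # not an exact doubling of 25

# Old behaviour: always double the spatial size -> (1, 64, 50, 50), which no longer matches.
doubled = F.interpolate(temp_tensor, scale_factor = 2, mode = 'bilinear', align_corners = False)

# New behaviour: match the next attribute map exactly.
matched = F.interpolate(temp_tensor, next_attribute.shape[2:], mode = 'bilinear', align_corners = False)
print(doubled.shape, matched.shape)  # torch.Size([1, 64, 50, 50]) torch.Size([1, 64, 51, 51])
```
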
@@ -113,6 +114,9 @@ class FeatureModulation(nn.Module):
 	def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
 		temp_tensor = self.instance_norm(input_tensor)
 
+		if attribute_embedding.shape[2:] != temp_tensor.shape[2:]:
+			attribute_embedding = nn.functional.interpolate(attribute_embedding, size = temp_tensor.shape[2:], mode = 'bilinear')
+
 		attribute_scale = self.conv1(attribute_embedding)
 		attribute_shift = self.conv2(attribute_embedding)
 		attribute_modulation = attribute_scale * temp_tensor + attribute_shift

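The added check only resamples the attribute embedding when its spatial size disagrees with the instance-normalized input, which can happen once the working resolution is no longer fixed at 256. A dummy-tensor sketch of the guard (standalone, not the repository's code):

```
import torch
import torch.nn as nn

temp_tensor = torch.randn(1, 128, 32, 32)
attribute_embedding = torch.randn(1, 128, 31, 31)  # off-by-one spatial size, as odd resolutions can produce

if attribute_embedding.shape[2:] != temp_tensor.shape[2:]:
	attribute_embedding = nn.functional.interpolate(attribute_embedding, size = temp_tensor.shape[2:], mode = 'bilinear')

print(attribute_embedding.shape)  # torch.Size([1, 128, 32, 32])
```
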
@@ -200,12 +200,13 @@ def train() -> None:
 	dataset_warp_template = cast(WarpTemplate, CONFIG.get('training.dataset', 'warp_template'))
 	dataset_batch_mode = cast(BatchMode, CONFIG.get('training.dataset', 'batch_mode'))
 	dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
+	dataset_resolution = CONFIG.getint('training.dataset', 'resolution')
 	output_resume_path = CONFIG.get('training.output', 'resume_path')
 
 	if torch.cuda.is_available():
 		torch.set_float32_matmul_precision('high')
 
-	dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_mode, dataset_batch_ratio)
+	dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_mode, dataset_batch_ratio, dataset_resolution)
 	training_loader, validation_loader = create_loaders(dataset)
 	face_swapper_trainer = FaceSwapperTrainer()
 	trainer = create_trainer()

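One practical consequence of the wiring above: assuming `CONFIG` is a standard `configparser` instance, `CONFIG.getint('training.dataset', 'resolution')` raises `configparser.NoOptionError` when the key is absent, so existing config files need the new `resolution` entry; setting it to 256 should reproduce the previous fixed-size behaviour.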