This commit is contained in:
harisreedhar
2025-02-20 18:09:37 +05:30
committed by henryruhs
parent dcf19634d1
commit b47c6b72ee
10 changed files with 112 additions and 170 deletions
+9 -32
View File
@@ -27,31 +27,23 @@ This `config.ini` utilizes the MegaFace dataset to train the Embedding Converter
```
[preparing.dataset]
dataset_path = .datasets/megaface/train.rec
crop_size = 112
process_limit = 650000
```
```
[preparing.model]
source_path = .models/arcface_w600k_r50.onnx
target_path = .models/arcface_simswap.onnx
```
```
[preparing.input]
directory_path = .inputs
source_path = .inputs/arcface_w600k_r50.npy
target_path = .inputs/arcface_simswap.npy
dataset_path = .datasets/images
image_pattern = {}/*.*g
```
```
[training.loader]
split_ratio = 0.8
batch_size = 51200
batch_size = 256
num_workers = 8
```
```
[training.model]
source_path = .models/arcface_w600k_r50.pt
target_path = .models/arcface_simswap.pt
```
```
[training.trainer]
learning_rate = 0.001
@@ -74,21 +66,6 @@ ir_version = 10
opset_version = 15
```
```
[execution]
providers = CUDAExecutionProvider
```
Preparing
---------
Prepare the embedding dataset.
```
python prepare.py
```
Training
--------
+5 -14
View File
@@ -1,22 +1,16 @@
[preparing.dataset]
dataset_path =
crop_size =
process_limit =
[preparing.model]
source_path =
target_path =
[preparing.input]
directory_path =
source_path =
target_path =
image_pattern =
[training.loader]
split_ratio =
batch_size =
num_workers =
[training.model]
source_path =
target_path =
[training.trainer]
learning_rate =
max_epochs =
@@ -32,6 +26,3 @@ source_path =
target_path =
ir_version =
opset_version =
[execution]
providers =
-6
View File
@@ -1,6 +0,0 @@
#!/usr/bin/env python3
from src.preparing import prepare
if __name__ == '__main__':
prepare()
+41
View File
@@ -0,0 +1,41 @@
import glob
import random
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from .helper import read_image
from .types import Batch, ImagePathList
class DataLoaderRecognition(Dataset[torch.Tensor]):
	"""Map-style dataset yielding augmented face crops for recognition training.

	Image paths are globbed once at construction time; each item access samples
	a random image, augments it and returns a normalized tensor.
	"""

	def __init__(self, dataset_path : str, dataset_image_pattern : str) -> None:
		self.image_paths = self.prepare_image_paths(dataset_path, dataset_image_pattern)
		self.dataset_total = len(self.image_paths)
		self.transforms = self.compose_transforms()

	def __getitem__(self, index : int) -> Batch:
		# NOTE(review): index is ignored and a random image is drawn instead,
		# so epochs are not reproducible — confirm this is intended.
		target_image_path = random.choice(self.image_paths)
		target_vision_frame = read_image(target_image_path)
		target_tensor = self.transforms(target_vision_frame)
		return target_tensor

	def __len__(self) -> int:
		return self.dataset_total

	def prepare_image_paths(self, dataset_path : str, dataset_image_pattern : str) -> ImagePathList:
		# the pattern embeds the dataset path, e.g. '{}/*.*g'.format(dataset_path)
		image_paths = glob.glob(dataset_image_pattern.format(dataset_path))
		return image_paths

	def compose_transforms(self) -> transforms.Compose:
		# fix: the return annotation previously named the 'transforms' module, not a type
		transform = transforms.Compose(
		[
			transforms.ToPILImage(),
			transforms.Resize((112, 112), interpolation = transforms.InterpolationMode.BICUBIC),
			transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
			transforms.ToTensor(),
			# reorder channels — presumably BGR (cv2) to RGB; confirm against read_image
			transforms.Lambda(lambda temp_tensor : temp_tensor[[2, 1, 0], :, :]),
			# map values from [0, 1] to [-1, 1]
			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
		])
		return transform
+2 -2
View File
@@ -19,6 +19,6 @@ def export() -> None:
makedirs(directory_path, exist_ok = True)
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
model.eval()
model.ir_version = ir_version
input_tensor = torch.randn(1, 512)
model.ir_version = torch.tensor(ir_version)
input_tensor = (torch.randn(1, 512), )
torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
+17
View File
@@ -0,0 +1,17 @@
import platform
import cv2
import numpy
from .types import VisionFrame
def is_windows() -> bool:
	"""Return True when running on a Windows host."""
	system_name = platform.system()
	return system_name.lower() == 'windows'
def read_image(image_path : str) -> VisionFrame:
	"""Load an image from disk as a vision frame.

	On Windows the file is read into a byte buffer and decoded manually —
	presumably to handle paths cv2.imread cannot open; confirm the motivation.
	"""
	if not is_windows():
		return cv2.imread(image_path)
	image_buffer = numpy.fromfile(image_path, dtype = numpy.uint8)
	return cv2.imdecode(image_buffer, cv2.IMREAD_COLOR)
-79
View File
@@ -1,79 +0,0 @@
import configparser
from os import makedirs
from os.path import isfile
from typing import List

import numpy

# NOTE(review): the numpy.bool alias was removed in newer NumPy releases while
# mxnet still references it — restore it BEFORE importing mxnet; confirm still needed
numpy.bool = numpy.bool_

from mxnet.io import ImageRecordIter
from onnxruntime import InferenceSession
from tqdm import tqdm

from .types import Embedding, EmbeddingDataset, VisionFrame

# shared configuration for the preparing pipeline, read from the working directory
CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
def prepare_crop_vision_frame(crop_vision_frame : VisionFrame) -> VisionFrame:
	"""Scale pixel values from [0, 255] to [-1, 1] as float32."""
	normalized_vision_frame = crop_vision_frame.astype(numpy.float32) / 255.0
	return (normalized_vision_frame - 0.5) * 2
def create_inference_session(model_path : str, execution_providers : List[str]) -> InferenceSession:
	"""Create an ONNX Runtime session for the model using the given execution providers."""
	return InferenceSession(model_path, providers = execution_providers)
def forward(inference_session : InferenceSession, crop_vision_frame : VisionFrame) -> Embedding:
	"""Run the crop through the inference session and return its first output (the embedding)."""
	session_inputs =\
	{
		'input': crop_vision_frame
	}
	return inference_session.run(None, session_inputs)[0]
def create_embedding_dataset(dataset_reader : ImageRecordIter, source_inference_session : InferenceSession, target_inference_session : InferenceSession) -> EmbeddingDataset:
	"""Run every dataset crop through both embedders and return the stacked embedding pairs."""
	dataset_process_limit = CONFIG.getint('preparing.dataset', 'process_limit')
	embedding_pairs = []

	with tqdm(total = dataset_process_limit) as progress:
		for batch in dataset_reader:
			crop_vision_frame = batch.data[0].asnumpy()
			crop_vision_frame = prepare_crop_vision_frame(crop_vision_frame)
			source_embedding = forward(source_inference_session, crop_vision_frame)
			target_embedding = forward(target_inference_session, crop_vision_frame)
			embedding_pairs.append([ source_embedding, target_embedding ])
			progress.update()

			# stop early once the configured number of images has been processed
			if progress.n == dataset_process_limit:
				return numpy.concatenate(embedding_pairs, axis = 1).T
	# reader exhausted before reaching the limit — same stacking as the early return;
	# presumably each embedding is (1, dim), making the result (dim, N, 2) — TODO confirm
	return numpy.concatenate(embedding_pairs, axis = 1).T
def prepare() -> None:
	"""Embed the face dataset with the source and target models and save both embedding sets as .npy files."""
	dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
	dataset_crop_size = CONFIG.getint('preparing.dataset', 'crop_size')
	model_source_path = CONFIG.get('preparing.model', 'source_path')
	model_target_path = CONFIG.get('preparing.model', 'target_path')
	input_directory_path = CONFIG.get('preparing.input', 'directory_path')
	input_source_path = CONFIG.get('preparing.input', 'source_path')
	input_target_path = CONFIG.get('preparing.input', 'target_path')
	# providers is a space separated list, e.g. 'CUDAExecutionProvider'
	execution_providers = CONFIG.get('execution', 'providers').split(' ')

	makedirs(input_directory_path, exist_ok = True)

	# NOTE(review): silently does nothing when any required file is missing — consider logging
	if isfile(dataset_path) and isfile(model_source_path) and isfile(model_target_path):
		dataset_reader = ImageRecordIter(
			path_imgrec = dataset_path,
			data_shape = (3, dataset_crop_size, dataset_crop_size),
			batch_size = 1,
			shuffle = False
		)
		source_inference_session = create_inference_session(model_source_path, execution_providers)
		target_inference_session = create_inference_session(model_target_path, execution_providers)
		embedding_dataset = create_embedding_dataset(dataset_reader, source_inference_session, target_inference_session)
		# the last axis selects the pair: 0 = source embeddings, 1 = target embeddings
		numpy.save(input_source_path, embedding_dataset[..., 0].T)
		numpy.save(input_target_path, embedding_dataset[..., 1].T)
+34 -28
View File
@@ -3,15 +3,15 @@ import os
from typing import Any, Tuple
import lightning
import numpy
import torch
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
from torch.utils.data import DataLoader, Dataset, random_split
from .data_loader import DataLoaderRecognition
from .models.embedding_converter import EmbeddingConverter
from .types import Batch, Embedding
@@ -22,23 +22,34 @@ CONFIG.read('config.ini')
class EmbeddingConverterTrainer(lightning.LightningModule):
def __init__(self) -> None:
super(EmbeddingConverterTrainer, self).__init__()
source_path = CONFIG.get('training.model', 'source_path')
target_path = CONFIG.get('training.model', 'target_path')
self.lr = CONFIG.getfloat('training.trainer', 'learning_rate')
self.embedding_converter = EmbeddingConverter()
self.source_embedder = torch.jit.load(source_path) # type:ignore[no-untyped-call]
self.target_embedder = torch.jit.load(target_path) # type:ignore[no-untyped-call]
self.source_embedder.eval()
self.target_embedder.eval()
self.mse_loss = nn.MSELoss()
def forward(self, source_embedding : Embedding) -> Embedding:
return self.embedding_converter(source_embedding)
def training_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target = batch
output_tensor = self(source_tensor)
loss_training = self.mse_loss(output_tensor, target)
with torch.no_grad():
source_embedding = self.source_embedder(batch)
target_embedding = self.target_embedder(batch)
output_embedding = self(source_embedding)
loss_training = self.mse_loss(output_embedding, target_embedding)
self.log('loss_training', loss_training, prog_bar = True)
return loss_training
def validation_step(self, batch : Batch, batch_index : int) -> Tensor:
source_tensor, target_tensor = batch
output_tensor = self(source_tensor)
validation = self.mse_loss(output_tensor, target_tensor)
with torch.no_grad():
source_embedding = self.source_embedder(batch)
target_embedding = self.target_embedder(batch)
output_embedding = self(source_embedding)
validation = self.mse_loss(output_embedding, target_embedding)
self.log('validation', validation, prog_bar = True)
return validation
@@ -53,36 +64,28 @@ class EmbeddingConverterTrainer(lightning.LightningModule):
'lr_scheduler':
{
'scheduler': scheduler,
'monitor': 'train_loss',
'monitor': 'loss_training',
'interval': 'epoch',
'frequency': 1
}
}
def create_loaders() -> Tuple[DataLoader, DataLoader]:
loader_batch_size = CONFIG.getint('training.loader', 'batch_size')
loader_num_workers = CONFIG.getint('training.loader', 'num_workers')
def create_loaders(dataset : Dataset[Any]) -> Tuple[DataLoader[Any], DataLoader[Any]]:
batch_size = CONFIG.getint('training.loader', 'batch_size')
num_workers = CONFIG.getint('training.loader', 'num_workers')
training_dataset, validate_dataset = split_dataset()
training_loader = DataLoader(training_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = True, pin_memory = True)
validation_loader = DataLoader(validate_dataset, batch_size = loader_batch_size, num_workers = loader_num_workers, shuffle = False, pin_memory = True)
training_dataset, validate_dataset = split_dataset(dataset)
training_loader = DataLoader(training_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
validation_loader = DataLoader(validate_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, drop_last = False, pin_memory = True, persistent_workers = True)
return training_loader, validation_loader
def split_dataset() -> Tuple[Dataset[Any], Dataset[Any]]:
input_source_path = CONFIG.get('preparing.input', 'source_path')
input_target_path = CONFIG.get('preparing.input', 'target_path')
def split_dataset(dataset : Dataset[Any]) -> Tuple[Dataset[Any], Dataset[Any]]:
loader_split_ratio = CONFIG.getfloat('training.loader', 'split_ratio')
source_tensor = torch.from_numpy(numpy.load(input_source_path)).float()
target_tensor = torch.from_numpy(numpy.load(input_target_path)).float()
dataset = TensorDataset(source_tensor, target_tensor)
dataset_size = len(dataset)
training_size = int(loader_split_ratio * len(dataset))
validation_size = int(dataset_size - training_size)
training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
training_size = int(loader_split_ratio * len(dataset)) # type:ignore[operator, arg-type]
validation_size = len(dataset) - training_size # type:ignore[arg-type]
training_dataset, validate_dataset = random_split(dataset, [training_size, validation_size])
return training_dataset, validate_dataset
@@ -112,9 +115,12 @@ def create_trainer() -> Trainer:
def train() -> None:
dataset_path = CONFIG.get('preparing.dataset', 'dataset_path')
dataset_image_pattern = CONFIG.get('preparing.dataset', 'image_pattern')
resume_file_path = CONFIG.get('training.output', 'resume_file_path')
training_loader, validation_loader = create_loaders()
dataset = DataLoaderRecognition(dataset_path, dataset_image_pattern)
training_loader, validation_loader = create_loaders(dataset)
embedding_converter_trainer = EmbeddingConverterTrainer()
trainer = create_trainer()
tuner = Tuner(trainer)
+4 -7
View File
@@ -1,12 +1,9 @@
from typing import Any, Tuple, TypeAlias
from typing import Any, List, TypeAlias
from numpy.typing import NDArray
from torch import Tensor
Batch : TypeAlias = Tuple[Tensor, Tensor]
Embedding : TypeAlias = NDArray[Any]
EmbeddingDataset : TypeAlias = NDArray[Embedding]
FaceLandmark5 : TypeAlias = NDArray[Any]
ImagePathList : TypeAlias = List[str]
Batch : TypeAlias = Tensor
Embedding : TypeAlias = Tensor
VisionFrame : TypeAlias = NDArray[Any]
-2
View File
@@ -2,9 +2,7 @@
lightning==2.5.0
numpy==1.26.4
onnx==1.17.0
onnxruntime==1.20.1
opencv-python==4.11.0.86
mxnet==1.9.1
pytorch-msssim==1.0.0
torch==2.6.0
torchvision==0.21.0