Revert the config dicts

This commit is contained in:
henryruhs
2025-03-06 23:28:11 +01:00
parent e5f983b2bf
commit 1dfd230fc5
7 changed files with 60 additions and 96 deletions
+2 -5
View File
@@ -10,11 +10,8 @@ from .types import Batch
class StaticDataset(Dataset[Tensor]):
def __init__(self, config_parser : ConfigParser) -> None:
    """
    Resolve the dataset files and prepare the transforms.

    :param config_parser: parser that provides `file_pattern` in the `training.dataset` section
    """
    self.config_file_pattern = config_parser.get('training.dataset', 'file_pattern')
    # expand the configured glob pattern a single time
    self.file_paths = glob.glob(self.config_file_pattern)
    self.transforms = self.compose_transforms()
def __getitem__(self, index : int) -> Batch:
+9 -12
View File
@@ -10,18 +10,15 @@ CONFIG_PARSER.read('config.ini')
def export() -> None:
    """
    Export the trained embedding converter checkpoint as an ONNX model.

    All settings (directory_path, source_path, target_path, ir_version,
    opset_version) are read from the `exporting` section of CONFIG_PARSER.
    """
    config_directory_path = CONFIG_PARSER.get('exporting', 'directory_path')
    config_source_path = CONFIG_PARSER.get('exporting', 'source_path')
    config_target_path = CONFIG_PARSER.get('exporting', 'target_path')
    config_ir_version = CONFIG_PARSER.getint('exporting', 'ir_version')
    config_opset_version = CONFIG_PARSER.getint('exporting', 'opset_version')

    os.makedirs(config_directory_path, exist_ok = True)
    # load the checkpoint once, on the CPU, and switch to inference mode
    model = EmbeddingConverterTrainer.load_from_checkpoint(config_source_path, map_location = 'cpu')
    model.eval()
    model.ir_version = torch.tensor(config_ir_version)
    # dummy input of shape (1, 512) used to trace the model during export
    input_tensor = torch.randn(1, 512)
    torch.onnx.export(model, input_tensor, config_target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = config_opset_version)
+23 -38
View File
@@ -21,15 +21,12 @@ CONFIG_PARSER.read('config.ini')
class EmbeddingConverterTrainer(LightningModule):
def __init__(self, config_parser : ConfigParser) -> None:
    """
    Set up the converter network, both embedders and the loss.

    :param config_parser: parser that provides `source_path` and `target_path`
        in `training.model` and `learning_rate` in `training.trainer`
    """
    super().__init__()
    self.config_source_path = config_parser.get('training.model', 'source_path')
    self.config_target_path = config_parser.get('training.model', 'target_path')
    self.config_learning_rate = config_parser.getfloat('training.trainer', 'learning_rate')
    self.embedding_converter = EmbeddingConverter()
    # each TorchScript embedder is loaded exactly once, onto the CPU
    self.source_embedder = torch.jit.load(self.config_source_path, map_location = 'cpu')
    self.target_embedder = torch.jit.load(self.config_target_path, map_location = 'cpu')
    self.mse_loss = nn.MSELoss()
def forward(self, source_embedding : Embedding) -> Embedding:
@@ -53,7 +50,7 @@ class EmbeddingConverterTrainer(LightningModule):
return validation_score
def configure_optimizers(self) -> OptimizerSet:
optimizer = torch.optim.Adam(self.parameters(), lr = self.config.get('learning_rate'))
optimizer = torch.optim.Adam(self.parameters(), lr = self.config_learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
optimizer_set =\
{
@@ -71,52 +68,43 @@ class EmbeddingConverterTrainer(LightningModule):
def create_loaders(dataset : Dataset[Tensor]) -> Tuple[StatefulDataLoader[Tensor], StatefulDataLoader[Tensor]]:
    """
    Split the dataset and wrap both parts in stateful data loaders.

    `batch_size` and `num_workers` are read from the `training.loader`
    section of CONFIG_PARSER.

    :param dataset: full dataset to split
    :return: the training loader followed by the validation loader
    """
    config_batch_size = CONFIG_PARSER.getint('training.loader', 'batch_size')
    config_num_workers = CONFIG_PARSER.getint('training.loader', 'num_workers')

    training_dataset, validate_dataset = split_dataset(dataset)
    # only the training loader shuffles and drops the last incomplete batch
    training_loader = StatefulDataLoader(training_dataset, batch_size = config_batch_size, shuffle = True, num_workers = config_num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
    validation_loader = StatefulDataLoader(validate_dataset, batch_size = config_batch_size, shuffle = False, num_workers = config_num_workers, pin_memory = True, persistent_workers = True)
    return training_loader, validation_loader
def split_dataset(dataset : Dataset[Tensor]) -> Tuple[Dataset[Tensor], Dataset[Tensor]]:
    """
    Randomly split the dataset into a training and a validation part.

    The share of samples that goes to training is `split_ratio` from the
    `training.loader` section of CONFIG_PARSER; the remainder is validation.

    :param dataset: dataset to split
    :return: the training dataset followed by the validation dataset
    """
    config_split_ratio = CONFIG_PARSER.getfloat('training.loader', 'split_ratio')

    dataset_size = len(dataset) # type:ignore[arg-type]
    training_size = int(dataset_size * config_split_ratio)
    # validation takes whatever is left so both sizes always add up to dataset_size
    validation_size = int(dataset_size - training_size)
    training_dataset, validate_dataset = random_split(dataset, [ training_size, validation_size ])
    return training_dataset, validate_dataset
def create_trainer() -> Trainer:
config =\
{
'max_epochs': CONFIG_PARSER.getint('training.trainer', 'max_epochs'),
'precision': CONFIG_PARSER.get('training.trainer', 'precision'),
'directory_path': CONFIG_PARSER.get('training.output', 'directory_path'),
'file_pattern': CONFIG_PARSER.get('training.output', 'file_pattern')
}
config_max_epochs = CONFIG_PARSER.getint('training.trainer', 'max_epochs')
config_precision = CONFIG_PARSER.get('training.trainer', 'precision')
config_directory_path = CONFIG_PARSER.get('training.output', 'directory_path')
config_file_pattern = CONFIG_PARSER.get('training.output', 'file_pattern')
logger = TensorBoardLogger('.logs', name = 'embedding_converter')
return Trainer(
logger = logger,
log_every_n_steps = 10,
max_epochs = config.get('max_epochs'),
precision = config.get('precision'),
max_epochs = config_max_epochs,
precision = config_precision,
callbacks =
[
ModelCheckpoint(
monitor = 'training_loss',
dirpath = config.get('directory_path'),
filename = config.get('file_pattern'),
dirpath = config_directory_path,
filename = config_file_pattern,
every_n_epochs = 1,
save_top_k = 3,
save_last = True
@@ -126,10 +114,7 @@ def create_trainer() -> Trainer:
def train() -> None:
config =\
{
'resume_path': CONFIG_PARSER.get('training.output', 'resume_path')
}
config_resume_path = CONFIG_PARSER.get('training.output', 'resume_path')
if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')
@@ -139,7 +124,7 @@ def train() -> None:
embedding_converter_trainer = EmbeddingConverterTrainer(CONFIG_PARSER)
trainer = create_trainer()
if os.path.exists(config.get('resume_path')):
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config.get('resume_path'))
if os.path.exists(config_resume_path):
trainer.fit(embedding_converter_trainer, training_loader, validation_loader, ckpt_path = config_resume_path)
else:
trainer.fit(embedding_converter_trainer, training_loader, validation_loader)
+2 -5
View File
@@ -9,10 +9,7 @@ from ..networks.nld import NLD
class Discriminator(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
    """
    Build the pooling layer and the list of discriminators.

    :param config_parser: parser that provides `num_discriminators` in the
        `training.model.discriminator` section; it is also kept on the
        instance because create_discriminators passes it on to NLD
    """
    super().__init__()
    self.config_num_discriminators = config_parser.getint('training.model.discriminator', 'num_discriminators')
    self.config_parser = config_parser
    self.avg_pool = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = (1, 1), count_include_pad = False)
    self.discriminators = self.create_discriminators()
@@ -20,7 +17,7 @@ class Discriminator(nn.Module):
def create_discriminators(self) -> nn.ModuleList:
discriminators = nn.ModuleList()
for _ in range(self.config.get('num_discriminators')):
for _ in range(self.config_num_discriminators):
discriminator = NLD(self.config_parser).sequences
discriminators.append(discriminator)
+6 -12
View File
@@ -101,28 +101,22 @@ class ReconstructionLoss(nn.Module):
class IdentityLoss(nn.Module):
    """
    Cosine-distance loss between the embeddings of the source and the output.
    """

    def __init__(self, config_parser : ConfigParser, embedder : EmbedderModule) -> None:
        """
        :param config_parser: parser that provides `identity_weight` in the `training.losses` section
        :param embedder: module handed to calc_embedding to produce the embeddings
        """
        super().__init__()
        self.config_identity_weight = config_parser.getfloat('training.losses', 'identity_weight')
        self.embedder = embedder

    def forward(self, source_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tensor, Tensor]:
        """
        :return: the raw identity loss followed by the loss scaled by `identity_weight`
        """
        # both tensors go through calc_embedding with the same (30, 0, 10, 10) argument
        output_embedding = calc_embedding(self.embedder, output_tensor, (30, 0, 10, 10))
        source_embedding = calc_embedding(self.embedder, source_tensor, (30, 0, 10, 10))
        # 1 - cosine similarity, averaged over the batch
        identity_loss = (1 - torch.cosine_similarity(source_embedding, output_embedding)).mean()
        weighted_identity_loss = identity_loss * self.config_identity_weight
        return identity_loss, weighted_identity_loss
class MotionLoss(nn.Module):
def __init__(self, config_parser : ConfigParser, motion_extractor : MotionExtractorModule) -> None:
    """
    Store the loss weights and the motion extractor.

    :param config_parser: parser that provides `pose_weight` and
        `expression_weight` in the `training.losses` section
    :param motion_extractor: module stored as `self.motion_extractor`
    """
    super().__init__()
    self.config_pose_weight = config_parser.getfloat('training.losses', 'pose_weight')
    # must be named config_expression_weight: calc_expression_loss reads exactly
    # that attribute, so the former `self.expression_weight` raised AttributeError
    self.config_expression_weight = config_parser.getfloat('training.losses', 'expression_weight')
    self.motion_extractor = motion_extractor
    self.mse_loss = nn.MSELoss()
@@ -141,12 +135,12 @@ class MotionLoss(nn.Module):
temp_tensors.append(temp_tensor)
pose_loss = torch.stack(temp_tensors).mean()
weighted_pose_loss = pose_loss * self.config.get('pose_weight')
weighted_pose_loss = pose_loss * self.config_pose_weight
return pose_loss, weighted_pose_loss
def calc_expression_loss(self, target_expression : Tensor, output_expression : Tensor) -> Tuple[Tensor, Tensor]:
    """
    Compute the cosine-distance loss between two expression tensors.

    :return: the raw expression loss followed by the loss scaled by
        `self.config_expression_weight`
    """
    # 1 - cosine similarity, averaged over the batch
    expression_loss = (1 - torch.cosine_similarity(target_expression, output_expression)).mean()
    weighted_expression_loss = expression_loss * self.config_expression_weight
    return expression_loss, weighted_expression_loss
def get_motions(self, input_tensor : Tensor) -> Tuple[Tuple[Tensor, ...], Tensor]:
+11 -14
View File
@@ -7,31 +7,28 @@ from torch import Tensor, nn
class NLD(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
    """
    Read the discriminator settings and assemble the layer sequence.

    :param config_parser: parser that provides `input_channels`, `num_filters`,
        `kernel_size` and `num_layers` in the `training.model.discriminator` section
    """
    super().__init__()
    self.config_input_channels = config_parser.getint('training.model.discriminator', 'input_channels')
    self.config_num_filters = config_parser.getint('training.model.discriminator', 'num_filters')
    self.config_kernel_size = config_parser.getint('training.model.discriminator', 'kernel_size')
    self.config_num_layers = config_parser.getint('training.model.discriminator', 'num_layers')
    # the settings above must be assigned first, create_layers reads them
    self.layers = self.create_layers()
    self.sequences = nn.Sequential(*self.layers)
def create_layers(self) -> nn.ModuleList:
padding = math.ceil((self.config.get('kernel_size') - 1) / 2)
current_filters = self.config.get('num_filters')
padding = math.ceil((self.config_kernel_size - 1) / 2)
current_filters = self.config_num_filters
layers = nn.ModuleList(
[
nn.Conv2d(self.config.get('input_channels'), current_filters, kernel_size = self.config.get('kernel_size'), stride = 2, padding = padding),
nn.Conv2d(self.config_input_channels, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
nn.LeakyReLU(0.2, True)
])
for _ in range(1, self.config.get('num_layers')):
for _ in range(1, self.config_num_layers):
previous_filters = current_filters
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config.get('kernel_size'), stride = 2, padding = padding),
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, stride = 2, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True)
]
@@ -40,10 +37,10 @@ class NLD(nn.Module):
current_filters = min(current_filters * 2, 512)
layers +=\
[
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config.get('kernel_size'), padding = padding),
nn.Conv2d(previous_filters, current_filters, kernel_size = self.config_kernel_size, padding = padding),
nn.InstanceNorm2d(current_filters),
nn.LeakyReLU(0.2, True),
nn.Conv2d(current_filters, 1, kernel_size = self.config.get('kernel_size'), padding = padding)
nn.Conv2d(current_filters, 1, kernel_size = self.config_kernel_size, padding = padding)
]
return layers
+7 -10
View File
@@ -8,10 +8,7 @@ from torch import Tensor, nn
class UNet(nn.Module):
def __init__(self, config_parser : ConfigParser) -> None:
    """
    Read the output size and build the down and up sampling stacks.

    :param config_parser: parser that provides `output_size` in the
        `training.model.generator` section
    """
    super().__init__()
    # must be set before the stacks are created, both builders branch on it
    self.config_output_size = config_parser.getint('training.model.generator', 'output_size')
    self.down_samples = self.create_down_samples()
    self.up_samples = self.create_up_samples()
@@ -25,20 +22,20 @@ class UNet(nn.Module):
DownSample(256, 512)
])
if self.config.get('output_size') == 128:
if self.config_output_size == 128:
down_samples.extend(
[
DownSample(512, 512)
])
if self.config.get('output_size') == 256:
if self.config_output_size == 256:
down_samples.extend(
[
DownSample(512, 1024),
DownSample(1024, 1024)
])
if self.config.get('output_size') == 512:
if self.config_output_size == 512:
down_samples.extend(
[
DownSample(512, 1024),
@@ -51,20 +48,20 @@ class UNet(nn.Module):
def create_up_samples(self) -> nn.ModuleList:
up_samples = nn.ModuleList()
if self.config.get('output_size') == 128:
if self.config_output_size == 128:
up_samples.extend(
[
UpSample(512, 512)
])
if self.config.get('output_size') == 256:
if self.config_output_size == 256:
up_samples.extend(
[
UpSample(1024, 1024),
UpSample(2048, 512)
])
if self.config.get('output_size') == 512:
if self.config_output_size == 512:
up_samples.extend(
[
UpSample(2048, 2048),