# --- repository-listing residue from the web export (kept for provenance) ---
# Files
# SimSwapPlus/train_scripts/trainer_multigpu_base.py
# T
# 2022-02-08 16:37:30 +08:00
#
# 122 lines
# 4.5 KiB
# Python
# ---------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_base.py
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 6th February 2022 3:06:45 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################
class TrainerBase(object):
    """Base class of a (multi-GPU) training pipeline.

    Concrete trainers subclass this and override the `__init_framework__`,
    `__setup_optimizers__` and `__evaluation__` hooks.
    """

    def __init__(self, config, reporter):
        """Store the run configuration and the log reporter.

        Args:
            config:   dict-like run configuration driving the pipeline.
            reporter: logger object used to report training progress.
        """
        self.config = config
        # logger
        self.reporter = reporter

        # ---- template scaffolding kept for subclass authors -------------
        #========build evaluation dataloader=============#
        # TODO to modify the key: "your_eval_dataset" to get your evaluation dataset path
        # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]
        # #================================================#
        # print("Prepare the evaluation dataloader...")
        # dlModulename = config["eval_dataloader"]
        # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
        # dataloaderClass = getattr(package, 'EvalDataset')
        # dataloader = dataloaderClass(eval_dataset, config["eval_batch_size"])
        # self.eval_loader = dataloader
        # self.eval_iter = len(dataloader)//config["eval_batch_size"]
        # if len(dataloader)%config["eval_batch_size"]>0:
        #     self.eval_iter += 1
        # #==============build tensorboard=================#
        # if self.config["logger"] == "tensorboard":
        #     from utilities.utilities import build_tensorboard
        #     tensorboard_writer = build_tensorboard(self.config["project_summary"])
        #     self.logger = tensorboard_writer
        # elif self.config["logger"] == "wandb":
        #     import wandb
        #     wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
        #                tags=[self.config["tag"]], name=self.config["version"])
        #     wandb.config = {
        #         "total_step": self.config["total_step"],
        #         "batch_size": self.config["batch_size"]
        #     }
        #     self.logger = wandb
# TODO modify this function to build your models
def __init_framework__(self):
'''
This function is designed to define the framework,
and print the framework information into the log file
'''
#===============build models================#
pass
# TODO modify this function to configurate the optimizer of your pipeline
def __setup_optimizers__(self):
pass
# TODO modify this function to evaluate your model
# Evaluate the checkpoint
def __evaluation__(self,
step = 0,
**kwargs
):
pass
def __create_dataloader__(self, config, cur_gpu):
    """Build and return the training dataloader described by *config*.

    Resolves the class ``GetLoader`` from the module
    ``data_tools.data_loader_<config["dataloader"]>`` at runtime and
    instantiates it with the dataset path, the current GPU index, the
    batch size and any extra dataset parameters.

    Args:
        config:  dict-like run configuration.
        cur_gpu: index of the GPU this process drives.

    Returns:
        The instantiated dataloader object.
    """
    #============build train dataloader==============#
    # TODO to modify the key: "your_train_dataset" to get your train dataset path
    dataset_path = config["dataset_paths"][config["dataset_name"]]
    #================================================#
    print("Prepare the train dataloader...")
    # dynamic import: data_tools.data_loader_<name> must expose GetLoader
    module_name = "data_tools.data_loader_%s" % config["dataloader"]
    loader_module = __import__(module_name, fromlist=True)
    loader_cls = getattr(loader_module, 'GetLoader')
    return loader_cls(dataset_path,
                      cur_gpu,
                      config["batch_size"],
                      **config["dataset_params"])
def train(self):
    """Prologue of the training entry point (visible portion).

    Builds the framework and optimizers, computes the starting step
    (resuming from ``config["checkpoint_step"]`` when the phase is
    ``"finetune"``), then prints the start-of-training banner.
    """
    #===============build framework================#
    # BUG FIX: the original called self.init_framework(), which does not
    # exist on this class — the hook is named __init_framework__ (names
    # with trailing double underscores are not subject to name mangling),
    # so the original raised AttributeError at runtime.
    self.__init_framework__()
    #===============build optimizer================#
    # Optimizer
    # TODO replace below lines to build your optimizer
    print("build the optimizer...")
    self.__setup_optimizers__()
    # set the start point for training loop: resume from the checkpoint
    # step when fine-tuning, otherwise start from scratch
    if self.config["phase"] == "finetune":
        self.start = self.config["checkpoint_step"]
    else:
        self.start = 0
    # Start time
    import datetime
    print("Start to train at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))
    from utilities.logo_class import logo_class
    logo_class.print_start_training()