#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: trainer_base.py
# Created Date: Sunday January 16th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 6th February 2022 3:06:45 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2022 Shanghai Jiao Tong University
#############################################################


class TrainerBase(object):
    '''
    Template base class for a training pipeline.

    Subclasses fill in the framework construction (``__init_framework__``
    or ``init_framework``), optimizer setup (``__setup_optimizers__``) and
    evaluation (``__evaluation__``) hooks; ``train`` drives the shared
    start-up sequence.
    '''

    def __init__(self, config, reporter):
        '''
        Store the run configuration and the reporter.

        Args:
            config: mapping of run options. ``train`` reads "phase" and,
                when finetuning, "checkpoint_step".
            reporter: logger object, stored as ``self.reporter``; this base
                class does not call it itself.
        '''
        self.config = config

        # logger
        self.reporter = reporter

        #========build evaluation dataloader=============#
        # TODO modify the key "eval_dataset_name" to select your evaluation dataset path
        # eval_dataset = config["dataset_paths"][config["eval_dataset_name"]]

        # #================================================#
        # print("Prepare the evaluation dataloader...")
        # dlModulename = config["eval_dataloader"]
        # package = __import__("data_tools.eval_dataloader_%s"%dlModulename, fromlist=True)
        # dataloaderClass = getattr(package, 'EvalDataset')
        # dataloader = dataloaderClass(eval_dataset,
        #                              config["eval_batch_size"])
        # self.eval_loader= dataloader

        # self.eval_iter = len(dataloader)//config["eval_batch_size"]
        # if len(dataloader)%config["eval_batch_size"]>0:
        #     self.eval_iter+=1

        # #==============build tensorboard=================#
        # if self.config["logger"] == "tensorboard":
        #     from utilities.utilities import build_tensorboard
        #     tensorboard_writer = build_tensorboard(self.config["project_summary"])
        #     self.logger = tensorboard_writer
        # elif self.config["logger"] == "wandb":
        #     import wandb
        #     wandb.init(project="Simswap_HQ", entity="xhchen", notes="512",
        #                tags=[self.config["tag"]], name=self.config["version"])

        #     wandb.config = {
        #         "total_step": self.config["total_step"],
        #         "batch_size": self.config["batch_size"]
        #     }
        #     self.logger = wandb

    # TODO modify this function to build your models
    def __init_framework__(self):
        '''
        This function is designed to define the framework,
        and print the framework information into the log file.

        The base implementation does nothing; override it in a subclass.
        '''
        #===============build models================#
        pass

    def init_framework(self):
        '''
        Build the framework; this is the hook that ``train`` calls.

        ``train`` calls ``self.init_framework()`` while the template hook is
        named ``__init_framework__``. Without this method, a subclass that
        only overrides ``__init_framework__`` (or the base class itself)
        fails with AttributeError. Delegating lets subclasses override
        either name.
        '''
        return self.__init_framework__()

    # TODO modify this function to configurate the optimizer of your pipeline
    def __setup_optimizers__(self):
        '''
        Create the optimizers for the pipeline.

        The base implementation does nothing; override it in a subclass.
        '''
        pass

    # TODO modify this function to evaluate your model
    # Evaluate the checkpoint
    def __evaluation__(self, step=0, **kwargs):
        '''
        Evaluate the model checkpoint.

        Args:
            step: training step being evaluated (defaults to 0).
            **kwargs: extra options for subclass implementations.

        The base implementation does nothing; override it in a subclass.
        '''
        pass

    def __create_dataloader__(self, config, cur_gpu):
        '''
        Build the training dataloader described by ``config``.

        Args:
            config: mapping providing "dataset_paths", "dataset_name",
                "dataloader", "batch_size" and "dataset_params".
            cur_gpu: device identifier forwarded to the loader constructor.

        Returns:
            An instance of ``GetLoader`` from the module
            ``data_tools.data_loader_<config["dataloader"]>``. It is
            constructed as
            ``GetLoader(dataset_path, cur_gpu, batch_size, **dataset_params)``.
        '''
        import importlib

        # Data loader
        #============build train dataloader==============#
        # TODO modify the key "dataset_name" to select your train dataset path
        dataset = config["dataset_paths"][config["dataset_name"]]
        #================================================#
        print("Prepare the train dataloader...")
        dl_module_name = config["dataloader"]
        # import_module returns the leaf submodule directly, which is what the
        # former ``__import__(..., fromlist=True)`` idiom was used for.
        package = importlib.import_module("data_tools.data_loader_%s" % dl_module_name)
        dataloader_class = getattr(package, 'GetLoader')
        dataloader = dataloader_class(dataset,
                                      cur_gpu,
                                      config["batch_size"],
                                      **config["dataset_params"])
        return dataloader

    def train(self):
        '''
        Run the training start-up sequence.

        This builds the framework and sets up the optimizers. It then sets
        ``self.start`` to ``config["checkpoint_step"]`` when
        ``config["phase"]`` is "finetune", and to 0 otherwise. Finally it
        prints the start time and the start-of-training banner.
        '''
        #===============build framework================#
        self.init_framework()

        #===============build optimizer================#
        # Optimizer
        # TODO replace below lines to build your optimizer
        print("build the optimizer...")
        self.__setup_optimizers__()

        # set the start point for training loop
        if self.config["phase"] == "finetune":
            self.start = self.config["checkpoint_step"]
        else:
            self.start = 0

        # Start time
        import datetime
        print("Start to train at %s" % (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

        from utilities.logo_class import logo_class
        logo_class.print_start_training()