fix the GPU0 problem

This commit is contained in:
chenxuanhong
2022-02-15 01:40:11 +08:00
parent a148db410c
commit 4a6197a685
10 changed files with 839 additions and 23 deletions
+7 -6
View File
@@ -5,7 +5,7 @@
# Created Date: Tuesday April 28th 2020
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 13th February 2022 2:16:50 am
# Last Modified: Monday, 14th February 2022 11:54:02 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################
@@ -31,24 +31,24 @@ def getParameters():
parser = argparse.ArgumentParser()
# general settings
parser.add_argument('-v', '--version', type=str, default='invoup2',
parser.add_argument('-v', '--version', type=str, default='depthwise',
help="version name for train, test, finetune")
parser.add_argument('-t', '--tag', type=str, default='invo_upsample',
parser.add_argument('-t', '--tag', type=str, default='depthwise_conv',
help="tag for current experiment")
parser.add_argument('-p', '--phase', type=str, default="train",
choices=['train', 'finetune','debug'],
help="The phase of current project")
parser.add_argument('-c', '--gpus', type=int, nargs='+', default=[0,1]) # <0 if it is set as -1, program will use CPU
parser.add_argument('-c', '--gpus', type=int, nargs='+', default=[0,1,2,3]) # <0 if it is set as -1, program will use CPU
parser.add_argument('-e', '--ckpt', type=int, default=74,
help="checkpoint epoch for test phase or finetune phase")
# training
parser.add_argument('--experiment_description', type=str,
default="generator网络前向部分残差的赋值错误,现纠正,重新训练网络")
default="使用depthwise卷积作为基础算子测试性能")
parser.add_argument('--train_yaml', type=str, default="train_Invoup.yaml")
parser.add_argument('--train_yaml', type=str, default="train_Depthwise.yaml")
# system logger
parser.add_argument('--logger', type=str,
@@ -141,6 +141,7 @@ def main():
config = getParameters()
# speed up the program
cudnn.benchmark = True
cudnn.enabled = True
from utilities.logo_class import logo_class
logo_class.print_group_logo()