update dataloader
This commit is contained in:
@@ -5,7 +5,7 @@
|
|||||||
# Created Date: Sunday January 16th 2022
|
# Created Date: Sunday January 16th 2022
|
||||||
# Author: Chen Xuanhong
|
# Author: Chen Xuanhong
|
||||||
# Email: chenxuanhongzju@outlook.com
|
# Email: chenxuanhongzju@outlook.com
|
||||||
# Last Modified: Monday, 14th February 2022 11:35:32 pm
|
# Last Modified: Tuesday, 15th February 2022 1:54:50 am
|
||||||
# Modified By: Chen Xuanhong
|
# Modified By: Chen Xuanhong
|
||||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||||
#############################################################
|
#############################################################
|
||||||
@@ -119,8 +119,11 @@ class Generator(nn.Module):
|
|||||||
|
|
||||||
# self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=False),
|
# self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(3, 64, kernel_size=7, padding=0, bias=False),
|
||||||
# nn.BatchNorm2d(64), activation)
|
# nn.BatchNorm2d(64), activation)
|
||||||
self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
self.first_layer = nn.Sequential(nn.ReflectionPad2d(1),
|
||||||
|
nn.Conv2d(3, 64, kernel_size=3, padding=0, bias=False),
|
||||||
nn.BatchNorm2d(64), activation)
|
nn.BatchNorm2d(64), activation)
|
||||||
|
# self.first_layer = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
|
||||||
|
# nn.BatchNorm2d(64), activation)
|
||||||
### downsample
|
### downsample
|
||||||
self.down1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, groups=64, padding=1, stride=2),
|
self.down1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, groups=64, padding=1, stride=2),
|
||||||
nn.Conv2d(64, 128, kernel_size=1, bias=False),
|
nn.Conv2d(64, 128, kernel_size=1, bias=False),
|
||||||
@@ -164,7 +167,9 @@ class Generator(nn.Module):
|
|||||||
DeConv(128,64,3),
|
DeConv(128,64,3),
|
||||||
nn.BatchNorm2d(64), activation
|
nn.BatchNorm2d(64), activation
|
||||||
)
|
)
|
||||||
self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))
|
# self.last_layer = nn.Sequential(nn.Conv2d(64, 3, kernel_size=3, padding=1))
|
||||||
|
self.last_layer = nn.Sequential(nn.ReflectionPad2d(1),
|
||||||
|
nn.Conv2d(64, 3, kernel_size=3, padding=0))
|
||||||
# self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
|
# self.last_layer = nn.Sequential(nn.ReflectionPad2d(3),
|
||||||
# nn.Conv2d(64, 3, kernel_size=7, padding=0))
|
# nn.Conv2d(64, 3, kernel_size=7, padding=0))
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
# Created Date: Sunday February 6th 2022
|
# Created Date: Sunday February 6th 2022
|
||||||
# Author: Chen Xuanhong
|
# Author: Chen Xuanhong
|
||||||
# Email: chenxuanhongzju@outlook.com
|
# Email: chenxuanhongzju@outlook.com
|
||||||
# Last Modified: Tuesday, 15th February 2022 1:35:41 am
|
# Last Modified: Tuesday, 15th February 2022 1:50:19 am
|
||||||
# Modified By: Chen Xuanhong
|
# Modified By: Chen Xuanhong
|
||||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||||
#############################################################
|
#############################################################
|
||||||
@@ -56,7 +56,7 @@ class InfiniteSampler(torch.utils.data.Sampler):
|
|||||||
|
|
||||||
class data_prefetcher():
|
class data_prefetcher():
|
||||||
def __init__(self, loader, cur_gpu):
|
def __init__(self, loader, cur_gpu):
|
||||||
torch.cuda.set_device(cur_gpu)
|
torch.cuda.set_device(cur_gpu) # must add this line to avoid excessive use of GPU 0 by the prefetcher
|
||||||
self.loader = loader
|
self.loader = loader
|
||||||
self.dataiter = iter(loader)
|
self.dataiter = iter(loader)
|
||||||
self.stream = torch.cuda.Stream(device=cur_gpu)
|
self.stream = torch.cuda.Stream(device=cur_gpu)
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ dataloader: VGGFace2HQ_multigpu
|
|||||||
dataset_name: vggface2_hq
|
dataset_name: vggface2_hq
|
||||||
dataset_params:
|
dataset_params:
|
||||||
random_seed: 1234
|
random_seed: 1234
|
||||||
dataloader_workers: 4
|
dataloader_workers: 6
|
||||||
|
|
||||||
eval_dataloader: DIV2K_hdf5
|
eval_dataloader: DIV2K_hdf5
|
||||||
eval_dataset_name: DF2K_H5_Eval
|
eval_dataset_name: DF2K_H5_Eval
|
||||||
|
|||||||
Reference in New Issue
Block a user