45 lines
1.4 KiB
Lua
45 lines
1.4 KiB
Lua
--------------------------------------------------------------------------------
|
|
-- Subclass of BaseDataLoader that provides data from two datasets.
|
|
-- The samples from the datasets are aligned
|
|
-- The datasets are of the same size
|
|
--------------------------------------------------------------------------------
|
|
require 'data.base_data_loader'
|
|
|
|
local class = require 'class'
|
|
data_util = paths.dofile('data_util.lua')
|
|
|
|
AlignedDataLoader = class('AlignedDataLoader', 'BaseDataLoader')
|
|
|
|
function AlignedDataLoader:__init(conf)
|
|
BaseDataLoader.__init(self, conf)
|
|
conf = conf or {}
|
|
end
|
|
|
|
function AlignedDataLoader:name()
|
|
return 'AlignedDataLoader'
|
|
end
|
|
|
|
function AlignedDataLoader:Initialize(opt)
|
|
opt.align_data = 1
|
|
self.idx_A = {1, opt.input_nc}
|
|
self.idx_B = {opt.input_nc+1, opt.input_nc+opt.output_nc}
|
|
local nc = 3--opt.input_nc + opt.output_nc
|
|
self.data = data_util.load_dataset('', opt, nc)
|
|
end
|
|
|
|
-- actually fetches the data
|
|
-- |return|: a table of two tables, each corresponding to
|
|
-- the batch for dataset A and dataset B
|
|
function AlignedDataLoader:LoadBatchForAllDatasets()
|
|
local batch_data, path = self.data:getBatch()
|
|
local batchA = batch_data[{ {}, self.idx_A, {}, {} }]
|
|
local batchB = batch_data[{ {}, self.idx_B, {}, {} }]
|
|
|
|
return batchA, batchB, path, path
|
|
end
|
|
|
|
-- returns the size of each dataset
|
|
function AlignedDataLoader:size(dataset)
|
|
return self.data:size()
|
|
end
|