Files
disrupting-deepfakes/cyclegan/data/aligned_data_loader.lua
T
Nataniel Ruiz a3fe19383a cyclegan
2019-12-25 17:21:03 -04:00

45 lines
1.4 KiB
Lua

--------------------------------------------------------------------------------
-- Subclass of BaseDataLoader that provides data from two datasets.
-- The samples from the two datasets are aligned, and
-- both datasets are of the same size.
--------------------------------------------------------------------------------
require 'data.base_data_loader'
local class = require 'class'
data_util = paths.dofile('data_util.lua')
-- Declare the AlignedDataLoader class with BaseDataLoader as its parent.
-- NOTE: assigned without `local`, so the class name is a global.
AlignedDataLoader = class('AlignedDataLoader', 'BaseDataLoader')
-- Constructor.
-- |conf|: optional configuration table; an empty table is used when omitted.
function AlignedDataLoader:__init(conf)
  -- Apply the default BEFORE delegating to the parent constructor. The
  -- original assigned it after the call, where it had no effect (the local
  -- was never read again and the parent could still receive nil).
  conf = conf or {}
  BaseDataLoader.__init(self, conf)
end
-- Returns the identifying name of this loader class as a string.
function AlignedDataLoader:name()
  local loader_name = 'AlignedDataLoader'
  return loader_name
end
-- Prepares the loader: records the index pairs used to split a batch into
-- its A and B parts, then loads the dataset.
-- |opt|: options table. Reads opt.input_nc and opt.output_nc, and sets
--        opt.align_data = 1 as a side effect (presumably consumed by
--        data_util.load_dataset -- confirm against data_util.lua).
function AlignedDataLoader:Initialize(opt)
  opt.align_data = 1

  local n_in, n_out = opt.input_nc, opt.output_nc
  -- {first, last} index pairs; LoadBatchForAllDatasets applies them to the
  -- second dimension of the batch returned by the dataset.
  self.idx_A = { 1, n_in }
  self.idx_B = { n_in + 1, n_in + n_out }

  -- Hard-coded to 3; the alternative `opt.input_nc + opt.output_nc` was
  -- left commented out in the original code.
  local num_channels = 3
  self.data = data_util.load_dataset('', opt, num_channels)
end
-- Fetches one batch from the underlying dataset and splits it into the
-- A and B parts along the second dimension.
-- |return|: four values -- the batch for dataset A, the batch for
-- dataset B, and the path value from getBatch() twice (once for A, once
-- for B, since both parts come from the same samples).
function AlignedDataLoader:LoadBatchForAllDatasets()
  local combined, sample_path = self.data:getBatch()

  -- Keep every index in dimensions 1, 3 and 4; restrict dimension 2 to the
  -- index pair recorded in Initialize.
  local part_A = combined[{ {}, self.idx_A, {}, {} }]
  local part_B = combined[{ {}, self.idx_B, {}, {} }]

  return part_A, part_B, sample_path, sample_path
end
-- Returns the size reported by the underlying dataset object.
-- |dataset|: accepted but not used; the same size is reported regardless
-- of which dataset is asked for.
function AlignedDataLoader:size(dataset)
  local source = self.data
  return source:size()
end