104 lines
2.7 KiB
Lua
104 lines
2.7 KiB
Lua
--[[
    This data loader is a modified version of the one from dcgan.torch
    (see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua).

    Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
]]--
|
|
|
|
-- Thread-pool library used to run the donkey (data-loading) workers.
local Threads = require 'threads'
-- Select the 'threads.sharedserialize' module as the serializer the
-- threads library uses when exchanging values with worker threads.
Threads.serialization('threads.sharedserialize')

-- Module table; data.new copies every function stored here onto each
-- loader instance it creates.
local data = {}

-- Single-slot mailbox shared by the whole module: data._pushResult stores
-- the most recent batch in result[1], data.getBatch reads and clears it.
local result = {}

-- Lua 5.1 / LuaJIT expose a global `unpack`; Lua 5.2+ only has table.unpack.
local unpack = unpack or table.unpack
|
|
|
|
--- Create a loader instance, optionally backed by a pool of worker threads.
--
-- With n > 0 a Threads pool of n workers is created; every worker seeds
-- torch, pins itself to one compute thread and executes the donkey file.
-- With n <= 0 the donkey file is executed in the calling thread and
-- `self.threads` becomes a synchronous stand-in exposing the same
-- addjob / dojob / synchronize methods.
--
-- @param n     number of worker threads (must be a number; compared with 0)
-- @param opt_  options table (defaults to {}); `manualSeed` is read here and
--              `batchSize` is read later by data._getFromThreads
-- @return the new loader instance (a table)
function data.new(n, opt_)
   opt_ = opt_ or {}

   -- Instances carry their own copy of every module-level function.
   local self = {}
   for name, member in pairs(data) do
      self[name] = member
   end

   local donkey_file = 'donkey_folder.lua'

   if n > 0 then
      local options = opt_

      -- Runs once inside each worker thread. `opt` and `tid` are set as
      -- globals on purpose: data._getFromThreads reads the global `opt`
      -- when it later runs in the same worker. Presumably the donkey file
      -- reads them as well -- TODO confirm against donkey_folder.lua.
      local init_worker = function(idx)
         opt = options
         tid = idx
         local seed = (opt.manualSeed or 0) + idx
         torch.manualSeed(seed)
         torch.setnumthreads(1)
         print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
         assert(options, 'options not found')
         assert(opt, 'opt not given')
         print(opt)
         paths.dofile(donkey_file)
      end

      self.threads = Threads(
         n,
         function() require 'torch' end,
         init_worker
      )
   else
      -- No workers: load the donkey into this thread and emulate the pool
      -- interface so the rest of the loader needs no special casing.
      paths.dofile(donkey_file)
      self.threads = {
         -- Run the job right away and hand its results to the callback.
         addjob = function(_, job, on_done) on_done(job()) end,
         dojob = function() end,
         synchronize = function() end,
      }
   end

   -- Ask the donkey for the dataset size and wait for the answer.
   local sample_count = 0
   self.threads:addjob(
      function() return trainLoader:size() end,
      function(count) sample_count = count end
   )
   self.threads:synchronize()
   self._size = sample_count

   -- Queue one sampling job per worker up front.
   for _ = 1, n do
      self.threads:addjob(self._getFromThreads, self._pushResult)
   end

   return self
end
|
|
|
|
--- Job body: draw one batch from the donkey's loader.
--
-- Relies on the globals `opt` and `trainLoader` being defined in the
-- environment where it runs; it must not capture upvalues because
-- data.new and data.getBatch pass it to `threads:addjob` as the job.
--
-- @return whatever `trainLoader:sample(opt.batchSize)` returns
function data._getFromThreads()
   local batch_size = opt.batchSize
   assert(batch_size, 'opt.batchSize not found')
   return trainLoader:sample(batch_size)
end
|
|
|
|
--- End-callback for sampling jobs: store the job's return values.
--
-- All values returned by the job are packed into a table and placed in the
-- module-level `result[1]` slot, where data.getBatch picks them up.
--
-- The former `if res == nil` guard was removed: `{...}` always yields a
-- table, so the guard could never fire, and its body referenced `self`,
-- which is not in scope in this dot-defined function.
--
-- @param ...  values returned by the job (see data._getFromThreads)
function data._pushResult(...)
   result[1] = {...}
end
|
|
|
|
|
|
|
|
--- Fetch the next batch.
--
-- Queues a fresh sampling job, then calls `dojob()` on the pool so that a
-- finished job's callback (data._pushResult) fills `result[1]`. With the
-- synchronous stand-in pool the callback has already run inside `addjob`.
--
-- @return img_data   first value returned by the sampling job; if that value
--                    is a table, only its first element is returned
-- @return img_paths  third value returned by the sampling job (presumably
--                    the file paths of the batch -- confirm against the
--                    donkey's `sample` implementation)
function data:getBatch()
   -- queue another job
   self.threads:addjob(self._getFromThreads, self._pushResult)
   self.threads:dojob()

   local res = result[1]

   -- Keep these local: they were previously assigned as globals, which
   -- leaked the batch into _G on every call.
   local img_data = res[1]
   local img_paths = res[3]

   -- Empty the slot so a stale batch is never handed out twice.
   result[1] = nil

   if torch.type(img_data) == 'table' then
      -- Assignment to a single variable keeps only the first unpacked value.
      img_data = unpack(img_data)
   end

   return img_data, img_paths
end
|
|
|
|
--- Number of samples reported by the donkey's loader.
--
-- The value is captured once in data.new (from `trainLoader:size()`) and
-- cached in `self._size`.
--
-- @return the cached sample count
function data:size()
   local cached_count = self._size
   return cached_count
end
|
|
|
|
-- Export the module table; callers build loader instances with data.new.
return data
|