Files
disrupting-deepfakes/cyclegan/data/data.lua
T
Nataniel Ruiz a3fe19383a cyclegan
2019-12-25 17:21:03 -04:00

104 lines
2.7 KiB
Lua

--[[
This data loader is a modified version of the one from dcgan.torch
(see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua).
Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
]]--
local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')
-- Module table. data.new copies every field of it onto each loader instance.
local data = {}
-- Single-slot mailbox: data._pushResult writes result[1], data:getBatch drains it.
local result = {}
-- Lua 5.1/LuaJIT provide a global `unpack`; Lua 5.2+ moved it to table.unpack.
local unpack = unpack or table.unpack
--- Create a data loader that samples batches via donkey_folder.lua.
-- If `n > 0`, a pool of `n` worker threads is started. Each worker sets the
-- globals `opt`/`tid`, seeds torch and runs donkey_folder.lua, which is
-- presumably what defines the global `trainLoader` used by the jobs below.
-- Otherwise donkey_folder.lua is run in the calling thread, and a synchronous
-- stand-in pool runs every job immediately.
-- @tparam number n number of worker threads (0 runs everything inline)
-- @tparam[opt] table opt_ options; `manualSeed` offsets each worker's seed and
--   `batchSize` is read by data._getFromThreads
-- @treturn table instance holding copies of the `data` functions, plus
--   `threads` (the pool) and `_size` (value returned by trainLoader:size())
function data.new(n, opt_)
   opt_ = opt_ or {}
   local self = {}
   -- Copy the module functions onto the instance (no metatable is used).
   for k, v in pairs(data) do
      self[k] = v
   end
   local donkey_file = 'donkey_folder.lua'
   if n > 0 then
      -- Captured as an upvalue of the thread-init closure below and shipped
      -- to each worker together with it.
      local options = opt_
      self.threads = Threads(n,
         function() require 'torch' end,
         function(idx)
            -- Deliberately global: data._getFromThreads reads `opt` inside the
            -- worker (donkey_folder.lua presumably reads `opt`/`tid` as well).
            opt = options
            tid = idx
            -- Base seed offset by the worker id, so each worker gets its own seed.
            local seed = (opt.manualSeed or 0) + idx
            torch.manualSeed(seed)
            torch.setnumthreads(1)
            print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
            assert(options, 'options not found')
            assert(opt, 'opt not given')
            print(opt)
            paths.dofile(donkey_file)
         end
      )
   else
      -- NOTE(review): unlike the threaded path, opt_ is not published as the
      -- global `opt` here. donkey_folder.lua and data._getFromThreads therefore
      -- see whatever global `opt` the caller set. Confirm this is intended.
      paths.dofile(donkey_file)
      -- Minimal synchronous stand-in for the parts of the Threads API used here:
      -- addjob runs the job and its end callback immediately.
      self.threads = {}
      function self.threads:addjob(f1, f2) f2(f1()) end
      function self.threads:dojob() end
      function self.threads:synchronize() end
   end
   -- Ask the loader for its sample count and wait for the answer.
   local nSamples = 0
   self.threads:addjob(function() return trainLoader:size() end,
                       function(c) nSamples = c end)
   self.threads:synchronize()
   self._size = nSamples
   -- Prefetch: queue one batch job per worker. data:getBatch queues one new
   -- job per batch it consumes, so n jobs stay queued or in flight.
   for _ = 1, n do
      self.threads:addjob(self._getFromThreads, self._pushResult)
   end
   return self
end
--- Job executed by the pool: draw one batch from the global `trainLoader`.
-- Uses the global `opt` for the batch size and fails loudly if it is unset.
-- @return all values returned by trainLoader:sample
function data._getFromThreads()
   local batch_size = opt.batchSize
   assert(batch_size, 'opt.batchSize not found')
   return trainLoader:sample(batch_size)
end
--- End callback passed to threads:addjob: store a finished job's results.
-- Every value the job returned is packed into a table and placed in the
-- single `result` slot, where data:getBatch picks it up.
-- @param ... values returned by data._getFromThreads
function data._pushResult(...)
   -- A table constructor always yields a table (never nil), so no
   -- "empty result" branch is possible here.
   result[1] = {...}
end
--- Return the next batch produced by the loader.
-- Queues a replacement job, then runs one finished job's end callback
-- (data._pushResult), which leaves that job's results in result[1]. With the
-- inline (n == 0) pool, addjob already ran the job and dojob does nothing.
-- @return the first value from trainLoader:sample; if it is a table, only its
--   first element
-- @return the third value from trainLoader:sample (presumably the image
--   paths; confirm against donkey_folder.lua)
function data:getBatch()
   -- Keep the pipeline full: enqueue a new job before consuming one.
   self.threads:addjob(self._getFromThreads, self._pushResult)
   self.threads:dojob()
   local res = result[1]
   assert(res, 'data:getBatch: no batch was produced')
   result[1] = nil
   -- Locals (previously leaked as globals).
   local img_data, img_paths = res[1], res[3]
   if torch.type(img_data) == 'table' then
      -- Single-target assignment keeps only the first unpacked value.
      img_data = unpack(img_data)
   end
   return img_data, img_paths
end
--- Sample count recorded at construction time (from trainLoader:size()).
-- @treturn number
function data:size()
   local sample_count = self._size
   return sample_count
end
-- Export the module table; data.new builds loader instances by copying its fields.
return data