Files
disrupting-deepfakes/cyclegan/data/donkey_folder.lua
T
Nataniel Ruiz a3fe19383a cyclegan
2019-12-25 17:21:03 -04:00

201 lines
6.0 KiB
Lua

--[[
This data loader is a modified version of the one from dcgan.torch
(see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua).
Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
Copyright (c) 2015-present, Facebook, Inc.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
]]--
require 'image'
paths.dofile('dataset.lua')
-- This file contains the data-loading logic and details.
-- It is run by each data-loader thread.
------------------------------------------
-------- COMMON CACHES and PATHS
-- Resolve the dataset root. opt.DATA_ROOT takes precedence; otherwise the
-- DATA_ROOT environment variable is used (and echoed, as before).
local data_root = opt.DATA_ROOT
if not data_root then
  data_root = os.getenv('DATA_ROOT')
  print(data_root)
  if not data_root then
    -- Fail with an explicit message instead of handing nil to paths.concat.
    error('DATA_ROOT is not set: provide opt.DATA_ROOT or the DATA_ROOT environment variable')
  end
end
opt.data = paths.concat(data_root, opt.phase)
if not paths.dirp(opt.data) then
  error('Did not find directory: ' .. opt.data)
end

-- a cache file of the training metadata (if it doesn't exist, it will be created)
local cache_prefix = opt.data:gsub('/', '_')
-- Single-quote the directory for the shell (embedded quotes are escaped) so
-- that spaces or shell metacharacters in opt.cache_dir cannot split or alter
-- the command.
local quoted_cache_dir = "'" .. opt.cache_dir:gsub("'", "'\\''") .. "'"
os.execute('mkdir -p ' .. quoted_cache_dir)
local trainCache = paths.concat(opt.cache_dir, cache_prefix .. '_trainCache.t7')
--------------------------------------------------------------------------------------------
local input_nc = opt.nc -- input channels
-- Both tables hold {channels, side length}; only index 2 is read below.
local loadSize = {input_nc, opt.loadSize}
local sampleSize = {input_nc, opt.fineSize}
--- Load an aligned A|B image pair stored side by side in a single file.
-- The left half is image A and the right half is image B. Each half is
-- scaled to loadSize[2] x loadSize[2], its channel order is reversed,
-- values are remapped from [0, 1] to [-1, 1], and both halves receive the
-- same random crop (to sampleSize[2]) and the same optional horizontal flip.
-- @param path path of the image file to load
-- @return tensor holding A followed by B, concatenated along dimension 1
local function loadImage(path)
  local input = image.load(path, 3, 'float')
  local h = input:size(2)
  local w = input:size(3)

  -- Split into left (A) and right (B) halves and bring both to load size.
  local imA = image.crop(input, 0, 0, w/2, h)
  imA = image.scale(imA, loadSize[2], loadSize[2])
  local imB = image.crop(input, w/2, 0, w, h)
  imB = image.scale(imB, loadSize[2], loadSize[2])

  -- Reverse the channel order (3,2,1) and rescale x -> 2x - 1.
  local perm = torch.LongTensor{3, 2, 1}
  imA = imA:index(1, perm)
  imA = imA:mul(2):add(-1)
  imB = imB:index(1, perm)
  imB = imB:mul(2):add(-1)

  assert(imA:max()<=1,"A: badly scaled inputs")
  assert(imA:min()>=-1,"A: badly scaled inputs")
  assert(imB:max()<=1,"B: badly scaled inputs")
  assert(imB:min()>=-1,"B: badly scaled inputs")

  local oW = sampleSize[2]
  local oH = sampleSize[2]
  local iH = imA:size(2)
  local iW = imA:size(3)

  -- Crop offsets are locals that default to 0 so they neither leak into the
  -- global environment nor stay nil when only one dimension needs cropping.
  local h1, w1 = 0, 0
  if iH ~= oH then
    h1 = math.ceil(torch.uniform(1e-2, iH - oH))
  end
  if iW ~= oW then
    w1 = math.ceil(torch.uniform(1e-2, iW - oW))
  end
  if iH ~= oH or iW ~= oW then
    -- Same window for both halves keeps the pair aligned.
    imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH)
    imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
  end

  -- Flip both halves together so they stay aligned.
  if opt.flip == 1 and torch.uniform() > 0.5 then
    imA = image.hflip(imA)
    imB = image.hflip(imB)
  end

  local concatenated = torch.cat(imA,imB,1)
  return concatenated
end
--- Load one unaligned image and apply the preprocessing selected by
-- opt.resize_or_crop:
--   'resize_and_crop' : scale to loadSize[2] square, then random crop to
--                       sampleSize[2] square
--   'combined'        : random crop of a randomly sized window, then scale
--                       to sampleSize[2] square
--   'crop'            : random square crop whose side is a multiple of 4
--   'scale_width'     : scale so the width equals sampleSize[2]
--   'scale_height'    : scale so the height equals sampleSize[2]
-- Any other value leaves the image size untouched. When input_nc == 3 the
-- channel order is reversed and values are remapped x -> 2x - 1. A random
-- horizontal flip is applied when opt.flip == 1.
-- @param path path of the image file to load
-- @return the preprocessed image tensor
local function loadSingleImage(path)
  local im = image.load(path, input_nc, 'float')
  if opt.resize_or_crop == 'resize_and_crop' then
    im = image.scale(im, loadSize[2], loadSize[2])
  end
  if input_nc == 3 then
    local perm = torch.LongTensor{3, 2, 1}
    im = im:index(1, perm) -- reverse channel order
    im = im:mul(2):add(-1)
  end
  assert(im:max()<=1,"A: badly scaled inputs")
  assert(im:min()>=-1,"A: badly scaled inputs")

  local oW = sampleSize[2]
  local oH = sampleSize[2]
  local iH = im:size(2)
  local iW = im:size(3)

  if opt.resize_or_crop == 'resize_and_crop' then
    local h1, w1 = 0, 0
    if iH ~= oH then
      h1 = math.ceil(torch.uniform(1e-2, iH - oH))
    end
    if iW ~= oW then
      w1 = math.ceil(torch.uniform(1e-2, iW - oW))
    end
    if iH ~= oH or iW ~= oW then
      im = image.crop(im, w1, h1, w1 + oW, h1 + oH)
    end
  elseif opt.resize_or_crop == 'combined' then
    -- Window side is between ~1x and ~2x the output size, capped just below
    -- the image size.
    local sH = math.min(math.ceil(oH * torch.uniform(1+1e-2, 2.0-1e-2)), iH-1e-2)
    local sW = math.min(math.ceil(oW * torch.uniform(1+1e-2, 2.0-1e-2)), iW-1e-2)
    local h1 = math.ceil(torch.uniform(1e-2, iH - sH))
    local w1 = math.ceil(torch.uniform(1e-2, iW - sW))
    im = image.crop(im, w1, h1, w1 + sW, h1 + sH)
    im = image.scale(im, oW, oH)
  elseif opt.resize_or_crop == 'crop' then
    -- Largest square that fits both the image and the output size, rounded
    -- down to a multiple of 4.
    local side = math.min(math.min(oH, iH), iW)
    side = math.floor(side/4)*4
    local x = math.floor(torch.uniform(0, iW - side))
    local y = math.floor(torch.uniform(0, iH - side))
    im = image.crop(im, x, y, x + side, y + side)
  elseif opt.resize_or_crop == 'scale_width' then
    -- Target sizes are locals so the plain names w/h do not leak as globals.
    local newW = oW
    local newH = math.floor(iH * oW/iW)
    im = image.scale(im, newW, newH)
  elseif opt.resize_or_crop == 'scale_height' then
    local newH = oH
    local newW = math.floor(iW * oH / iH)
    im = image.scale(im, newW, newH)
  end

  if opt.flip == 1 and torch.uniform() > 0.5 then
    im = image.hflip(im)
  end
  return im
end
-- Channel-wise mean and std, declared here to be calculated or loaded from
-- disk later in the script.
local mean, std
--------------------------------------------------------------------------------
-- Per-image hooks: each one receives the loader (self) and a file path and
-- returns the loaded, jittered (random crop / flip) image tensor.

-- Hook for datasets where each file holds a single image.
local function trainHook_singleimage(self, path)
  collectgarbage()
  return loadSingleImage(path)
end

-- Hook for datasets where each file is a juxtaposition of two images,
-- one from each domain.
local function trainHook_doubleimage(self, path)
  collectgarbage()
  return loadImage(path)
end
-- Aligned data carries two images per file, so a sample has twice the
-- channels and goes through the double-image hook. sample_nc and trainHook
-- are deliberately left global, as in the original script.
local use_aligned = opt.align_data > 0
sample_nc = use_aligned and input_nc * 2 or input_nc
trainHook = use_aligned and trainHook_doubleimage or trainHook_singleimage

-- Global loader object for this script.
local load_side = loadSize[2]
local sample_side = sampleSize[2]
trainLoader = dataLoader{
  paths = {opt.data},
  loadSize = {input_nc, load_side, load_side},
  sampleSize = {sample_nc, sample_side, sample_side},
  split = 100,
  serial_batches = opt.serial_batches,
  verbose = true
}
trainLoader.sampleHookTrain = trainHook
collectgarbage()

-- Sanity checks: every entry of imageClass must be in 1..#classes.
do
  local labels = trainLoader.imageClass
  local class_count = #trainLoader.classes
  assert(labels:max() <= class_count, "class logic has error")
  assert(labels:min() >= 1, "class logic has error")
end