Files
disrupting-deepfakes/cyclegan/test.lua
T
Nataniel Ruiz a3fe19383a cyclegan
2019-12-25 17:21:03 -04:00

143 lines
3.7 KiB
Lua

-- usage: DATA_ROOT=/path/to/data/ name=expt1 which_direction=BtoA th test.lua
--
-- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix
require 'image'
require 'nn'
require 'nngraph'
require 'models.architectures'
util = paths.dofile('util/util.lua')
options = require 'options'
-- Parse the options for the 'test' mode; the result is kept in the global
-- `opt` and read throughout the rest of this script.
opt = options.parse_options('test')

-- initialize torch GPU/CPU mode
-- A positive opt.gpu is handed to cutorch.setDevice and CUDA tensors become
-- the default tensor type; any other value leaves float tensors as default.
local use_gpu = opt.gpu > 0
if not use_gpu then
  torch.setdefaulttensortype('torch.FloatTensor')
  print ("CPU Mode")
else
  -- The CUDA packages are only required when a GPU was requested.
  require 'cutorch'
  require 'cunn'
  cutorch.setDevice(opt.gpu)
  print ("GPU Mode")
  torch.setdefaulttensortype('torch.CudaTensor')
end
-- setup visualization
visualizer = require 'util/visualizer'
--- Append the array part of one table onto another, in place.
-- @param t1 destination table; it is modified and also returned
-- @param t2 source table; indices 1..#t2 are read in order
-- @return t1, holding its previous entries followed by those of t2
function TableConcat(t1,t2)
  -- Length of the source is captured once, before any element is copied.
  local last = #t2
  local src = 1
  while src <= last do
    t1[#t1 + 1] = t2[src]
    src = src + 1
  end
  return t1
end
-- load data
-- opt.align_data > 0 selects AlignedDataLoader; every other value selects
-- UnalignedDataLoader. Only the chosen loader module is required.
local data_loader = nil
if not (opt.align_data > 0) then
  require 'data.unaligned_data_loader'
  data_loader = UnalignedDataLoader()
else
  require 'data.aligned_data_loader'
  data_loader = AlignedDataLoader()
end
print( "DataLoader " .. data_loader:name() .. " was created.")
data_loader:Initialize(opt)

-- Resolve how many items to process: the literal 'all' means the loader's
-- full size, and any other request is capped at that size.
do
  local requested = opt.how_many
  if requested == 'all' then
    requested = data_loader:size()
  end
  opt.how_many = math.min(requested, data_loader:size())
end
-- set batch/instance normalization
set_normalization(opt.norm)

-- load model
-- NOTE(review): continue_train is forced on before model:Initialize,
-- presumably so saved weights are loaded rather than new networks being
-- created -- confirm against the model classes.
opt.continue_train = 1

-- define model
-- Each entry requires its model file lazily and then calls the constructor
-- that file defines, so only the selected model's module is loaded.
do
  local factories = {
    cycle_gan = function()
      require 'models.cycle_gan_model'
      return CycleGANModel()
    end,
    one_direction_test = function()
      require 'models.one_direction_test_model'
      return OneDirectionTestModel()
    end,
    pix2pix = function()
      require 'models.pix2pix_model'
      return Pix2PixModel()
    end,
    bigan = function()
      require 'models.bigan_model'
      return BiGANModel()
    end,
    content_gan = function()
      require 'models.content_gan_model'
      return ContentGANModel()
    end,
  }

  local build = factories[opt.model]
  if build == nil then
    error('Please specify a correct model')
  end
  model = build()
end

model:Initialize(opt)
-- Run the model over opt.how_many batches, save every visual it produces as
-- an image under <results_dir>/<name>/<which_epoch>_<phase>/images/<label>/,
-- then write an index.html next to that directory tabulating the images.
local pathsA = {} -- paths to images A tested on
local pathsB = {} -- paths to images B tested on

local web_dir = paths.concat(opt.results_dir, opt.name .. '/' .. opt.which_epoch .. '_' .. opt.phase)
paths.mkdir(web_dir)
local image_dir = paths.concat(web_dir, 'images')
paths.mkdir(image_dir)

--- Map an input image name to the file name used for its saved output.
-- The dot is escaped so that only a literal ".jpg" is rewritten; an
-- unescaped '.' is a pattern wildcard and would also rewrite e.g. "ajpg".
-- The outer parentheses drop gsub's second result (the substitution count).
-- @param path image file name
-- @return the name with every ".jpg" replaced by ".png"
local function png_name(path)
  return (string.gsub(path, '%.jpg', '.png'))
end

-- Output size passed to visualizer.save_images. For the scale_width and
-- scale_height modes no size is passed (nil). The mode never changes while
-- the script runs, so it is decided once here rather than per image.
local save_s1 = opt.fineSize
local save_s2 = opt.fineSize / opt.aspect_ratio
if opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
  save_s1 = nil
  save_s2 = nil
end

local visuals = {}
for n = 1, math.floor(opt.how_many) do
  print('processing batch ' .. n)
  local cur_dataA, cur_dataB, cur_pathsA, cur_pathsB = data_loader:GetNextBatch()
  cur_pathsA = util.basename_batch(cur_pathsA)
  cur_pathsB = util.basename_batch(cur_pathsB)
  print('pathsA', cur_pathsA)
  print('pathsB', cur_pathsB)

  model:Forward({real_A = cur_dataA, real_B = cur_dataB}, opt)
  visuals = model:GetCurrentVisuals(opt, opt.fineSize)

  for _, visual in ipairs(visuals) do
    -- NOTE(review): only the first A path names the saved output, so this
    -- looks like it assumes one image per batch -- confirm against
    -- visualizer.save_images.
    local names = {png_name(cur_pathsA[1])}
    visualizer.save_images(visual.img, paths.concat(image_dir, visual.label), names, save_s1, save_s2)
  end
  print('Saved images to: ', image_dir)

  pathsA = TableConcat(pathsA, cur_pathsA)
  pathsB = TableConcat(pathsB, cur_pathsB)
end

-- Column headers come from the visuals of the last processed batch.
local labels = {}
for _, visual in ipairs(visuals) do
  labels[#labels + 1] = visual.label
end

-- make webpage
-- An explicit handle is used (instead of redirecting io.output) so the file
-- is flushed and closed as soon as the page is complete.
local html = assert(io.open(paths.concat(web_dir, 'index.html'), 'w'))
html:write('<table style="text-align:center;">')
html:write('<tr><td> Image </td>')
for i = 1, #labels do
  html:write('<td>' .. labels[i] .. '</td>')
end
html:write('</tr>')

-- One row per processed item, one image cell per label.
for n = 1, math.floor(opt.how_many) do
  html:write('<tr>')
  html:write('<td>' .. tostring(n) .. '</td>')
  for j = 1, #labels do
    html:write('<td><img src="./images/' .. labels[j] .. '/' .. png_name(pathsA[n]) .. '"/></td>')
  end
  html:write('</tr>')
end
html:write('</table>')
html:close()