Files
disrupting-deepfakes/cyclegan/models/one_direction_test_model.lua
T
Nataniel Ruiz a3fe19383a cyclegan
2019-12-25 17:21:03 -04:00

80 lines
2.1 KiB
Lua

-- Module dependencies. The 'models.*' and 'util.image_pool' modules are
-- required only for their load-time effects here (no return value is kept).
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
-- NOTE(review): assigned without `local`, so `util` becomes a global.
-- Other files may rely on that global -- confirm before localizing it.
util = paths.dofile('../util/util.lua')
-- Declares the model class and stores it in a global of the same name.
-- The second argument names 'BaseModel'; presumably this is the parent class
-- per the `class` library's (name, parentName) convention -- TODO confirm.
OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel')
--- Constructor.
-- @param conf optional configuration table; an empty table is used when it
--             is omitted (nil).
function OneDirectionTestModel:__init(conf)
  -- Default the argument *before* delegating, so the parent initializer
  -- always receives a table rather than nil.
  conf = conf or {}
  BaseModel.__init(self, conf)
end
--- Returns the identifying name of this model.
-- @return the string 'OneDirectionTestModel'
function OneDirectionTestModel:model_name()
  local name = 'OneDirectionTestModel'
  return name
end
--- Allocates the input buffer, loads the generator and reports its size.
-- Fields of `opt` read directly here: batchSize, input_nc, fineSize and
-- use_optnet. `opt` is also handed to util.load_test_model as-is.
-- @param opt options table
function OneDirectionTestModel:Initialize(opt)
  -- Input buffer: batchSize x input_nc x fineSize x fineSize.
  local batch, channels, side = opt.batchSize, opt.input_nc, opt.fineSize
  self.real_A = torch.Tensor(batch, channels, side, side)

  -- Generator network, obtained through the shared util helper.
  self.netG_A = util.load_test_model('G', opt)

  -- Optional optnet pass (the original comment says it saves a bit of memory).
  -- A small 1 x input_nc x 2 x 2 random tensor is passed as the sample input.
  if opt.use_optnet == 1 then
    local optnet = require 'optnet'
    local probe = torch.randn(1, opt.input_nc, 2, 2)
    local optnet_opts = {inplace=true, reuseBuffers=true}
    optnet.optimizeMemory(self.netG_A, probe, optnet_opts)
  end

  -- Populate self.parametersG_A / self.gradparametersG_A.
  self:RefreshParameters()

  -- Report the size of the first dimension of the parameter tensor.
  print('---------- # Learnable Parameters --------------')
  local param_count = self.parametersG_A:size(1)
  print(string.format('G_A = %d', param_count))
  print('------------------------------------------------')
end
--- Runs the generator on the given input and stores the results on self.
-- After the call, self.real_A holds a copy of the network input and
-- self.fake_B holds a copy of the generator output.
-- Side effect: when opt.which_direction is 'BtoA', input.real_A is
-- overwritten with a clone of input.real_B.
-- @param input table providing real_A (and real_B for the 'BtoA' direction)
-- @param opt   options table; reads which_direction and gpu
function OneDirectionTestModel:Forward(input, opt)
  local swap_direction = (opt.which_direction == 'BtoA')
  if swap_direction then
    local b_copy = input.real_B:clone()
    input.real_A = b_copy
  end

  -- Keep a private copy so later changes to `input` do not affect self.real_A.
  self.real_A = input.real_A:clone()

  local on_gpu = (opt.gpu > 0)
  if on_gpu then
    local device_copy = self.real_A:cuda()
    self.real_A = device_copy
  end

  -- Clone the output of forward() before storing it on self.
  local generated = self.netG_A:forward(self.real_A)
  self.fake_B = generated:clone()
end
--- Re-fetches the generator's parameter and gradient tensors.
-- The previous references are cleared first, then replaced with the two
-- values returned by self.netG_A:getParameters().
function OneDirectionTestModel:RefreshParameters()
  -- Drop the old references before asking the network for new ones.
  self.parametersG_A = nil
  self.gradparametersG_A = nil

  local flat_params, flat_grads = self.netG_A:getParameters()
  self.parametersG_A = flat_params
  self.gradparametersG_A = flat_grads
end
--- Expands an image whose second dimension has size 1 to size 3.
-- Any other input is returned unchanged (same object, no copy).
-- @param im tensor; the repeat call uses four dimensions, so `im` is
--           presumably 4-D (batch x channel x H x W) -- TODO confirm
-- @return `im` itself, or a new tensor repeated 3 times along dimension 2
local function MakeIm3(im)
  local channel_count = im:size(2)
  if channel_count ~= 1 then
    return im
  end
  return torch.repeatTensor(im, 1, 3, 1, 1)
end
--- Collects the current input and output images for display.
-- @param opt  options table; display_winsize is read when `size` is not given
-- @param size optional display size; it is defaulted here but not used
--             further inside this function
-- @return array of two {img=..., label=...} entries: 'real_A' then 'fake_B'
function OneDirectionTestModel:GetCurrentVisuals(opt, size)
  size = size or opt.display_winsize

  -- Each image goes through MakeIm3 before being stored.
  local real_entry = {img=MakeIm3(self.real_A), label='real_A'}
  local fake_entry = {img=MakeIm3(self.fake_B), label='fake_B'}

  return {real_entry, fake_entry}
end