80 lines
2.1 KiB
Lua
80 lines
2.1 KiB
Lua
local class = require 'class'
|
|
require 'models.base_model'
|
|
require 'models.architectures'
|
|
require 'util.image_pool'
|
|
|
|
util = paths.dofile('../util/util.lua')
|
|
-- Inference model holding a single generator (netG_A); Forward maps real_A to fake_B.
-- Declared with torch's `class` using 'BaseModel' as the parent class
-- (its constructor chains to BaseModel.__init).
-- NOTE(review): assigned as a global (no `local`), presumably so other files can
-- look the class up by name — confirm before making it local.
OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel')
|
|
|
|
--- Constructs the model.
-- @param conf optional configuration table, forwarded to BaseModel.__init;
--   an empty table is used when conf is nil/false.
function OneDirectionTestModel:__init(conf)
  -- Apply the default before forwarding so the parent always receives a table.
  conf = conf or {}
  BaseModel.__init(self, conf)
end
|
|
|
|
--- Identifier string for this model class.
-- @return the string 'OneDirectionTestModel'
function OneDirectionTestModel:model_name()
  local name = 'OneDirectionTestModel'
  return name
end
|
|
|
|
--- Defines the tensors and networks used at test time.
-- Loads the saved generator 'G' via util.load_test_model, optionally runs
-- optnet memory optimization on it, flattens its parameters and prints the
-- learnable parameter count.
-- @param opt option table; this method reads batchSize, input_nc, fineSize,
--   use_optnet and gpu (util.load_test_model may read more).
function OneDirectionTestModel:Initialize(opt)
  -- Placeholder input buffer; Forward replaces it with a clone of each batch.
  self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)

  -- Load the trained generator.
  self.netG_A = util.load_test_model('G', opt)

  -- Optionally let optnet share buffers to save a bit of memory.
  if opt.use_optnet == 1 then
    local optnet = require 'optnet'
    local sample_input = torch.randn(1, opt.input_nc, 2, 2)
    -- Forward feeds the network CUDA tensors when opt.gpu > 0, so the sample
    -- used to trace the network must be on the same device; a CPU sample
    -- would hit a tensor-type mismatch inside the traced forward pass.
    if opt.gpu > 0 then
      sample_input = sample_input:cuda()
    end
    optnet.optimizeMemory(self.netG_A, sample_input, {inplace=true, reuseBuffers=true})
  end

  self:RefreshParameters()

  print('---------- # Learnable Parameters --------------')
  print(('G_A = %d'):format(self.parametersG_A:size(1)))
  print('------------------------------------------------')
end
|
|
|
|
--- Runs the generator on one batch and caches input and output on the model.
-- Afterwards self.real_A holds the generator input and self.fake_B holds a
-- clone of the generator output.
-- @param input table with field real_A (and real_B when
--   opt.which_direction == 'BtoA')
-- @param opt option table; reads which_direction and gpu
function OneDirectionTestModel:Forward(input, opt)
  local reversed = opt.which_direction == 'BtoA'
  if reversed then
    -- NOTE(review): this overwrites input.real_A in the caller's table.
    input.real_A = input.real_B:clone()
  end

  local batch = input.real_A:clone()
  if opt.gpu > 0 then
    batch = batch:cuda()
  end
  self.real_A = batch

  -- Keep a private copy of the output; presumably the network reuses its
  -- output buffer on the next forward call.
  local output = self.netG_A:forward(self.real_A)
  self.fake_B = output:clone()
end
|
|
|
|
--- Re-flattens the generator's parameters and gradients.
-- Stores the results of netG_A:getParameters() in self.parametersG_A and
-- self.gradparametersG_A.
function OneDirectionTestModel:RefreshParameters()
  -- Drop the previous references first, presumably so the old flattened
  -- tensors can be collected while getParameters allocates new ones.
  self.parametersG_A = nil
  self.gradparametersG_A = nil

  local params, grads = self.netG_A:getParameters()
  self.parametersG_A = params
  self.gradparametersG_A = grads
end
|
|
|
|
|
|
--- Expands a single-channel image batch to three channels for display.
-- @param im tensor whose dimension 2 is the channel dimension
--   (assumed batch x channels x H x W given the 4-arg repeat — TODO confirm)
-- @return im unchanged when it does not have exactly one channel, otherwise
--   a new tensor with that channel repeated three times
local function MakeIm3(im)
  if im:size(2) ~= 1 then
    return im
  end
  return torch.repeatTensor(im, 1, 3, 1, 1)
end
|
|
|
|
--- Collects the images to display or save for the current batch.
-- Single-channel tensors are expanded to three channels via MakeIm3.
-- @param opt option table; opt.display_winsize is read when size is not given
-- @param size optional display size
-- @return array of {img=tensor, label=string}: real_A first, then fake_B
function OneDirectionTestModel:GetCurrentVisuals(opt, size)
  -- NOTE(review): size is resolved here but never used below.
  size = size or opt.display_winsize

  return {
    {img = MakeIm3(self.real_A), label = 'real_A'},
    {img = MakeIm3(self.fake_B), label = 'fake_B'},
  }
end
|