202 lines
7.2 KiB
Lua
202 lines
7.2 KiB
Lua
local class = require 'class'
|
|
require 'models.base_model'
|
|
require 'models.architectures'
|
|
require 'util.image_pool'
|
|
util = paths.dofile('../util/util.lua')
|
|
content = paths.dofile('../util/content_loss.lua')
|
|
|
|
-- Declare the ContentGANModel class through the `class` library, passing
-- 'BaseModel' as the parent class name.
-- NOTE(review): BaseModel is presumably provided by 'models.base_model' — confirm.
ContentGANModel = class('ContentGANModel', 'BaseModel')
|
|
|
|
-- Constructor: forwards the configuration table to the BaseModel constructor.
-- @param conf optional configuration table; defaults to an empty table
function ContentGANModel:__init(conf)
  -- Apply the default BEFORE handing conf to the parent constructor. The
  -- original assigned it afterwards, where the value was never read again.
  conf = conf or {}
  BaseModel.__init(self, conf)
end
|
|
|
|
-- Returns the identifier string of this model type.
-- @return the string 'ContentGANModel'
function ContentGANModel:model_name()
  local name = 'ContentGANModel'
  return name
end
|
|
|
|
-- Builds a fresh optimizer state table.
-- @param opt optional options table providing `lr` and `beta1`. When it is
--            omitted (or not a table) the global `opt` table is used, which
--            is what this method always read before.
-- @return table of the form {learningRate = <lr>, beta1 = <beta1>}
function ContentGANModel:InitializeStates(opt)
  -- Prefer the explicit argument; fall back to the global for old call sites.
  local options = (type(opt) == 'table') and opt or _G.opt
  return {learningRate = options.lr, beta1 = options.beta1}
end
|
|
-- Defines the tensors, criteria, networks, target labels and optimizer states.
-- @param opt options table. Fields read here: test, pool_size, batchSize,
--            input_nc, output_nc, fineSize, content_loss, layer_name,
--            continue_train, which_epoch, ngf, ndf, which_model_netG,
--            which_model_netD, n_layers_D.
function ContentGANModel:Initialize(opt)
  -- The pool of generated images is only created when opt.test == 0.
  if opt.test == 0 then
    self.fakePool = ImagePool(opt.pool_size)
  end

  -- define tensors
  self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
  self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
  self.real_B = self.fake_B:clone() -- same size as fake_B

  -- load/define models
  self.criterionGAN = nn.MSECriterion()
  self.criterionContent = nn.AbsCriterion()
  self.contentFunc = content.defineContent(opt.content_loss, opt.layer_name)
  self.netG, self.netD = nil, nil
  if opt.continue_train == 1 then
    if opt.which_epoch then -- which_epoch option exists in test mode
      self.netG = util.load_test_model('G_A', opt)
      self.netD = util.load_test_model('D_A', opt)
    else
      self.netG = util.load_model('G_A', opt)
      self.netD = util.load_model('D_A', opt)
    end
  else
    self.netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
    print('netG...', self.netG)
    self.netD = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false)
    print('netD...', self.netD)
  end

  -- define real/fake labels
  -- Probe the discriminator with a tensor that has opt.output_nc channels
  -- (real_B): netD is defined for output_nc input channels and is only ever
  -- fed real_B / fake_B during training. The original probed with real_A,
  -- which has opt.input_nc channels and mismatches whenever the two differ.
  -- The size is kept in a local so it no longer leaks into the global table.
  local netD_output_size = self.netD:forward(self.real_B):size()
  self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
  self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing

  self.optimStateD = self:InitializeStates()
  self.optimStateG = self:InitializeStates()
  self:RefreshParameters()

  print('---------- # Learnable Parameters --------------')
  print(('G = %d'):format(self.parametersG:size(1)))
  print(('D = %d'):format(self.parametersD:size(1)))
  print('------------------------------------------------')
end
|
|
|
|
-- Runs the forward pass of the generator and saves the result to member
-- variables of the class: the batch is copied into self.real_A / self.real_B
-- and a clone of the generator output is stored in self.fake_B.
-- @param input table with fields real_A and real_B; it is only read, the
--              caller's table is no longer modified
-- @param opt options table; when opt.which_direction == 'BtoA' the roles of
--            input.real_A and input.real_B are exchanged
-- @return an empty table
function ContentGANModel:Forward(input, opt)
  -- Pick source/target through locals. The original swapped the fields on
  -- the caller's table in place, so passing the same table twice silently
  -- flipped the direction back on every other call.
  local source, target = input.real_A, input.real_B
  if opt.which_direction == 'BtoA' then
    source, target = target, source
  end

  self.real_A:copy(source)
  self.real_B:copy(target)
  -- clone so fake_B does not alias the generator's internal output buffer
  self.fake_B = self.netG:forward(self.real_A):clone()

  -- kept local; previously this assignment created a global named `output`
  local output = {}
  return output
end
|
|
|
|
-- Evaluates the discriminator loss f(X) and accumulates df/dX.
-- @param x unused here (kept for the optimizer closure signature)
-- @param gradParams gradient tensor that is zeroed before the backward passes
--        (presumably the flattened gradients of netD — confirm with caller)
-- @param netD discriminator network
-- @param netG generator network (only passed to util.BiasZero)
-- @param real_target batch scored against self.real_label
-- @param fake_target batch scored against self.fake_label
-- @param opt options table (unused here)
-- @return errD (mean of the real and fake criterion values), gradParams
function ContentGANModel:fDx_basic(x, gradParams, netD, netG,
                                   real_target, fake_target, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()

  -- Real batch: discriminator output against the "real" label.
  local output = netD:forward(real_target)
  local errD_real = self.criterionGAN:forward(output, self.real_label)
  -- df_do is local now; the original leaked it as a global variable.
  local df_do = self.criterionGAN:backward(output, self.real_label)
  netD:backward(real_target, df_do)

  -- Fake batch: discriminator output against the "fake" label.
  -- Note: netD's cached output/state after this call is what fGx_basic reuses.
  output = netD:forward(fake_target)
  local errD_fake = self.criterionGAN:forward(output, self.fake_label)
  df_do = self.criterionGAN:backward(output, self.fake_label)
  netD:backward(fake_target, df_do)

  local errD = (errD_real + errD_fake) / 2.0
  return errD, gradParams
end
|
|
|
|
|
|
-- Discriminator objective for the optimizer: queries the image pool with the
-- current generated batch and evaluates fDx_basic on real_B versus that batch.
-- Stores the loss in self.errD.
-- @param x passed through to fDx_basic
-- @param opt options table, passed through to fDx_basic
-- @return self.errD, gradient tensor returned by fDx_basic
function ContentGANModel:fDx(x, opt)
  -- Both values are locals now; the original created the globals `fake_B`
  -- and `gradParams` on every call.
  local fake_B = self.fakePool:Query(self.fake_B)
  local errD, gradParams = self:fDx_basic(x, self.gradparametersD, self.netD, self.netG,
                                          self.real_B, fake_B, opt)
  self.errD = errD
  return self.errD, gradParams
end
|
|
|
|
-- Evaluates the generator losses and back-propagates through the generator.
-- @param x unused here (kept for the optimizer closure signature)
-- @param netG_source generator network
-- @param netD_source discriminator network; its cached `output` is reused
-- @param real_source generator input batch
-- @param real_target unused here
-- @param fake_target generated batch
-- @param gradParametersG_source gradient tensor, zeroed before the backward pass
-- @param opt options table; reads content_loss and lambda_A
-- @return gradParametersG_source, GAN criterion value, content loss value
function ContentGANModel:fGx_basic(x, netG_source, netD_source, real_source, real_target, fake_target,
                                   gradParametersG_source, opt)
  util.BiasZero(netD_source)
  util.BiasZero(netG_source)
  gradParametersG_source:zero()

  -- GAN term.
  -- [hack] the discriminator is not run again: the output cached by its most
  -- recent forward call (done in fDx) is reused to save computation.
  local d_out = netD_source.output
  local gan_err = self.criterionGAN:forward(d_out, self.real_label)
  local gan_grad_out = self.criterionGAN:backward(d_out, self.real_label)
  local grad_gan = netD_source:updateGradInput(fake_target, gan_grad_out)

  -- Content term: value and gradient come from content.lossUpdate.
  local content_err, grad_content = content.lossUpdate(
    self.criterionContent, real_source, fake_target,
    self.contentFunc, opt.content_loss, opt.lambda_A)

  -- Re-run the generator forward, then back-propagate the summed gradients.
  netG_source:forward(real_source)
  local grad_fake = grad_gan + grad_content
  netG_source:backward(real_source, grad_fake)

  return gradParametersG_source, gan_err, content_err
end
|
|
|
|
-- Generator objective for the optimizer: delegates to fGx_basic with this
-- model's networks and tensors and records the results on the instance
-- (self.gradparametersG, self.errG, self.errCont).
-- @param x passed through to fGx_basic
-- @param opt options table, passed through to fGx_basic
-- @return self.errG, self.gradparametersG
function ContentGANModel:fGx(x, opt)
  local grads, gan_err, content_err = self:fGx_basic(
    x, self.netG, self.netD,
    self.real_A, self.real_B, self.fake_B,
    self.gradparametersG, opt)

  self.gradparametersG = grads
  self.errG = gan_err
  self.errCont = content_err

  return self.errG, self.gradparametersG
end
|
|
|
|
-- Performs one optim.adam step on the discriminator followed by one on the
-- generator.
-- @param opt options table, captured by both objective closures
function ContentGANModel:OptimizeParameters(opt)
  -- objective closures handed to the optimizer
  local function eval_D(x)
    return self:fDx(x, opt)
  end
  local function eval_G(x)
    return self:fGx(x, opt)
  end

  -- discriminator first, then generator
  optim.adam(eval_D, self.parametersD, self.optimStateD)
  optim.adam(eval_G, self.parametersG, self.optimStateG)
end
|
|
|
|
-- Re-acquires the parameter and gradient tensors of both networks via
-- getParameters() and stores them on the instance.
function ContentGANModel:RefreshParameters()
  -- release the old references first to avoid spiking memory
  self.parametersD = nil
  self.gradparametersD = nil
  self.parametersG = nil
  self.gradparametersG = nil

  -- define parameters of optimization (generator first, then discriminator)
  local params_G, grads_G = self.netG:getParameters()
  self.parametersG, self.gradparametersG = params_G, grads_G

  local params_D, grads_D = self.netD:getParameters()
  self.parametersD, self.gradparametersD = params_D, grads_D
end
|
|
|
|
-- Saves both networks through util.save_model.
-- @param prefix path prefix; '_net_G_A.t7' and '_net_D_A.t7' are appended
-- @param opt unused here
function ContentGANModel:Save(prefix, opt)
  local generator_path = prefix .. '_net_G_A.t7'
  util.save_model(self.netG, generator_path, 1.0)

  local discriminator_path = prefix .. '_net_D_A.t7'
  util.save_model(self.netD, discriminator_path, 1.0)
end
|
|
|
|
-- Formats the most recent losses as a one-line string.
-- A loss that has not been computed yet is reported as -1.
-- @return string of the form 'G: %.4f D: %.4f Content: %.4f'
function ContentGANModel:GetCurrentErrorDescription()
  -- local: the original assignment created a global named `description`
  local description = ('G: %.4f D: %.4f Content: %.4f'):format(
    self.errG or -1,
    self.errD or -1,
    self.errCont or -1)
  return description
end
|
|
|
|
|
|
-- Collects the most recent losses in a table keyed errG / errD / errCont.
-- A loss that has not been computed yet is reported as -1.
-- @return table {errG = ..., errD = ..., errCont = ...}
function ContentGANModel:GetCurrentErrors()
  local errors = {}
  errors.errG = self.errG or -1
  errors.errD = self.errD or -1
  errors.errCont = self.errCont or -1
  return errors
end
|
|
|
|
-- Returns the string describing the display plot configuration: a
-- comma-separated list of the error names to plot.
-- @param opt unused here
function ContentGANModel:DisplayPlot(opt)
  local plot_config = 'errG,errD,errCont'
  return plot_config
end
|
|
|
|
|
|
-- Returns the current images as a list of {img = tensor, label = string}
-- entries, in the order real_A, fake_B, real_B.
-- @param opt options table; display_winsize is read when size is not given
-- @param size optional display size (defaulted but not otherwise used here)
-- @return list of visual entries
function ContentGANModel:GetCurrentVisuals(opt, size)
  size = size or opt.display_winsize

  return {
    {img = self.real_A, label = 'real_A'},
    {img = self.fake_B, label = 'fake_B'},
    {img = self.real_B, label = 'real_B'},
  }
end
|