Files
disrupting-deepfakes/cyclegan/util/VGG_preprocess.lua
T
Nataniel Ruiz a3fe19383a cyclegan
2019-12-25 17:21:03 -04:00

28 lines
925 B
Lua

-- define nn module for VGG postprocessing
local VGG_postprocess, parent = torch.class('nn.VGG_postprocess', 'nn.Module')
function VGG_postprocess:__init()
parent.__init(self)
end
function VGG_postprocess:updateOutput(input)
self.output = input:add(1):mul(127.5)
-- print(self.output:max(), self.output:min())
if self.output:max() > 255 or self.output:min() < 0 then
print(self.output:min(), self.output:max())
end
-- assert(self.output:min()>=0,"badly scaled inputs")
-- assert(self.output:max()<=255,"badly scaled inputs")
local mean_pixel = torch.FloatTensor({103.939, 116.779, 123.68})
mean_pixel = mean_pixel:reshape(1,3,1,1)
mean_pixel = mean_pixel:repeatTensor(input:size(1), 1, input:size(3), input:size(4)):cuda()
self.output:add(-1, mean_pixel)
return self.output
end
function VGG_postprocess:updateGradInput(input, gradOutput)
self.gradInput = gradOutput:div(127.5)
return self.gradInput
end