28 lines
925 B
Lua
28 lines
925 B
Lua
--- nn.VGG_postprocess: an nn.Module that rescales images with values in
-- [-1, 1] to [0, 255] and subtracts a per-channel mean, presumably to feed a
-- pretrained VGG network.
-- `parent` is the nn.Module class table; methods use it to chain up.
local VGG_postprocess, parent =
   torch.class('nn.VGG_postprocess', 'nn.Module')
|
|
|
|
--- Constructor. Adds no state of its own; initialisation is delegated
-- entirely to the nn.Module base-class constructor.
VGG_postprocess.__init = function(self)
   parent.__init(self)
end
|
|
|
|
--- Forward pass: rescale the input from [-1, 1] to [0, 255], then subtract a
-- fixed per-channel mean.
--
-- The input tensor is NOT modified; the result is written into this module's
-- own `self.output` buffer. (An earlier version ran add/mul in place on
-- `input`, corrupting it for any other consumer of that tensor.)
--
-- The mean (103.939, 116.779, 123.68) is presumably the Caffe VGG mean in
-- B, G, R order. NOTE(review): confirm the channel order of `input` matches.
--
-- @param input 3D (3 x H x W) or 4D (N x 3 x H x W) tensor of any tensor type
--   (CPU or CUDA); channel dimension must have size 3.
-- @return self.output, same type and size as `input`.
function VGG_postprocess:updateOutput(input)
   -- Keep the output buffer's tensor type in step with the input so that
   -- copy() works even if the module itself was never type-converted.
   self.output = self.output:typeAs(input)
   self.output:resizeAs(input):copy(input):add(1):mul(127.5)

   -- Debug aid: report values outside [0, 255] (i.e. inputs outside [-1, 1]).
   -- Each of min/max is a full reduction, so compute them only once.
   local lo, hi = self.output:min(), self.output:max()
   if hi > 255 or lo < 0 then
      print(lo, hi)
   end

   -- Shape the mean so it broadcasts over the channel dimension, convert it
   -- to the input's tensor type (instead of hard-wiring CUDA), and expand it
   -- as a stride-0 view rather than materialising a full-size copy.
   local mean_pixel = torch.FloatTensor({103.939, 116.779, 123.68})
   if input:dim() == 3 then
      mean_pixel = mean_pixel:view(3, 1, 1)
   else
      mean_pixel = mean_pixel:view(1, 3, 1, 1)
   end
   mean_pixel = mean_pixel:typeAs(input)

   self.output:add(-1, mean_pixel:expandAs(self.output))

   return self.output
end
|
|
|
|
--- Backward pass.
-- The forward map is y = 127.5 * (x + 1) - mean, whose derivative with
-- respect to x is the constant 127.5, so dL/dx = 127.5 * dL/dy.
--
-- An earlier version divided by 127.5 (the inverse of the scale, not its
-- derivative) and did so in place on `gradOutput`, a buffer that typically
-- belongs to the downstream module.
-- NOTE(review): the old behaviour produced gradients 127.5^2 (~16256x)
-- smaller than the true ones; loss weights tuned against it may need
-- rescaling.
--
-- @param input unused: the map is affine, so its gradient does not depend
--   on the input values.
-- @param gradOutput gradient of the loss with respect to this module's output.
-- @return self.gradInput, same type and size as `gradOutput`.
function VGG_postprocess:updateGradInput(input, gradOutput)
   self.gradInput = self.gradInput:typeAs(gradOutput)
   self.gradInput:resizeAs(gradOutput):copy(gradOutput):mul(127.5)
   return self.gradInput
end
|