fix fanseg
This commit is contained in:
@@ -3,6 +3,7 @@ import os
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
from nnlib import nnlib
|
||||
from interact import interact as io
|
||||
|
||||
class FANSegmentator(object):
|
||||
def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None):
|
||||
@@ -19,7 +20,15 @@ class FANSegmentator(object):
|
||||
|
||||
if load_weights:
|
||||
self.model.load_weights (str(self.weights_path))
|
||||
|
||||
else:
|
||||
io.log_info ("Initializing CA weights...")
|
||||
conv_weights_list = []
|
||||
for layer in self.model.layers:
|
||||
if type(layer) == Conv2D:
|
||||
conv_weights_list += [layer.weights[0]] # Conv2D kernel_weights
|
||||
CAInitializerMP(conv_weights_list)
|
||||
self.model.compile(loss='mse', optimizer=Adam(tf_cpu_mode=2))
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
@@ -43,7 +52,6 @@ class FANSegmentator(object):
|
||||
x = FANSegmentator.EncFlow(ngf=ngf)(x)
|
||||
x = FANSegmentator.DecFlow(ngf=ngf)(x)
|
||||
model = Model(inp,x)
|
||||
model.compile (loss='mse', optimizer=Adam(tf_cpu_mode=2) )
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user