Merger: fix load time of xseg if it has no model files
This commit is contained in:
+11
-4
@@ -41,7 +41,6 @@ class XSegNet(object):
|
||||
self.model_weights = self.model.get_weights()
|
||||
|
||||
model_name = f'{name}_{resolution}'
|
||||
|
||||
self.model_filename_list = [ [self.model, f'{model_name}.npy'] ]
|
||||
|
||||
if training:
|
||||
@@ -59,6 +58,7 @@ class XSegNet(object):
|
||||
return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0]
|
||||
self.net_run = net_run
|
||||
|
||||
self.initialized = True
|
||||
# Loading/initializing all models/optimizers weights
|
||||
for model, filename in self.model_filename_list:
|
||||
do_init = not load_weights
|
||||
@@ -66,12 +66,16 @@ class XSegNet(object):
|
||||
if not do_init:
|
||||
model_file_path = self.weights_file_root / filename
|
||||
do_init = not model.load_weights( model_file_path )
|
||||
if do_init and raise_on_no_model_files:
|
||||
raise Exception(f'{model_file_path} does not exists.')
|
||||
if do_init:
|
||||
if raise_on_no_model_files:
|
||||
raise Exception(f'{model_file_path} does not exists.')
|
||||
if not training:
|
||||
self.initialized = False
|
||||
break
|
||||
|
||||
if do_init:
|
||||
model.init_weights()
|
||||
|
||||
|
||||
def get_resolution(self):
|
||||
return self.resolution
|
||||
|
||||
@@ -86,6 +90,9 @@ class XSegNet(object):
|
||||
model.save_weights( self.weights_file_root / filename )
|
||||
|
||||
def extract (self, input_image):
|
||||
if not self.initialized:
|
||||
return 0.5*np.ones ( (self.resolution, self.resolution, 1), nn.floatx.as_numpy_dtype )
|
||||
|
||||
input_shape_len = len(input_image.shape)
|
||||
if input_shape_len == 3:
|
||||
input_image = input_image[None,...]
|
||||
|
||||
Reference in New Issue
Block a user