remove lr_dropout for plaidml backend
This commit is contained in:
@@ -141,8 +141,9 @@ class Quick96Model(ModelBase):
|
||||
self.CA_conv_weights_list += [layer.weights[0]] #- is Conv2D kernel_weights
|
||||
|
||||
if self.is_training_mode:
|
||||
self.src_dst_opt = RMSprop(lr=2e-4, lr_dropout=0.3)
|
||||
self.src_dst_mask_opt = RMSprop(lr=2e-4, lr_dropout=0.3)
|
||||
lr_dropout = 0.3 if nnlib.device.backend != 'plaidML' else 0.0
|
||||
self.src_dst_opt = RMSprop(lr=2e-4, lr_dropout=lr_dropout)
|
||||
self.src_dst_mask_opt = RMSprop(lr=2e-4, lr_dropout=lr_dropout)
|
||||
|
||||
target_src_masked = self.model.target_src*self.model.target_srcm
|
||||
target_dst_masked = self.model.target_dst*self.model.target_dstm
|
||||
|
||||
Reference in New Issue
Block a user