diff --git a/m1_tf_test.py b/m1_tf_test.py index aab7601..306a038 100644 --- a/m1_tf_test.py +++ b/m1_tf_test.py @@ -31,13 +31,12 @@ class ComputeSum(Layer): n = len(data) return n - def ComputeSumModel(input_shape): - inputs = Input(shape = input_shape) - outputs = ComputeSum(input_shape[0])(inputs) - - model = tf.keras.Model(inputs = inputs, outputs = outputs) +def ComputeSumModel(input_shape): + inputs = Input(shape = input_shape) + outputs = ComputeSum(input_shape[0])(inputs) + model = tf.keras.Model(inputs = inputs, outputs = outputs) - return model + return model def xatoi(Str): @@ -71,7 +70,7 @@ def xatoi(Str): return base * sign inputs = tf.keras.Input(shape=(xatoi(sys.argv[1]),), name="digits") -model = tf.keras.models.load_model('model') + mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train, x_test = x_train / randrange(255+1), x_test / randrange(255+1) @@ -93,8 +92,10 @@ model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy']) these = model.fit(x_train, y_train, epochs=10).history that = ComputeSum(len(these)) those = that(these) +print(those) model = ComputeSumModel((tf.shape(these)[1],)) +print(model(these)) model = tf.keras.Model(inputs=inputs, outputs=outputs) model.build() -model.save('model') +model.save(these) model.summary()