diff --git a/m1_tf_test.py b/m1_tf_test.py index ce63bf8..a1cc884 100644 --- a/m1_tf_test.py +++ b/m1_tf_test.py @@ -30,9 +30,9 @@ class ComputeSum(Layer): n = len(data) return n -def ComputeSumModel(input_shape): - inputs = Input(shape = input_shape) - outputs = ComputeSum(input_shape[0])(inputs) + def ComputeSumModel(input_shape): + inputs = Input(shape = input_shape) + outputs = ComputeSum(input_shape[0])(inputs) model = tf.keras.Model(inputs = inputs, outputs = outputs) @@ -91,7 +91,7 @@ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy']) -these = model.fit(x_train, y_train, epochs=10) +these = model.fit(x_train, y_train, epochs=10).history that = ComputeSum(len(these)) those = that(these) outputs = tf.keras.layers.Dense(4, activation='softmax', name='predictions')(those)