From 4464f660e96902f0a23d702ea5cb5c0e8d2e529c Mon Sep 17 00:00:00 2001 From: Smiril Date: Mon, 11 Jul 2022 10:06:29 +0200 Subject: [PATCH] Update m1_tf_test.py --- m1_tf_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/m1_tf_test.py b/m1_tf_test.py index ce63bf8..a1cc884 100644 --- a/m1_tf_test.py +++ b/m1_tf_test.py @@ -30,9 +30,9 @@ class ComputeSum(Layer): n = len(data) return n -def ComputeSumModel(input_shape): - inputs = Input(shape = input_shape) - outputs = ComputeSum(input_shape[0])(inputs) + def ComputeSumModel(input_shape): + inputs = Input(shape = input_shape) + outputs = ComputeSum(input_shape[0])(inputs) model = tf.keras.Model(inputs = inputs, outputs = outputs) @@ -91,7 +91,7 @@ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy']) -these = model.fit(x_train, y_train, epochs=10) +these = model.fit(x_train, y_train, epochs=10).history that = ComputeSum(len(these)) those = that(these) outputs = tf.keras.layers.Dense(4, activation='softmax', name='predictions')(those)