From 84ab81856323dc449ae8f77f079a8506697d8634 Mon Sep 17 00:00:00 2001 From: Smiril Date: Mon, 11 Jul 2022 09:50:10 +0200 Subject: [PATCH] Update m1_tf_test.py --- m1_tf_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/m1_tf_test.py b/m1_tf_test.py index b37c2cc..03dd641 100644 --- a/m1_tf_test.py +++ b/m1_tf_test.py @@ -26,6 +26,10 @@ class ComputeSum(Layer): self.total.assign_add(tf.reduce_sum(inputs, axis=0)) return self.total + def Build_Tree(self, data, depth = 0): + n = len(data) + return n + def ComputeSumModel(input_shape): inputs = Input(shape = input_shape) outputs = ComputeSum(input_shape[0])(inputs) @@ -88,7 +92,7 @@ model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy']) these = model.fit(x_train, y_train, epochs=10) -that = ComputeSum(len(these)) +that = ComputeSum.Build_Tree(these) those = that(these) outputs = tf.keras.layers.Dense(4, activation='softmax', name='predictions')(those) model = ComputeSumModel((tf.shape(these)[1],))