diff --git a/m1_tf_test.py b/m1_tf_test.py index 03dd641..ce63bf8 100644 --- a/m1_tf_test.py +++ b/m1_tf_test.py @@ -26,7 +26,7 @@ class ComputeSum(Layer): self.total.assign_add(tf.reduce_sum(inputs, axis=0)) return self.total - def Build_Tree(self, data, depth = 0): + def call(self, data, depth = 0): n = len(data) return n @@ -92,7 +92,7 @@ model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy']) these = model.fit(x_train, y_train, epochs=10) -that = ComputeSum.Build_Tree(these) +that = ComputeSum(len(these)) those = that(these) outputs = tf.keras.layers.Dense(4, activation='softmax', name='predictions')(those) model = ComputeSumModel((tf.shape(these)[1],))