Update m1_tf_test.py
This commit is contained in:
@@ -26,7 +26,7 @@ class ComputeSum(Layer):
|
||||
self.total.assign_add(tf.reduce_sum(inputs, axis=0))
|
||||
return self.total
|
||||
|
||||
def Build_Tree(self, data, depth = 0):
|
||||
def call(self, data, depth = 0):
|
||||
n = len(data)
|
||||
return n
|
||||
|
||||
@@ -92,7 +92,7 @@ model.compile(optimizer='adam',
|
||||
loss=loss_fn,
|
||||
metrics=['accuracy'])
|
||||
these = model.fit(x_train, y_train, epochs=10)
|
||||
that = ComputeSum.Build_Tree(these)
|
||||
that = ComputeSum(len(these))
|
||||
those = that(these)
|
||||
outputs = tf.keras.layers.Dense(4, activation='softmax', name='predictions')(those)
|
||||
model = ComputeSumModel((tf.shape(these)[1],))
|
||||
|
||||
Reference in New Issue
Block a user