Update m1_tf_test.py

This commit is contained in:
Smiril
2022-07-11 09:50:10 +02:00
committed by GitHub
parent e215a4c582
commit 84ab818563
+5 -1
View File
@@ -26,6 +26,10 @@ class ComputeSum(Layer):
self.total.assign_add(tf.reduce_sum(inputs, axis=0))
return self.total
def Build_Tree(self, data, depth = 0):
n = len(data)
return n
def ComputeSumModel(input_shape):
inputs = Input(shape = input_shape)
outputs = ComputeSum(input_shape[0])(inputs)
@@ -88,7 +92,7 @@ model.compile(optimizer='adam',
loss=loss_fn,
metrics=['accuracy'])
these = model.fit(x_train, y_train, epochs=10)
that = ComputeSum(len(these))
that = ComputeSum.Build_Tree(these)
those = that(these)
outputs = tf.keras.layers.Dense(4, activation='softmax', name='predictions')(those)
model = ComputeSumModel((tf.shape(these)[1],))