Update m1_tf_test.py

This commit is contained in:
Smiril
2022-07-11 10:06:29 +02:00
committed by GitHub
parent 8daacea01f
commit 4464f660e9
+4 -4
View File
@@ -30,9 +30,9 @@ class ComputeSum(Layer):
n = len(data)
return n
def ComputeSumModel(input_shape):
inputs = Input(shape = input_shape)
outputs = ComputeSum(input_shape[0])(inputs)
def ComputeSumModel(input_shape):
inputs = Input(shape = input_shape)
outputs = ComputeSum(input_shape[0])(inputs)
model = tf.keras.Model(inputs = inputs, outputs = outputs)
@@ -91,7 +91,7 @@ loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam',
loss=loss_fn,
metrics=['accuracy'])
these = model.fit(x_train, y_train, epochs=10)
these = model.fit(x_train, y_train, epochs=10).history
that = ComputeSum(len(these))
those = that(these)
outputs = tf.keras.layers.Dense(4, activation='softmax', name='predictions')(those)