From 5a8346a48b0fdd6cf239afea3d4652f59646911c Mon Sep 17 00:00:00 2001 From: Smiril Date: Mon, 11 Jul 2022 10:19:31 +0200 Subject: [PATCH] Update m1_tf_test.py --- m1_tf_test.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/m1_tf_test.py b/m1_tf_test.py index a1cc884..ba12c0c 100644 --- a/m1_tf_test.py +++ b/m1_tf_test.py @@ -10,6 +10,7 @@ import sys import tensorflow as tf from tensorflow import keras from tensorflow.keras.layers import Layer, Input +from tensorflow.keras.models import LSTM tf.__version__ tf.config.list_physical_devices() from random import randrange @@ -30,9 +31,9 @@ class ComputeSum(Layer): n = len(data) return n - def ComputeSumModel(input_shape): - inputs = Input(shape = input_shape) - outputs = ComputeSum(input_shape[0])(inputs) +def ComputeSumModel(input_shape): + inputs = Input(shape = input_shape) + outputs = ComputeSum(input_shape[0])(inputs) model = tf.keras.Model(inputs = inputs, outputs = outputs) @@ -88,6 +89,7 @@ model = tf.keras.models.Sequential([ tf.keras.layers.Dropout(0.2) ]) loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) +model.add(LSTM(32, input_shape=(1024, ))) model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])