master
/ mnist.py

mnist.py @96fc089 raw · history · blame

import tensorflow as tf
import datetime
import os

def create_model(hidden_units=512, dropout_rate=0.2, num_classes=10,
                 input_shape=(28, 28)):
    """Build an uncompiled one-hidden-layer MLP image classifier.

    The network flattens each input sample, applies a ReLU ``Dense`` layer
    followed by ``Dropout``, and ends in a softmax ``Dense`` layer. Its
    outputs are therefore class probabilities, not logits: pair it with a
    loss that expects probabilities (e.g. ``sparse_categorical_crossentropy``
    without ``from_logits=True``).

    All arguments are optional; the defaults reproduce the original
    MNIST-sized model exactly.

    Args:
        hidden_units: Width of the hidden ``Dense`` layer (default 512).
        dropout_rate: Fraction of hidden activations dropped during training
            (default 0.2).
        num_classes: Number of output units / classes (default 10).
        input_shape: Shape of a single input sample, excluding the batch
            dimension (default ``(28, 28)``).

    Returns:
        An uncompiled ``tf.keras.models.Sequential`` model; call
        ``compile`` on it before training.
    """
    return tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=input_shape),
        tf.keras.layers.Dense(hidden_units, activation='relu'),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])

def main():
    """Train the ``create_model()`` classifier on MNIST with TensorBoard logs.

    Steps:
    - Load MNIST, caching the archive as ``mnist.npz`` in the current
      working directory.
    - Scale pixel values into [0, 1].
    - Train for 5 epochs with Adam and sparse categorical cross-entropy.
    - Write TensorBoard logs, including per-epoch weight histograms, to a
      timestamped directory under ``./results/tb_results``.
    """
    # An absolute path makes Keras cache/read the archive here rather than
    # under its default ~/.keras/datasets directory.
    data_path = os.path.join(os.getcwd(), 'mnist.npz')
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
        path=data_path)

    # load_data returns uint8 pixel arrays (per the Keras docs). Dividing them
    # by a Python float would produce float64 arrays (twice the memory) that
    # the float32 model then casts back, so cast to float32 first. The
    # resulting values are identical.
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0

    model = create_model()
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # One directory per run so TensorBoard can compare runs side by side.
    run_name = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = os.path.join(".", "results", "tb_results", run_name)
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                          histogram_freq=1)

    # NOTE(review): the test split doubles as validation data here, so the
    # reported val_* metrics are not an untouched held-out estimate.
    model.fit(x=x_train,
              y=y_train,
              epochs=5,
              validation_data=(x_test, y_test),
              callbacks=[tensorboard_callback])


if __name__ == '__main__':
    main()