Overview
TensorFlow Estimators can be created from new and existing tf.keras models. This tutorial contains a complete, minimal example of that process.
Setup
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
Create a simple Keras model.
In Keras, you assemble layers to build models. A model is (usually) a graph of layers. The most common type of model is a stack of layers: the tf.keras.Sequential model.
To build a simple, fully connected network (i.e., a multi-layer perceptron):
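model = tf.keras.models.Sequential([
    # Hidden layer: 4 input features -> 16 ReLU units
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    # Randomly zeroes 20% of the activations, during training only
    tf.keras.layers.Dropout(0.2),
    # Output layer: one raw logit per class (3 iris species)
    tf.keras.layers.Dense(3)
])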
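As a framework-free illustration of what this stack computes at inference time, the sketch below reimplements the forward pass in NumPy. The weights are randomly drawn here purely as an assumption; Keras uses its own initializers. Dropout acts as an identity outside training, so it is omitted. The weight shapes also account for the 131 trainable parameters that model.summary() reports: 4*16 + 16 = 80 for the hidden layer and 16*3 + 3 = 51 for the output layer.

```python
import numpy as np

# Hypothetical random weights with the same shapes Keras would create.
rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(4, 16)), np.zeros(16)   # Dense(16): 4 inputs -> 16 units
W2, b2 = rng.normal(size=(16, 3)), np.zeros(3)    # Dense(3): 16 units -> 3 logits

def forward(x):
    """Inference-time forward pass; Dropout is an identity outside training."""
    hidden = np.maximum(x @ W1 + b1, 0.0)  # ReLU activation
    return hidden @ W2 + b2                # raw logits, no softmax applied

batch = rng.normal(size=(5, 4))            # 5 examples with 4 features each
logits = forward(batch)
print(logits.shape)                        # (5, 3)

n_params = sum(p.size for p in (W1, b1, W2, b2))
print(n_params)                            # 131
```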
Compile the model and get a summary.
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer='adam')
model.summary()
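Because the final Dense(3) layer has no activation, the model outputs logits, which is why the loss is created with from_logits=True. The NumPy sketch below illustrates the computation that loss performs for integer labels: a softmax over the logits, then the negative log-probability of the true class, averaged over the batch. It is an illustration only; TensorFlow's own implementation offers more options (for example, reduction modes).

```python
import numpy as np

def sparse_ce_from_logits(labels, logits):
    """Mean sparse categorical cross-entropy for integer labels and raw logits."""
    # Subtract the row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each example.
    picked = log_probs[np.arange(len(labels)), labels]
    return -picked.mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
print(sparse_ce_from_logits(labels, logits))
```

With all-equal logits (the second row), every class gets probability 1/3, so that example contributes log(3) to the average.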
Create an input function
Use the tf.data API to scale to large datasets or multi-device training.
Estimators need control of when and how their input pipeline is built. To allow this, they require an "input function", or input_fn. The Estimator calls this function with no arguments. In this tutorial, the input_fn returns a tf.data.Dataset of (features, labels) pairs.
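def input_fn():
    split = tfds.Split.TRAIN
    dataset = tfds.load('iris', split=split, as_supervised=True)
    # Wrap the features in a dict keyed by the Keras input layer's name
    # ('dense_input', derived from the first Dense layer's default name).
    dataset = dataset.map(lambda features, labels: ({'dense_input': features}, labels))
    # Batch, then repeat indefinitely; train/evaluate stop after `steps` batches.
    dataset = dataset.batch(32).repeat()
    return dataset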
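Since the Estimator calls input_fn with no arguments, any configuration such as the batch size has to be bound before the function is handed over; functools.partial is one common way to do this (a lambda works too). The sketch below uses a hypothetical make_batches function over plain Python data instead of tfds so that it runs without TensorFlow; the same binding pattern would apply to a parametrized variant of input_fn.

```python
import functools

def make_batches(batch_size, data=tuple(range(10))):
    """Hypothetical input function with a parameter; returns a list of batches."""
    return [list(data[i:i + batch_size]) for i in range(0, len(data), batch_size)]

# Bind the parameter so the result can be called with no arguments,
# which is how an Estimator invokes its input_fn.
train_input_fn = functools.partial(make_batches, batch_size=4)
print(train_input_fn())  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```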
Test out your input_fn
for features_batch, labels_batch in input_fn().take(1):
    print(features_batch)
    print(labels_batch)
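The tfds iris training split has 150 examples, so batch(32) without drop_remainder yields four full batches of 32 followed by a batch of 22. Because repeat() comes after batch(), that short batch recurs at the end of every pass. The plain-Python sketch below mimics this ordering with itertools; it illustrates the semantics and is not tf.data itself.

```python
import itertools

def batch_then_repeat(examples, batch_size):
    """Mimic dataset.batch(batch_size).repeat(): yield batches forever."""
    batches = [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]
    return itertools.cycle(batches)  # repeat() restarts from the first batch

examples = list(range(150))  # stand-ins for the 150 iris examples
stream = batch_then_repeat(examples, 32)
sizes = [len(next(stream)) for _ in range(7)]
print(sizes)  # [32, 32, 32, 32, 22, 32, 32]
```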
Create an Estimator from the tf.keras model.
A tf.keras.Model can be trained with the tf.estimator API by converting the model to a tf.estimator.Estimator object with tf.keras.estimator.model_to_estimator.
import tempfile

# The Estimator writes its checkpoints and event files to this directory.
model_dir = tempfile.mkdtemp()
keras_estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir=model_dir)
Train and evaluate the estimator.
keras_estimator.train(input_fn=input_fn, steps=500)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))
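Because input_fn repeats indefinitely, the steps arguments, not epochs, bound training and evaluation. With 150 examples in batches of 32, one pass over the data takes five steps (the last batch holds 22 examples). The back-of-the-envelope arithmetic below shows that steps=500 amounts to 100 passes and that steps=10 during evaluation covers the data twice. Note that evaluate reads the same training split as train here, so the reported loss is not a held-out estimate.

```python
import math

num_examples = 150   # size of the tfds iris 'train' split
batch_size = 32

steps_per_pass = math.ceil(num_examples / batch_size)  # the partial batch still counts as a step
train_passes = 500 / steps_per_pass
eval_passes = 10 / steps_per_pass
print(steps_per_pass, train_passes, eval_passes)  # 5 100.0 2.0
```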