Посмотреть на TensorFlow.org | Запустить в Google Colab | Посмотреть на GitHub | Скачать блокнот |
Обзор
В этом руководстве представлен список лучших практик по написанию кода с использованием TensorFlow 2 (TF2). Оно написано для пользователей, которые недавно перешли с TensorFlow 1 (TF1). Обратитесь к разделу руководства по переносу для получения дополнительной информации о переносе кода TF1 в TF2.
Настраивать
Импортируйте TensorFlow и другие зависимости для примеров в этом руководстве.
import tensorflow as tf
import tensorflow_datasets as tfds
Рекомендации для идиоматического TensorFlow 2
Рефакторинг вашего кода в более мелкие модули
Хорошей практикой является реорганизация вашего кода в более мелкие функции, которые вызываются по мере необходимости. Для лучшей производительности вы должны попытаться декорировать самые большие блоки вычислений, которые вы можете в tf.function
(обратите внимание, что вложенные функции python, вызываемые tf.function
, не требуют своих собственных отдельных украшений, если только вы не хотите использовать разные jit_compile
настройки для tf.function
). В зависимости от варианта использования это может быть несколько этапов обучения или даже весь цикл обучения. Для вариантов использования логического вывода это может быть прямой проход одной модели.
Отрегулируйте скорость обучения по умолчанию для некоторых tf.keras.optimizer
s
Некоторые оптимизаторы Keras имеют разную скорость обучения в TF2. Если вы видите изменение поведения сходимости для ваших моделей, проверьте скорость обучения по умолчанию.
Нет никаких изменений для optimizers.SGD
, optimizers.Adam
или optimizers.RMSprop
.
Изменены следующие скорости обучения по умолчанию:
-
optimizers.Adagrad
с0.01
до0.001
-
optimizers.Adadelta
с1.0
до0.001
-
optimizers.Adamax
с0.002
до0.001
-
optimizers.Nadam
с0.002
на0.001
Используйте tf.Module
s и Keras для управления переменными.
tf.Module
s и tf.keras.layers.Layer
s предлагают удобные variables
и свойства trainable_variables
, которые рекурсивно собирают все зависимые переменные. Это упрощает управление переменными локально там, где они используются.
Слои/модели Keras наследуются от tf.train.Checkpointable
и интегрированы с @tf.function
, что позволяет напрямую создавать контрольные точки или экспортировать SavedModels из объектов Keras. Вам не обязательно использовать API Model.fit
, чтобы воспользоваться преимуществами этих интеграций.
Прочтите раздел о переносе обучения и тонкой настройке в руководстве по Keras, чтобы узнать, как собрать подмножество соответствующих переменных с помощью Keras.
Объедините tf.data.Dataset
s и tf.function
Пакет наборов данных TensorFlow ( tfds
) содержит утилиты для загрузки предопределенных наборов данных в виде объектов tf.data.Dataset
. В этом примере вы можете загрузить набор данных MNIST с помощью tfds
:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Затем подготовьте данные для обучения:
- Измените масштаб каждого изображения.
- Перетасуйте порядок примеров.
- Соберите партии изображений и этикеток.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
Чтобы пример был коротким, обрежьте набор данных, чтобы он возвращал только 5 пакетов:
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-12-08 17:15:01.637157: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Используйте обычную итерацию Python для перебора обучающих данных, которые помещаются в памяти. В противном случае tf.data.Dataset
— лучший способ потоковой передачи обучающих данных с диска. Наборы данных являются итерируемыми (а не итераторами) и работают точно так же, как другие итерируемые объекты Python при активном выполнении. Вы можете в полной мере использовать функции асинхронной предварительной выборки/потоковой передачи набора данных, обернув свой код в tf.function
, который заменяет итерацию Python эквивалентными операциями графа с использованием AutoGraph.
@tf.function
def train(model, dataset, optimizer):
for x, y in dataset:
with tf.GradientTape() as tape:
# training=True is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
prediction = model(x, training=True)
loss = loss_fn(prediction, y)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
Если вы используете Model.fit
API, вам не придется беспокоиться об итерации набора данных.
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
Используйте обучающие циклы Keras
Если вам не нужен низкоуровневый контроль над процессом обучения, рекомендуется использовать встроенные в Keras методы fit
, evaluate
и predict
. Эти методы обеспечивают единый интерфейс для обучения модели независимо от реализации (последовательной, функциональной или подклассовой).
К преимуществам этих методов относятся:
- Они принимают массивы Numpy, генераторы Python и
tf.data.Datasets
. - Они автоматически применяют регуляризацию и потери активации.
- Они поддерживают
tf.distribute
где обучающий код остается неизменным независимо от конфигурации оборудования . - Они поддерживают произвольные callables как потери и метрики.
- Они поддерживают обратные вызовы, такие как
tf.keras.callbacks.TensorBoard
, и пользовательские обратные вызовы. - Они производительны, автоматически используют графики TensorFlow.
Вот пример обучения модели с использованием Dataset
. Подробнее о том, как это работает, читайте в руководствах .
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5 5/5 [==============================] - 9s 7ms/step - loss: 1.5762 - accuracy: 0.4938 Epoch 2/5 2021-12-08 17:15:11.145429: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 6ms/step - loss: 0.5087 - accuracy: 0.8969 Epoch 3/5 2021-12-08 17:15:11.559374: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 2s 5ms/step - loss: 0.3348 - accuracy: 0.9469 Epoch 4/5 2021-12-08 17:15:13.860407: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 5ms/step - loss: 0.2445 - accuracy: 0.9688 Epoch 5/5 2021-12-08 17:15:14.269850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 6ms/step - loss: 0.2006 - accuracy: 0.9719 2021-12-08 17:15:14.717552: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 1s 4ms/step - loss: 1.4553 - accuracy: 0.5781 Loss 1.4552843570709229, Accuracy 0.578125 2021-12-08 17:15:15.862684: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Настройте обучение и напишите свой собственный цикл
Если вам подходят модели Keras, но вам нужно больше гибкости и контроля над этапом обучения или внешними циклами обучения, вы можете реализовать свои собственные этапы обучения или даже целые циклы обучения. См. руководство Keras по настройке fit
, чтобы узнать больше.
Вы также можете реализовать многие вещи как tf.keras.callbacks.Callback
.
Этот метод обладает многими преимуществами, упомянутыми ранее , но дает вам контроль над шагом поезда и даже над внешним циклом.
Стандартный цикл обучения состоит из трех шагов:
- Повторите генератор Python или
tf.data.Dataset
, чтобы получить партии примеров. - Используйте
tf.GradientTape
для сбора градиентов. - Используйте один из
tf.keras.optimizers
, чтобы применить обновления веса к переменным модели.
Помните:
- Всегда включайте
training
аргумент в методcall
подклассов слоев и моделей. - Обязательно вызовите модель с правильно установленным
training
аргументом. - В зависимости от использования переменные модели могут не существовать до тех пор, пока модель не будет запущена на пакете данных.
- Вам нужно вручную обрабатывать такие вещи, как потери регуляризации для модели.
Нет необходимости запускать инициализаторы переменных или добавлять зависимости ручного управления. tf.function
обрабатывает автоматические зависимости управления и инициализацию переменных при создании для вас.
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
for epoch in range(NUM_EPOCHS):
for inputs, labels in train_data:
train_step(inputs, labels)
print("Finished epoch", epoch)
2021-12-08 17:15:16.714849: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 0 2021-12-08 17:15:17.097043: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 1 2021-12-08 17:15:17.502480: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 2 2021-12-08 17:15:17.873701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 3 Finished epoch 4 2021-12-08 17:15:18.344196: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Воспользуйтесь преимуществами tf.function
с потоком управления Python
tf.function
предоставляет способ преобразования потока управления, зависящего от данных, в эквиваленты режима графа, такие как tf.cond
и tf.while_loop
.
Одним из распространенных мест, где появляется поток управления, зависящий от данных, являются модели последовательности. tf.keras.layers.RNN
оборачивает ячейку RNN, позволяя вам статически или динамически развернуть повторение. Например, вы можете повторно реализовать динамическое развертывание следующим образом.
class DynamicRNN(tf.keras.Model):
def __init__(self, rnn_cell):
super(DynamicRNN, self).__init__(self)
self.cell = rnn_cell
@tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
def call(self, input_data):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
timesteps = tf.shape(input_data)[0]
batch_size = tf.shape(input_data)[1]
outputs = tf.TensorArray(tf.float32, timesteps)
state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
for i in tf.range(timesteps):
output, state = self.cell(input_data[i], state)
outputs = outputs.write(i, output)
return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)
my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)
Прочтите руководство по tf.function
для получения дополнительной информации.
Метрики нового стиля и потери
Метрики и потери — это как объекты, которые охотно работают, так и в tf.function
.
Объект потери доступен для вызова и ожидает ( y_true
, y_pred
) в качестве аргументов:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
Используйте метрики для сбора и отображения данных
Вы можете использовать tf.metrics
для агрегирования данных и tf.summary
для регистрации сводок и перенаправления их автору с помощью менеджера контекста. Сводки передаются непосредственно автору, что означает, что вы должны указать значение step
на сайте вызова.
summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
tf.summary.scalar('loss', 0.1, step=42)
Используйте tf.metrics
для агрегирования данных перед их записью в виде сводок. Метрики имеют состояние; они накапливают значения и возвращают совокупный результат, когда вы вызываете метод result
(например, Mean.result
). Очистите накопленные значения с помощью Model.reset_states
.
def train(model, optimizer, dataset, log_freq=10):
avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
for images, labels in dataset:
loss = train_step(model, optimizer, images, labels)
avg_loss.update_state(loss)
if tf.equal(optimizer.iterations % log_freq, 0):
tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
avg_loss.reset_states()
def test(model, test_x, test_y, step_num):
# training=False is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
loss = loss_fn(model(test_x, training=False), test_y)
tf.summary.scalar('loss', loss, step=step_num)
train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')
with train_summary_writer.as_default():
train(model, optimizer, dataset)
with test_summary_writer.as_default():
test(model, test_x, test_y, optimizer.iterations)
Визуализируйте сгенерированные сводки, указав TensorBoard на каталог сводных журналов:
tensorboard --logdir /tmp/summaries
Используйте API tf.summary
для записи сводных данных для визуализации в TensorBoard. Для получения дополнительной информации прочитайте руководство по tf.summary
.
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# Update the metrics
loss_metric.update_state(total_loss)
accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
# Reset the metrics
loss_metric.reset_states()
accuracy_metric.reset_states()
for inputs, labels in train_data:
train_step(inputs, labels)
# Get the metric results
mean_loss=loss_metric.result()
mean_accuracy = accuracy_metric.result()
print('Epoch: ', epoch)
print(' loss: {:.3f}'.format(mean_loss))
print(' accuracy: {:.3f}'.format(mean_accuracy))
2021-12-08 17:15:19.339736: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 0 loss: 0.142 accuracy: 0.991 2021-12-08 17:15:19.781743: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 1 loss: 0.125 accuracy: 0.997 2021-12-08 17:15:20.219033: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 2 loss: 0.110 accuracy: 0.997 2021-12-08 17:15:20.598085: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 3 loss: 0.099 accuracy: 0.997 Epoch: 4 loss: 0.085 accuracy: 1.000 2021-12-08 17:15:20.981787: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Имена метрик Keras
Модели Keras последовательны в отношении обработки имен метрик. Когда вы передаете строку в списке метрик, именно эта строка используется в качестве name
метрики. Эти имена видны в объекте истории, возвращаемом model.fit
, и в журналах, передаваемых keras.callbacks
. устанавливается на строку, которую вы передали в списке метрик.
model.compile(
optimizer = tf.keras.optimizers.Adam(0.001),
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0963 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969 2021-12-08 17:15:21.942940: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
Отладка
Используйте активное выполнение для пошагового запуска кода для проверки форм, типов данных и значений. Некоторые API, такие как tf.function
, tf.keras
и т. д., предназначены для использования выполнения Graph для повышения производительности и переносимости. При отладке используйте tf.config.run_functions_eagerly(True)
, чтобы использовать активное выполнение внутри этого кода.
Например:
@tf.function
def f(x):
if x > 0:
import pdb
pdb.set_trace()
x = x + 1
return x
tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
6 @tf.function
7 def f(x):
8 if x > 0:
9 import pdb
10 pdb.set_trace()
11 -> x = x + 1
12 return x
13
14 tf.config.run_functions_eagerly(True)
15 f(tf.constant(1))
[EOF]
Это также работает внутри моделей Keras и других API, которые поддерживают активное выполнение:
class CustomModel(tf.keras.models.Model):
@tf.function
def call(self, input_data):
if tf.reduce_mean(input_data) > 0:
return input_data
else:
import pdb
pdb.set_trace()
return input_data // 2
tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
10 if tf.reduce_mean(input_data) > 0:
11 return input_data
12 else:
13 import pdb
14 pdb.set_trace()
15 -> return input_data // 2
16
17
18 tf.config.run_functions_eagerly(True)
19 model = CustomModel()
20 model(tf.constant([-2, -4]))
Примечания:
Методы
tf.keras.Model
, такие какfit
,evaluate
иtf.function
, выполняются как графики сpredict
под капотом.При использовании
tf.keras.Model.compile
установитеrun_eagerly = True
, чтобы запретить перенос логикиModel
вtf.function
.Используйте
tf.data.experimental.enable_debug_mode
, чтобы включить режим отладки дляtf.data
. Подробнее читайте в документации по API .
Не держите tf.Tensors
в своих объектах
Эти тензорные объекты могут быть созданы либо в tf.function
, либо в активном контексте, и эти тензоры ведут себя по-разному. Всегда используйте tf.Tensor
только для промежуточных значений.
Чтобы отслеживать состояние, используйте tf.Variable
, поскольку они всегда могут использоваться в обоих контекстах. Прочтите руководство по tf.Variable
, чтобы узнать больше.
Ресурсы и дополнительная литература
Прочтите руководства и учебные пособия по TF2, чтобы узнать больше о том, как использовать TF2.
Если вы ранее использовали TF1.x, настоятельно рекомендуется перенести свой код на TF2. Прочтите руководства по миграции, чтобы узнать больше.