Transfer pembelajaran dan penyempurnaan

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Mempersiapkan

import numpy as np
import tensorflow as tf
from tensorflow import keras

pengantar

Pembelajaran Transfer terdiri dari mengambil fitur belajar pada satu masalah, dan memanfaatkan mereka pada baru, masalah yang sama. Misalnya, fitur dari model yang telah belajar mengidentifikasi rakun mungkin berguna untuk memulai model yang dimaksudkan untuk mengidentifikasi tanuki.

Pembelajaran transfer biasanya dilakukan untuk tugas-tugas di mana kumpulan data Anda memiliki terlalu sedikit data untuk melatih model skala penuh dari awal.

Inkarnasi pembelajaran transfer yang paling umum dalam konteks pembelajaran mendalam adalah alur kerja berikut:

  1. Ambil lapisan dari model yang telah dilatih sebelumnya.
  2. Bekukan mereka, untuk menghindari penghancuran informasi apa pun yang dikandungnya selama putaran pelatihan di masa mendatang.
  3. Tambahkan beberapa lapisan baru yang dapat dilatih di atas lapisan beku. Mereka akan belajar mengubah fitur lama menjadi prediksi pada kumpulan data baru.
  4. Latih lapisan baru pada kumpulan data Anda.

Sebuah terakhir, langkah opsional, adalah fine-tuning, yang terdiri dari unfreezing seluruh model yang Anda diperoleh di atas (atau bagian dari itu), dan-pelatihan ulang pada data baru dengan tingkat belajar yang sangat rendah. Ini berpotensi mencapai peningkatan yang berarti, dengan secara bertahap mengadaptasi fitur yang telah dilatih sebelumnya ke data baru.

Pertama, kita akan pergi ke Keras trainable API secara rinci, yang mendasari paling Transfer belajar & fine-tuning alur kerja.

Kemudian, kami akan mendemonstrasikan alur kerja tipikal dengan mengambil model yang telah dilatih sebelumnya pada kumpulan data ImageNet, dan melatihnya kembali pada kumpulan data klasifikasi "kucing vs anjing" Kaggle.

Ini diadaptasi dari Jauh Belajar dengan Python dan 2016 posting blog "membangun model klasifikasi citra yang kuat menggunakan sangat sedikit data" .

Lapisan beku: memahami trainable atribut

Lapisan & model memiliki tiga atribut bobot:

  • weights adalah daftar semua variabel bobot dari lapisan.
  • trainable_weights adalah daftar orang-orang yang dimaksudkan untuk diperbarui (melalui gradient descent) untuk meminimalkan kerugian selama pelatihan.
  • non_trainable_weights adalah daftar orang-orang yang tidak dimaksudkan untuk dilatih. Biasanya mereka diperbarui oleh model selama forward pass.

Contoh: Dense lapisan memiliki 2 bobot dilatih (kernel & bias)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

Secara umum, semua beban adalah beban yang bisa dilatih. Satu-satunya built-in lapisan yang memiliki bobot non-dilatih adalah BatchNormalization lapisan. Ini menggunakan bobot yang tidak dapat dilatih untuk melacak rata-rata dan varians inputnya selama pelatihan. Untuk mempelajari cara menggunakan beban non-dilatih di lapisan kustom Anda sendiri, lihat panduan untuk menulis lapisan baru dari awal .

Contoh: BatchNormalization lapisan memiliki 2 bobot dilatih dan 2 bobot non-dilatih

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Lapisan & model juga fitur atribut boolean trainable . Nilainya bisa diubah. Pengaturan layer.trainable ke False bergerak semua bobot layer dari dilatih untuk non-dilatih. Ini disebut "pembekuan" lapisan: keadaan lapisan beku tidak akan diperbarui selama pelatihan (baik ketika pelatihan dengan fit() atau ketika pelatihan dengan lingkaran kustom yang mengandalkan trainable_weights untuk menerapkan update gradien).

Contoh: pengaturan trainable untuk False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Saat bobot yang dapat dilatih menjadi tidak dapat dilatih, nilainya tidak lagi diperbarui selama pelatihan.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945

Jangan bingung layer.trainable atribut dengan argumen training di layer.__call__() (yang mengontrol apakah lapisan harus berjalan ke depan lulus dalam modus inferensi atau mode pelatihan). Untuk informasi lebih lanjut, lihat Keras FAQ .

Rekursif pengaturan dari trainable atribut

Jika Anda menetapkan trainable = False pada model atau pada setiap lapisan yang memiliki sub-lapisan, semua anak lapisan menjadi non-dilatih juga.

Contoh:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

Alur kerja transfer-learning yang khas

Ini mengarahkan kita pada bagaimana alur kerja transfer learning yang khas dapat diterapkan di Keras:

  1. Buat instance model dasar dan muat beban yang telah dilatih sebelumnya ke dalamnya.
  2. Membekukan semua lapisan dalam model dasar dengan menetapkan trainable = False .
  3. Buat model baru di atas output dari satu (atau beberapa) lapisan dari model dasar.
  4. Latih model baru Anda pada kumpulan data baru Anda.

Perhatikan bahwa alternatif, alur kerja yang lebih ringan juga dapat berupa:

  1. Buat instance model dasar dan muat beban yang telah dilatih sebelumnya ke dalamnya.
  2. Jalankan dataset baru Anda melaluinya dan catat output dari satu (atau beberapa) layer dari model dasar. Ini disebut ekstraksi fitur.
  3. Gunakan output itu sebagai data input untuk model baru yang lebih kecil.

Keuntungan utama dari alur kerja kedua itu adalah Anda hanya menjalankan model dasar sekali pada data Anda, bukan sekali per periode pelatihan. Jadi jauh lebih cepat & lebih murah.

Masalah dengan alur kerja kedua itu, bagaimanapun, adalah bahwa itu tidak memungkinkan Anda untuk secara dinamis mengubah data input model baru Anda selama pelatihan, yang diperlukan saat melakukan augmentasi data, misalnya. Pembelajaran transfer biasanya digunakan untuk tugas-tugas ketika kumpulan data baru Anda memiliki terlalu sedikit data untuk melatih model skala penuh dari awal, dan dalam skenario seperti itu, augmentasi data sangat penting. Jadi berikut ini, kita akan fokus pada alur kerja pertama.

Inilah yang terlihat seperti alur kerja pertama di Keras:

Pertama, buat contoh model dasar dengan bobot yang telah dilatih sebelumnya.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Kemudian, bekukan model dasar.

base_model.trainable = False

Buat model baru di atas.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Latih model pada data baru.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Mencari setelan

Setelah model Anda menyatu pada data baru, Anda dapat mencoba untuk mencairkan semua atau sebagian dari model dasar dan melatih kembali seluruh model secara menyeluruh dengan tingkat pembelajaran yang sangat rendah.

Ini adalah langkah terakhir opsional yang berpotensi memberi Anda peningkatan bertahap. Ini juga berpotensi menyebabkan overfitting cepat - ingatlah itu.

Hal ini penting untuk hanya melakukan langkah ini setelah model dengan lapisan beku telah dilatih untuk konvergensi. Jika Anda mencampur lapisan yang dapat dilatih secara acak dengan lapisan yang dapat dilatih yang menyimpan fitur yang telah dilatih sebelumnya, lapisan yang diinisialisasi secara acak akan menyebabkan pembaruan gradien yang sangat besar selama pelatihan, yang akan menghancurkan fitur yang telah dilatih sebelumnya.

Penting juga untuk menggunakan tingkat pembelajaran yang sangat rendah pada tahap ini, karena Anda melatih model yang jauh lebih besar daripada di putaran pertama pelatihan, pada kumpulan data yang biasanya sangat kecil. Akibatnya, Anda berisiko mengalami overfitting dengan sangat cepat jika menerapkan pembaruan bobot yang besar. Di sini, Anda hanya ingin menyesuaikan kembali bobot yang telah dilatih sebelumnya secara bertahap.

Ini adalah bagaimana menerapkan fine-tuning dari seluruh model dasar:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Catatan penting tentang compile() dan trainable

Memanggil compile() pada model dimaksudkan untuk "membekukan" perilaku model itu. Ini berarti bahwa trainable nilai atribut pada saat model dikompilasi harus dipertahankan sepanjang masa model yang, sampai compile dipanggil lagi. Oleh karena itu, jika Anda mengubah trainable nilai, pastikan untuk panggilan compile() lagi pada model Anda agar perubahan diperhitungkan.

Catatan penting tentang BatchNormalization lapisan

Banyak model gambar mengandung BatchNormalization lapisan. Lapisan itu adalah kasus khusus pada setiap hitungan yang bisa dibayangkan. Berikut adalah beberapa hal yang perlu diingat.

  • BatchNormalization berisi 2 bobot non-dilatih yang mendapatkan update selama pelatihan. Ini adalah variabel yang melacak mean dan varians dari input.
  • Ketika Anda menetapkan bn_layer.trainable = False , yang BatchNormalization lapisan akan berjalan dalam mode inferensi, dan tidak akan memperbarui berarti & varians statistiknya. Ini bukan kasus untuk lapisan lain pada umumnya, seperti kemampuan dilatihnya berat & inferensi / mode pelatihan adalah dua konsep ortogonal . Tapi dua terikat dalam kasus BatchNormalization lapisan.
  • Ketika Anda mencairkan model yang berisi BatchNormalization lapisan untuk melakukan fine-tuning, Anda harus menjaga BatchNormalization lapisan dalam mode inferensi dengan melewati training=False saat memanggil model dasar. Jika tidak, pembaruan yang diterapkan pada bobot yang tidak dapat dilatih akan tiba-tiba menghancurkan apa yang telah dipelajari model.

Anda akan melihat pola ini beraksi dalam contoh ujung ke ujung di akhir panduan ini.

Transfer pembelajaran & penyetelan dengan loop pelatihan khusus

Jika bukan fit() , Anda menggunakan lingkaran pelatihan tingkat rendah Anda sendiri, alur kerja tetap pada dasarnya sama. Anda harus berhati-hati untuk hanya memperhitungkan daftar model.trainable_weights ketika menerapkan pembaruan gradien:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Begitu juga untuk fine-tuning.

Contoh ujung ke ujung: menyempurnakan model klasifikasi gambar pada kumpulan data kucing vs. anjing

Untuk memperkuat konsep-konsep ini, mari memandu Anda melalui contoh pembelajaran transfer & penyetelan ujung-ke-ujung yang konkret. Kami akan memuat model Xception, yang telah dilatih sebelumnya di ImageNet, dan menggunakannya pada kumpulan data klasifikasi "kucing vs. anjing" Kaggle.

Mendapatkan data

Pertama, mari kita ambil dataset kucing vs. anjing menggunakan TFDS. Jika Anda memiliki dataset Anda sendiri, Anda mungkin ingin menggunakan utilitas tf.keras.preprocessing.image_dataset_from_directory untuk menghasilkan sejenis dataset berlabel objek dari serangkaian gambar pada disk ke dalam folder kelas khusus.

Pembelajaran transfer paling berguna saat bekerja dengan kumpulan data yang sangat kecil. Untuk menjaga agar kumpulan data kami tetap kecil, kami akan menggunakan 40% dari data pelatihan asli (25.000 gambar) untuk pelatihan, 10% untuk validasi, dan 10% untuk pengujian.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Ini adalah 9 gambar pertama dalam set data pelatihan -- seperti yang Anda lihat, semuanya memiliki ukuran yang berbeda.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

Kita juga dapat melihat bahwa label 1 adalah "anjing" dan label 0 adalah "kucing".

Standarisasi data

Gambar mentah kami memiliki berbagai ukuran. Selain itu, setiap piksel terdiri dari 3 nilai integer antara 0 dan 255 (nilai level RGB). Ini tidak cocok untuk memberi makan jaringan saraf. Kita perlu melakukan 2 hal:

  • Standarisasi ke ukuran gambar tetap. Kami memilih 150x150.
  • Nilai-nilai pixel menormalkan antara -1 dan 1. Kami akan melakukan ini dengan menggunakan Normalization lapisan sebagai bagian dari model itu sendiri.

Secara umum, ini adalah praktik yang baik untuk mengembangkan model yang mengambil data mentah sebagai input, sebagai lawan model yang mengambil data yang sudah diproses sebelumnya. Alasannya adalah, jika model Anda mengharapkan data yang telah diproses sebelumnya, setiap kali Anda mengekspor model Anda untuk menggunakannya di tempat lain (di browser web, di aplikasi seluler), Anda harus mengimplementasikan kembali pipeline prapemrosesan yang sama persis. Ini menjadi sangat rumit dengan sangat cepat. Jadi kita harus melakukan pra-pemrosesan sesedikit mungkin sebelum mencapai model.

Di sini, kita akan melakukan pengubahan ukuran gambar di saluran data (karena jaringan saraf dalam hanya dapat memproses kumpulan data yang berdekatan), dan kita akan melakukan penskalaan nilai input sebagai bagian dari model, saat kita membuatnya.

Mari kita ubah ukuran gambar menjadi 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Selain itu, mari kita mengelompokkan data dan menggunakan caching & prefetching untuk mengoptimalkan kecepatan pemuatan.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Menggunakan augmentasi data acak

Jika Anda tidak memiliki kumpulan data gambar yang besar, merupakan praktik yang baik untuk memperkenalkan keragaman sampel secara artifisial dengan menerapkan transformasi acak namun realistis pada gambar pelatihan, seperti pembalikan horizontal acak atau rotasi acak kecil. Ini membantu mengekspos model ke berbagai aspek data pelatihan sambil memperlambat overfitting.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

Mari kita visualisasikan seperti apa gambar pertama dari kumpulan pertama setelah berbagai transformasi acak:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Membangun model

Sekarang mari kita membangun sebuah model yang mengikuti cetak biru yang telah kita jelaskan sebelumnya.

Perhatikan bahwa:

  • Kami menambahkan Rescaling lapisan untuk nilai input skala (awalnya di [0, 255] range) ke [-1, 1] jangkauan.
  • Kami menambahkan Dropout lapisan sebelum lapisan klasifikasi, untuk regularisasi.
  • Kami pastikan untuk lulus training=False saat memanggil model dasar, sehingga berjalan dalam mode inferensi, sehingga statistik batchnorm tidak mendapatkan update bahkan setelah kami mencairkan model dasar untuk fine-tuning.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
83697664/83683744 [==============================] - 2s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

Latih lapisan atas

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695
Epoch 3/20
291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703
Epoch 5/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725
Epoch 6/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716
Epoch 8/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678
Epoch 9/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729
Epoch 10/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695
Epoch 14/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712
Epoch 15/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712
Epoch 16/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699
Epoch 17/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708
Epoch 18/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716
Epoch 20/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695
<keras.callbacks.History at 0x7f849a3b2950>

Lakukan putaran fine-tuning seluruh model

Terakhir, mari kita mencairkan model dasar dan melatih seluruh model dari ujung ke ujung dengan kecepatan belajar yang rendah.

Yang penting, meskipun model dasar menjadi dilatih, masih berjalan dalam modus kesimpulan karena kita melewati training=False ketika menyebutnya ketika kita membangun model. Ini berarti bahwa lapisan normalisasi batch di dalamnya tidak akan memperbarui statistik batchnya. Jika mereka melakukannya, mereka akan merusak representasi yang dipelajari oleh model sejauh ini.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764
Epoch 2/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764
Epoch 3/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798
Epoch 4/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802
Epoch 8/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832
<keras.callbacks.History at 0x7f83982d4cd0>

Setelah 10 epoch, fine-tuning memberi kami peningkatan yang bagus di sini.