Réglage distribué avec Gemma à l'aide de Keras

Afficher sur ai.google.dev Exécuter dans Google Colab Exécuter dans Kaggle Ouvrir dans Vertex AI Consulter le code source sur GitHub

Présentation

Gemma est une famille de modèles ouverts légers et de pointe, élaborés à partir des recherches et des technologies utilisées pour créer des modèles Google Gemini. Gemma peut être affinée pour répondre à des besoins spécifiques. Mais les grands modèles de langage, tels que Gemma, peuvent être très volumineux et certains d'entre eux peuvent ne pas être compatibles avec un accélérateur de réglage. Dans ce cas, il existe deux approches générales pour les affiner:

  1. L'affinage efficace des paramètres (PEFT), qui cherche à réduire la taille effective du modèle en sacrifiant une partie de la fidélité. La LoRA appartient à cette catégorie. Le tutoriel Régler les modèles Gemma dans Keras avec LoRA explique comment affiner le modèle Gemma 2B gemma_2b_en avec LoRA en utilisant KerasNLP sur un seul GPU.
  2. Affinage complet des paramètres avec parallélisme des modèles Le parallélisme des modèles répartit les pondérations d'un seul modèle sur plusieurs appareils et active le scaling horizontal. Pour en savoir plus sur l'entraînement distribué, consultez ce guide Keras.

Ce tutoriel vous explique comment utiliser Keras avec un backend JAX pour affiner le modèle Gemma 7B à l'aide de la technologie LoRA ainsi que de l'entraînement distribué basé sur le parallisme du modèle sur le Tensor Processing Unit (TPU) de Google. Notez que la LoRA peut être désactivée dans ce tutoriel pour un réglage complet des paramètres plus lents mais plus précis.

Utiliser des accélérateurs

Techniquement, vous pouvez utiliser un TPU ou un GPU pour ce tutoriel.

Remarques sur les environnements TPU

Trois produits Google proposent des TPU:

  • Colab fournit sans frais TPU v2, ce qui est suffisant pour ce tutoriel.
  • Kaggle propose TPU v3 sans frais et fonctionne également pour ce tutoriel.
  • Cloud TPU est disponible pour les TPU v3 et les générations plus récentes. Vous pouvez le configurer de la manière suivante: <ph type="x-smartling-placeholder">
      </ph>
    1. créer une VM TPU ;
    2. Configurez le transfert de port SSH pour le port de serveur Jupyter que vous souhaitez utiliser.
    3. Installez Jupyter et démarrez-le sur la VM TPU, puis connectez-vous à Colab via "Se connecter à un environnement d'exécution local"

Remarques sur la configuration multi-GPU

Bien que ce tutoriel porte sur le cas d'utilisation des TPU, vous pouvez facilement l'adapter à vos besoins si vous disposez d'une machine multiGPU.

Si vous préférez travailler via Colab, vous pouvez également provisionner une VM multi-GPU pour Colab directement via "Se connecter à une VM GCE personnalisée". dans le menu Colab Connect.

Nous allons nous concentrer ici sur l'utilisation du TPU sans frais de Kaggle.

Avant de commencer

Identifiants Kaggle

Les modèles Gemma sont hébergés par Kaggle. Pour utiliser Gemma, demandez l'accès sur Kaggle:

  • Connectez-vous ou inscrivez-vous sur kaggle.com
  • Ouvrez la fiche du modèle Gemma, puis sélectionnez Demander l'accès.
  • Remplissez le formulaire de consentement et acceptez les conditions d'utilisation

Ensuite, pour utiliser l'API Kaggle, créez un jeton d'API:

  • Ouvrez les paramètres Kaggle.
  • Sélectionnez Create New Token (Créer un jeton).
  • Un fichier kaggle.json est téléchargé. Il contient vos identifiants Kaggle

Exécutez la cellule suivante et saisissez vos identifiants Kaggle lorsque vous y êtes invité.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Une autre méthode consiste à définir KAGGLE_USERNAME et KAGGLE_KEY dans votre environnement si kagglehub.login() ne fonctionne pas pour vous.

Installation

Installer Keras et KerasNLP avec le modèle Gemma

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Configurer le backend JAX Keras

Importez JAX et effectuez une évaluation de l'intégrité sur TPU. Kaggle propose des appareils TPUv3-8 dotés de 8 cœurs de TPU avec chacun 16 Go de mémoire.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

Charger le modèle

import keras
import keras_nlp

Remarques sur l'entraînement de précision mixte sur les GPU NVIDIA

Lors de l'entraînement sur des GPU NVIDIA, la précision mixte (keras.mixed_precision.set_global_policy('mixed_bfloat16')) peut être utilisée pour accélérer l'entraînement avec un impact minimal sur la qualité. Dans la plupart des cas, nous vous recommandons d'activer la précision mixte, car elle permet d'économiser de la mémoire et du temps. Toutefois, sachez qu'avec des lots de petite taille, l'utilisation de la mémoire peut être multipliée par 1,5 (les pondérations seront chargées deux fois, avec une demi-précision et une précision totale).

Pour l'inférence, la demi-précision (keras.config.set_floatx("bfloat16")) fonctionne et permet d'économiser de la mémoire, tandis que la précision mixte n'est pas applicable.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Pour charger le modèle avec les pondérations et les Tensors répartis sur les TPU, commencez par créer un DeviceMesh. DeviceMesh représente un ensemble de périphériques configurés pour le calcul distribué et a été introduit dans Keras 3 dans le cadre de l'API de distribution unifiée.

L'API de distribution permet le parallélisme des données et des modèles, ce qui permet un scaling efficace des modèles de deep learning sur plusieurs accélérateurs et hôtes. Il exploite le framework sous-jacent (par exemple, JAX) pour distribuer le programme et les Tensors en fonction des directives de segmentation par le biais d'une procédure appelée expansion SPMD (Single Program, Multiple Data). Pour en savoir plus, consultez le nouveau guide de l'API de distribution Keras 3.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

LayoutMap de l'API de distribution spécifie comment les pondérations et les Tensors doivent être segmentés ou répliqués à l'aide des clés de chaîne (par exemple, token_embedding/embeddings ci-dessous), qui sont traitées comme des expressions régulières pour correspondre aux chemins d'accès des Tensors. Les Tensors mis en correspondance sont segmentés avec les dimensions du modèle (8 TPU). d'autres seront entièrement répliquées.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel vous permet de segmenter les pondérations du modèle ou les Tensors d'activation sur toutes les décimales de DeviceMesh. Dans ce cas, certaines pondérations du modèle Gemma 7B sont segmentées sur huit puces TPU conformément à la layout_map définie ci-dessus. À présent, chargez le modèle de manière distribuée.

model_parallel = keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

À présent, vérifiez que le modèle a été correctement partitionné. Prenons decoder_block_1 comme exemple.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

Inférence avant réglage

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

Le modèle génère une liste de grands films de comédie des années 90 à regarder. Nous allons maintenant affiner le modèle Gemma pour modifier le style de sortie.

Finaliser avec IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Effectuez des réglages à l'aide de la fonctionnalité LoRA (Low Rank Adaptation). La LoRA est une technique d'affinage qui réduit considérablement le nombre de paramètres pouvant être entraînés pour les tâches en aval en gelant toutes les pondérations du modèle et en insérant un plus petit nombre de nouvelles pondérations pouvant être entraînées dans le modèle. En gros, LoRA reparamètre les matrices de pondération complète les plus volumineuses par 2 plus petites matrices de faible rang AxB à entraîner. Cette technique rend l'entraînement beaucoup plus rapide et plus efficace en mémoire.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

Notez que l'activation de la LoRA réduit considérablement le nombre de paramètres pouvant être entraînés, passant de 7 milliards à seulement 11 millions.

Inférence après réglage

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

Après l'affinage, le modèle a appris le style des critiques de films et génère maintenant des résultats dans ce style pour les comédies des années 90.

Étape suivante

Dans ce tutoriel, vous avez appris à utiliser le backend JAX KerasNLP pour affiner un modèle Gemma sur l'ensemble de données IMDb de manière distribuée sur les puissants TPU. Voici quelques suggestions d'autres points à retenir: