כוונון מבוזר עם Gemma באמצעות Keras

הצגה ב-ai.google.dev הפעלה ב-Google Colab הרצה ב-Kaggle פתיחה ב-Vertex AI הצגת המקור ב-GitHub

סקירה כללית

Gemma היא משפחה של מודלים פתוחים וקלים לשימוש, שנוצרו על סמך המחקר והטכנולוגיה ששימשו ליצירת המודלים של Google Gemini. אפשר לשפר את Gemma בהתאם לצרכים הספציפיים שלכם. עם זאת, מודלים גדולים של שפה, כמו Gemma, יכולים להיות גדולים מאוד, וחלק מהם לא יכולים להתאים למאיץ יחיד לצורך כוונון מדויק. במקרה כזה, יש שתי גישות כלליות לשיפור שלהן:

  1. 'כוונון יעיל בפרמטרים' (PEFT), שמטרתו לצמצם את גודל המודל היעיל על ידי ויתור על חלק מהרזולוציה. LoRA נכללת בקטגוריה הזו, ובמדריך ביצוע שיפורים ועדכונים למודלים של Gemma ב-Keras באמצעות LoRA מוסבר איך לבצע שיפורים ועדכונים למודל Gemma 2B‏ gemma_2b_en באמצעות LoRA באמצעות KerasNLP ב-GPU יחיד.
  2. כוונון מדויק מלא של הפרמטרים באמצעות מודל מקבילי. במקרה של מודל מקבילי, המשקלים של מודל אחד מופצים בין כמה מכשירים ומאפשרים התאמה רוחבית. מידע נוסף על הכשרות מבוזרות זמין במדריך הזה של Keeras.

במדריך הזה תלמדו איך להשתמש ב-Keras עם קצה עורפי של JAX כדי לשפר את המודל Gemma 7B באמצעות LoRA ואימון מבוזבז של מודל-מקביליות ביחידה לעיבוד נתונים (TPU) של Google. לתשומת ליבכם: אפשר להשבית את LoRA במדריך הזה כדי לבצע כוונון איטי ומדויק יותר של פרמטר מלא.

שימוש במאיצים

באופן טכני, אפשר להשתמש ב-TPU או ב-GPU במדריך הזה.

הערות על סביבות TPU

ל-Google יש 3 מוצרים שמספקים מערכות TPU:

  • Colab מספק TPU v2 בחינם, וזה מספיק למדריך הזה.
  • ב-Kaggle אפשר להשתמש ב-TPU v3 בחינם, והם מתאימים גם למדריך הזה.
  • Cloud TPU מציע TPU v3 ודורות חדשים יותר. אחת הדרכים להגדיר את האפשרות הזו היא:
    1. יצירת מכונה וירטואלית של TPU
    2. מגדירים העברה ליציאה אחרת של יציאת SSH ליציאה של שרת Jupyter הרצויה.
    3. מתקינים את Jupyter ומפעילים אותו במכונה הווירטואלית של ה-TPU, ואז מתחברים ל-Colab דרך 'התחברות לסביבת זמן ריצה מקומית'

הערות לגבי הגדרה של כמה מעבדי GPU

המדריך הזה מתמקד בתרחיש לדוגמה של TPU, אבל אם יש לכם מכונה עם כמה יחידות GPU, תוכלו להתאים אותו בקלות לצרכים שלכם.

אם אתם מעדיפים לעבוד דרך Colab, אפשר גם להקצות מכונה וירטואלית עם כמה מעבדי GPU ל-Colab ישירות דרך האפשרות 'התחברות למכונה וירטואלית מותאמת אישית של GCE' בתפריט Colab Connect.

כאן נתמקד בשימוש ב-TPU בחינם מ-Kaggle.

לפני שמתחילים

פרטי הכניסה ל-Kaggle

המודלים של Gemma מתארחים ב-Kaggle. כדי להשתמש ב-Gemma, צריך לבקש גישה ב-Kaggle:

  • נכנסים לחשבון או נרשמים באתר kaggle.com.
  • פותחים את כרטיס המודל של Gemma ובוחרים באפשרות 'בקשת גישה'
  • ממלאים את טופס ההסכמה ומאשרים את התנאים וההגבלות

לאחר מכן, כדי להשתמש ב-Kaggle API, יוצרים אסימון API:

  • פותחים את הגדרות Kaggle
  • בוחרים באפשרות Create New Token (יצירת אסימון חדש).
  • מתבצעת הורדה של קובץ kaggle.json. הוא מכיל את פרטי הכניסה שלכם ל-Kaggle

מריצים את התא הבא ומזינים את פרטי הכניסה ל-Kaggle כשמתבקשים.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

אם הפונקציה kagglehub.login() לא פועלת, אפשר להגדיר את KAGGLE_USERNAME ו-KAGGLE_KEY בסביבה.

התקנה

התקנת Keras ו-KerasNLP עם מודל Gemma.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

הגדרת הקצה העורפי של Keras JAX

מייבאים את JAX ומריצים בדיקת תקינות ב-TPU. ב-Kaggle יש מכשירי TPUv3-8 עם 8 ליבות TPU בנפח זיכרון של 16GB כל אחת.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

טעינת מודל

import keras
import keras_nlp

הערות על אימון ברמת דיוק משולבת ב-GPU של NVIDIA

כשמתאמנים על מעבדי NVIDIA GPU, אפשר להשתמש ברמת דיוק משולבת (keras.mixed_precision.set_global_policy('mixed_bfloat16')) כדי לזרז את האימון תוך השפעה מינימלית על איכות האימון. ברוב המקרים, מומלץ להפעיל דיוק משולב כי זה חוסך גם זיכרון וגם זמן. עם זאת, חשוב לזכור שבקבוצות קטנות, הדבר עלול להגדיל את השימוש בזיכרון פי 1.5 (המשקלים יוטמעו פעמיים, ברמת דיוק של חצי ושל דיוק מלא).

להסקה, דיוק חצי (keras.config.set_floatx("bfloat16")) יפעל ויחסוך זיכרון, בעוד שדיוק מעורב לא רלוונטי.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

כדי לטעון את המודל עם המשקולות והטנזורים שמפוזרים בין מעבדי TPU, קודם צריך ליצור DeviceMesh חדש. DeviceMesh מייצג אוסף של מכשירי חומרה שהוגדרו לחישוב מבוזבז, והוא הוצג ב-Keras 3 כחלק מ-API המאוחד לחלוקה.

ה-API של ההפצה מאפשר מקבילה של נתונים ומודלים, ומאפשר התאמה לעומס (scaling) של מודלים של למידה עמוקה באופן יעיל בכמה מאיצים ומארחים. הוא משתמש ב-framework הבסיסי (למשל, JAX) כדי להפיץ את התוכנה ואת הפרמטרים tensors בהתאם להנחיות הפיצול (shard) באמצעות תהליך שנקרא תוכנית יחידה, הרחבת נתונים מרובים (SPMD). פרטים נוספים זמינים במדריך החדש של Keras 3 לספק API.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

השדה LayoutMap מ-distribution API מציין איך צריך לפצל או לשכפל את המשקלים והטנסורים, באמצעות מפתחות המחרוזות. לדוגמה, token_embedding/embeddings שבהמשך, שמטופלים כמו ביטוי רגולרי כדי להתאים לנתיבי הטנסורים. טינסורים מותאמים מחולקים לפי מאפייני מודל (8 TPUs), ואחרים ישוחזרו במלואם.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

באמצעות ModelParallel אפשר לפצל את משקולות המודל או את רכיבי ההפעלה של Tensors בין כל הסטיות ב-DeviceMesh. במקרה כזה, חלק ממשקלי המודל של Gemma 7B מחולקים ל-8 צ'יפים של TPU בהתאם ל-layout_map שהוגדר למעלה. עכשיו טוענים את המודל באופן מבוזר.

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

עכשיו מוודאים שהמודל חולק למחיצות בצורה נכונה. ניקח לדוגמה את decoder_block_1.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

הסקת מסקנות לפני כוונון

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

המודל יוצר רשימה של סרטי קומדיה נהדרים משנות ה-90 של המאה ה-20. עכשיו אנחנו משפרים את המודל של Gemma כדי לשנות את סגנון הפלט.

שיפור נוסף באמצעות IMDB

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

ביצוע כוונון עדין באמצעות התאמה של דירוג נמוך (LoRA). LoRA היא טכניקה של כוונון עדין שמפחיתה באופן משמעותי את מספר הפרמטרים שניתן לאמן למשימות במורד הזרם. לשם כך, היא מקפיאה את המשקלים המלאים של המודל ומוסיפה למודל מספר קטן יותר של משקלים חדשים שניתן לאמן. בעיקרון, LoRA משנה את הפרמטרים של מטריצות המשקל המלאות הגדולות יותר באמצעות 2 מטריצות קטנות יותר בעלות דירוג נמוך יותר (AxB) כדי לאמן את המערכת. הטכניקה הזו מאפשרת לבצע אימון מהר יותר ויעיל יותר מבחינת שימוש בזיכרון.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

חשוב לזכור שהפעלת LoRA מפחיתה באופן משמעותי את מספר הפרמטרים שאפשר לאמן, מ-7 מיליארד ל-11 מיליון בלבד.

הסקת מסקנות אחרי כוונון עדין

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

אחרי כוונון מדויק, המודל למד את הסגנון של ביקורות על סרטים, והוא יוצר עכשיו פלט בסגנון הזה בהקשר של סרטי קומדיה משנות ה-90.

המאמרים הבאים

במדריך הזה למדתם איך להשתמש בקצה העורפי של KerasNLP JAX כדי לבצע שיפורים למודל Gemma במערך הנתונים של IMDb באופן מבוזבז על המעבדים החזקים מסוג TPU. ריכזנו כאן כמה הצעות לנושאים נוספים שאפשר ללמוד: