Keras를 사용하여 Gemma를 사용한 분산 조정

ai.google.dev에서 보기 Google Colab에서 실행 Kaggle에서 실행하기 Vertex AI에서 열기 GitHub에서 소스 보기

개요

Gemma는 Google Gemini 모델을 만드는 데 사용되는 연구와 기술로 빌드된 최첨단 경량 개방형 모델군입니다. Gemma는 특정 니즈에 맞게 더욱 미세하게 조정할 수 있습니다. 하지만 Gemma와 같은 대규모 언어 모델은 크기가 매우 클 수 있으며 일부는 미세 조정을 위한 sing 가속기에 맞지 않을 수 있습니다. 이 경우 미세 조정하는 일반적인 두 가지 방법이 있습니다.

  1. 일부 충실도를 희생하여 유효 모델 크기를 줄이는 매개변수 효율적 미세 조정 (PEFT) LoRA가 이 카테고리에 속합니다. LoRA를 사용하여 Keras에서 Gemma 모델 세부 조정 튜토리얼에서는 단일 GPU에서 KerasNLP를 사용하여 LoRA로 Gemma 2B 모델 gemma_2b_en를 미세 조정하는 방법을 보여줍니다.
  2. 모델 병렬 처리를 사용한 전체 매개변수 미세 조정 모델 동시 로드는 단일 모델의 가중치를 여러 기기에 분산하고 가로 확장을 사용 설정합니다. 분산 학습에 대한 자세한 내용은 이 Keras 가이드를 참조하세요.

이 튜토리얼에서는 JAX 백엔드와 함께 Keras를 사용하여 LoRA로 Gemma 7B 모델을 미세 조정하고 Google의 Tensor Processing Unit (TPU)에서 모델-패럴리즘 분산 학습을 사용하는 방법을 안내합니다. 이 튜토리얼에서는 속도가 느리지만 더 정확한 전체 매개변수 조정을 위해 LoRA를 사용 중지할 수 있습니다.

가속기 사용

이 튜토리얼에서는 TPU 또는 GPU 중 무엇이든 사용할 수 있습니다.

TPU 환경에 관한 참고사항

Google에는 TPU를 제공하는 3가지 제품이 있습니다.

  • Colab은 TPU v2를 무료로 제공하므로 이 튜토리얼을 진행하기에 충분합니다.
  • Kaggle은 TPU v3를 무료로 제공하며 이 튜토리얼에서도 작동합니다.
  • Cloud TPU는 TPU v3 및 이후 버전을 제공합니다. 설정 방법은 다음과 같습니다.
    1. TPU VM 만들기
    2. 원하는 Jupyter 서버 포트에 SSH 포트 전달을 설정합니다.
    3. Jupyter를 설치하고 TPU VM에서 시작한 다음 '로컬 런타임에 연결'을 통해 Colab에 연결하세요.

멀티 GPU 설정에 관한 참고사항

이 튜토리얼에서는 TPU 사용 사례를 중점적으로 다루지만 다중 GPU 머신을 사용하는 경우 사용자의 필요에 맞게 쉽게 조정할 수 있습니다.

Colab을 통해 작업하려는 경우 Colab Connect 메뉴의 '맞춤 GCE VM에 연결'을 통해 Colab용 멀티 GPU VM을 직접 프로비저닝할 수도 있습니다.

여기서는 Kaggle의 무료 TPU 사용에 중점을 둡니다.

시작하기 전에

Kaggle 사용자 인증 정보

Gemma 모델은 Kaggle에서 호스팅됩니다. Gemma를 사용하려면 Kaggle에서 액세스 권한을 요청하세요.

  • kaggle.com에서 로그인 또는 등록
  • Gemma 모델 카드를 열고 '액세스 요청'을 선택합니다.
  • 동의 양식을 작성하고 이용약관에 동의합니다.

그런 다음 Kaggle API를 사용하기 위해 API 토큰을 만듭니다.

  • Kaggle 설정을 엽니다.
  • '새 토큰 만들기'를 선택합니다.
  • kaggle.json 파일이 다운로드됩니다. Kaggle 사용자 인증 정보가 포함되어 있습니다.

다음 셀을 실행하고 메시지가 표시되면 Kaggle 사용자 인증 정보를 입력합니다.

# If you are using Kaggle, you don't need to login again.
!pip install ipywidgets
import kagglehub

kagglehub.login()
VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

kagglehub.login()이 작동하지 않는 경우 환경에서 KAGGLE_USERNAME 및 KAGGLE_KEY를 설정하는 것도 방법입니다.

설치

Gemma 모델과 함께 Keras 및 KerasNLP를 설치합니다.

pip install -q -U keras-nlp
# Work around an import error with tensorflow-hub. The library is not used.
pip install -q -U tensorflow-hub
# Install tensorflow-cpu so tensorflow does not attempt to access the TPU.
pip install -q -U tensorflow-cpu tensorflow-text
# Install keras 3 last. See https://keras.io/getting_started for details.
pip install -q -U keras

Keras JAX 백엔드 설정

JAX를 가져오고 TPU에서 상태 검사를 실행합니다. Kaggle은 메모리가 각각 16GB인 TPU 코어 8개가 있는 TPUv3-8 기기를 제공합니다.

import jax

jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
import os

# The Keras 3 distribution API is only implemented for the JAX backend for now
os.environ["KERAS_BACKEND"] = "jax"
# Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation
# overhead
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

모델 로드

import keras
import keras_nlp

NVIDIA GPU의 혼합 정밀도 학습 관련 참고사항

NVIDIA GPU에서 학습할 때 혼합 정밀도 (keras.mixed_precision.set_global_policy('mixed_bfloat16'))를 사용하면 학습 품질에 미치는 영향이 최소화된 상태에서 학습 속도를 높일 수 있습니다. 대부분의 경우 메모리와 시간을 모두 절약할 수 있으므로 혼합 정밀도를 사용 설정하는 것이 좋습니다. 그러나 배치 크기가 작으면 메모리 사용량이 1.5배 늘어날 수 있습니다 (가중치가 절반 정밀도와 전체 정밀도로 두 번 로드됨).

추론의 경우 혼합 정밀도는 적용되지 않지만 절반 정밀도 (keras.config.set_floatx("bfloat16"))가 작동하여 메모리를 절약합니다.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

TPU에 분산된 가중치와 텐서가 포함된 모델을 로드하려면 먼저 새 DeviceMesh를 만듭니다. DeviceMesh는 분산 컴퓨팅용으로 구성된 하드웨어 기기 모음을 나타내며 통합 배포 API의 일부로 Keras 3에 도입되었습니다.

distribution API는 데이터 및 모델 병렬화를 지원하므로 여러 가속기와 호스트에서 딥 러닝 모델을 효율적으로 확장할 수 있습니다. 기본 프레임워크 (예: JAX)를 활용하여 단일 프로그램, 다중 데이터 (SPMD) 확장이라는 절차를 통해 샤딩 지시문에 따라 프로그램과 텐서를 배포합니다. 자세한 내용은 새로운 Keras 3 배포 API 가이드를 참고하세요.

# Create a device mesh with (1, 8) shape so that the weights are sharded across
# all 8 TPUs.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

distribution API의 LayoutMap는 문자열 키(예: 아래 token_embedding/embeddings)를 사용하여 가중치와 텐서를 샤딩하거나 복제하는 방법을 지정합니다. 이 문자열 키는 정규식으로 취급되어 텐서 경로와 일치합니다. 일치하는 텐서는 모델 크기 (8개 TPU)로 샤딩됩니다. 나머지는 완전히 복제됩니다.

model_dim = "model"

layout_map = keras.distribution.LayoutMap(device_mesh)

# Weights that match 'token_embedding/embeddings' will be sharded on 8 TPUs
layout_map["token_embedding/embeddings"] = (model_dim, None)
# Regex to match against the query, key and value matrices in the decoder
# attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
    model_dim, None, None)

layout_map["decoder_block.*attention_output.*kernel"] = (
    model_dim, None, None)
layout_map["decoder_block.*ffw_gating.*kernel"] = (None, model_dim)
layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, None)

ModelParallel를 사용하면 DeviceMesh의 모든 기기에서 모델 가중치 또는 활성화 텐서를 샤딩할 수 있습니다. 이 경우 일부 Gemma 7B 모델 가중치는 위에 정의된 layout_map에 따라 8개의 TPU 칩에 샤딩됩니다. 이제 분산된 방식으로 모델을 로드합니다.

model_parallel = keras.distribution.ModelParallel(
    layout_map=layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_7b_en/1' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.

이제 모델이 올바르게 파티션된지 확인합니다. decoder_block_1를 예로 들어 보겠습니다.

decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
  print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')
<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (24576, 3072)     PartitionSpec(None, 'model')

미세 조정 전 추론

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
'Best comedy movies in the 90s 1. The Naked Gun 2½: The Smell of Fear (1991) 2. Wayne’s World (1992) 3. The Naked Gun 33⅓: The Final Insult (1994)'

모델은 시청할 만한 90년대 코미디 영화 목록을 생성합니다. 이제 Gemma 모델을 미세 조정하여 출력 스타일을 변경합니다.

IMDB로 미세 조정

import tensorflow_datasets as tfds

imdb_train = tfds.load(
    "imdb_reviews",
    split="train",
    as_supervised=True,
    batch_size=2,
)
# Drop labels.
imdb_train = imdb_train.map(lambda x, y: x)

imdb_train.unbatch().take(1).get_single_element().numpy()
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]
Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-train.tfrecord…
Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-test.tfrecord*…
Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]
Shuffling /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incompleteAJDUZT/imdb_reviews-unsupervised.t…
Dataset imdb_reviews downloaded and prepared to /root/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it."
# Use a subset of the dataset for faster training.
imdb_train = imdb_train.take(2000)

Low Rank Adaptation (LoRA)을 사용하여 미세 조정합니다. LoRA는 모델의 전체 가중치를 동결하고 더 적은 수의 새로운 학습 가능한 가중치를 모델에 삽입하여 다운스트림 작업의 학습 가능한 매개변수 수를 크게 줄이는 미세 조정 기법입니다. 기본적으로 LoRA는 더 큰 전체 가중치 행렬을 더 작은 하위 순위 행렬 AxB 2개로 재매개변수화하여 학습하며, 이 기법을 사용하면 학습 속도가 훨씬 빨라지고 메모리 효율이 높아집니다.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
# Fine-tune on the IMDb movie reviews dataset.

# Limit the input sequence length to 128 to control memory usage.
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.summary()
gemma_lm.fit(imdb_train, epochs=1)
/usr/local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:756: UserWarning: Some donated buffers were not usable: ShapedArray(float32[256000,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,384,256]), ShapedArray(float32[16,256,384]), ShapedArray(float32[384,24576]), ShapedArray(float32[384,24576]), ShapedArray(float32[24576,384]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer_donation.
  warnings.warn("Some donated buffers were not usable:"
2000/2000 ━━━━━━━━━━━━━━━━━━━━ 358s 163ms/step - loss: 2.7145 - sparse_categorical_accuracy: 0.4329
<keras.src.callbacks.history.History at 0x7e9cac7f41c0>

LoRA를 사용 설정하면 학습 가능한 매개변수 수가 70억 개에서 1, 100만 개로 크게 줄어듭니다.

미세 조정 후 추론

gemma_lm.generate("Best comedy movies in the 90s ", max_length=64)
"Best comedy movies in the 90s \n\nThis is the movie that made me want to be a director. It's a great movie, and it's still funny today. The acting is superb, the writing is excellent, the music is perfect for the movie, and the story is great."

미세 조정 후 모델은 영화 리뷰의 스타일을 학습했으며 이제 90년대 코미디 영화의 맥락에서 해당 스타일로 출력을 생성합니다.

다음 단계

이 튜토리얼에서는 KerasNLP JAX 백엔드를 사용하여 강력한 TPU에서 분산 방식으로 IMDb 데이터 세트의 Gemma 모델을 미세 조정하는 방법을 알아봤습니다. 다음은 다른 학습 주제에 대한 몇 가지 제안사항입니다.