مدل های جما را در Keras با استفاده از LoRA تنظیم کنید

مشاهده در ai.google.dev در Google Colab اجرا شود در Vertex AI باز کنید مشاهده منبع در GitHub

نمای کلی

Gemma خانواده ای از مدل های سبک وزن و مدرن است که از همان تحقیقات و فناوری استفاده شده برای ایجاد مدل های Gemini ساخته شده است.

نشان داده شده است که مدل های زبان بزرگ (LLM) مانند Gemma در انواع وظایف NLP موثر هستند. یک LLM ابتدا بر روی مجموعه بزرگی از متن به صورت خود نظارتی از قبل آموزش داده می شود. پیش‌آموزش به LLMها کمک می‌کند تا دانش عمومی، مانند روابط آماری بین کلمات را بیاموزند. سپس یک LLM را می توان با داده های دامنه خاص برای انجام وظایف پایین دستی (مانند تجزیه و تحلیل احساسات) تنظیم کرد.

LLMها از نظر اندازه بسیار بزرگ هستند (پارامترها در حد میلیاردها). تنظیم دقیق کامل (که تمام پارامترهای مدل را به روز می کند) برای اکثر برنامه ها مورد نیاز نیست زیرا مجموعه داده های تنظیم دقیق معمولی نسبتاً کوچکتر از مجموعه داده های قبل از آموزش هستند.

انطباق با رتبه پایین (LoRA) یک تکنیک تنظیم دقیق است که تعداد پارامترهای قابل آموزش برای کارهای پایین دستی را با انجماد وزن های مدل و درج تعداد کمتری وزنه های جدید در مدل به میزان زیادی کاهش می دهد. این باعث می شود که آموزش با LoRA بسیار سریعتر و حافظه کارآمدتر باشد، و وزن مدل کوچکتر (چند صد مگابایت) تولید می شود، همه اینها با حفظ کیفیت خروجی های مدل.

این آموزش شما را با استفاده از KerasNLP برای انجام تنظیم دقیق LoRA بر روی یک مدل Gemma 2B با استفاده از مجموعه داده Databricks Dolly 15k راهنمایی می کند. این مجموعه داده شامل 15000 جفت اعلان / پاسخ با کیفیت بالا است که به طور خاص برای تنظیم دقیق LLM طراحی شده است.

راه اندازی

به Gemma دسترسی پیدا کنید

برای تکمیل این آموزش، ابتدا باید دستورالعمل‌های راه‌اندازی را در Gemma setup تکمیل کنید. دستورالعمل های راه اندازی Gemma به شما نشان می دهد که چگونه کارهای زیر را انجام دهید:

  • در kaggle.com به Gemma دسترسی پیدا کنید.
  • یک زمان اجرا Colab با منابع کافی برای اجرای مدل Gemma 2B انتخاب کنید.
  • نام کاربری و کلید API Kaggle را ایجاد و پیکربندی کنید.

پس از تکمیل تنظیمات Gemma، به بخش بعدی بروید، جایی که متغیرهای محیطی را برای محیط Colab خود تنظیم خواهید کرد.

زمان اجرا را انتخاب کنید

برای تکمیل این آموزش، باید یک زمان اجرا Colab با منابع کافی برای اجرای مدل Gemma داشته باشید. در این مورد، می توانید از یک GPU T4 استفاده کنید:

  1. در سمت راست بالای پنجره Colab، ▾ ( گزینه های اتصال اضافی ) را انتخاب کنید.
  2. تغییر نوع زمان اجرا را انتخاب کنید.
  3. در بخش شتاب دهنده سخت افزار ، GPU T4 را انتخاب کنید.

کلید API خود را پیکربندی کنید

برای استفاده از Gemma، باید نام کاربری Kaggle و یک کلید Kaggle API ارائه دهید.

برای ایجاد یک کلید Kaggle API، به تب Account پروفایل کاربری Kaggle خود بروید و Create New Token را انتخاب کنید. با این کار دانلود فایل kaggle.json حاوی اطلاعات کاربری API شما راه اندازی می شود.

در Colab، Secrets (🔑) را در قسمت سمت چپ انتخاب کنید و نام کاربری Kaggle و کلید Kaggle API را اضافه کنید. نام کاربری خود را با نام KAGGLE_USERNAME و کلید API خود را با نام KAGGLE_KEY ذخیره کنید.

تنظیم متغیرهای محیطی

متغیرهای محیطی را برای KAGGLE_USERNAME و KAGGLE_KEY تنظیم کنید.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

وابستگی ها را نصب کنید

Keras، KerasNLP و سایر وابستگی ها را نصب کنید.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

یک باطن انتخاب کنید

Keras یک API یادگیری عمیق چند چارچوبی و سطح بالا است که برای سادگی و سهولت استفاده طراحی شده است. با استفاده از Keras 3، می‌توانید گردش‌های کاری را روی یکی از سه Backend اجرا کنید: TensorFlow، JAX یا PyTorch.

برای این آموزش، Backend را برای JAX پیکربندی کنید.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

بسته های وارداتی

Keras و KerasNLP را وارد کنید.

import keras
import keras_nlp

بارگذاری مجموعه داده

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

داده ها را از قبل پردازش کنید. این آموزش از زیر مجموعه ای از 1000 مثال آموزشی برای اجرای سریعتر نوت بوک استفاده می کند. استفاده از داده های آموزشی بیشتر را برای تنظیم دقیق با کیفیت بالاتر در نظر بگیرید.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

مدل بارگذاری

KerasNLP پیاده سازی بسیاری از معماری های مدل محبوب را ارائه می دهد. در این آموزش، یک مدل با استفاده از GemmaCausalLM ، یک مدل Gemma سرتاسر برای مدل‌سازی زبان علی ایجاد می‌کنید. یک مدل زبان علی، نشانه بعدی را بر اساس نشانه های قبلی پیش بینی می کند.

مدل را با استفاده از متد from_preset ایجاد کنید:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

متد from_preset مدل را از یک معماری و وزن از پیش تعیین شده نمونه سازی می کند. در کد بالا، رشته "gemma2_2b_en" معماری از پیش تعیین شده را مشخص می کند - یک مدل Gemma با 2 میلیارد پارامتر.

استنتاج قبل از تنظیم دقیق

در این بخش، مدل را با اعلان های مختلف پرس و جو می کنید تا ببینید چگونه پاسخ می دهد.

درخواست سفر اروپا

برای پیشنهادات در مورد اقداماتی که در سفر به اروپا باید انجام دهید، مدل را جویا شوید.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

این مدل با نکات کلی در مورد نحوه برنامه ریزی یک سفر پاسخ می دهد.

درخواست فتوسنتز ELI5

از مدل بخواهید که فتوسنتز را با عباراتی به اندازه کافی ساده برای درک یک کودک 5 ساله توضیح دهد.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

پاسخ مدل حاوی کلماتی است که ممکن است برای کودک آسان نباشد مانند کلروفیل.

تنظیم دقیق LoRA

برای دریافت پاسخ‌های بهتر از مدل، مدل را با انطباق رتبه پایین (LoRA) با استفاده از مجموعه داده Databricks Dolly 15k تنظیم دقیق کنید.

رتبه LoRA ابعاد ماتریس های قابل آموزش را تعیین می کند که به وزن های اصلی LLM اضافه می شوند. بیان و دقت تنظیمات تنظیم دقیق را کنترل می کند.

رتبه بالاتر به این معنی است که تغییرات با جزئیات بیشتر امکان پذیر است، اما همچنین به معنی پارامترهای قابل آموزش بیشتر است. رتبه پایین تر به معنای سربار محاسباتی کمتر، اما به طور بالقوه انطباق دقیق کمتر است.

این آموزش از رتبه LoRA 4 استفاده می کند. در عمل، با یک رتبه نسبتاً کوچک (مانند 4، 8، 16) شروع کنید. این از نظر محاسباتی برای آزمایش کارآمد است. مدل خود را با این رتبه آموزش دهید و بهبود عملکرد را در کار خود ارزیابی کنید. به تدریج رتبه را در آزمایش‌های بعدی افزایش دهید و ببینید که آیا این کار عملکرد را بیشتر می‌کند یا خیر.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

توجه داشته باشید که فعال کردن LoRA تعداد پارامترهای قابل آموزش را به میزان قابل توجهی کاهش می دهد (از 2.6 میلیارد به 2.9 میلیون).

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

نکته ای در مورد تنظیم دقیق ترکیبی در پردازنده های گرافیکی NVIDIA

دقت کامل برای تنظیم دقیق توصیه می شود. هنگام تنظیم دقیق پردازنده‌های گرافیکی NVIDIA، توجه داشته باشید که می‌توانید از دقت ترکیبی ( keras.mixed_precision.set_global_policy('mixed_bfloat16') ) برای سرعت بخشیدن به آموزش با حداقل تأثیر بر کیفیت آموزش استفاده کنید. تنظیم دقیق ترکیبی حافظه بیشتری مصرف می کند، بنابراین فقط در پردازنده های گرافیکی بزرگتر مفید است.

برای استنباط، نیم دقت ( keras.config.set_floatx("bfloat16") ) کار می کند و حافظه را ذخیره می کند در حالی که دقت ترکیبی قابل اعمال نیست.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

استنتاج پس از تنظیم دقیق

پس از تنظیم دقیق، پاسخ ها از دستورالعمل ارائه شده در اعلان پیروی می کنند.

درخواست سفر اروپا

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

این مدل اکنون مکان هایی را برای بازدید در اروپا توصیه می کند.

درخواست فتوسنتز ELI5

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

این مدل اکنون فتوسنتز را به زبان ساده‌تر توضیح می‌دهد.

توجه داشته باشید که برای اهداف نمایشی، این آموزش مدل را در زیر مجموعه کوچکی از مجموعه داده فقط برای یک دوره و با مقدار رتبه LoRA پایین تنظیم می‌کند. برای دریافت پاسخ‌های بهتر از مدل تنظیم‌شده، می‌توانید موارد زیر را آزمایش کنید:

  1. افزایش اندازه مجموعه داده تنظیم دقیق
  2. آموزش مراحل بیشتر (دوران)
  3. تنظیم یک رتبه LoRA بالاتر
  4. اصلاح مقادیر فراپارامتر مانند learning_rate و weight_decay .

خلاصه و مراحل بعدی

این آموزش تنظیم دقیق LoRA را بر روی یک مدل Gemma با استفاده از KerasNLP پوشش می دهد. در ادامه اسناد زیر را بررسی کنید: