ดูใน ai.google.dev | เรียกใช้ใน Google Colab | ดูแหล่งที่มาใน GitHub |
เรานำเสนอ CodeGemma ซึ่งเป็นคอลเล็กชันของโมเดลโค้ดแบบเปิดซึ่งอิงตามโมเดล Gemma ของ Google DeepMind (Gemma Team et al., 2024) CodeGemma เป็นชุดโมเดลเปิดที่ทันสมัยและน้ำหนักเบา สร้างขึ้นจากการวิจัยและเทคโนโลยีเดียวกันกับที่ใช้ในการสร้างโมเดล Gemini
โมเดล CodeGemma จะได้รับการฝึกเพิ่มเติมจากโมเดลที่ฝึกไว้แล้วของ Gemma ด้วยโทเค็นโค้ดหลักมากกว่า 500 ถึง 1 แสนล้านโทเค็น โดยใช้ สถาปัตยกรรมแบบเดียวกับกลุ่มโมเดล Gemma ด้วยเหตุนี้ โมเดล CodeGemma จึงมีประสิทธิภาพของโค้ดที่ล้ำสมัยในทั้ง 2 โค้ดนี้ และงานการสร้าง ไปพร้อมๆ กับรักษาความแข็งแรง ทักษะการทำความเข้าใจและการให้เหตุผลในวงกว้าง
CodeGemma มี 3 ตัวแปร:
- โมเดลที่ฝึกด้วยโค้ด 7B ล่วงหน้า
- โมเดลโค้ดที่มีการปรับแต่งตามคำสั่ง 7B
- โมเดล 2B ที่ได้รับการฝึกมาโดยเฉพาะสำหรับการใส่ข้อมูลโค้ดและการสร้างแบบเปิดกว้าง
คู่มือนี้จะแนะนำให้คุณทราบเกี่ยวกับการใช้โมเดล CodeGemma ร่วมกับ Flax ในการจัดทำโค้ด
ตั้งค่า
1. ตั้งค่าการเข้าถึง Kaggle สำหรับ CodeGemma
หากต้องการจบบทแนะนำนี้ ก่อนอื่นคุณต้องทำตามวิธีการตั้งค่าที่การตั้งค่า Gemma ซึ่งแสดงวิธีดำเนินการต่อไปนี้
- รับสิทธิ์เข้าถึง CodeGemma ใน kaggle.com
- เลือกรันไทม์ของ Colab ที่มีทรัพยากรเพียงพอ (GPU T4 มีหน่วยความจำไม่เพียงพอ ใช้ TPU v2 แทน) เพื่อเรียกใช้โมเดล CodeGemma
- สร้างและกำหนดค่าชื่อผู้ใช้และคีย์ API ของ Kaggle
หลังจากตั้งค่า Gemma เสร็จแล้ว ให้ไปยังส่วนถัดไปซึ่งจะตั้งค่าตัวแปรสภาพแวดล้อมสำหรับสภาพแวดล้อม Colab
2. ตั้งค่าตัวแปรสภาพแวดล้อม
ตั้งค่าตัวแปรสภาพแวดล้อมสำหรับ KAGGLE_USERNAME
และ KAGGLE_KEY
เมื่อได้รับข้อความแจ้งผ่านข้อความ "ให้สิทธิ์เข้าถึงไหม" ยอมรับการเข้าถึงข้อมูลลับ
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. ติดตั้งไลบรารี gemma
การเร่งฮาร์ดแวร์ฟรีของ Colab ยังไม่เพียงพอในการเรียกใช้สมุดบันทึกนี้ หากคุณใช้ Colab Pay As You Go หรือ Colab Pro ให้คลิกแก้ไข > การตั้งค่าสมุดบันทึก > เลือก A100 GPU > บันทึกเพื่อเปิดใช้การเร่งฮาร์ดแวร์
ถัดไป คุณต้องติดตั้งไลบรารี Google DeepMind gemma
จาก github.com/google-deepmind/gemma
หากคุณได้รับข้อผิดพลาดเกี่ยวกับ "รีโซลเวอร์ทรัพยากร Dependency ของ PIP" โดยปกติแล้วคุณไม่ต้องสนใจ
pip install -q git+https://github.com/google-deepmind/gemma.git
4. นำเข้าไลบรารี
สมุดบันทึกนี้ใช้ Gemma (ซึ่งใช้ Flax ในการสร้างเลเยอร์โครงข่ายระบบประสาทเทียม) และ SentencePiece (สำหรับการแปลงข้อมูลเป็นโทเค็น)
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
โหลดโมเดล CodeGemma
โหลดโมเดล CodeGemma ด้วย kagglehub.model_download
ซึ่งมีอาร์กิวเมนต์ 3 อย่าง ดังนี้
handle
: แฮนเดิลโมเดลจาก Kagglepath
: (สตริงที่ไม่บังคับ) เส้นทางในเครื่องforce_download
: (บูลีนที่ไม่บังคับ) บังคับให้ดาวน์โหลดโมเดลอีกครั้ง
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
ตรวจสอบตำแหน่งของน้ำหนักโมเดลและเครื่องมือแปลงข้อมูลเป็นโทเค็น จากนั้นตั้งค่าตัวแปรเส้นทาง ไดเรกทอรีโทเคนไลซ์จะอยู่ในไดเรกทอรีหลักที่คุณดาวน์โหลดโมเดลไป ขณะที่น้ำหนักโมเดลจะอยู่ในไดเรกทอรีย่อย เช่น
- ไฟล์ Tokenizer ของ
spm.model
จะอยู่ใน/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
- จุดตรวจสอบโมเดลจะอยู่ใน
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
ทำการสุ่มตัวอย่าง/การอนุมาน
โหลดและจัดรูปแบบจุดตรวจสอบโมเดล CodeGemma ด้วยเมธอด gemma.params.load_and_format_params
params = params_lib.load_and_format_params(CKPT_PATH)
โหลดเครื่องมือแปลงข้อมูลเป็นโทเค็น CodeGemma ซึ่งสร้างขึ้นโดยใช้ sentencepiece.SentencePieceProcessor
:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
หากต้องการโหลดการกําหนดค่าที่ถูกต้องโดยอัตโนมัติจากจุดตรวจสอบโมเดล CodeGemma ให้ใช้ gemma.transformer.TransformerConfig
อาร์กิวเมนต์ cache_size
คือจำนวนขั้นตอนในแคชของ CodeGemma Transformer
หลังจากนั้น ให้สร้างอินสแตนซ์โมเดล CodeGemma เป็น model_2b
ด้วย gemma.transformer.Transformer
(ซึ่งรับค่าจาก flax.linen.Module
)
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
สร้าง sampler
ด้วย gemma.sampler.Sampler
โดยจะใช้จุดตรวจสอบโมเดล CodeGemma และเครื่องมือแปลงข้อมูลเป็นโทเค็น
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
สร้างตัวแปรบางตัวเพื่อแสดงโทเค็นแบบเติมตรงกลาง (fim) และสร้างฟังก์ชันตัวช่วยบางอย่างเพื่อจัดรูปแบบพรอมต์และเอาต์พุตที่สร้างขึ้น
ลองดูโค้ดต่อไปนี้เป็นตัวอย่าง
def function(string):
assert function('asdf') == 'fdsa'
เราต้องการกรอก function
เพื่อให้การยืนยันระงับ True
ในกรณีนี้ คำนำหน้าจะเป็น
"def function(string):\n"
และคำต่อท้ายจะเป็น
"assert function('asdf') == 'fdsa'"
จากนั้นเราจะจัดรูปแบบพรอมต์นี้เป็นพรอมต์ PREFIX-SUFFIX-MIDDLE (ส่วนตรงกลางที่ต้องเติมจะแสดงที่ส่วนท้ายของข้อความแจ้งเสมอ):
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
สร้างพรอมต์และดำเนินการอนุมาน ระบุข้อความนำ before
และข้อความคำต่อท้าย after
แล้วสร้างพรอมต์ที่มีการจัดรูปแบบโดยใช้ฟังก์ชันตัวช่วย format_completion prompt
คุณสามารถปรับแต่ง total_generation_steps
(จำนวนขั้นตอนที่ดำเนินการเมื่อสร้างคำตอบ โดยตัวอย่างนี้ใช้ 100
เพื่อเก็บรักษาหน่วยความจำของโฮสต์)
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
ดูข้อมูลเพิ่มเติม
- คุณสามารถดูข้อมูลเพิ่มเติมเกี่ยวกับไลบรารี Google DeepMind
gemma
ใน GitHub ซึ่งมีเอกสารสตริงของโมดูลที่คุณใช้ในบทแนะนำนี้ เช่นgemma.params
gemma.transformer
และgemma.sampler
- ไลบรารีต่อไปนี้มีเว็บไซต์เอกสารประกอบของตนเอง ได้แก่ JAX หลัก, Flax และ Orbax
- ดูเอกสารประกอบเกี่ยวกับเครื่องมือแปลงข้อมูลเป็นโทเค็น/เครื่องมือถอดรหัสของ
sentencepiece
ได้ที่ที่เก็บsentencepiece
GitHub ของ Google - ดูเอกสารประกอบเกี่ยวกับ
kagglehub
ได้ที่README.md
ในที่เก็บ GitHub ของkagglehub
ของ Kaggle - ดูวิธีใช้โมเดล Gemma กับ Vertex AI ของ Google Cloud
- หากคุณใช้ Google Cloud TPU (v3-8 ขึ้นไป) โปรดอัปเดตเป็นแพ็กเกจ
jax[tpu]
ล่าสุด (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
) รีสตาร์ทรันไทม์ และตรวจสอบว่าเวอร์ชันjax
และjaxlib
ตรงกัน (!pip list | grep jax
) ซึ่งจะป้องกันRuntimeError
ที่อาจเกิดขึ้นเนื่องจากเวอร์ชันjaxlib
และjax
ไม่ตรงกัน ดูคำแนะนำในการติดตั้ง JAX เพิ่มเติมในเอกสาร JAX