前往 ai.google.dev 查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 |
我们展示了 CodeGemma,它是一系列基于 Google DeepMind 的 Gemma 模型(Gemma Team 等,2024 年)。 CodeGemma 是一系列先进的轻量级开放模型,采用与 Gemini 模型相同的研究和技术构建而成。
接着 Gemma 预训练模型,我们使用 5000 到 1,000 亿个主要代码词元对 CodeGemma 模型进行了进一步的训练, 与 Gemma 模型系列相同的架构。因此,CodeGemma 模型在 生成任务和生成任务, 理解能力和推理能力。
CodeGemma 有 3 个变体:
- 70 亿个代码的预训练模型
- 70 亿指令调优的代码模型
- 一种 2B 模型,专为代码填充和开放式生成而训练。
本指南将引导您将 CodeGemma 模型与 Flax 搭配使用来执行代码补全任务。
设置
1. 设置对 CodeGemma 的 Kaggle 访问权限
要完成本教程,您首先需要按照 Gemma 设置中的设置说明进行操作,了解如何执行以下操作:
- 在 kaggle.com 上访问 CodeGemma。
- 选择具有足够资源(T4 GPU 内存不足,请改用 TPU v2)的 Colab 运行时来运行 CodeGemma 模型。
- 生成并配置 Kaggle 用户名和 API 密钥。
完成 Gemma 设置后,请继续执行下一部分,您将为 Colab 环境设置环境变量。
2. 设置环境变量
为 KAGGLE_USERNAME
和 KAGGLE_KEY
设置环境变量。当看到“是否授予访问权限?”的提示时,消息,请同意提供私密访问。
import os
from google.colab import userdata # `userdata` is a Colab API.
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
3. 安装 gemma
库
目前免费的 Colab 硬件加速功能不足以运行此笔记本。如果您使用的是 Colab 随用随付或 Colab Pro,请点击修改 >笔记本设置 >选择 A100 GPU >点击保存即可启用硬件加速。
接下来,您需要从 github.com/google-deepmind/gemma
安装 Google DeepMind gemma
库。如果您收到有关“pip 的依赖项解析器”的错误,通常可以忽略它。
pip install -q git+https://github.com/google-deepmind/gemma.git
4. 导入库
此笔记本使用 Gemma(使用 Flax 构建其神经网络层)和 SentencePiece(用于标记化)。
import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
加载 CodeGemma 模型
使用 kagglehub.model_download
加载 CodeGemma 模型,该方法接受三个参数:
handle
:Kaggle 的模型句柄path
:(可选字符串)本地路径force_download
:(可选布尔值)强制重新下载模型
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub
GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7) Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download... 100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s] Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3
检查模型权重和标记生成器的位置,然后设置路径变量。标记生成器目录位于下载模型的主目录中,而模型权重则位于子目录中。例如:
spm.model
标记生成器文件将位于/LOCAL/PATH/TO/codegemma/flax/2b-pt/3
中- 模型检查点将在
/LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
中
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model
执行采样/推理
使用 gemma.params.load_and_format_params
方法加载 CodeGemma 模型检查点并设置其格式:
params = params_lib.load_and_format_params(CKPT_PATH)
加载使用 sentencepiece.SentencePieceProcessor
构造的 CodeGemma 标记生成器:
vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True
如需从 CodeGemma 模型检查点自动加载正确配置,请使用 gemma.transformer.TransformerConfig
。cache_size
参数是 CodeGemma Transformer
缓存中的时间步长。然后,使用 gemma.transformer.Transformer
(继承自 flax.linen.Module
)将 CodeGemma 模型实例化为 model_2b
。
transformer_config = transformer_lib.TransformerConfig.from_params(
params,
cache_size=1024
)
transformer = transformer_lib.Transformer(config=transformer_config)
使用 gemma.sampler.Sampler
创建 sampler
。它使用 CodeGemma 模型检查点和标记生成器。
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer']
)
创建一些变量来表示中间填充 (fim) 令牌,并创建一些辅助函数来设置提示和生成的输出的格式。
例如,我们来看看以下代码:
def function(string):
assert function('asdf') == 'fdsa'
我们想要填充 function
,以便断言包含 True
。在这种情况下,前缀为:
"def function(string):\n"
后缀为:
"assert function('asdf') == 'fdsa'"
然后,我们将以下内容的格式设置为 PREFIX-SUFFIX-MIDDLE(需要填写的中间部分始终在提示的末尾):
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"
def format_completion_prompt(before, after):
print(f"\nORIGINAL PROMPT:\n{before}{after}")
prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
return prompt
def format_generated_output(before, after, output):
print(f"\nGENERATED OUTPUT:\n{repr(output)}")
formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
return formatted_output
创建提示并进行推理。指定前缀 before
文本和后缀 after
文本,并使用辅助函数 format_completion prompt
生成格式化提示。
您可以调整 total_generation_steps
(生成响应时执行的步骤数 - 此示例使用 100
保留主机内存)。
before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: def function(string): assert function('asdf') == 'fdsa' FORMATTED PROMPT: "<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>" GENERATED OUTPUT: ' return string[::-1]\n\n<|file_separator|>' FILL-IN COMPLETION: def function(string): return string[::-1] assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n sys.exit(0)"""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import if __name__ == "__main__": sys.exit(0) FORMATTED PROMPT: '<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n sys.exit(0)<|fim_middle|>' GENERATED OUTPUT: 'sys\n<|file_separator|>' FILL-IN COMPLETION: import sys if __name__ == "__main__": sys.exit(0)
before = """import numpy as np
def reflect(matrix):
# horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)
output = sampler(
[prompt],
total_generation_steps=100,
).text
formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT: import numpy as np def reflect(matrix): # horizontally reflect a matrix FORMATTED PROMPT: '<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>' GENERATED OUTPUT: ' return np.flip(matrix, axis=1)\n<|file_separator|>' FILL-IN COMPLETION: import numpy as np def reflect(matrix): # horizontally reflect a matrix return np.flip(matrix, axis=1)
了解详情
- 您可以在 GitHub 上详细了解 Google DeepMind 的
gemma
库,该库包含您在本教程中使用的模块的文档字符串,例如gemma.params
。gemma.transformer
和gemma.sampler
。 - 以下库有自己的文档网站:core JAX、Flax 和 Orbax。
- 如需查看
sentencepiece
标记生成器/detokenizer 文档,请查看 Google 的sentencepiece
GitHub 代码库。 - 如需查看
kagglehub
文档,请参阅 Kaggle 的kagglehub
GitHub 代码库中的README.md
。 - 了解如何将 Gemma 模型与 Google Cloud Vertex AI 搭配使用。
- 如果您使用的是 Google Cloud TPU(v3-8 及更高版本),请务必也更新到最新的
jax[tpu]
软件包 (!pip install -U jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
),重启运行时,并检查jax
和jaxlib
版本是否匹配 (!pip list | grep jax
)。这可以防止因jaxlib
和jax
版本不匹配而出现的RuntimeError
。如需详细了解 JAX 安装说明,请参阅 JAX 文档。