使用 JAX 和 Flax 通过 CodeGemma 进行推断

前往 ai.google.dev 查看 在 Google Colab 中运行 在 GitHub 上查看源代码

我们展示了 CodeGemma,它是一系列基于 Google DeepMind 的 Gemma 模型(Gemma Team 等,2024 年)。 CodeGemma 是一系列先进的轻量级开放模型,采用与 Gemini 模型相同的研究和技术构建而成。

接着 Gemma 预训练模型,我们使用 5000 到 1,000 亿个主要代码词元对 CodeGemma 模型进行了进一步的训练, 与 Gemma 模型系列相同的架构。因此,CodeGemma 模型在 生成任务和生成任务, 理解能力和推理能力。

CodeGemma 有 3 个变体:

  • 70 亿个代码的预训练模型
  • 70 亿指令调优的代码模型
  • 一种 2B 模型,专为代码填充和开放式生成而训练。

本指南将引导您将 CodeGemma 模型与 Flax 搭配使用来执行代码补全任务。

设置

1. 设置对 CodeGemma 的 Kaggle 访问权限

要完成本教程,您首先需要按照 Gemma 设置中的设置说明进行操作,了解如何执行以下操作:

  • kaggle.com 上访问 CodeGemma。
  • 选择具有足够资源(T4 GPU 内存不足,请改用 TPU v2)的 Colab 运行时来运行 CodeGemma 模型。
  • 生成并配置 Kaggle 用户名和 API 密钥。

完成 Gemma 设置后,请继续执行下一部分,您将为 Colab 环境设置环境变量。

2. 设置环境变量

KAGGLE_USERNAMEKAGGLE_KEY 设置环境变量。当看到“是否授予访问权限?”的提示时,消息,请同意提供私密访问。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

3. 安装 gemma

目前免费的 Colab 硬件加速功能不足以运行此笔记本。如果您使用的是 Colab 随用随付或 Colab Pro,请点击修改 >笔记本设置 >选择 A100 GPU >点击保存即可启用硬件加速。

接下来,您需要从 github.com/google-deepmind/gemma 安装 Google DeepMind gemma 库。如果您收到有关“pip 的依赖项解析器”的错误,通常可以忽略它。

pip install -q git+https://github.com/google-deepmind/gemma.git

4. 导入库

此笔记本使用 Gemma(使用 Flax 构建其神经网络层)和 SentencePiece(用于标记化)。

import os
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

加载 CodeGemma 模型

使用 kagglehub.model_download 加载 CodeGemma 模型,该方法接受三个参数:

  • handle:Kaggle 的模型句柄
  • path:(可选字符串)本地路径
  • force_download:(可选布尔值)强制重新下载模型
GEMMA_VARIANT = '2b-pt' # @param ['2b-pt', '7b-it', '7b-pt', '1.1-2b-pt', '1.1-7b-it'] {type:"string"}
import kagglehub

GEMMA_PATH = kagglehub.model_download(f'google/codegemma/flax/{GEMMA_VARIANT}')
Warning: Looks like you're using an outdated `kagglehub` version, please consider updating (latest version: 0.2.7)
Downloading from https://www.kaggle.com/api/v1/models/google/codegemma/flax/2b-pt/3/download...
100%|██████████| 3.67G/3.67G [00:22<00:00, 173MB/s]
Extracting model files...
print('GEMMA_PATH:', GEMMA_PATH)
GEMMA_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3

检查模型权重和标记生成器的位置,然后设置路径变量。标记生成器目录位于下载模型的主目录中,而模型权重则位于子目录中。例如:

  • spm.model 标记生成器文件将位于 /LOCAL/PATH/TO/codegemma/flax/2b-pt/3
  • 模型检查点将在 /LOCAL/PATH/TO/codegemma/flax/2b-pt/3/2b-pt
CKPT_PATH = os.path.join(GEMMA_PATH, GEMMA_VARIANT[-5:])
TOKENIZER_PATH = os.path.join(GEMMA_PATH, 'spm.model')
print('CKPT_PATH:', CKPT_PATH)
print('TOKENIZER_PATH:', TOKENIZER_PATH)
CKPT_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/2b-pt
TOKENIZER_PATH: /root/.cache/kagglehub/models/google/codegemma/flax/2b-pt/3/spm.model

执行采样/推理

使用 gemma.params.load_and_format_params 方法加载 CodeGemma 模型检查点并设置其格式:

params = params_lib.load_and_format_params(CKPT_PATH)

加载使用 sentencepiece.SentencePieceProcessor 构造的 CodeGemma 标记生成器:

vocab = spm.SentencePieceProcessor()
vocab.Load(TOKENIZER_PATH)
True

如需从 CodeGemma 模型检查点自动加载正确配置,请使用 gemma.transformer.TransformerConfigcache_size 参数是 CodeGemma Transformer 缓存中的时间步长。然后,使用 gemma.transformer.Transformer(继承自 flax.linen.Module)将 CodeGemma 模型实例化为 model_2b

transformer_config = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=1024
)

transformer = transformer_lib.Transformer(config=transformer_config)

使用 gemma.sampler.Sampler 创建 sampler。它使用 CodeGemma 模型检查点和标记生成器。

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    params=params['transformer']
)

创建一些变量来表示中间填充 (fim) 令牌,并创建一些辅助函数来设置提示和生成的输出的格式。

例如,我们来看看以下代码:

def function(string):
assert function('asdf') == 'fdsa'

我们想要填充 function,以便断言包含 True。在这种情况下,前缀为:

"def function(string):\n"

后缀为:

"assert function('asdf') == 'fdsa'"

然后,我们将以下内容的格式设置为 PREFIX-SUFFIX-MIDDLE(需要填写的中间部分始终在提示的末尾):

"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"
# In the context of a code editor,
# the cursor is the location where the text will be inserted
BEFORE_CURSOR = "<|fim_prefix|>"
AFTER_CURSOR = "<|fim_suffix|>"
AT_CURSOR = "<|fim_middle|>"
FILE_SEPARATOR = "<|file_separator|>"

def format_completion_prompt(before, after):
  print(f"\nORIGINAL PROMPT:\n{before}{after}")
  prompt = f"{BEFORE_CURSOR}{before}{AFTER_CURSOR}{after}{AT_CURSOR}"
  print(f"\nFORMATTED PROMPT:\n{repr(prompt)}")
  return prompt
def format_generated_output(before, after, output):
  print(f"\nGENERATED OUTPUT:\n{repr(output)}")
  formatted_output = f"{before}{output.replace(FILE_SEPARATOR, '')}{after}"
  print(f"\nFILL-IN COMPLETION:\n{formatted_output}")
  return formatted_output

创建提示并进行推理。指定前缀 before 文本和后缀 after 文本,并使用辅助函数 format_completion prompt 生成格式化提示。

您可以调整 total_generation_steps(生成响应时执行的步骤数 - 此示例使用 100 保留主机内存)。

before = "def function(string):\n"
after = "assert function('asdf') == 'fdsa'"
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
def function(string):
assert function('asdf') == 'fdsa'

FORMATTED PROMPT:
"<|fim_prefix|>def function(string):\n<|fim_suffix|>assert function('asdf') == 'fdsa'<|fim_middle|>"

GENERATED OUTPUT:
'    return string[::-1]\n\n<|file_separator|>'

FILL-IN COMPLETION:
def function(string):
    return string[::-1]

assert function('asdf') == 'fdsa'
before = "import "
after = """if __name__ == "__main__":\n    sys.exit(0)"""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import if __name__ == "__main__":
    sys.exit(0)

FORMATTED PROMPT:
'<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":\n    sys.exit(0)<|fim_middle|>'

GENERATED OUTPUT:
'sys\n<|file_separator|>'

FILL-IN COMPLETION:
import sys
if __name__ == "__main__":
    sys.exit(0)
before = """import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
"""
after = ""
prompt = format_completion_prompt(before, after)

output = sampler(
    [prompt],
    total_generation_steps=100,
    ).text

formatted_output = format_generated_output(before, after, output[0])
ORIGINAL PROMPT:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix


FORMATTED PROMPT:
'<|fim_prefix|>import numpy as np\ndef reflect(matrix):\n  # horizontally reflect a matrix\n<|fim_suffix|><|fim_middle|>'

GENERATED OUTPUT:
'  return np.flip(matrix, axis=1)\n<|file_separator|>'

FILL-IN COMPLETION:
import numpy as np
def reflect(matrix):
  # horizontally reflect a matrix
  return np.flip(matrix, axis=1)

了解详情