Skip to content

Commit

Permalink
add generation_config and safety_settings to google cloud multimodal model operators (#40126)
Browse files Browse the repository at this point in the history
  • Loading branch information
CYarros10 committed Jun 14, 2024
1 parent f0bae33 commit e2b8f68
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 7 deletions.
18 changes: 15 additions & 3 deletions airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def prompt_multimodal_model(
self,
prompt: str,
location: str,
generation_config: dict | None = None,
safety_settings: dict | None = None,
pretrained_model: str = "gemini-pro",
project_id: str = PROVIDE_PROJECT_ID,
) -> str:
Expand All @@ -149,17 +151,21 @@ def prompt_multimodal_model(
:param prompt: Required. Inputs or queries that a user or a program gives
to the Multi-modal model, in order to elicit a specific response.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per request settings for blocking unsafe content.
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
output text and code.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
"""
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

model = self.get_generative_model(pretrained_model)
response = model.generate_content(prompt)
response = model.generate_content(
contents=[prompt], generation_config=generation_config, safety_settings=safety_settings
)

return response.text

Expand All @@ -170,6 +176,8 @@ def prompt_multimodal_model_with_media(
location: str,
media_gcs_path: str,
mime_type: str,
generation_config: dict | None = None,
safety_settings: dict | None = None,
pretrained_model: str = "gemini-pro-vision",
project_id: str = PROVIDE_PROJECT_ID,
) -> str:
Expand All @@ -178,6 +186,8 @@ def prompt_multimodal_model_with_media(
:param prompt: Required. Inputs or queries that a user or a program gives
to the Multi-modal model, in order to elicit a specific response.
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per request settings for blocking unsafe content.
:param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
Expand All @@ -192,6 +202,8 @@ def prompt_multimodal_model_with_media(

model = self.get_generative_model(pretrained_model)
part = self.get_generative_model_part(media_gcs_path, mime_type)
response = model.generate_content([prompt, part])
response = model.generate_content(
contents=[prompt, part], generation_config=generation_config, safety_settings=safety_settings
)

return response.text
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ class PromptMultimodalModelOperator(GoogleCloudBaseOperator):
service belongs to (templated).
:param prompt: Required. Inputs or queries that a user or a program gives
to the Multi-modal model, in order to elicit a specific response (templated).
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per request settings for blocking unsafe content.
:param pretrained_model: By default uses the pre-trained model `gemini-pro`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
Expand All @@ -210,6 +212,8 @@ def __init__(
project_id: str,
location: str,
prompt: str,
generation_config: dict | None = None,
safety_settings: dict | None = None,
pretrained_model: str = "gemini-pro",
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
Expand All @@ -219,6 +223,8 @@ def __init__(
self.project_id = project_id
self.location = location
self.prompt = prompt
self.generation_config = generation_config
self.safety_settings = safety_settings
self.pretrained_model = pretrained_model
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
Expand All @@ -232,6 +238,8 @@ def execute(self, context: Context):
project_id=self.project_id,
location=self.location,
prompt=self.prompt,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
pretrained_model=self.pretrained_model,
)

Expand All @@ -251,6 +259,8 @@ class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
service belongs to (templated).
:param prompt: Required. Inputs or queries that a user or a program gives
to the Multi-modal model, in order to elicit a specific response (templated).
:param generation_config: Optional. Generation configuration settings.
:param safety_settings: Optional. Per request settings for blocking unsafe content.
:param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
supporting prompts with text-only input, including natural language
tasks, multi-turn text and code chat, and code generation. It can
Expand Down Expand Up @@ -279,6 +289,8 @@ def __init__(
prompt: str,
media_gcs_path: str,
mime_type: str,
generation_config: dict | None = None,
safety_settings: dict | None = None,
pretrained_model: str = "gemini-pro-vision",
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
Expand All @@ -288,6 +300,8 @@ def __init__(
self.project_id = project_id
self.location = location
self.prompt = prompt
self.generation_config = generation_config
self.safety_settings = safety_settings
self.pretrained_model = pretrained_model
self.media_gcs_path = media_gcs_path
self.mime_type = mime_type
Expand All @@ -303,6 +317,8 @@ def execute(self, context: Context):
project_id=self.project_id,
location=self.location,
prompt=self.prompt,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
pretrained_model=self.pretrained_model,
media_gcs_path=self.media_gcs_path,
mime_type=self.mime_type,
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ dependencies:
- google-api-python-client>=2.0.2
- google-auth>=2.29.0
- google-auth-httplib2>=0.0.1
- google-cloud-aiplatform>=1.42.1
- google-cloud-aiplatform>=1.54.0
- google-cloud-automl>=2.12.0
# google-cloud-bigquery version 3.21.0 introduced a performance enhancement in QueryJob.result(),
# which has led to backward compatibility issues
Expand Down
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@
"google-api-python-client>=2.0.2",
"google-auth-httplib2>=0.0.1",
"google-auth>=2.29.0",
"google-cloud-aiplatform>=1.42.1",
"google-cloud-aiplatform>=1.54.0",
"google-cloud-automl>=2.12.0",
"google-cloud-batch>=0.13.0",
"google-cloud-bigquery-datatransfer>=3.13.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

# For no Pydantic environment, we need to skip the tests
pytest.importorskip("google.cloud.aiplatform_v1")
vertexai = pytest.importorskip("vertexai.generative_models")
from vertexai.generative_models import HarmBlockThreshold, HarmCategory

from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import (
GenerativeModelHook,
Expand All @@ -45,6 +47,17 @@
TEST_TEXT_EMBEDDING_MODEL = ""

TEST_MULTIMODAL_PRETRAINED_MODEL = "gemini-pro"
TEST_SAFETY_SETTINGS = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
}
TEST_GENERATION_CONFIG = {
"max_output_tokens": TEST_MAX_OUTPUT_TOKENS,
"top_p": TEST_TOP_P,
"temperature": TEST_TEMPERATURE,
}

TEST_MULTIMODAL_VISION_MODEL = "gemini-pro-vision"
TEST_VISION_PROMPT = "In 10 words or less, describe this content."
Expand Down Expand Up @@ -104,10 +117,16 @@ def test_prompt_multimodal_model(self, mock_model) -> None:
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=TEST_PROMPT,
generation_config=TEST_GENERATION_CONFIG,
safety_settings=TEST_SAFETY_SETTINGS,
pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
)
mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
mock_model.return_value.generate_content.assert_called_once_with(TEST_PROMPT)
mock_model.return_value.generate_content.assert_called_once_with(
contents=[TEST_PROMPT],
generation_config=TEST_GENERATION_CONFIG,
safety_settings=TEST_SAFETY_SETTINGS,
)

@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model_part"))
@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model"))
Expand All @@ -116,6 +135,8 @@ def test_prompt_multimodal_model_with_media(self, mock_model, mock_part) -> None
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=TEST_VISION_PROMPT,
generation_config=TEST_GENERATION_CONFIG,
safety_settings=TEST_SAFETY_SETTINGS,
pretrained_model=TEST_MULTIMODAL_VISION_MODEL,
media_gcs_path=TEST_MEDIA_GCS_PATH,
mime_type=TEST_MIME_TYPE,
Expand All @@ -124,5 +145,7 @@ def test_prompt_multimodal_model_with_media(self, mock_model, mock_part) -> None
mock_part.assert_called_once_with(TEST_MEDIA_GCS_PATH, TEST_MIME_TYPE)

mock_model.return_value.generate_content.assert_called_once_with(
[TEST_VISION_PROMPT, mock_part.return_value]
contents=[TEST_VISION_PROMPT, mock_part.return_value],
generation_config=TEST_GENERATION_CONFIG,
safety_settings=TEST_SAFETY_SETTINGS,
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

# For no Pydantic environment, we need to skip the tests
pytest.importorskip("google.cloud.aiplatform_v1")
vertexai = pytest.importorskip("vertexai.generative_models")
from vertexai.generative_models import HarmBlockThreshold, HarmCategory

from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
GenerateTextEmbeddingsOperator,
Expand Down Expand Up @@ -112,12 +114,21 @@ class TestVertexAIPromptMultimodalModelOperator:
def test_execute(self, mock_hook):
prompt = "In 10 words or less, what is Apache Airflow?"
pretrained_model = "gemini-pro"
safety_settings = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
}
generation_config = {"max_output_tokens": 256, "top_p": 0.8, "temperature": 0.0}

op = PromptMultimodalModelOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=prompt,
generation_config=generation_config,
safety_settings=safety_settings,
pretrained_model=pretrained_model,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
Expand All @@ -131,6 +142,8 @@ def test_execute(self, mock_hook):
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=prompt,
generation_config=generation_config,
safety_settings=safety_settings,
pretrained_model=pretrained_model,
)

Expand All @@ -142,12 +155,21 @@ def test_execute(self, mock_hook):
vision_prompt = "In 10 words or less, describe this content."
media_gcs_path = "gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
mime_type = "image/jpeg"
safety_settings = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
}
generation_config = {"max_output_tokens": 256, "top_p": 0.8, "temperature": 0.0}

op = PromptMultimodalModelWithMediaOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=vision_prompt,
generation_config=generation_config,
safety_settings=safety_settings,
pretrained_model=pretrained_model,
media_gcs_path=media_gcs_path,
mime_type=mime_type,
Expand All @@ -163,6 +185,8 @@ def test_execute(self, mock_hook):
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=vision_prompt,
generation_config=generation_config,
safety_settings=safety_settings,
pretrained_model=pretrained_model,
media_gcs_path=media_gcs_path,
mime_type=mime_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import os
from datetime import datetime

from vertexai.generative_models import HarmBlockThreshold, HarmCategory

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
GenerateTextEmbeddingsOperator,
Expand All @@ -44,6 +46,13 @@
VISION_PROMPT = "In 10 words or less, describe this content."
MEDIA_GCS_PATH = "gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
MIME_TYPE = "image/jpeg"
GENERATION_CONFIG = {"max_output_tokens": 256, "top_p": 0.95, "temperature": 0.0}
SAFETY_SETTINGS = {
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
}

with DAG(
dag_id=DAG_ID,
Expand Down Expand Up @@ -79,6 +88,8 @@
project_id=PROJECT_ID,
location=REGION,
prompt=PROMPT,
generation_config=GENERATION_CONFIG,
safety_settings=SAFETY_SETTINGS,
pretrained_model=MULTIMODAL_MODEL,
)
# [END how_to_cloud_vertex_ai_prompt_multimodal_model_operator]
Expand All @@ -89,6 +100,8 @@
project_id=PROJECT_ID,
location=REGION,
prompt=VISION_PROMPT,
generation_config=GENERATION_CONFIG,
safety_settings=SAFETY_SETTINGS,
pretrained_model=MULTIMODAL_VISION_MODEL,
media_gcs_path=MEDIA_GCS_PATH,
mime_type=MIME_TYPE,
Expand Down

0 comments on commit e2b8f68

Please sign in to comment.