Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 56cbd3b

Browse files
authoredMay 30, 2024
feat: add GeminiText 1.5 Preview models (#737)
* feat: add gemini 1.5 preview models * tests
1 parent e5a2992 commit 56cbd3b

File tree

4 files changed

+68
-20
lines changed

4 files changed

+68
-20
lines changed
 

‎bigframes/ml/llm.py

+33-8
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@
4646
)
4747

4848
_GEMINI_PRO_ENDPOINT = "gemini-pro"
49+
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
50+
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
51+
_GEMINI_ENDPOINTS = (
52+
_GEMINI_PRO_ENDPOINT,
53+
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT,
54+
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT,
55+
)
56+
4957

5058
_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
5159
_ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
@@ -547,13 +555,16 @@ def to_gbq(
547555
class GeminiTextGenerator(base.BaseEstimator):
548556
"""Gemini text generator LLM model.
549557
550-
.. note::
551-
This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
552-
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
553-
and might have limited support. For more information, see the launch stage descriptions
554-
(https://cloud.google.com/products#product-launch-stages).
555-
556558
Args:
559+
model_name (str, Default to "gemini-pro"):
560+
The model for natural language tasks. Accepted values are "gemini-pro", "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514". Defaults to "gemini-pro".
561+
562+
.. note::
563+
"gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514" are subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
564+
Service Specific Terms (https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
565+
and might have limited support. For more information, see the launch stage descriptions
566+
(https://cloud.google.com/products#product-launch-stages).
567+
557568
session (bigframes.Session or None):
558569
BQ session to create the model. If None, use the global default session.
559570
connection_name (str or None):
@@ -565,9 +576,13 @@ class GeminiTextGenerator(base.BaseEstimator):
565576
def __init__(
566577
self,
567578
*,
579+
model_name: Literal[
580+
"gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"
581+
] = "gemini-pro",
568582
session: Optional[bigframes.Session] = None,
569583
connection_name: Optional[str] = None,
570584
):
585+
self.model_name = model_name
571586
self.session = session or bpd.get_global_session()
572587
self._bq_connection_manager = self.session.bqconnectionmanager
573588

@@ -601,7 +616,12 @@ def _create_bqml_model(self):
601616
iam_role="aiplatform.user",
602617
)
603618

604-
options = {"endpoint": _GEMINI_PRO_ENDPOINT}
619+
if self.model_name not in _GEMINI_ENDPOINTS:
620+
raise ValueError(
621+
f"Model name {self.model_name} is not supported. We only support {', '.join(_GEMINI_ENDPOINTS)}."
622+
)
623+
624+
options = {"endpoint": self.model_name}
605625

606626
return self._bqml_model_factory.create_remote_model(
607627
session=self.session, connection_name=self.connection_name, options=options
@@ -613,12 +633,17 @@ def _from_bq(
613633
) -> GeminiTextGenerator:
614634
assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
615635
assert "remoteModelInfo" in bq_model._properties
636+
assert "endpoint" in bq_model._properties["remoteModelInfo"]
616637
assert "connection" in bq_model._properties["remoteModelInfo"]
617638

618639
# Parse the remote model endpoint
640+
bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
619641
model_connection = bq_model._properties["remoteModelInfo"]["connection"]
642+
model_endpoint = bqml_endpoint.split("/")[-1]
620643

621-
model = cls(session=session, connection_name=model_connection)
644+
model = cls(
645+
model_name=model_endpoint, session=session, connection_name=model_connection
646+
)
622647
model._bqml_model = core.BqmlModel(session, bq_model)
623648
return model
624649

‎bigframes/ml/loader.py

+2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
llm._EMBEDDING_GENERATOR_GECKO_ENDPOINT: llm.PaLM2TextEmbeddingGenerator,
6262
llm._EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT: llm.PaLM2TextEmbeddingGenerator,
6363
llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator,
64+
llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
65+
llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
6466
}
6567
)
6668

‎tests/system/small/ml/conftest.py

-5
Original file line numberDiff line numberDiff line change
@@ -275,11 +275,6 @@ def palm2_embedding_generator_multilingual_model(
275275
)
276276

277277

278-
@pytest.fixture(scope="session")
279-
def gemini_text_generator_model(session, bq_connection) -> llm.GeminiTextGenerator:
280-
return llm.GeminiTextGenerator(session=session, connection_name=bq_connection)
281-
282-
283278
@pytest.fixture(scope="session")
284279
def linear_remote_model_params() -> dict:
285280
# Pre-deployed endpoint of linear reg model in Vertex.

‎tests/system/small/ml/test_llm.py

+33-7
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,16 @@ def test_embedding_generator_predict_series_success(
303303
assert len(value) == 768
304304

305305

306-
def test_create_gemini_text_generator_model(
307-
gemini_text_generator_model, dataset_id, bq_connection
306+
@pytest.mark.parametrize(
307+
"model_name",
308+
("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
309+
)
310+
def test_create_load_gemini_text_generator_model(
311+
dataset_id, model_name, session, bq_connection
308312
):
309-
# Model creation doesn't return error
313+
gemini_text_generator_model = llm.GeminiTextGenerator(
314+
model_name=model_name, connection_name=bq_connection, session=session
315+
)
310316
assert gemini_text_generator_model is not None
311317
assert gemini_text_generator_model._bqml_model is not None
312318

@@ -316,23 +322,43 @@ def test_create_gemini_text_generator_model(
316322
)
317323
assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name
318324
assert reloaded_model.connection_name == bq_connection
319-
320-
325+
assert reloaded_model.model_name == model_name
326+
327+
328+
@pytest.mark.parametrize(
329+
"model_name",
330+
(
331+
"gemini-pro",
332+
"gemini-1.5-pro-preview-0514",
333+
# TODO(garrrettwu): enable when cl/637028077 is in prod.
334+
# "gemini-1.5-flash-preview-0514"
335+
),
336+
)
321337
@pytest.mark.flaky(retries=2)
322338
def test_gemini_text_generator_predict_default_params_success(
323-
gemini_text_generator_model, llm_text_df
339+
llm_text_df, model_name, session, bq_connection
324340
):
341+
gemini_text_generator_model = llm.GeminiTextGenerator(
342+
model_name=model_name, connection_name=bq_connection, session=session
343+
)
325344
df = gemini_text_generator_model.predict(llm_text_df).to_pandas()
326345
assert df.shape == (3, 4)
327346
assert "ml_generate_text_llm_result" in df.columns
328347
series = df["ml_generate_text_llm_result"]
329348
assert all(series.str.len() > 20)
330349

331350

351+
@pytest.mark.parametrize(
352+
"model_name",
353+
("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),
354+
)
332355
@pytest.mark.flaky(retries=2)
333356
def test_gemini_text_generator_predict_with_params_success(
334-
gemini_text_generator_model, llm_text_df
357+
llm_text_df, model_name, session, bq_connection
335358
):
359+
gemini_text_generator_model = llm.GeminiTextGenerator(
360+
model_name=model_name, connection_name=bq_connection, session=session
361+
)
336362
df = gemini_text_generator_model.predict(
337363
llm_text_df, temperature=0.5, max_output_tokens=100, top_k=20, top_p=0.5
338364
).to_pandas()

0 commit comments

Comments
 (0)
Failed to load comments.