Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 6bc6a41

Browse files
authored Aug 19, 2024
feat: add llm.TextEmbeddingGenerator to support new embedding models (#905)
* feat: add llm.TextEmbeddingGenerator to support new embedding models * fix docs
1 parent 92fdb93 commit 6bc6a41

File tree

3 files changed

+207
-3
lines changed

3 files changed

+207
-3
lines changed
 

‎bigframes/ml/llm.py

+163-3
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,18 @@
4040

4141
_EMBEDDING_GENERATOR_GECKO_ENDPOINT = "textembedding-gecko"
4242
_EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT = "textembedding-gecko-multilingual"
43-
_EMBEDDING_GENERATOR_ENDPOINTS = (
43+
_PALM2_EMBEDDING_GENERATOR_ENDPOINTS = (
4444
_EMBEDDING_GENERATOR_GECKO_ENDPOINT,
4545
_EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT,
4646
)
4747

48+
_TEXT_EMBEDDING_004_ENDPOINT = "text-embedding-004"
49+
_TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT = "text-multilingual-embedding-002"
50+
_TEXT_EMBEDDING_ENDPOINTS = (
51+
_TEXT_EMBEDDING_004_ENDPOINT,
52+
_TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT,
53+
)
54+
4855
_GEMINI_PRO_ENDPOINT = "gemini-pro"
4956
_GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
5057
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
@@ -57,6 +64,7 @@
5764

5865
_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
5966
_ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
67+
_ML_GENERATE_EMBEDDING_STATUS = "ml_generate_embedding_status"
6068

6169

6270
@log_adapter.class_logger
@@ -387,6 +395,10 @@ def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator:
387395
class PaLM2TextEmbeddingGenerator(base.BaseEstimator):
388396
"""PaLM2 text embedding generator LLM model.
389397
398+
.. note::
399+
Models in this class are outdated and going to be deprecated. To use the most updated text embedding models, go to the TextEmbeddingGenerator class.
400+
401+
390402
Args:
391403
model_name (str, Default to "textembedding-gecko"):
392404
The model for text embedding. “textembedding-gecko” returns model embeddings for text inputs.
@@ -447,9 +459,9 @@ def _create_bqml_model(self):
447459
iam_role="aiplatform.user",
448460
)
449461

450-
if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS:
462+
if self.model_name not in _PALM2_EMBEDDING_GENERATOR_ENDPOINTS:
451463
raise ValueError(
452-
f"Model name {self.model_name} is not supported. We only support {', '.join(_EMBEDDING_GENERATOR_ENDPOINTS)}."
464+
f"Model name {self.model_name} is not supported. We only support {', '.join(_PALM2_EMBEDDING_GENERATOR_ENDPOINTS)}."
453465
)
454466

455467
endpoint = (
@@ -551,6 +563,154 @@ def to_gbq(
551563
return new_model.session.read_gbq_model(model_name)
552564

553565

566+
@log_adapter.class_logger
class TextEmbeddingGenerator(base.BaseEstimator):
    """Text embedding generator LLM model.

    Args:
        model_name (str, Default to "text-embedding-004"):
            The model for text embedding. Possible values are "text-embedding-004" or "text-multilingual-embedding-002".
            text-embedding models return model embeddings for text inputs.
            text-multilingual-embedding models return model embeddings for text inputs and support over 100 languages.
            Defaults to "text-embedding-004".
        session (bigframes.Session or None):
            BQ session to create the model. If None, use the global default session.
        connection_name (str or None):
            Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
            If None, use default connection in session context.
    """

    def __init__(
        self,
        *,
        model_name: Literal[
            "text-embedding-004", "text-multilingual-embedding-002"
        ] = "text-embedding-004",
        session: Optional[bigframes.Session] = None,
        connection_name: Optional[str] = None,
    ):
        self.model_name = model_name
        self.session = session or bpd.get_global_session()
        self._bq_connection_manager = self.session.bqconnectionmanager

        # Fall back to the session's default connection, then expand the name
        # using the session's project and location as defaults.
        connection_name = connection_name or self.session._bq_connection
        self.connection_name = clients.resolve_full_bq_connection_name(
            connection_name,
            default_project=self.session._project,
            default_location=self.session._location,
        )

        self._bqml_model_factory = globals.bqml_model_factory()
        self._bqml_model: core.BqmlModel = self._create_bqml_model()

    def _create_bqml_model(self):
        """Create the remote BQML model backing this estimator.

        Returns:
            The object returned by the model factory's ``create_remote_model``.

        Raises:
            ValueError: If no connection name is available, if the model name
                is not one of the supported endpoints, or if the connection
                name does not have exactly three dot-separated parts.
        """
        if not self.connection_name:
            raise ValueError(
                "Must provide connection_name, either in constructor or through session options."
            )

        # Validate the model name before touching the connection so that an
        # unsupported name fails fast, without first creating a connection.
        if self.model_name not in _TEXT_EMBEDDING_ENDPOINTS:
            raise ValueError(
                f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_EMBEDDING_ENDPOINTS)}."
            )

        # Parse and create connection if needed.
        if self._bq_connection_manager:
            connection_name_parts = self.connection_name.split(".")
            if len(connection_name_parts) != 3:
                raise ValueError(
                    f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
                )
            self._bq_connection_manager.create_bq_connection(
                project_id=connection_name_parts[0],
                location=connection_name_parts[1],
                connection_id=connection_name_parts[2],
                iam_role="aiplatform.user",
            )

        options = {
            "endpoint": self.model_name,
        }
        return self._bqml_model_factory.create_remote_model(
            session=self.session, connection_name=self.connection_name, options=options
        )

    @classmethod
    def _from_bq(
        cls, session: bigframes.Session, bq_model: bigquery.Model
    ) -> TextEmbeddingGenerator:
        """Build an estimator from an existing BigQuery remote model.

        Args:
            session (bigframes.Session):
                Session used to construct the estimator and wrap the model.
            bq_model (bigquery.Model):
                Model whose ``remoteModelInfo`` properties carry the endpoint
                and connection.

        Returns:
            TextEmbeddingGenerator: Estimator whose ``_bqml_model`` wraps ``bq_model``.
        """
        assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
        assert "remoteModelInfo" in bq_model._properties
        assert "endpoint" in bq_model._properties["remoteModelInfo"]
        assert "connection" in bq_model._properties["remoteModelInfo"]

        # Parse the remote model endpoint; only the last path segment is the
        # model name.
        bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
        model_connection = bq_model._properties["remoteModelInfo"]["connection"]
        model_endpoint = bqml_endpoint.split("/")[-1]

        model = cls(
            session=session,
            model_name=model_endpoint,  # type: ignore
            connection_name=model_connection,
        )

        # The constructor built its own model; point the estimator at the
        # existing BigQuery model instead.
        model._bqml_model = core.BqmlModel(session, bq_model)
        return model

    def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
        """Generate embeddings for the input text.

        Args:
            X (bigframes.dataframe.DataFrame or bigframes.series.Series):
                Input DataFrame or Series with exactly one column holding the
                text to embed. The column may have any name; it is renamed to
                "content" before being passed to the model.

        Returns:
            bigframes.dataframe.DataFrame: The DataFrame returned by the
            model's ``generate_embedding`` call. It includes an
            "ml_generate_embedding_status" column; a ``RuntimeWarning`` is
            issued when any row has a non-empty status.

        Raises:
            ValueError: If the input does not have exactly one column.
        """

        # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
        (X,) = utils.convert_to_dataframe(X)

        if len(X.columns) != 1:
            raise ValueError(
                f"Only support one column as input. {constants.FEEDBACK_LINK}"
            )

        # BQML identifies the input column by name.
        col_label = cast(blocks.Label, X.columns[0])
        X = X.rename(columns={col_label: "content"})

        options = {
            "flatten_json_output": True,
        }

        df = self._bqml_model.generate_embedding(X, options)

        if (df[_ML_GENERATE_EMBEDDING_STATUS] != "").any():
            warnings.warn(
                f"Some predictions failed. Check column {_ML_GENERATE_EMBEDDING_STATUS} for detailed status. You may want to filter the failed rows and retry.",
                RuntimeWarning,
            )

        return df

    def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerator:
        """Save the model to BigQuery.

        Args:
            model_name (str):
                The name of the model.
            replace (bool, default False):
                Determine whether to replace if the model already exists. Default to False.

        Returns:
            TextEmbeddingGenerator: Saved model.
        """

        new_model = self._bqml_model.copy(model_name, replace)
        return new_model.session.read_gbq_model(model_name)
712+
713+
554714
@log_adapter.class_logger
555715
class GeminiTextGenerator(base.BaseEstimator):
556716
"""Gemini text generator LLM model.

‎bigframes/ml/loader.py

+3
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator,
6464
llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
6565
llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
66+
llm._TEXT_EMBEDDING_004_ENDPOINT: llm.TextEmbeddingGenerator,
67+
llm._TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT: llm.TextEmbeddingGenerator,
6668
}
6769
)
6870

@@ -84,6 +86,7 @@ def from_bq(
8486
imported.XGBoostModel,
8587
llm.PaLM2TextGenerator,
8688
llm.PaLM2TextEmbeddingGenerator,
89+
llm.TextEmbeddingGenerator,
8790
pipeline.Pipeline,
8891
compose.ColumnTransformer,
8992
preprocessing.PreprocessingType,

‎tests/system/small/ml/test_llm.py

+41
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,47 @@ def test_embedding_generator_predict_series_success(
304304
assert len(value) == 768
305305

306306

307+
@pytest.mark.parametrize(
    "model_name",
    ("text-embedding-004", "text-multilingual-embedding-002"),
)
def test_create_load_text_embedding_generator_model(
    dataset_id, model_name, session, bq_connection
):
    """Create a text embedding model, save it to BigQuery and reload it."""
    created = llm.TextEmbeddingGenerator(
        model_name=model_name, connection_name=bq_connection, session=session
    )
    assert created is not None
    assert created._bqml_model is not None

    # Round-trip through BigQuery to ensure the configuration was kept.
    saved_model_id = f"{dataset_id}.temp_text_model"
    reloaded = created.to_gbq(saved_model_id, replace=True)

    assert saved_model_id == reloaded._bqml_model.model_name
    assert reloaded.connection_name == bq_connection
    assert reloaded.model_name == model_name
327+
328+
329+
@pytest.mark.parametrize(
    "model_name",
    ("text-embedding-004", "text-multilingual-embedding-002"),
)
@pytest.mark.flaky(retries=2)
def test_gemini_text_embedding_generator_predict_default_params_success(
    llm_text_df, model_name, session, bq_connection
):
    """Predict with default parameters and check the embedding output."""
    result_column = "ml_generate_embedding_result"

    generator = llm.TextEmbeddingGenerator(
        model_name=model_name, connection_name=bq_connection, session=session
    )
    result_pdf = generator.predict(llm_text_df).to_pandas()

    assert result_pdf.shape == (3, 4)
    assert result_column in result_pdf.columns

    # Each embedding is expected to be a 768-dimensional vector.
    first_embedding = result_pdf[result_column][0]
    assert len(first_embedding) == 768
346+
347+
307348
@pytest.mark.parametrize(
308349
"model_name",
309350
("gemini-pro", "gemini-1.5-pro-preview-0514", "gemini-1.5-flash-preview-0514"),

0 commit comments

Comments
 (0)
Failed to load comments.