Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e0f1ab0

Browse files
authored Feb 29, 2024
feat: add TextEmbedding model version support (#394)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:

- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
1 parent 1726588 commit e0f1ab0

File tree

5 files changed

+61
-5
lines changed

5 files changed

+61
-5
lines changed
 

‎bigframes/ml/llm.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,9 @@ class PaLM2TextEmbeddingGenerator(base.Predictor):
266266
The model for text embedding. “textembedding-gecko” returns model embeddings for text inputs.
267267
"textembedding-gecko-multilingual" returns model embeddings for text inputs which support over 100 languages
268268
Default to "textembedding-gecko".
269+
version (str or None):
270+
Model version. Accepted values are "001", "002", "003", "latest" etc. Will use the default version if unset.
271+
See https://cloud.google.com/vertex-ai/docs/generative-ai/learn/model-versioning for details.
269272
session (bigframes.Session or None):
270273
BQ session to create the model. If None, use the global default session.
271274
connection_name (str or None):
@@ -279,10 +282,12 @@ def __init__(
279282
model_name: Literal[
280283
"textembedding-gecko", "textembedding-gecko-multilingual"
281284
] = "textembedding-gecko",
285+
version: Optional[str] = None,
282286
session: Optional[bigframes.Session] = None,
283287
connection_name: Optional[str] = None,
284288
):
285289
self.model_name = model_name
290+
self.version = version
286291
self.session = session or bpd.get_global_session()
287292
self._bq_connection_manager = clients.BqConnectionManager(
288293
self.session.bqconnectionclient, self.session.resourcemanagerclient
@@ -321,8 +326,11 @@ def _create_bqml_model(self):
321326
f"Model name {self.model_name} is not supported. We only support {', '.join(_EMBEDDING_GENERATOR_ENDPOINTS)}."
322327
)
323328

329+
endpoint = (
330+
self.model_name + "@" + self.version if self.version else self.model_name
331+
)
324332
options = {
325-
"endpoint": self.model_name,
333+
"endpoint": endpoint,
326334
}
327335
return self._bqml_model_factory.create_remote_model(
328336
session=self.session, connection_name=self.connection_name, options=options
@@ -342,8 +350,14 @@ def _from_bq(
342350
model_connection = model._properties["remoteModelInfo"]["connection"]
343351
model_endpoint = bqml_endpoint.split("/")[-1]
344352

353+
model_name, version = utils.parse_model_endpoint(model_endpoint)
354+
345355
embedding_generator_model = cls(
346-
session=session, model_name=model_endpoint, connection_name=model_connection
356+
session=session,
357+
# str to literals
358+
model_name=model_name, # type: ignore
359+
version=version,
360+
connection_name=model_connection,
347361
)
348362
embedding_generator_model._bqml_model = core.BqmlModel(session, model)
349363
return embedding_generator_model

‎bigframes/ml/loader.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
linear_model,
3131
llm,
3232
pipeline,
33+
utils,
3334
)
3435

3536
_BQML_MODEL_TYPE_MAPPING = MappingProxyType(
@@ -106,8 +107,10 @@ def _model_from_bq(session: bigframes.Session, bq_model: bigquery.Model):
106107
):
107108
# Parse the remote model endpoint
108109
bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"]
109-
endpoint_model = bqml_endpoint.split("/")[-1]
110-
return _BQML_ENDPOINT_TYPE_MAPPING[endpoint_model]._from_bq( # type: ignore
110+
model_endpoint = bqml_endpoint.split("/")[-1]
111+
model_name, _ = utils.parse_model_endpoint(model_endpoint)
112+
113+
return _BQML_ENDPOINT_TYPE_MAPPING[model_name]._from_bq( # type: ignore
111114
session=session, model=bq_model
112115
)
113116

‎bigframes/ml/utils.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import typing
16-
from typing import Iterable, Union
16+
from typing import Iterable, Optional, Union
1717

1818
import bigframes.constants as constants
1919
from bigframes.core import blocks
@@ -56,3 +56,16 @@ def _convert_to_series(frame: ArrayType) -> bpd.Series:
5656
raise ValueError(
5757
f"Unsupported type {type(frame)} to convert to Series. {constants.FEEDBACK_LINK}"
5858
)
59+
60+
61+
def parse_model_endpoint(model_endpoint: str) -> tuple[str, Optional[str]]:
    """Parse model endpoint string to model_name and version.

    The endpoint is split on its first "@": the text before it is the model
    name and the text after it is the version.

    Args:
        model_endpoint: Endpoint string, either "<model_name>" or
            "<model_name>@<version>".

    Returns:
        A ``(model_name, version)`` tuple. ``version`` is ``None`` when the
        endpoint has no "@" or when nothing follows the "@".
    """
    model_name, _, version = model_endpoint.partition("@")
    # An empty version (no "@", or a trailing "@") means "unset"; report it
    # as None so callers never have to distinguish "" from None.
    return model_name, version or None

‎tests/system/small/ml/conftest.py

+9
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,15 @@ def palm2_embedding_generator_model(
256256
)
257257

258258

259+
@pytest.fixture(scope="session")
def palm2_embedding_generator_model_002(
    session, bq_connection
) -> llm.PaLM2TextEmbeddingGenerator:
    """Session-scoped PaLM2 text embedding generator pinned to version "002"."""
    generator = llm.PaLM2TextEmbeddingGenerator(
        session=session,
        connection_name=bq_connection,
        version="002",
    )
    return generator
266+
267+
259268
@pytest.fixture(scope="session")
260269
def palm2_embedding_generator_multilingual_model(
261270
session, bq_connection

‎tests/system/small/ml/test_llm.py

+17
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,23 @@ def test_create_embedding_generator_model(
194194
assert reloaded_model.connection_name == bq_connection
195195

196196

197+
def test_create_embedding_generator_model_002(
    palm2_embedding_generator_model_002, dataset_id, bq_connection
):
    """Check the version-"002" embedding generator keeps its config when saved.

    Verifies the fixture produced a model with an underlying BQML model, then
    saves it with ``to_gbq`` and asserts that the returned model reports the
    expected table name, model name, version and connection name.
    """
    generator = palm2_embedding_generator_model_002
    target_table = f"{dataset_id}.temp_embedding_model"

    # Model creation doesn't return error
    assert generator is not None
    assert generator._bqml_model is not None

    # save, load to ensure configuration was kept
    reloaded = generator.to_gbq(target_table, replace=True)
    assert target_table == reloaded._bqml_model.model_name
    assert reloaded.model_name == "textembedding-gecko"
    assert reloaded.version == "002"
    assert reloaded.connection_name == bq_connection
212+
213+
197214
def test_create_embedding_generator_multilingual_model(
198215
palm2_embedding_generator_multilingual_model,
199216
dataset_id,

0 commit comments

Comments
 (0)
Failed to load comments.