|
40 | 40 |
|
41 | 41 | _EMBEDDING_GENERATOR_GECKO_ENDPOINT = "textembedding-gecko"
|
42 | 42 | _EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT = "textembedding-gecko-multilingual"
|
43 |
| -_EMBEDDING_GENERATOR_ENDPOINTS = ( |
| 43 | +_PALM2_EMBEDDING_GENERATOR_ENDPOINTS = ( |
44 | 44 | _EMBEDDING_GENERATOR_GECKO_ENDPOINT,
|
45 | 45 | _EMBEDDING_GENERATOR_GECKO_MULTILINGUAL_ENDPOINT,
|
46 | 46 | )
|
47 | 47 |
|
| 48 | +_TEXT_EMBEDDING_004_ENDPOINT = "text-embedding-004" |
| 49 | +_TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT = "text-multilingual-embedding-002" |
| 50 | +_TEXT_EMBEDDING_ENDPOINTS = ( |
| 51 | + _TEXT_EMBEDDING_004_ENDPOINT, |
| 52 | + _TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT, |
| 53 | +) |
| 54 | + |
48 | 55 | _GEMINI_PRO_ENDPOINT = "gemini-pro"
|
49 | 56 | _GEMINI_1P5_PRO_PREVIEW_ENDPOINT = "gemini-1.5-pro-preview-0514"
|
50 | 57 | _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT = "gemini-1.5-flash-preview-0514"
|
|
57 | 64 |
|
58 | 65 | _ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
|
59 | 66 | _ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
|
| 67 | +_ML_GENERATE_EMBEDDING_STATUS = "ml_generate_embedding_status" |
60 | 68 |
|
61 | 69 |
|
62 | 70 | @log_adapter.class_logger
|
@@ -387,6 +395,10 @@ def to_gbq(self, model_name: str, replace: bool = False) -> PaLM2TextGenerator:
|
387 | 395 | class PaLM2TextEmbeddingGenerator(base.BaseEstimator):
|
388 | 396 | """PaLM2 text embedding generator LLM model.
|
389 | 397 |
|
| 398 | + .. note:: |
| 399 | + Models in this class are outdated and will be deprecated. To use the latest text embedding models, use the TextEmbeddingGenerator class. |
| 400 | +
|
| 401 | +
|
390 | 402 | Args:
|
391 | 403 | model_name (str, Default to "textembedding-gecko"):
|
392 | 404 | The model for text embedding. “textembedding-gecko” returns model embeddings for text inputs.
|
@@ -447,9 +459,9 @@ def _create_bqml_model(self):
|
447 | 459 | iam_role="aiplatform.user",
|
448 | 460 | )
|
449 | 461 |
|
450 |
| - if self.model_name not in _EMBEDDING_GENERATOR_ENDPOINTS: |
| 462 | + if self.model_name not in _PALM2_EMBEDDING_GENERATOR_ENDPOINTS: |
451 | 463 | raise ValueError(
|
452 |
| - f"Model name {self.model_name} is not supported. We only support {', '.join(_EMBEDDING_GENERATOR_ENDPOINTS)}." |
| 464 | + f"Model name {self.model_name} is not supported. We only support {', '.join(_PALM2_EMBEDDING_GENERATOR_ENDPOINTS)}." |
453 | 465 | )
|
454 | 466 |
|
455 | 467 | endpoint = (
|
@@ -551,6 +563,154 @@ def to_gbq(
|
551 | 563 | return new_model.session.read_gbq_model(model_name)
|
552 | 564 |
|
553 | 565 |
|
@log_adapter.class_logger
class TextEmbeddingGenerator(base.BaseEstimator):
    """Text embedding generator LLM model.

    Args:
        model_name (str, Default to "text-embedding-004"):
            The model for text embedding. Possible values are "text-embedding-004" or "text-multilingual-embedding-002".
            text-embedding models return model embeddings for text inputs.
            text-multilingual-embedding models return model embeddings for text inputs which support over 100 languages.
            Default to "text-embedding-004".
        session (bigframes.Session or None):
            BQ session to create the model. If None, use the global default session.
        connection_name (str or None):
            Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
            If None, use default connection in session context.

    Raises:
        ValueError: If no connection name can be resolved, if the connection
            name is malformed, or if ``model_name`` is not a supported endpoint.
    """

    def __init__(
        self,
        *,
        model_name: Literal[
            "text-embedding-004", "text-multilingual-embedding-002"
        ] = "text-embedding-004",
        session: Optional[bigframes.Session] = None,
        connection_name: Optional[str] = None,
    ):
        self.model_name = model_name
        self.session = session or bpd.get_global_session()
        self._bq_connection_manager = self.session.bqconnectionmanager

        # Fall back to the session's default connection, then expand it to the
        # fully qualified form using the session's project and location.
        connection_name = connection_name or self.session._bq_connection
        self.connection_name = clients.resolve_full_bq_connection_name(
            connection_name,
            default_project=self.session._project,
            default_location=self.session._location,
        )

        self._bqml_model_factory = globals.bqml_model_factory()
        self._bqml_model: core.BqmlModel = self._create_bqml_model()

    def _create_bqml_model(self):
        """Create the remote BQML model backing this estimator.

        All argument validation happens before the connection manager is
        invoked, so invalid arguments never trigger a remote call.

        Returns:
            The model returned by the BQML model factory's
            ``create_remote_model`` for the configured endpoint.

        Raises:
            ValueError: If ``connection_name`` is empty or malformed, or if
                ``model_name`` is not one of the supported endpoints.
        """
        if not self.connection_name:
            raise ValueError(
                "Must provide connection_name, either in constructor or through session options."
            )

        # Reject unsupported endpoints up front, before any connection is
        # created on the user's behalf.
        if self.model_name not in _TEXT_EMBEDDING_ENDPOINTS:
            raise ValueError(
                f"Model name {self.model_name} is not supported. We only support {', '.join(_TEXT_EMBEDDING_ENDPOINTS)}."
            )

        # Parse and create connection if needed.
        if self._bq_connection_manager:
            connection_name_parts = self.connection_name.split(".")
            if len(connection_name_parts) != 3:
                raise ValueError(
                    f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
                )
            project_id, location, connection_id = connection_name_parts
            self._bq_connection_manager.create_bq_connection(
                project_id=project_id,
                location=location,
                connection_id=connection_id,
                iam_role="aiplatform.user",
            )

        options = {
            "endpoint": self.model_name,
        }
        return self._bqml_model_factory.create_remote_model(
            session=self.session, connection_name=self.connection_name, options=options
        )

    @classmethod
    def _from_bq(
        cls, session: bigframes.Session, bq_model: bigquery.Model
    ) -> TextEmbeddingGenerator:
        """Build an estimator from an existing BigQuery remote model.

        Args:
            session (bigframes.Session):
                Session the resulting estimator is bound to.
            bq_model (bigquery.Model):
                Model whose ``remoteModelInfo`` supplies the endpoint and
                connection.

        Returns:
            TextEmbeddingGenerator: Estimator wrapping ``bq_model``.
        """
        assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
        assert "remoteModelInfo" in bq_model._properties
        remote_model_info = bq_model._properties["remoteModelInfo"]
        assert "endpoint" in remote_model_info
        assert "connection" in remote_model_info

        # Parse the remote model endpoint: only the last path segment is the
        # model name accepted by the constructor.
        bqml_endpoint = remote_model_info["endpoint"]
        model_connection = remote_model_info["connection"]
        model_endpoint = bqml_endpoint.split("/")[-1]

        # NOTE(review): the constructor runs _create_bqml_model(); its result
        # is replaced immediately below by a wrapper around ``bq_model``.
        model = cls(
            session=session,
            model_name=model_endpoint,  # type: ignore
            connection_name=model_connection,
        )

        model._bqml_model = core.BqmlModel(session, bq_model)
        return model

    def predict(self, X: Union[bpd.DataFrame, bpd.Series]) -> bpd.DataFrame:
        """Generate embeddings for the input text.

        Args:
            X (bigframes.dataframe.DataFrame or bigframes.series.Series):
                Input DataFrame or Series. Must have exactly one column; that
                column is renamed to "content" before being sent to the model,
                so its original name does not matter.

        Returns:
            bigframes.dataframe.DataFrame: The DataFrame produced by the
            model's ``generate_embedding`` call. It includes the
            "ml_generate_embedding_status" column; a non-empty value in that
            column marks a row whose prediction failed.

        Raises:
            ValueError: If ``X`` does not have exactly one column.
        """

        # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
        (X,) = utils.convert_to_dataframe(X)

        if len(X.columns) != 1:
            raise ValueError(
                f"Only support one column as input. {constants.FEEDBACK_LINK}"
            )

        # BQML identified the column by name
        col_label = cast(blocks.Label, X.columns[0])
        X = X.rename(columns={col_label: "content"})

        options = {
            "flatten_json_output": True,
        }

        df = self._bqml_model.generate_embedding(X, options)

        # Failures are reported per row in the status column rather than raised,
        # so warn and let the caller decide how to handle them.
        if (df[_ML_GENERATE_EMBEDDING_STATUS] != "").any():
            warnings.warn(
                f"Some predictions failed. Check column {_ML_GENERATE_EMBEDDING_STATUS} for detailed status. You may want to filter the failed rows and retry.",
                RuntimeWarning,
            )

        return df

    def to_gbq(self, model_name: str, replace: bool = False) -> TextEmbeddingGenerator:
        """Save the model to BigQuery.

        Args:
            model_name (str):
                The name of the model.
            replace (bool, default False):
                Determine whether to replace if the model already exists. Default to False.

        Returns:
            TextEmbeddingGenerator: Saved model.
        """

        new_model = self._bqml_model.copy(model_name, replace)
        return new_model.session.read_gbq_model(model_name)
| 712 | + |
| 713 | + |
554 | 714 | @log_adapter.class_logger
|
555 | 715 | class GeminiTextGenerator(base.BaseEstimator):
|
556 | 716 | """Gemini text generator LLM model.
|
|
0 commit comments