Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit d751f5c

Browse files
authored Jun 6, 2024
feat: support fit() in GeminiTextGenerator (#758)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes internal #343765747🦕
1 parent e452203 commit d751f5c

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed
 

‎bigframes/ml/llm.py

+53
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,8 @@ class GeminiTextGenerator(base.BaseEstimator):
571571
Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
572572
If None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach
573573
permission if the connection isn't fully set up.
574+
max_iterations (Optional[int], Default to 300):
575+
The number of steps to run when performing supervised tuning.
574576
"""
575577

576578
def __init__(
@@ -581,9 +583,11 @@ def __init__(
581583
] = "gemini-pro",
582584
session: Optional[bigframes.Session] = None,
583585
connection_name: Optional[str] = None,
586+
max_iterations: int = 300,
584587
):
585588
self.model_name = model_name
586589
self.session = session or bpd.get_global_session()
590+
self.max_iterations = max_iterations
587591
self._bq_connection_manager = self.session.bqconnectionmanager
588592

589593
connection_name = connection_name or self.session._bq_connection
@@ -647,6 +651,55 @@ def _from_bq(
647651
model._bqml_model = core.BqmlModel(session, bq_model)
648652
return model
649653

654+
@property
def _bqml_options(self) -> dict:
    """Build the option mapping handed to BQML when the model is created.

    A new dict is constructed on every access, so callers may add keys to
    the result without affecting later reads of this property.

    Returns:
        dict: ``max_iterations`` taken from the estimator, plus a fixed
        ``data_split_method`` of ``"NO_SPLIT"``.
    """
    return {
        "max_iterations": self.max_iterations,
        "data_split_method": "NO_SPLIT",
    }
662+
663+
def fit(
    self,
    X: Union[bpd.DataFrame, bpd.Series],
    y: Union[bpd.DataFrame, bpd.Series],
) -> GeminiTextGenerator:
    """Fine tune GeminiTextGenerator model. Only support "gemini-pro" model for now.

    .. note::

        This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
        Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
        and might have limited support. For more information, see the launch stage descriptions
        (https://cloud.google.com/products#product-launch-stages).

    Args:
        X (bigframes.dataframe.DataFrame or bigframes.series.Series):
            DataFrame of shape (n_samples, n_features). Training data.
            The first column is passed to BQML as the prompt column.
        y (bigframes.dataframe.DataFrame or bigframes.series.Series):
            Training labels.

    Returns:
        GeminiTextGenerator: Fitted estimator.

    Raises:
        NotImplementedError: If the estimator was configured with a
            "gemini-1.5" model name.
    """
    # Guard on the endpoint name this estimator was configured with.
    # NOTE(review): the previous check read self._bqml_model.model_name, which
    # appears to identify the BigQuery model object rather than the Gemini
    # endpoint, so it would not match "gemini-1.5*" -- confirm against
    # core.BqmlModel.
    if self.model_name.startswith("gemini-1.5"):
        raise NotImplementedError("Fit is not supported for gemini-1.5 model.")

    X, y = utils.convert_to_dataframe(X, y)

    # Build a fresh dict instead of mutating whatever the property returned.
    options = {
        **self._bqml_options,
        "endpoint": "gemini-1.0-pro-002",
        "prompt_col": X.columns.tolist()[0],
    }

    self._bqml_model = self._bqml_model_factory.create_llm_remote_model(
        X,
        y,
        options=options,
        connection_name=self.connection_name,
    )
    return self
702+
650703
def predict(
651704
self,
652705
X: Union[bpd.DataFrame, bpd.Series],

‎tests/system/load/test_llm.py

+27
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,30 @@ def test_llm_palm_score_params(llm_fine_tune_df_default_index):
9999
"evaluation_status",
100100
]
101101
assert all(col in score_result_col for col in expected_col)
102+
103+
104+
@pytest.mark.flaky(retries=2)
def test_llm_gemini_configure_fit(llm_fine_tune_df_default_index, llm_remote_text_df):
    """Tune "gemini-pro" with max_iterations=1, then sanity-check predict() output."""
    train_df = llm_fine_tune_df_default_index
    model = bigframes.ml.llm.GeminiTextGenerator(
        model_name="gemini-pro",
        max_iterations=1,
    )
    model.fit(train_df[["prompt"]], train_df[["label"]])

    assert model is not None

    predict_kwargs = dict(
        temperature=0.5,
        max_output_tokens=100,
        top_k=20,
        top_p=0.5,
    )
    result = model.predict(llm_remote_text_df["prompt"], **predict_kwargs).to_pandas()

    result_col = "ml_generate_text_llm_result"
    assert result.shape == (3, 4)
    assert result_col in result.columns
    # Every generated answer is expected to be exactly one character long
    # (presumably because the tuning labels are single characters -- confirm
    # against the fixture data).
    assert all(result[result_col].str.len() == 1)

    # TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept

0 commit comments

Comments
 (0)
Failed to load comments.