Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9c106bd

Browse files
ashleyxuu and milkshakeiii
authored Apr 18, 2024
feat: Add fine tuning fit() for Palm2TextGenerator (#616)
* feat: support list of numerics in pandas.cut (#580) An internal user encountered this missing overload * move the tests to load-testing * add predict tests * address comments * address comments --------- Co-authored-by: Henry Solberg <henry.j.solberg@gmail.com>
1 parent 8f9ece6 commit 9c106bd

File tree

6 files changed

+219
-2
lines changed

6 files changed

+219
-2
lines changed
 

‎bigframes/ml/core.py

+40
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,46 @@ def create_model(
321321

322322
return self._create_model_with_sql(session=session, sql=sql)
323323

324+
def create_llm_remote_model(
    self,
    X_train: bpd.DataFrame,
    y_train: bpd.DataFrame,
    connection_name: str,
    options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
) -> BqmlModel:
    """Create a session-temporary remote LLM model with CREATE OR REPLACE MODEL.

    Args:
        X_train: features columns for training
        y_train: labels columns for training
        connection_name:
            a BQ connection to talk with Vertex AI, of the format <PROJECT_NUMBER>.<REGION>.<CONNECTION_NAME>. https://cloud.google.com/bigquery/docs/create-cloud-resource-connection
        options: a dict of options to configure the model. Generates a BQML OPTIONS
            clause

    Returns: a BqmlModel, wrapping a trained model in BigQuery
    """
    session = X_train._session

    # Force-cache both frames before joining: a cached dataframe is a full
    # copy, so the base table of the training data is never a snapshot.
    features = X_train._cached(force=True)
    labels = y_train._cached(force=True)
    training_data = features.join(labels, how="outer")

    # Build a private copy of the options so the caller's mapping (and the
    # shared default) is never modified.
    model_options = {**options, "INPUT_LABEL_COLS": y_train.columns.tolist()}

    sql = self._model_creation_sql_generator.create_llm_remote_model(
        source_df=training_data,
        model_ref=self._create_model_ref(session._anonymous_dataset),
        options=model_options,
        connection_name=connection_name,
    )
    return self._create_model_with_sql(session=session, sql=sql)
363+
324364
def create_time_series_model(
325365
self,
326366
X_train: bpd.DataFrame,

‎bigframes/ml/llm.py

+70-1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
from bigframes.ml import base, core, globals, utils
2828
import bigframes.pandas as bpd
2929

30+
# Maps an estimator parameter name to the key under which the same setting
# appears in a BQML model's training options.
_BQML_PARAMS_MAPPING = {"max_iterations": "maxIterations"}
33+
3034
_TEXT_GENERATOR_BISON_ENDPOINT = "text-bison"
3135
_TEXT_GENERATOR_BISON_32K_ENDPOINT = "text-bison-32k"
3236
_TEXT_GENERATOR_ENDPOINTS = (
@@ -62,6 +66,8 @@ class PaLM2TextGenerator(base.BaseEstimator):
6266
Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
6367
if None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach
6468
permission if the connection isn't fully setup.
69+
max_iterations (Optional[int], Default to 300):
70+
The number of steps to run when performing supervised tuning.
6571
"""
6672

6773
def __init__(
@@ -70,9 +76,11 @@ def __init__(
7076
model_name: Literal["text-bison", "text-bison-32k"] = "text-bison",
7177
session: Optional[bigframes.Session] = None,
7278
connection_name: Optional[str] = None,
79+
max_iterations: int = 300,
7380
):
7481
self.model_name = model_name
7582
self.session = session or bpd.get_global_session()
83+
self.max_iterations = max_iterations
7684
self._bq_connection_manager = self.session.bqconnectionmanager
7785

7886
connection_name = connection_name or self.session._bq_connection
@@ -132,12 +140,73 @@ def _from_bq(
132140
model_connection = model._properties["remoteModelInfo"]["connection"]
133141
model_endpoint = bqml_endpoint.split("/")[-1]
134142

143+
# Get the optional params
144+
kwargs: dict = {}
145+
last_fitting = model.training_runs[-1]["trainingOptions"]
146+
147+
dummy_text_generator = cls()
148+
for bf_param, _ in dummy_text_generator.__dict__.items():
149+
bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
150+
if bqml_param in last_fitting:
151+
# Convert types
152+
if bf_param in ["max_iterations"]:
153+
kwargs[bf_param] = int(last_fitting[bqml_param])
154+
135155
text_generator_model = cls(
136-
session=session, model_name=model_endpoint, connection_name=model_connection
156+
**kwargs,
157+
session=session,
158+
model_name=model_endpoint,
159+
connection_name=model_connection,
137160
)
138161
text_generator_model._bqml_model = core.BqmlModel(session, model)
139162
return text_generator_model
140163

164+
@property
def _bqml_options(self) -> dict:
    """The model options as they will be set for BQML.

    A new dict is built on every access, so callers may add keys to the
    result without affecting later reads.
    """
    return {
        "max_iterations": self.max_iterations,
        "data_split_method": "NO_SPLIT",
    }
172+
173+
def fit(
    self,
    X: Union[bpd.DataFrame, bpd.Series],
    y: Union[bpd.DataFrame, bpd.Series],
) -> PaLM2TextGenerator:
    """Fine tune PaLM2TextGenerator model.

    .. note::

        This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
        Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
        and might have limited support. For more information, see the launch stage descriptions
        (https://cloud.google.com/products#product-launch-stages).

    Args:
        X (bigframes.dataframe.DataFrame or bigframes.series.Series):
            DataFrame of shape (n_samples, n_features). Training data.
            Its first column is used as the prompt column.
        y (bigframes.dataframe.DataFrame or bigframes.series.Series):
            Training labels.

    Returns:
        PaLM2TextGenerator: Fitted Estimator.
    """
    X, y = utils.convert_to_dataframe(X, y)

    # Start from the estimator-level options, then add the settings that
    # depend on this particular call.
    tuning_options = self._bqml_options
    tuning_options.update(
        endpoint=self.model_name + "@001",
        prompt_col=X.columns.tolist()[0],
    )

    self._bqml_model = self._bqml_model_factory.create_llm_remote_model(
        X,
        y,
        options=tuning_options,
        connection_name=self.connection_name,
    )
    return self
209+
141210
def predict(
142211
self,
143212
X: Union[bpd.DataFrame, bpd.Series],

‎bigframes/ml/sql.py

+17
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,23 @@ def create_model(
177177
parts.append(f"AS {source_sql}")
178178
return "\n".join(parts)
179179

180+
def create_llm_remote_model(
    self,
    source_df: bpd.DataFrame,
    connection_name: str,
    model_ref: google.cloud.bigquery.ModelReference,
    options: Mapping[str, Union[str, int, float, Iterable[str]]] = {},
) -> str:
    """Encode the CREATE OR REPLACE MODEL statement for BQML.

    The statement is assembled one clause per line: the model header, the
    connection clause, the OPTIONS clause (only when options are given),
    and finally the source query of ``source_df``.
    """
    training_query = source_df.sql

    clauses = [
        f"CREATE OR REPLACE MODEL {self._model_id_sql(model_ref)}",
        self.connection(connection_name),
    ]
    # An empty options mapping produces no OPTIONS clause at all.
    if options:
        clauses.append(self.options(**options))
    clauses.append(f"AS {training_query}")

    return "\n".join(clauses)
196+
180197
def create_remote_model(
181198
self,
182199
connection_name: str,

‎tests/system/load/test_llm.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pandas as pd
16+
import pytest
17+
18+
import bigframes.ml.llm
19+
20+
21+
@pytest.fixture(scope="session")
def llm_fine_tune_df_default_index(
    session: bigframes.Session,
) -> bigframes.dataframe.DataFrame:
    """Read the emotion-classification tuning data as `prompt` / `label` string columns."""
    # Each prompt is the fixed sentiment-analysis instruction followed by the
    # row's text; the label is cast to STRING.
    tuning_query = """
SELECT
CONCAT("Please do sentiment analysis on the following text and only output a number from 0 to 5 where 0 means sadness, 1 means joy, 2 means love, 3 means anger, 4 means fear, and 5 means surprise. Text: ", text) as prompt,
CAST(label AS STRING) as label
FROM `llm_tuning.emotion_classification_train`
"""
    return session.read_gbq(tuning_query)
32+
33+
34+
@pytest.fixture(scope="session")
def llm_remote_text_pandas_df():
    """Three sentiment-analysis prompts in a local pandas DataFrame.

    Returns:
        A pandas DataFrame with a single `prompt` column holding three rows.
    """
    # The instruction is written once and shared by every row, so all prompts
    # carry exactly the same wording (a hand-copied version previously lost
    # the space in "0 to 5 where").
    instruction = (
        "Please do sentiment analysis on the following text and only output a number "
        "from 0 to 5 where 0 means sadness, 1 means joy, 2 means love, 3 means anger, "
        "4 means fear, and 5 means surprise. Text: "
    )
    texts = [
        "i feel beautifully emotional knowing that these women of whom i knew just a handful were holding me and my baba on our journey",
        "i was feeling a little vain when i did this one",
        "a father of children killed in an accident",
    ]
    return pd.DataFrame({"prompt": [instruction + text for text in texts]})
46+
47+
48+
def test_llm_palm_configure_fit(
    llm_fine_tune_df_default_index, llm_remote_text_pandas_df
):
    """Fine tune text-bison for one iteration, then check predictions on three prompts."""
    model = bigframes.ml.llm.PaLM2TextGenerator(
        model_name="text-bison", max_iterations=1
    )

    # Rows with missing values are dropped before tuning.
    training_df = llm_fine_tune_df_default_index.dropna()
    model.fit(training_df[["prompt"]], training_df[["label"]])

    assert model is not None

    predictions = model.predict(llm_remote_text_pandas_df).to_pandas()
    assert predictions.shape == (3, 4)
    assert "ml_generate_text_llm_result" in predictions.columns
    # Every generated answer is expected to be a single character.
    generated = predictions["ml_generate_text_llm_result"]
    assert all(generated.str.len() == 1)

    # TODO(ashleyxu b/335492787): After bqml rolled out version control: save, load, check parameters to ensure configuration was kept

‎tests/system/small/ml/test_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Google LLC
1+
# Copyright 2024 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

‎tests/unit/ml/test_sql.py

+23
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,29 @@ def test_create_model_transform_correct(
181181
)
182182

183183

184+
def test_create_llm_remote_model_correct(
    model_creation_sql_generator: ml_sql.ModelCreationSqlGenerator,
    mock_df: bpd.DataFrame,
):
    """Check the full CREATE OR REPLACE MODEL statement built for an LLM remote model."""
    model_ref = bigquery.ModelReference.from_string(
        "test-proj._anonXYZ.create_remote_model"
    )

    generated_sql = model_creation_sql_generator.create_llm_remote_model(
        source_df=mock_df,
        connection_name="my_project.us.my_connection",
        model_ref=model_ref,
        options={"option_key1": "option_value1", "option_key2": 2},
    )

    # NOTE(review): the expected text is whitespace-sensitive; confirm the
    # leading whitespace of the option lines against the generator's output.
    expected_sql = """CREATE OR REPLACE MODEL `test-proj`.`_anonXYZ`.`create_remote_model`
REMOTE WITH CONNECTION `my_project.us.my_connection`
OPTIONS(
option_key1="option_value1",
option_key2=2)
AS input_X_y_sql"""
    assert generated_sql == expected_sql
205+
206+
184207
def test_create_remote_model_correct(
185208
model_creation_sql_generator: ml_sql.ModelCreationSqlGenerator,
186209
):

0 commit comments

Comments
 (0)
Failed to load comments.