
Commit 21b2188

authored Mar 20, 2024
feat: add params for LinearRegression model (#464)
* feat: add params for LinearRegression model
* fix tests
* update docs
1 parent fb5d83b commit 21b2188

5 files changed: +67 −25 lines changed

bigframes/ml/linear_model.py

+26 −6
@@ -61,19 +61,25 @@ def __init__(
             "auto_strategy", "batch_gradient_descent", "normal_equation"
         ] = "normal_equation",
         fit_intercept: bool = True,
+        l1_reg: Optional[float] = None,
         l2_reg: float = 0.0,
         max_iterations: int = 20,
+        warm_start: bool = False,
+        learn_rate: Optional[float] = None,
         learn_rate_strategy: Literal["line_search", "constant"] = "line_search",
         early_stop: bool = True,
         min_rel_progress: float = 0.01,
-        ls_init_learn_rate: float = 0.1,
+        ls_init_learn_rate: Optional[float] = None,
         calculate_p_values: bool = False,
         enable_global_explain: bool = False,
     ):
         self.optimize_strategy = optimize_strategy
         self.fit_intercept = fit_intercept
+        self.l1_reg = l1_reg
         self.l2_reg = l2_reg
         self.max_iterations = max_iterations
+        self.warm_start = warm_start
+        self.learn_rate = learn_rate
         self.learn_rate_strategy = learn_rate_strategy
         self.early_stop = early_stop
         self.min_rel_progress = min_rel_progress
@@ -99,17 +105,21 @@ def _from_bq(
         for bf_param, bf_value in dummy_linear.__dict__.items():
             bqml_param = _BQML_PARAMS_MAPPING.get(bf_param)
             if bqml_param in last_fitting:
-                kwargs[bf_param] = type(bf_value)(last_fitting[bqml_param])
+                # Convert types
+                kwargs[bf_param] = (
+                    float(last_fitting[bqml_param])
+                    if bf_param in ["l1_reg", "learn_rate", "ls_init_learn_rate"]
+                    else type(bf_value)(last_fitting[bqml_param])
+                )

         new_linear_regression = cls(**kwargs)
         new_linear_regression._bqml_model = core.BqmlModel(session, model)
         return new_linear_regression

     @property
-    def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
+    def _bqml_options(self) -> dict:
         """The model options as they will be set for BQML"""
-        # TODO: Support l1_reg, warm_start, and learn_rate with error catching.
-        return {
+        options = {
             "model_type": "LINEAR_REG",
             "data_split_method": "NO_SPLIT",
             "optimize_strategy": self.optimize_strategy,
@@ -119,10 +129,20 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
             "learn_rate_strategy": self.learn_rate_strategy,
             "early_stop": self.early_stop,
             "min_rel_progress": self.min_rel_progress,
-            "ls_init_learn_rate": self.ls_init_learn_rate,
             "calculate_p_values": self.calculate_p_values,
             "enable_global_explain": self.enable_global_explain,
         }
+        if self.l1_reg is not None:
+            options["l1_reg"] = self.l1_reg
+        if self.learn_rate is not None:
+            options["learn_rate"] = self.learn_rate
+        if self.ls_init_learn_rate is not None:
+            options["ls_init_learn_rate"] = self.ls_init_learn_rate
+        # Even presenting warm_start returns error for NORMAL_EQUATION optimizer
+        if self.warm_start is True:
+            options["warm_start"] = self.warm_start
+
+        return options

     def _fit(
         self,
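
A rough sketch (not part of the commit) of how the new parameters flow into the BQML options: unset Optional parameters are simply omitted, so BQML applies its own defaults instead of receiving an explicit value. The printed dict is approximate and assumes the unchanged middle of the options literal (fit_intercept, l2_reg, max_iterations) matches the golden-SQL test below; inspecting the options needs only a local bigframes install, not a BigQuery session.

from bigframes.ml.linear_model import LinearRegression

# New gradient-descent-related parameters; l1_reg and learn_rate are only
# forwarded to BQML because they are not None.
model = LinearRegression(
    optimize_strategy="batch_gradient_descent",
    learn_rate_strategy="constant",
    learn_rate=0.2,
    l1_reg=0.2,
)

print(model._bqml_options)
# Roughly:
# {'model_type': 'LINEAR_REG', 'data_split_method': 'NO_SPLIT',
#  'optimize_strategy': 'batch_gradient_descent', 'fit_intercept': True,
#  'l2_reg': 0.0, 'max_iterations': 20, 'learn_rate_strategy': 'constant',
#  'early_stop': True, 'min_rel_progress': 0.01, 'calculate_p_values': False,
#  'enable_global_explain': False, 'l1_reg': 0.2, 'learn_rate': 0.2}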

bigframes/ml/sql.py

+3 −1
@@ -38,7 +38,9 @@ def encode_value(self, v: Union[str, int, float, Iterable[str]]) -> str:
             inner = ", ".join([self.encode_value(x) for x in v])
             return f"[{inner}]"
         else:
-            raise ValueError(f"Unexpected value type. {constants.FEEDBACK_LINK}")
+            raise ValueError(
+                f"Unexpected value type {type(v)}. {constants.FEEDBACK_LINK}"
+            )

     def build_parameters(self, **kwargs: Union[str, int, float, Iterable[str]]) -> str:
         """Encode a dict of values into a formatted Iterable of key-value pairs for SQL"""

tests/system/large/ml/test_linear_model.py

+28 −14
@@ -60,9 +60,11 @@ def test_linear_regression_configure_fit_score(penguins_df_default_index, datase
     assert reloaded_model.calculate_p_values is False
     assert reloaded_model.early_stop is True
     assert reloaded_model.enable_global_explain is False
+    assert reloaded_model.l1_reg is None
     assert reloaded_model.l2_reg == 0.0
+    assert reloaded_model.learn_rate is None
     assert reloaded_model.learn_rate_strategy == "line_search"
-    assert reloaded_model.ls_init_learn_rate == 0.1
+    assert reloaded_model.ls_init_learn_rate is None
     assert reloaded_model.max_iterations == 20
     assert reloaded_model.min_rel_progress == 0.01

@@ -71,7 +73,14 @@ def test_linear_regression_customized_params_fit_score(
     penguins_df_default_index, dataset_id
 ):
     model = bigframes.ml.linear_model.LinearRegression(
-        fit_intercept=False, l2_reg=0.1, min_rel_progress=0.01
+        fit_intercept=False,
+        l2_reg=0.2,
+        min_rel_progress=0.02,
+        l1_reg=0.2,
+        max_iterations=30,
+        optimize_strategy="batch_gradient_descent",
+        learn_rate_strategy="constant",
+        learn_rate=0.2,
     )

     df = penguins_df_default_index.dropna()
@@ -92,12 +101,12 @@ def test_linear_regression_customized_params_fit_score(
     result = model.score(X_train, y_train).to_pandas()
     expected = pd.DataFrame(
         {
-            "mean_absolute_error": [226.108411],
-            "mean_squared_error": [80459.668456],
-            "mean_squared_log_error": [0.00497],
-            "median_absolute_error": [171.618872],
-            "r2_score": [0.875415],
-            "explained_variance": [0.875417],
+            "mean_absolute_error": [240],
+            "mean_squared_error": [91197],
+            "mean_squared_log_error": [0.00573],
+            "median_absolute_error": [197],
+            "r2_score": [0.858],
+            "explained_variance": [0.8588],
         },
         dtype="Float64",
     )
@@ -109,16 +118,21 @@ def test_linear_regression_customized_params_fit_score(
     assert (
         f"{dataset_id}.temp_configured_model" in reloaded_model._bqml_model.model_name
     )
-    assert reloaded_model.optimize_strategy == "NORMAL_EQUATION"
+    assert reloaded_model.optimize_strategy == "BATCH_GRADIENT_DESCENT"
     assert reloaded_model.fit_intercept is False
     assert reloaded_model.calculate_p_values is False
     assert reloaded_model.early_stop is True
     assert reloaded_model.enable_global_explain is False
-    assert reloaded_model.l2_reg == 0.1
-    assert reloaded_model.learn_rate_strategy == "line_search"
-    assert reloaded_model.ls_init_learn_rate == 0.1
-    assert reloaded_model.max_iterations == 20
-    assert reloaded_model.min_rel_progress == 0.01
+    assert reloaded_model.l1_reg == 0.2
+    assert reloaded_model.l2_reg == 0.2
+    assert reloaded_model.ls_init_learn_rate is None
+    assert reloaded_model.max_iterations == 30
+    assert reloaded_model.min_rel_progress == 0.02
+    assert reloaded_model.learn_rate_strategy == "CONSTANT"
+    assert reloaded_model.learn_rate == 0.2
+
+
+    # TODO(garrettwu): add tests for param warm_start. Requires a trained model.


 def test_logistic_regression_configure_fit_score(penguins_df_default_index, dataset_id):

tests/unit/ml/test_golden_sql.py

+2 −2
@@ -105,7 +105,7 @@ def test_linear_regression_default_fit(
     model.fit(mock_X, mock_y)

     mock_session._start_query_ml_ddl.assert_called_once_with(
-        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n ls_init_learn_rate=0.1,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
+        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
     )


@@ -115,7 +115,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X,
     model.fit(mock_X, mock_y)

     mock_session._start_query_ml_ddl.assert_called_once_with(
-        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n ls_init_learn_rate=0.1,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
+        'CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type="LINEAR_REG",\n data_split_method="NO_SPLIT",\n optimize_strategy="normal_equation",\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy="line_search",\n early_stop=True,\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=["input_column_label"])\nAS input_X_y_sql'
     )

third_party/bigframes_vendored/sklearn/linear_model/_base.py

+8 −2
@@ -71,18 +71,24 @@ class LinearRegression(RegressorMixin, LinearModel):
             Default ``True``. Whether to calculate the intercept for this
             model. If set to False, no intercept will be used in calculations
             (i.e. data is expected to be centered).
+        l1_reg (float or None, default None):
+            The amount of L1 regularization applied. Default to None. Can't be set in "normal_equation" mode. If unset, value 0 is used.
         l2_reg (float, default 0.0):
             The amount of L2 regularization applied. Default to 0.
         max_iterations (int, default 20):
             The maximum number of training iterations or steps. Default to 20.
+        warm_start (bool, default False):
+            Determines whether to train a model with new training data, new model options, or both. Unless you explicitly override them, the initial options used to train the model are used for the warm start run. Default to False.
+        learn_rate (float or None, default None):
+            The learn rate for gradient descent when learn_rate_strategy='constant'. If unset, value 0.1 is used. If learn_rate_strategy='line_search', an error is returned.
         learn_rate_strategy (str, default "line_search"):
             The strategy for specifying the learning rate during training. Default to "line_search".
         early_stop (bool, default True):
             Whether training should stop after the first iteration in which the relative loss improvement is less than the value specified for min_rel_progress. Default to True.
         min_rel_progress (float, default 0.01):
             The minimum relative loss improvement that is necessary to continue training when EARLY_STOP is set to true. For example, a value of 0.01 specifies that each iteration must reduce the loss by 1% for training to continue. Default to 0.01.
-        ls_init_learn_rate (float, default 0.1):
-            Sets the initial learning rate that learn_rate_strategy='line_search' uses. This option can only be used if line_search is specified. Default to 0.1.
+        ls_init_learn_rate (float or None, default None):
+            Sets the initial learning rate that learn_rate_strategy='line_search' uses. This option can only be used if line_search is specified. If unset, value 0.1 is used.
         calculate_p_values (bool, default False):
             Specifies whether to compute p-values and standard errors during training. Default to False.
         enable_global_explain (bool, default False):
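
Putting the newly documented parameters together, a hedged end-to-end sketch mirroring the penguins system test above. X_train and y_train are placeholders for bigframes DataFrames prepared as in that test; they are not defined by this commit.

from bigframes.ml.linear_model import LinearRegression

model = LinearRegression(
    fit_intercept=False,
    optimize_strategy="batch_gradient_descent",
    learn_rate_strategy="constant",
    learn_rate=0.2,        # only valid with the "constant" strategy
    l1_reg=0.2,            # cannot be set in "normal_equation" mode
    l2_reg=0.2,
    max_iterations=30,
    min_rel_progress=0.02,
)

# X_train / y_train: feature and label DataFrames, as prepared in the system test.
model.fit(X_train, y_train)                        # trains a BQML LINEAR_REG model
print(model.score(X_train, y_train).to_pandas())   # regression metrics as a pandas DataFrame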

0 commit comments
