Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2218c21

Browse files
authored May 10, 2024
feat: add Series.case_when() (#673)
* feat: add `Series.case_when()`
* rename to ScalarOp
* rename to exprs
* add type annotations

feat: add `DataFrame.__delitem__` (#673)

docs: add logistic regression samples (#673)
1 parent 93416ed commit 2218c21

File tree

9 files changed

+311
-34
lines changed

9 files changed

+311
-34
lines changed
 

‎bigframes/core/__init__.py

+10-7
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from dataclasses import dataclass
1717
import functools
1818
import io
19+
import itertools
1920
import typing
2021
from typing import Iterable, Sequence
2122

@@ -370,14 +371,16 @@ def unpivot(
370371
for col_id, input_ids in unpivot_columns:
371372
# row explode offset used to choose the input column
372373
# we use offset instead of label as labels are not necessarily unique
373-
cases = tuple(
374-
(
375-
ops.eq_op.as_expr(explode_offsets_id, ex.const(i)),
376-
ex.free_var(id_or_null)
377-
if (id_or_null is not None)
378-
else ex.const(None),
374+
cases = itertools.chain(
375+
*(
376+
(
377+
ops.eq_op.as_expr(explode_offsets_id, ex.const(i)),
378+
ex.free_var(id_or_null)
379+
if (id_or_null is not None)
380+
else ex.const(None),
381+
)
382+
for i, id_or_null in enumerate(input_ids)
379383
)
380-
for i, id_or_null in enumerate(input_ids)
381384
)
382385
col_expr = ops.case_when_op.as_expr(*cases)
383386
unpivot_exprs.append((col_expr, col_id))

‎bigframes/core/blocks.py

+9
Original file line number | Diff line number | Diff line change
@@ -803,6 +803,15 @@ def apply_ternary_op(
803803
expr = op.as_expr(col_id_1, col_id_2, col_id_3)
804804
return self.project_expr(expr, result_label)
805805

806+
def apply_nary_op(
807+
self,
808+
columns: Iterable[str],
809+
op: ops.NaryOp,
810+
result_label: Label = None,
811+
) -> typing.Tuple[Block, str]:
812+
expr = op.as_expr(*columns)
813+
return self.project_expr(expr, result_label)
814+
806815
def multi_apply_window_op(
807816
self,
808817
columns: typing.Sequence[str],

‎bigframes/dataframe.py

+4
Original file line number | Diff line number | Diff line change
@@ -655,6 +655,10 @@ def _repr_html_(self) -> str:
655655
html_string += f"[{row_count} rows x {column_count} columns in total]"
656656
return html_string
657657

658+
def __delitem__(self, key: str):
659+
df = self.drop(columns=[key])
660+
self._set_block(df._get_block())
661+
658662
def __setitem__(self, key: str, value: SingleItemValue):
659663
df = self._assign_single_item(key, value)
660664
self._set_block(df._get_block())

‎bigframes/operations/__init__.py

+25-26
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
import dataclasses
1818
import functools
1919
import typing
20-
from typing import Tuple, Union
20+
from typing import Union
2121

2222
import numpy as np
2323
import pandas as pd
@@ -46,7 +46,7 @@ def order_preserving(self) -> bool:
4646

4747

4848
@dataclasses.dataclass(frozen=True)
49-
class NaryOp:
49+
class ScalarOp:
5050
@property
5151
def name(self) -> str:
5252
raise NotImplementedError("RowOp abstract base class has no implementation")
@@ -60,10 +60,30 @@ def order_preserving(self) -> bool:
6060
return False
6161

6262

63+
@dataclasses.dataclass(frozen=True)
64+
class NaryOp(ScalarOp):
65+
def as_expr(
66+
self,
67+
*exprs: Union[str | bigframes.core.expression.Expression],
68+
) -> bigframes.core.expression.Expression:
69+
import bigframes.core.expression
70+
71+
# Keep this in sync with output_type and compilers
72+
inputs: list[bigframes.core.expression.Expression] = []
73+
74+
for expr in exprs:
75+
inputs.append(_convert_expr_input(expr))
76+
77+
return bigframes.core.expression.OpExpression(
78+
self,
79+
tuple(inputs),
80+
)
81+
82+
6383
# These classes can be used to create simple ops that don't take local parameters
6484
# All that is needed is a unique name and a registered implementation in ibis_mappings.py
6585
@dataclasses.dataclass(frozen=True)
66-
class UnaryOp(NaryOp):
86+
class UnaryOp(ScalarOp):
6787
@property
6888
def arguments(self) -> int:
6989
return 1
@@ -79,7 +99,7 @@ def as_expr(
7999

80100

81101
@dataclasses.dataclass(frozen=True)
82-
class BinaryOp(NaryOp):
102+
class BinaryOp(ScalarOp):
83103
@property
84104
def arguments(self) -> int:
85105
return 2
@@ -101,7 +121,7 @@ def as_expr(
101121

102122

103123
@dataclasses.dataclass(frozen=True)
104-
class TernaryOp(NaryOp):
124+
class TernaryOp(ScalarOp):
105125
@property
106126
def arguments(self) -> int:
107127
return 3
@@ -655,27 +675,6 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
655675
output_expr_types,
656676
)
657677

658-
def as_expr(
659-
self,
660-
*case_output_pairs: Tuple[
661-
Union[str | bigframes.core.expression.Expression],
662-
Union[str | bigframes.core.expression.Expression],
663-
],
664-
) -> bigframes.core.expression.Expression:
665-
import bigframes.core.expression
666-
667-
# Keep this in sync with output_type and compilers
668-
inputs: list[bigframes.core.expression.Expression] = []
669-
670-
for case, output in case_output_pairs:
671-
inputs.append(_convert_expr_input(case))
672-
inputs.append(_convert_expr_input(output))
673-
674-
return bigframes.core.expression.OpExpression(
675-
self,
676-
tuple(inputs),
677-
)
678-
679678

680679
case_when_op = CaseWhenOp()
681680

‎bigframes/operations/base.py

+22-1
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import typing
18+
from typing import List, Sequence
1819

1920
import bigframes_vendored.pandas.pandas._typing as vendored_pandas_typing
2021
import numpy
@@ -205,6 +206,21 @@ def _apply_binary_op(
205206
block, result_id = self._block.project_expr(expr, name)
206207
return series.Series(block.select_column(result_id))
207208

209+
def _apply_nary_op(
210+
self,
211+
op: ops.NaryOp,
212+
others: Sequence[typing.Union[series.Series, scalars.Scalar]],
213+
ignore_self=False,
214+
):
215+
"""Applies an n-ary operator to the series and others."""
216+
values, block = self._align_n(others, ignore_self=ignore_self)
217+
block, result_id = block.apply_nary_op(
218+
values,
219+
op,
220+
self._name,
221+
)
222+
return series.Series(block.select_column(result_id))
223+
208224
def _apply_binary_aggregation(
209225
self, other: series.Series, stat: agg_ops.BinaryAggregateOp
210226
) -> float:
@@ -226,8 +242,13 @@ def _align_n(
226242
self,
227243
others: typing.Sequence[typing.Union[series.Series, scalars.Scalar]],
228244
how="outer",
245+
ignore_self=False,
229246
) -> tuple[typing.Sequence[str], blocks.Block]:
230-
value_ids = [self._value_column]
247+
if ignore_self:
248+
value_ids: List[str] = []
249+
else:
250+
value_ids = [self._value_column]
251+
231252
block = self._block
232253
for other in others:
233254
if isinstance(other, series.Series):

‎bigframes/series.py

+19
Original file line number | Diff line number | Diff line change
@@ -410,6 +410,25 @@ def between(self, left, right, inclusive="both"):
410410
self._apply_binary_op(right, right_op)
411411
)
412412

413+
def case_when(self, caselist) -> Series:
414+
return self._apply_nary_op(
415+
ops.case_when_op,
416+
tuple(
417+
itertools.chain(
418+
itertools.chain(*caselist),
419+
# Fallback to current value if no other matches.
420+
(
421+
# We make a Series with a constant value to avoid casts to
422+
# types other than boolean.
423+
Series(True, index=self.index, dtype=pandas.BooleanDtype()),
424+
self,
425+
),
426+
),
427+
),
428+
# Self is already included in "others".
429+
ignore_self=True,
430+
)
431+
413432
def cumsum(self) -> Series:
414433
return self._apply_window_op(
415434
agg_ops.sum_op, bigframes.core.window_spec.WindowSpec(following=0)
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""BigQuery DataFrames code samples for
16+
https://cloud.google.com/bigquery/docs/logistic-regression-prediction.
17+
"""
18+
19+
20+
def test_logistic_regression_prediction(random_model_id: str) -> None:
21+
your_model_id = random_model_id
22+
23+
# [START bigquery_dataframes_logistic_regression_prediction_examine]
24+
import bigframes.pandas as bpd
25+
26+
df = bpd.read_gbq(
27+
"bigquery-public-data.ml_datasets.census_adult_income",
28+
columns=(
29+
"age",
30+
"workclass",
31+
"marital_status",
32+
"education_num",
33+
"occupation",
34+
"hours_per_week",
35+
"income_bracket",
36+
"functional_weight",
37+
),
38+
max_results=100,
39+
)
40+
df.peek()
41+
# Output:
42+
# age workclass marital_status education_num occupation hours_per_week income_bracket functional_weight
43+
# 47 Local-gov Married-civ-spouse 13 Prof-specialty 40 >50K 198660
44+
# 56 Private Never-married 9 Adm-clerical 40 <=50K 85018
45+
# 40 Private Married-civ-spouse 12 Tech-support 40 >50K 285787
46+
# 34 Self-emp-inc Married-civ-spouse 9 Craft-repair 54 >50K 207668
47+
# 23 Private Married-civ-spouse 10 Handlers-cleaners 40 <=50K 40060
48+
# [END bigquery_dataframes_logistic_regression_prediction_examine]
49+
50+
# [START bigquery_dataframes_logistic_regression_prediction_prepare]
51+
import bigframes.pandas as bpd
52+
53+
input_data = bpd.read_gbq(
54+
"bigquery-public-data.ml_datasets.census_adult_income",
55+
columns=(
56+
"age",
57+
"workclass",
58+
"marital_status",
59+
"education_num",
60+
"occupation",
61+
"hours_per_week",
62+
"income_bracket",
63+
"functional_weight",
64+
),
65+
)
66+
input_data["dataframe"] = bpd.Series("training", index=input_data.index,).case_when(
67+
[
68+
(((input_data["functional_weight"] % 10) == 8), "evaluation"),
69+
(((input_data["functional_weight"] % 10) == 9), "prediction"),
70+
]
71+
)
72+
del input_data["functional_weight"]
73+
# [END bigquery_dataframes_logistic_regression_prediction_prepare]
74+
75+
# [START bigquery_dataframes_logistic_regression_prediction_create_model]
76+
import bigframes.ml.linear_model
77+
78+
# input_data is defined in an earlier step.
79+
training_data = input_data[input_data["dataframe"] == "training"]
80+
X = training_data.drop(columns=["income_bracket", "dataframe"])
81+
y = training_data["income_bracket"]
82+
83+
census_model = bigframes.ml.linear_model.LogisticRegression()
84+
census_model.fit(X, y)
85+
86+
census_model.to_gbq(
87+
your_model_id, # For example: "your-project.census.census_model"
88+
replace=True,
89+
)
90+
# [END bigquery_dataframes_logistic_regression_prediction_create_model]
91+
92+
# [START bigquery_dataframes_logistic_regression_prediction_evaluate_model]
93+
# Select model you'll use for predictions. `read_gbq_model` loads model
94+
# data from BigQuery, but you could also use the `census_model` object
95+
# from previous steps.
96+
census_model = bpd.read_gbq_model(
97+
your_model_id, # For example: "your-project.census.census_model"
98+
)
99+
100+
# input_data is defined in an earlier step.
101+
evaluation_data = input_data[input_data["dataframe"] == "evaluation"]
102+
X = evaluation_data.drop(columns=["income_bracket", "dataframe"])
103+
y = evaluation_data["income_bracket"]
104+
105+
# The score() method evaluates how the model performs compared to the
106+
# actual data. Output DataFrame matches that of ML.EVALUATE().
107+
score = census_model.score(X, y)
108+
score.peek()
109+
# Output:
110+
# precision recall accuracy f1_score log_loss roc_auc
111+
# 0 0.685764 0.536685 0.83819 0.602134 0.350417 0.882953
112+
# [END bigquery_dataframes_logistic_regression_prediction_evaluate_model]
113+
114+
# [START bigquery_dataframes_logistic_regression_prediction_predict_income_bracket]
115+
# Select model you'll use for predictions. `read_gbq_model` loads model
116+
# data from BigQuery, but you could also use the `census_model` object
117+
# from previous steps.
118+
census_model = bpd.read_gbq_model(
119+
your_model_id, # For example: "your-project.census.census_model"
120+
)
121+
122+
# input_data is defined in an earlier step.
123+
prediction_data = input_data[input_data["dataframe"] == "prediction"]
124+
125+
predictions = census_model.predict(prediction_data)
126+
predictions.peek()
127+
# Output:
128+
# predicted_income_bracket predicted_income_bracket_probs age workclass ... occupation hours_per_week income_bracket dataframe
129+
# 18004 <=50K [{'label': ' >50K', 'prob': 0.0763305999358786... 75 ? ... ? 6 <=50K prediction
130+
# 18886 <=50K [{'label': ' >50K', 'prob': 0.0448866871906495... 73 ? ... ? 22 >50K prediction
131+
# 31024 <=50K [{'label': ' >50K', 'prob': 0.0362982319421936... 69 ? ... ? 1 <=50K prediction
132+
# 31022 <=50K [{'label': ' >50K', 'prob': 0.0787836112058324... 75 ? ... ? 5 <=50K prediction
133+
# 23295 <=50K [{'label': ' >50K', 'prob': 0.3385373037905673... 78 ? ... ? 32 <=50K prediction
134+
# [END bigquery_dataframes_logistic_regression_prediction_predict_income_bracket]
135+
136+
# TODO(tswast): Implement ML.EXPLAIN_PREDICT() and corresponding sample.
137+
# TODO(tswast): Implement ML.GLOBAL_EXPLAIN() and corresponding sample.
There was a problem loading the remainder of the diff.

0 commit comments

Comments
 (0)
Failed to load comments.