Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 79a638e

Browse files
authored Oct 30, 2023
feat: Implement operator @ for DataFrame.dot (#139)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:

- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes b/297502513 🦕
1 parent ac44ccd commit 79a638e

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed
 

‎bigframes/dataframe.py

+2
Original file line number | Diff line number | Diff line change
@@ -2707,3 +2707,5 @@ def get_right_id(id):
27072707
result = result[other.name].rename()
27082708

27092709
return result
2710+
2711+
__matmul__ = dot

‎tests/system/small/test_dataframe.py

+33
Original file line number | Diff line number | Diff line change
@@ -3264,6 +3264,23 @@ def test_df_dot(
32643264
)
32653265

32663266

3267+
def test_df_dot_operator(
3268+
matrix_2by3_df, matrix_2by3_pandas_df, matrix_3by4_df, matrix_3by4_pandas_df
3269+
):
3270+
bf_result = (matrix_2by3_df @ matrix_3by4_df).to_pandas()
3271+
pd_result = matrix_2by3_pandas_df @ matrix_3by4_pandas_df
3272+
3273+
# Patch pandas dtypes for testing parity
3274+
# Pandas result is object instead of Int64 (nullable) dtype.
3275+
for name in pd_result.columns:
3276+
pd_result[name] = pd_result[name].astype(pd.Int64Dtype())
3277+
3278+
pd.testing.assert_frame_equal(
3279+
bf_result,
3280+
pd_result,
3281+
)
3282+
3283+
32673284
def test_df_dot_series(
32683285
matrix_2by3_df, matrix_2by3_pandas_df, matrix_3by4_df, matrix_3by4_pandas_df
32693286
):
@@ -3278,3 +3295,19 @@ def test_df_dot_series(
32783295
bf_result,
32793296
pd_result,
32803297
)
3298+
3299+
3300+
def test_df_dot_operator_series(
3301+
matrix_2by3_df, matrix_2by3_pandas_df, matrix_3by4_df, matrix_3by4_pandas_df
3302+
):
3303+
bf_result = (matrix_2by3_df @ matrix_3by4_df["x"]).to_pandas()
3304+
pd_result = matrix_2by3_pandas_df @ matrix_3by4_pandas_df["x"]
3305+
3306+
# Patch pandas dtypes for testing parity
3307+
# Pandas result is object instead of Int64 (nullable) dtype.
3308+
pd_result = pd_result.astype(pd.Int64Dtype())
3309+
3310+
pd.testing.assert_series_equal(
3311+
bf_result,
3312+
pd_result,
3313+
)

‎tests/system/small/test_multiindex.py

+16
Original file line number | Diff line number | Diff line change
@@ -998,13 +998,19 @@ def test_df_multi_index_dot_not_supported():
998998
with pytest.raises(NotImplementedError, match="Multi-index input is not supported"):
999999
bf1.dot(bf2)
10001000

1001+
with pytest.raises(NotImplementedError, match="Multi-index input is not supported"):
1002+
bf1 @ bf2
1003+
10011004
# right multi-index
10021005
right_index = pandas.MultiIndex.from_tuples([("a", "aa"), ("a", "ab"), ("b", "bb")])
10031006
bf1 = bpd.DataFrame(left_matrix)
10041007
bf2 = bpd.DataFrame(right_matrix, index=right_index)
10051008
with pytest.raises(NotImplementedError, match="Multi-index input is not supported"):
10061009
bf1.dot(bf2)
10071010

1011+
with pytest.raises(NotImplementedError, match="Multi-index input is not supported"):
1012+
bf1 @ bf2
1013+
10081014

10091015
def test_column_multi_index_dot_not_supported():
10101016
left_matrix = [[1, 2, 3], [2, 5, 7]]
@@ -1022,10 +1028,20 @@ def test_column_multi_index_dot_not_supported():
10221028
):
10231029
bf1.dot(bf2)
10241030

1031+
with pytest.raises(
1032+
NotImplementedError, match="Multi-level column input is not supported"
1033+
):
1034+
bf1 @ bf2
1035+
10251036
# right multi-columns
10261037
bf1 = bpd.DataFrame(left_matrix)
10271038
bf2 = bpd.DataFrame(right_matrix, columns=multi_level_columns)
10281039
with pytest.raises(
10291040
NotImplementedError, match="Multi-level column input is not supported"
10301041
):
10311042
bf1.dot(bf2)
1043+
1044+
with pytest.raises(
1045+
NotImplementedError, match="Multi-level column input is not supported"
1046+
):
1047+
bf1 @ bf2

0 commit comments

Comments
 (0)
Failed to load comments.