Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit a5c94ec

Browse files
authored Apr 26, 2024
feat: Add .cache() method to persist intermediate dataframe (#626)
1 parent 1e7793c commit a5c94ec

File tree

5 files changed

+30
-14
lines changed

5 files changed

+30
-14
lines changed
 

‎bigframes/dataframe.py

+11
Original file line numberDiff line numberDiff line change
@@ -3397,6 +3397,17 @@ def _set_block(self, block: blocks.Block):
33973397
def _get_block(self) -> blocks.Block:
33983398
return self._block
33993399

3400+
def cache(self):
3401+
"""
3402+
Materializes the DataFrame to a temporary table.
3403+
3404+
Useful if the dataframe will be used multiple times, as this will avoid recomputing the shared intermediate value.
3405+
3406+
Returns:
3407+
DataFrame: Self
3408+
"""
3409+
return self._cached(force=True)
3410+
34003411
def _cached(self, *, force: bool = False) -> DataFrame:
34013412
"""Materialize dataframe to a temporary table.
34023413
No-op if the dataframe represents a trivial transformation of an existing materialization.

‎bigframes/ml/core.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def distance(
8383
"""
8484
assert len(x.columns) == 1 and len(y.columns) == 1
8585

86-
input_data = x._cached().join(y._cached(), how="outer")
86+
input_data = x.cache().join(y.cache(), how="outer")
8787
x_column_id, y_column_id = x._block.value_columns[0], y._block.value_columns[0]
8888

8989
return self._apply_sql(
@@ -310,11 +310,9 @@ def create_model(
310310
# Cache dataframes to make sure base table is not a snapshot
311311
# cached dataframe creates a full copy, never uses snapshot
312312
if y_train is None:
313-
input_data = X_train._cached(force=True)
313+
input_data = X_train.cache()
314314
else:
315-
input_data = X_train._cached(force=True).join(
316-
y_train._cached(force=True), how="outer"
317-
)
315+
input_data = X_train.cache().join(y_train.cache(), how="outer")
318316
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})
319317

320318
session = X_train._session
@@ -354,9 +352,7 @@ def create_llm_remote_model(
354352
options = dict(options)
355353
# Cache dataframes to make sure base table is not a snapshot
356354
# cached dataframe creates a full copy, never uses snapshot
357-
input_data = X_train._cached(force=True).join(
358-
y_train._cached(force=True), how="outer"
359-
)
355+
input_data = X_train.cache().join(y_train.cache(), how="outer")
360356
options.update({"INPUT_LABEL_COLS": y_train.columns.tolist()})
361357

362358
session = X_train._session
@@ -389,9 +385,7 @@ def create_time_series_model(
389385
options = dict(options)
390386
# Cache dataframes to make sure base table is not a snapshot
391387
# cached dataframe creates a full copy, never uses snapshot
392-
input_data = X_train._cached(force=True).join(
393-
y_train._cached(force=True), how="outer"
394-
)
388+
input_data = X_train.cache().join(y_train.cache(), how="outer")
395389
options.update({"TIME_SERIES_TIMESTAMP_COL": X_train.columns.tolist()[0]})
396390
options.update({"TIME_SERIES_DATA_COL": y_train.columns.tolist()[0]})
397391

‎bigframes/series.py

+11
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,17 @@ def _slice(
16821682
),
16831683
)
16841684

1685+
def cache(self):
1686+
"""
1687+
Materializes the Series to a temporary table.
1688+
1689+
Useful if the series will be used multiple times, as this will avoid recomputing the shared intermediate value.
1690+
1691+
Returns:
1692+
Series: Self
1693+
"""
1694+
return self._cached(force=True)
1695+
16851696
def _cached(self, *, force: bool = True) -> Series:
16861697
self._set_block(self._block.cached(force=force))
16871698
return self

‎tests/system/small/test_dataframe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4204,7 +4204,7 @@ def test_df_cached(scalars_df_index):
42044204
)
42054205
df = df[df["rowindex_2"] % 2 == 0]
42064206

4207-
df_cached_copy = df._cached()
4207+
df_cached_copy = df.cache()
42084208
pandas.testing.assert_frame_equal(df.to_pandas(), df_cached_copy.to_pandas())
42094209

42104210

‎tests/unit/ml/test_golden_sql.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def bqml_model_factory(mocker: pytest_mock.MockerFixture):
6363
def mock_y():
6464
mock_y = mock.create_autospec(spec=bpd.DataFrame)
6565
mock_y.columns = pd.Index(["input_column_label"])
66-
mock_y._cached.return_value = mock_y
66+
mock_y.cache.return_value = mock_y
6767

6868
return mock_y
6969

@@ -83,7 +83,7 @@ def mock_X(mock_y, mock_session):
8383
["index_column_id"],
8484
["index_column_label"],
8585
)
86-
mock_X._cached.return_value = mock_X
86+
mock_X.cache.return_value = mock_X
8787

8888
return mock_X
8989

0 commit comments

Comments
 (0)
Failed to load comments.