Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit abe48d6

Browse files
sycai and gcf-owl-bot[bot]
authored Feb 19, 2025
feat: (Preview) Support diff aggregation for timestamp series. (#1405)
* [WIP] support time series diff
* add tests
* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 4df61b4 commit abe48d6

File tree

4 files changed

+82
-0
lines changed

4 files changed

+82
-0
lines changed
 

‎bigframes/core/compile/aggregate_compiler.py

+18
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,24 @@ def _(
551551
raise TypeError(f"Cannot perform diff on type{column.type()}")
552552

553553

554+
@compile_unary_agg.register
def _(
    op: agg_ops.TimeSeriesDiffOp,
    column: ibis_types.Column,
    window=None,
) -> ibis_types.Value:
    """Compile a time series diff over a timestamp column.

    The column is shifted by ``op.periods`` by dispatching a ``ShiftOp``
    through ``compile_unary_agg``; the result is the ``delta`` between the
    original and the shifted column, measured in microseconds.

    Args:
        op: The time series diff operation; only ``op.periods`` is read.
        column: The input column. Must have a timestamp type.
        window: Window specification forwarded unchanged to the shift
            compilation.

    Returns:
        The ibis expression produced by ``delta(..., part="microsecond")``.

    Raises:
        TypeError: If ``column`` is not of a timestamp type.
    """
    if not column.type().is_timestamp():
        # Keep a space before the interpolated type so the message reads
        # "... on type timestamp" rather than "... on typetimestamp".
        raise TypeError(f"Cannot perform time series diff on type {column.type()}")

    # The casts only narrow the static type; the runtime check above already
    # guarantees a timestamp column.
    original_column = cast(ibis_types.TimestampColumn, column)
    shifted_column = cast(
        ibis_types.TimestampColumn,
        compile_unary_agg(agg_ops.ShiftOp(op.periods), column, window),
    )

    return original_column.delta(shifted_column, part="microsecond")
570+
571+
554572
@compile_unary_agg.register
555573
def _(
556574
op: agg_ops.AllOp,

‎bigframes/core/rewrite/timedeltas.py

+33
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from bigframes import operations as ops
2323
from bigframes.core import expression as ex
2424
from bigframes.core import nodes, schema, utils
25+
from bigframes.operations import aggregations as aggs
2526

2627

2728
@dataclasses.dataclass
@@ -59,6 +60,16 @@ def rewrite_timedelta_expressions(root: nodes.BigFrameNode) -> nodes.BigFrameNod
5960
by = tuple(_rewrite_ordering_expr(x, root.schema) for x in root.by)
6061
return nodes.OrderByNode(root.child, by)
6162

63+
if isinstance(root, nodes.WindowOpNode):
64+
return nodes.WindowOpNode(
65+
root.child,
66+
_rewrite_aggregation(root.expression, root.schema),
67+
root.window_spec,
68+
root.output_name,
69+
root.never_skip_nulls,
70+
root.skip_reproject_unsafe,
71+
)
72+
6273
return root
6374

6475

@@ -166,3 +177,25 @@ def _rewrite_floordiv_op(left: _TypedExpr, right: _TypedExpr) -> _TypedExpr:
166177
return _TypedExpr.create_op_expr(ops.ToTimedeltaOp("us"), result)
167178

168179
return result
180+
181+
182+
@functools.cache
def _rewrite_aggregation(
    aggregation: ex.Aggregation, schema: schema.ArraySchema
) -> ex.Aggregation:
    """Replace a ``DiffOp`` over datetime-like input with ``TimeSeriesDiffOp``.

    Any aggregation that is not a unary ``DiffOp``, or whose input is not
    datetime-like, is returned unchanged.

    Args:
        aggregation: The aggregation expression to inspect.
        schema: Schema used to look up the type of a column reference.

    Returns:
        Either ``aggregation`` itself or a new ``UnaryAggregation`` wrapping a
        ``TimeSeriesDiffOp`` with the same ``periods`` and argument.
    """
    if not (
        isinstance(aggregation, ex.UnaryAggregation)
        and isinstance(aggregation.op, aggs.DiffOp)
    ):
        return aggregation

    operand = aggregation.arg
    # Column references are typed through the schema; every other expression
    # reports its own dtype.
    operand_type = (
        schema.get_type(operand.id.sql)
        if isinstance(operand, ex.DerefOp)
        else operand.dtype
    )

    if not dtypes.is_datetime_like(operand_type):
        return aggregation

    return ex.UnaryAggregation(
        aggs.TimeSeriesDiffOp(aggregation.op.periods), operand
    )

‎bigframes/operations/aggregations.py

+19
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,25 @@ class DiffOp(UnaryWindowOp):
484484
def skips_nulls(self):
485485
return False
486486

487+
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
    """Return the result type of the diff for the given input types.

    Datetime-like input yields ``dtypes.TIMEDELTA_DTYPE``; every other input
    defers to the parent class implementation.
    """
    first_type = input_types[0]
    if not dtypes.is_datetime_like(first_type):
        return super().output_type(*input_types)
    return dtypes.TIMEDELTA_DTYPE
491+
492+
493+
@dataclasses.dataclass(frozen=True)
class TimeSeriesDiffOp(UnaryWindowOp):
    """Window op that differences a datetime-like input, producing timedeltas."""

    # Number of periods to shift by when computing the difference.
    periods: int

    @property
    def skips_nulls(self):
        """This op never skips nulls."""
        return False

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return ``dtypes.TIMEDELTA_DTYPE`` for datetime-like input.

        Raises:
            TypeError: If the first input type is not datetime-like.
        """
        first_type = input_types[0]
        if not dtypes.is_datetime_like(first_type):
            raise TypeError(f"expect datetime-like types, but got {input_types[0]}")
        return dtypes.TIMEDELTA_DTYPE
505+
487506

488507
@dataclasses.dataclass(frozen=True)
489508
class AllOp(UnaryAggregateOp):

‎tests/system/small/operations/test_datetimes.py

+12
Original file line numberDiff line numberDiff line change
@@ -448,3 +448,15 @@ def test_timestamp_diff_literal_sub_series(scalars_dfs, column, value):
448448

449449
expected_result = value - pd_series
450450
assert_series_equal(actual_result, expected_result)
451+
452+
453+
@pytest.mark.parametrize("column", ["timestamp_col", "datetime_col"])
def test_timestamp_series_diff_agg(scalars_dfs, column):
    """``diff()`` on the given column should match the pandas result."""
    bf_df, pd_df = scalars_dfs

    actual_result = bf_df[column].diff().to_pandas()
    expected_result = pd_df[column].diff()

    assert_series_equal(actual_result, expected_result)

0 commit comments

Comments
 (0)
Failed to load comments.