15
15
from __future__ import annotations
16
16
17
17
import typing
18
- from typing import List , Sequence
18
+ from typing import List , Sequence , Union
19
19
20
20
import bigframes_vendored .constants as constants
21
21
import bigframes_vendored .pandas .pandas ._typing as vendored_pandas_typing
@@ -180,9 +180,10 @@ def _apply_binary_op(
180
180
(self_col , other_col , block ) = self ._align (other_series , how = alignment )
181
181
182
182
name = self ._name
183
+ # Drop name if both objects have name attr, but they don't match
183
184
if (
184
185
hasattr (other , "name" )
185
- and other .name != self ._name
186
+ and other_series .name != self ._name
186
187
and alignment == "outer"
187
188
):
188
189
name = None
@@ -208,41 +209,78 @@ def _apply_nary_op(
208
209
ignore_self = False ,
209
210
):
210
211
"""Applies an n-ary operator to the series and others."""
211
- values , block = self ._align_n (others , ignore_self = ignore_self )
212
- block , result_id = block .apply_nary_op (
213
- values ,
214
- op ,
215
- self ._name ,
212
+ values , block = self ._align_n (
213
+ others , ignore_self = ignore_self , cast_scalars = False
216
214
)
215
+ block , result_id = block .project_expr (op .as_expr (* values ))
217
216
return series .Series (block .select_column (result_id ))
218
217
219
218
def _apply_binary_aggregation (
220
219
self , other : series .Series , stat : agg_ops .BinaryAggregateOp
221
220
) -> float :
222
221
(left , right , block ) = self ._align (other , how = "outer" )
222
+ assert isinstance (left , ex .UnboundVariableExpression )
223
+ assert isinstance (right , ex .UnboundVariableExpression )
224
+ return block .get_binary_stat (left .id , right .id , stat )
225
+
226
+ AlignedExprT = Union [ex .ScalarConstantExpression , ex .UnboundVariableExpression ]
223
227
224
- return block .get_binary_stat (left , right , stat )
228
+ @typing .overload
229
+ def _align (
230
+ self , other : series .Series , how = "outer"
231
+ ) -> tuple [
232
+ ex .UnboundVariableExpression ,
233
+ ex .UnboundVariableExpression ,
234
+ blocks .Block ,
235
+ ]:
236
+ ...
225
237
226
- def _align (self , other : series .Series , how = "outer" ) -> tuple [str , str , blocks .Block ]: # type: ignore
238
+ @typing .overload
239
+ def _align (
240
+ self , other : typing .Union [series .Series , scalars .Scalar ], how = "outer"
241
+ ) -> tuple [ex .UnboundVariableExpression , AlignedExprT , blocks .Block ,]:
242
+ ...
243
+
244
+ def _align (
245
+ self , other : typing .Union [series .Series , scalars .Scalar ], how = "outer"
246
+ ) -> tuple [ex .UnboundVariableExpression , AlignedExprT , blocks .Block ,]:
227
247
"""Aligns the series value with another scalar or series object. Returns new left column id, right column id and joined tabled expression."""
228
248
values , block = self ._align_n (
229
249
[
230
250
other ,
231
251
],
232
252
how ,
233
253
)
234
- return (values [0 ], values [1 ], block )
254
+ return (typing .cast (ex .UnboundVariableExpression , values [0 ]), values [1 ], block )
255
+
256
+ def _align3 (self , other1 : series .Series | scalars .Scalar , other2 : series .Series | scalars .Scalar , how = "left" ) -> tuple [ex .UnboundVariableExpression , AlignedExprT , AlignedExprT , blocks .Block ]: # type: ignore
257
+ """Aligns the series value with 2 other scalars or series objects. Returns new values and joined tabled expression."""
258
+ values , index = self ._align_n ([other1 , other2 ], how )
259
+ return (
260
+ typing .cast (ex .UnboundVariableExpression , values [0 ]),
261
+ values [1 ],
262
+ values [2 ],
263
+ index ,
264
+ )
235
265
236
266
def _align_n (
237
267
self ,
238
268
others : typing .Sequence [typing .Union [series .Series , scalars .Scalar ]],
239
269
how = "outer" ,
240
270
ignore_self = False ,
241
- ) -> tuple [typing .Sequence [str ], blocks .Block ]:
271
+ cast_scalars : bool = True ,
272
+ ) -> tuple [
273
+ typing .Sequence [
274
+ Union [ex .ScalarConstantExpression , ex .UnboundVariableExpression ]
275
+ ],
276
+ blocks .Block ,
277
+ ]:
242
278
if ignore_self :
243
- value_ids : List [str ] = []
279
+ value_ids : List [
280
+ Union [ex .ScalarConstantExpression , ex .UnboundVariableExpression ]
281
+ ] = []
244
282
else :
245
- value_ids = [self ._value_column ]
283
+ value_ids = [ex . free_var ( self ._value_column ) ]
246
284
247
285
block = self ._block
248
286
for other in others :
@@ -252,14 +290,16 @@ def _align_n(
252
290
get_column_right ,
253
291
) = block .join (other ._block , how = how )
254
292
value_ids = [
255
- * [get_column_left [ value ] for value in value_ids ],
256
- get_column_right [other ._value_column ],
293
+ * [value . rename ( get_column_left ) for value in value_ids ],
294
+ ex . free_var ( get_column_right [other ._value_column ]) ,
257
295
]
258
296
else :
259
297
# Will throw if can't interpret as scalar.
260
298
dtype = typing .cast (bigframes .dtypes .Dtype , self ._dtype )
261
- block , constant_col_id = block .create_constant (other , dtype = dtype )
262
- value_ids = [* value_ids , constant_col_id ]
299
+ value_ids = [
300
+ * value_ids ,
301
+ ex .const (other , dtype = dtype if cast_scalars else None ),
302
+ ]
263
303
return (value_ids , block )
264
304
265
305
def _throw_if_null_index (self , opname : str ):
0 commit comments