Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 4a12e3c

Browse files
Authored May 25, 2024
feat: support type annotations to supply input and output types to @remote_function decorator (#717)
* feat: support type annotations to supply input and output types to `@remote_function` decorator
* make tests robust to cloud function listing failures too
1 parent 1fca588 commit 4a12e3c

File tree

7 files changed

+181
-96
lines changed

7 files changed

+181
-96
lines changed
 

‎bigframes/functions/remote_function.py

+69-25
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,17 @@
2424
import sys
2525
import tempfile
2626
import textwrap
27-
from typing import cast, List, NamedTuple, Optional, Sequence, TYPE_CHECKING, Union
27+
from typing import (
28+
Any,
29+
cast,
30+
List,
31+
Mapping,
32+
NamedTuple,
33+
Optional,
34+
Sequence,
35+
TYPE_CHECKING,
36+
Union,
37+
)
2838
import warnings
2939

3040
import ibis
@@ -736,8 +746,8 @@ def get_routine_reference(
736746
# which has moved as @js to the ibis package
737747
# https://github.com/ibis-project/ibis/blob/master/ibis/backends/bigquery/udf/__init__.py
738748
def remote_function(
739-
input_types: Union[type, Sequence[type]],
740-
output_type: type,
749+
input_types: Union[None, type, Sequence[type]] = None,
750+
output_type: Optional[type] = None,
741751
session: Optional[Session] = None,
742752
bigquery_client: Optional[bigquery.Client] = None,
743753
bigquery_connection_client: Optional[
@@ -801,11 +811,11 @@ def remote_function(
801811
`$ gcloud projects add-iam-policy-binding PROJECT_ID --member="serviceAccount:CONNECTION_SERVICE_ACCOUNT_ID" --role="roles/run.invoker"`.
802812
803813
Args:
804-
input_types (type or sequence(type)):
814+
input_types (None, type, or sequence(type)):
805815
For scalar user defined function it should be the input type or
806816
sequence of input types. For row processing user defined function,
807817
type `Series` should be specified.
808-
output_type (type):
818+
output_type (Optional[type]):
809819
Data type of the output in the user defined function.
810820
session (bigframes.Session, Optional):
811821
BigQuery DataFrames session to use for getting default project,
@@ -908,27 +918,10 @@ def remote_function(
908918
service(s) that are on a VPC network. See for more details
909919
https://cloud.google.com/functions/docs/networking/connecting-vpc.
910920
"""
911-
is_row_processor = False
912-
913-
import bigframes.series
914-
import bigframes.session
915-
916-
if input_types == bigframes.series.Series:
917-
warnings.warn(
918-
"input_types=Series scenario is in preview.",
919-
stacklevel=1,
920-
category=bigframes.exceptions.PreviewWarning,
921-
)
922-
923-
# we will model the row as a json serialized string containing the data
924-
# and the metadata representing the row
925-
input_types = [str]
926-
is_row_processor = True
927-
elif isinstance(input_types, type):
928-
input_types = [input_types]
929-
930921
# Some defaults may be used from the session if not provided otherwise
931922
import bigframes.pandas as bpd
923+
import bigframes.series
924+
import bigframes.session
932925

933926
session = cast(bigframes.session.Session, session or bpd.get_global_session())
934927

@@ -1021,10 +1014,61 @@ def remote_function(
10211014
bq_connection_manager = None if session is None else session.bqconnectionmanager
10221015

10231016
def wrapper(f):
1017+
nonlocal input_types, output_type
1018+
10241019
if not callable(f):
10251020
raise TypeError("f must be callable, got {}".format(f))
10261021

1027-
signature = inspect.signature(f)
1022+
if sys.version_info >= (3, 10):
1023+
# Add `eval_str = True` so that deferred annotations are turned into their
1024+
# corresponding type objects. Need Python 3.10 for eval_str parameter.
1025+
# https://docs.python.org/3/library/inspect.html#inspect.signature
1026+
signature_kwargs: Mapping[str, Any] = {"eval_str": True}
1027+
else:
1028+
signature_kwargs = {}
1029+
1030+
signature = inspect.signature(
1031+
f,
1032+
**signature_kwargs,
1033+
)
1034+
1035+
# Try to get input types via type annotations.
1036+
if input_types is None:
1037+
input_types = []
1038+
for parameter in signature.parameters.values():
1039+
if (param_type := parameter.annotation) is inspect.Signature.empty:
1040+
raise ValueError(
1041+
"'input_types' was not set and parameter "
1042+
f"'{parameter.name}' is missing a type annotation. "
1043+
"Types are required to use @remote_function."
1044+
)
1045+
input_types.append(param_type)
1046+
1047+
if output_type is None:
1048+
if (output_type := signature.return_annotation) is inspect.Signature.empty:
1049+
raise ValueError(
1050+
"'output_type' was not set and function is missing a "
1051+
"return type annotation. Types are required to use "
1052+
"@remote_function."
1053+
)
1054+
1055+
# The function will actually be receiving a pandas Series, but allow both
1056+
# BigQuery DataFrames and pandas object types for compatibility.
1057+
is_row_processor = False
1058+
if input_types == bigframes.series.Series or input_types == pandas.Series:
1059+
warnings.warn(
1060+
"input_types=Series scenario is in preview.",
1061+
stacklevel=1,
1062+
category=bigframes.exceptions.PreviewWarning,
1063+
)
1064+
1065+
# we will model the row as a json serialized string containing the data
1066+
# and the metadata representing the row
1067+
input_types = [str]
1068+
is_row_processor = True
1069+
elif isinstance(input_types, type):
1070+
input_types = [input_types]
1071+
10281072
# TODO(b/340898611): fix type error
10291073
ibis_signature = ibis_signature_from_python_signature(
10301074
signature, input_types, output_type # type: ignore

‎tests/system/conftest.py

+51-51
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
import math
1919
import pathlib
2020
import textwrap
21+
import traceback
2122
import typing
2223
from typing import Dict, Generator, Optional
23-
import warnings
2424

2525
import google.api_core.exceptions
2626
import google.cloud.bigquery as bigquery
@@ -1097,54 +1097,54 @@ def cleanup_cloud_functions(session, cloudfunctions_client, dataset_id_permanent
10971097
session.bqclient, dataset_id_permanent
10981098
)
10991099
delete_count = 0
1100-
for cloud_function in tests.system.utils.get_cloud_functions(
1101-
cloudfunctions_client,
1102-
session.bqclient.project,
1103-
session.bqclient.location,
1104-
name_prefix="bigframes-",
1105-
):
1106-
# Ignore bigframes cloud functions referred by the remote functions in
1107-
# the permanent dataset
1108-
if cloud_function.service_config.uri in permanent_endpoints:
1109-
continue
1110-
1111-
# Ignore the functions less than one day old
1112-
age = datetime.now() - datetime.fromtimestamp(
1113-
cloud_function.update_time.timestamp()
1114-
)
1115-
if age.days <= 0:
1116-
continue
1117-
1118-
# Go ahead and delete
1119-
try:
1120-
tests.system.utils.delete_cloud_function(
1121-
cloudfunctions_client, cloud_function.name
1100+
try:
1101+
for cloud_function in tests.system.utils.get_cloud_functions(
1102+
cloudfunctions_client,
1103+
session.bqclient.project,
1104+
session.bqclient.location,
1105+
name_prefix="bigframes-",
1106+
):
1107+
# Ignore bigframes cloud functions referred by the remote functions in
1108+
# the permanent dataset
1109+
if cloud_function.service_config.uri in permanent_endpoints:
1110+
continue
1111+
1112+
# Ignore the functions less than one day old
1113+
age = datetime.now() - datetime.fromtimestamp(
1114+
cloud_function.update_time.timestamp()
11221115
)
1123-
delete_count += 1
1124-
if delete_count >= MAX_NUM_FUNCTIONS_TO_DELETE_PER_SESSION:
1125-
break
1126-
except google.api_core.exceptions.NotFound:
1127-
# This can happen when multiple pytest sessions are running in
1128-
# parallel. Two or more sessions may discover the same cloud
1129-
# function, but only one of them would be able to delete it
1130-
# successfully, while the other instance will run into this
1131-
# exception. Ignore this exception.
1132-
pass
1133-
except Exception as exc:
1134-
# Don't fail the tests for unknown exceptions.
1135-
#
1136-
# This can happen if we are hitting GCP limits, e.g.
1137-
# google.api_core.exceptions.ResourceExhausted: 429 Quota exceeded
1138-
# for quota metric 'Per project mutation requests' and limit
1139-
# 'Per project mutation requests per minute per region' of service
1140-
# 'cloudfunctions.googleapis.com' for consumer
1141-
# 'project_number:1084210331973'.
1142-
# [reason: "RATE_LIMIT_EXCEEDED" domain: "googleapis.com" ...
1143-
#
1144-
# It can also happen occasionally with
1145-
# google.api_core.exceptions.ServiceUnavailable when there is some
1146-
# backend flakiness.
1147-
#
1148-
# Let's stop further clean up and leave it to later.
1149-
warnings.warn(f"Cloud functions cleanup failed: {str(exc)}")
1150-
break
1116+
if age.days <= 0:
1117+
continue
1118+
1119+
# Go ahead and delete
1120+
try:
1121+
tests.system.utils.delete_cloud_function(
1122+
cloudfunctions_client, cloud_function.name
1123+
)
1124+
delete_count += 1
1125+
if delete_count >= MAX_NUM_FUNCTIONS_TO_DELETE_PER_SESSION:
1126+
break
1127+
except google.api_core.exceptions.NotFound:
1128+
# This can happen when multiple pytest sessions are running in
1129+
# parallel. Two or more sessions may discover the same cloud
1130+
# function, but only one of them would be able to delete it
1131+
# successfully, while the other instance will run into this
1132+
# exception. Ignore this exception.
1133+
pass
1134+
except Exception as exc:
1135+
# Don't fail the tests for unknown exceptions.
1136+
#
1137+
# This can happen if we are hitting GCP limits, e.g.
1138+
# google.api_core.exceptions.ResourceExhausted: 429 Quota exceeded
1139+
# for quota metric 'Per project mutation requests' and limit
1140+
# 'Per project mutation requests per minute per region' of service
1141+
# 'cloudfunctions.googleapis.com' for consumer
1142+
# 'project_number:1084210331973'.
1143+
# [reason: "RATE_LIMIT_EXCEEDED" domain: "googleapis.com" ...
1144+
#
1145+
# It can also happen occasionally with
1146+
# google.api_core.exceptions.ServiceUnavailable when there is some
1147+
# backend flakiness.
1148+
#
1149+
# Let's stop further clean up and leave it to later.
1150+
traceback.print_exception(exc)

‎tests/unit/resources.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def create_bigquery_session(
3939
session_id: str = "abcxyz",
4040
table_schema: Sequence[google.cloud.bigquery.SchemaField] = TEST_SCHEMA,
4141
anonymous_dataset: Optional[google.cloud.bigquery.DatasetReference] = None,
42+
location: str = "test-region",
4243
) -> bigframes.Session:
4344
credentials = mock.create_autospec(
4445
google.auth.credentials.Credentials, instance=True
@@ -53,11 +54,12 @@ def create_bigquery_session(
5354
if bqclient is None:
5455
bqclient = mock.create_autospec(google.cloud.bigquery.Client, instance=True)
5556
bqclient.project = "test-project"
57+
bqclient.location = location
5658

5759
# Mock the location.
5860
table = mock.create_autospec(google.cloud.bigquery.Table, instance=True)
5961
table._properties = {}
60-
type(table).location = mock.PropertyMock(return_value="test-region")
62+
type(table).location = mock.PropertyMock(return_value=location)
6163
type(table).schema = mock.PropertyMock(return_value=table_schema)
6264
type(table).reference = mock.PropertyMock(
6365
return_value=anonymous_dataset.table("test_table")
@@ -93,9 +95,7 @@ def query_mock(query, *args, **kwargs):
9395
type(clients_provider).bqclient = mock.PropertyMock(return_value=bqclient)
9496
clients_provider._credentials = credentials
9597

96-
bqoptions = bigframes.BigQueryOptions(
97-
credentials=credentials, location="test-region"
98-
)
98+
bqoptions = bigframes.BigQueryOptions(credentials=credentials, location=location)
9999
session = bigframes.Session(context=bqoptions, clients_provider=clients_provider)
100100
return session
101101

‎tests/unit/test_pandas.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def all_session_methods():
5050
[(method_name,) for method_name in all_session_methods()],
5151
)
5252
def test_method_matches_session(method_name: str):
53-
if sys.version_info <= (3, 10):
53+
if sys.version_info < (3, 10):
5454
pytest.skip(
5555
"Need Python 3.10 to reconcile deferred annotations."
5656
) # pragma: no cover

‎tests/unit/test_remote_function.py

+39
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414

1515
import bigframes_vendored.ibis.backends.bigquery.datatypes as third_party_ibis_bqtypes
1616
from ibis.expr import datatypes as ibis_types
17+
import pytest
1718

1819
import bigframes.dtypes
20+
import bigframes.functions.remote_function
21+
from tests.unit import resources
1922

2023

2124
def test_supported_types_correspond():
@@ -29,3 +32,39 @@ def test_supported_types_correspond():
2932
}
3033

3134
assert ibis_types_from_python == ibis_types_from_bigquery
35+
36+
37+
def test_missing_input_types():
38+
session = resources.create_bigquery_session()
39+
remote_function_decorator = bigframes.functions.remote_function.remote_function(
40+
session=session
41+
)
42+
43+
def function_without_parameter_annotations(myparam) -> str:
44+
return str(myparam)
45+
46+
assert function_without_parameter_annotations(42) == "42"
47+
48+
with pytest.raises(
49+
ValueError,
50+
match="'input_types' was not set .* 'myparam' is missing a type annotation",
51+
):
52+
remote_function_decorator(function_without_parameter_annotations)
53+
54+
55+
def test_missing_output_type():
56+
session = resources.create_bigquery_session()
57+
remote_function_decorator = bigframes.functions.remote_function.remote_function(
58+
session=session
59+
)
60+
61+
def function_without_return_annotation(myparam: int):
62+
return str(myparam)
63+
64+
assert function_without_return_annotation(42) == "42"
65+
66+
with pytest.raises(
67+
ValueError,
68+
match="'output_type' was not set .* missing a return type annotation",
69+
):
70+
remote_function_decorator(function_without_return_annotation)

‎third_party/bigframes_vendored/pandas/core/frame.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -3916,8 +3916,8 @@ def map(self, func, na_action: Optional[str] = None) -> DataFrame:
39163916
to potentially reuse a previously deployed ``remote_function`` from
39173917
the same user defined function.
39183918
3919-
>>> @bpd.remote_function(int, float, reuse=False)
3920-
... def minutes_to_hours(x):
3919+
>>> @bpd.remote_function(reuse=False)
3920+
... def minutes_to_hours(x: int) -> float:
39213921
... return x/60
39223922
39233923
>>> df_minutes = bpd.DataFrame(
@@ -4238,6 +4238,7 @@ def apply(self, func, *, axis=0, args=(), **kwargs):
42384238
**Examples:**
42394239
42404240
>>> import bigframes.pandas as bpd
4241+
>>> import pandas as pd
42414242
>>> bpd.options.display.progress_bar = None
42424243
42434244
>>> df = bpd.DataFrame({'col1': [1, 2], 'col2': [3, 4]})
@@ -4259,16 +4260,19 @@ def apply(self, func, *, axis=0, args=(), **kwargs):
42594260
[2 rows x 2 columns]
42604261
42614262
You could apply a user defined function to every row of the DataFrame by
4262-
creating a remote function out of it, and using it with `axis=1`.
4263+
creating a remote function out of it, and using it with `axis=1`. Within
4264+
the function, each row is passed as a ``pandas.Series``. It is recommended
4265+
to select only the necessary columns before calling `apply()`. Note: This
4266+
feature is currently in **preview**.
42634267
4264-
>>> @bpd.remote_function(bpd.Series, int, reuse=False)
4265-
... def foo(row):
4268+
>>> @bpd.remote_function(reuse=False)
4269+
... def foo(row: pd.Series) -> int:
42664270
... result = 1
42674271
... result += row["col1"]
42684272
... result += row["col2"]*row["col2"]
42694273
... return result
42704274
4271-
>>> df.apply(foo, axis=1)
4275+
>>> df[["col1", "col2"]].apply(foo, axis=1)
42724276
0 11
42734277
1 19
42744278
dtype: Int64
There was a problem loading the remainder of the diff.

0 commit comments

Comments
 (0)
Failed to load comments.