diff --git a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py index 19a130c50638..d2f5ee1acfce 100644 --- a/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/cassandra_to_gcs.py @@ -26,7 +26,7 @@ from datetime import datetime from decimal import Decimal from tempfile import NamedTemporaryFile -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterable, List, NewType, Optional, Sequence, Tuple, Union from uuid import UUID from cassandra.util import Date, OrderedMapSerializedKey, SortedSet, Time @@ -36,6 +36,9 @@ from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook from airflow.providers.google.cloud.hooks.gcs import GCSHook +NotSetType = NewType('NotSetType', object) +NOT_SET = NotSetType(object()) + class CassandraToGCSOperator(BaseOperator): """ @@ -84,6 +87,10 @@ class CassandraToGCSOperator(BaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). :type impersonation_chain: Union[str, Sequence[str]] + :param query_timeout: (Optional) The amount of time, in seconds, used to execute the Cassandra query. + If not set, the timeout value will be set in Session.execute() by Cassandra driver. + If set to None, there is no timeout. + :type query_timeout: float | None """ template_fields = ( @@ -110,6 +117,7 @@ def __init__( google_cloud_storage_conn_id: Optional[str] = None, delegate_to: Optional[str] = None, impersonation_chain: Optional[Union[str, Sequence[str]]] = None, + query_timeout: Union[float, None, NotSetType] = NOT_SET, **kwargs, ) -> None: super().__init__(**kwargs) @@ -133,6 +141,7 @@ def __init__( self.delegate_to = delegate_to self.gzip = gzip self.impersonation_chain = impersonation_chain + self.query_timeout = query_timeout # Default Cassandra to BigQuery type mapping CQL_TYPE_MAP = { @@ -162,7 +171,12 @@ def __init__( def execute(self, context: Dict[str, str]): hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id) - cursor = hook.get_conn().execute(self.cql) + + query_extra = {} + if self.query_timeout is not NOT_SET: + query_extra['timeout'] = self.query_timeout + + cursor = hook.get_conn().execute(self.cql, **query_extra) files_to_upload = self._write_local_data_files(cursor) diff --git a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py index e71745d112a1..b53bbb4e66cb 100644 --- a/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_cassandra_to_gcs.py @@ -34,6 +34,7 @@ def test_execute(self, mock_hook, mock_upload, mock_tempfile): schema = "schema.json" filename = "data.json" gzip = True + query_timeout = 20 mock_tempfile.return_value.name = TMP_FILE_NAME operator = CassandraToGCSOperator( @@ -43,9 +44,14 @@ def test_execute(self, mock_hook, mock_upload, mock_tempfile): filename=filename, schema_filename=schema, gzip=gzip, + query_timeout=query_timeout, ) operator.execute(None) mock_hook.return_value.get_conn.assert_called_once_with() + mock_hook.return_value.get_conn.return_value.execute.assert_called_once_with( + "select * from keyspace1.table1", + timeout=20, + ) call_schema = call( bucket_name=test_bucket,