feat(spanner): add support for txn change stream exclusion (#1152)
* feat(spanner): add support for txn change stream exclusion

* feat(spanner): add tests for txn change streams exclusion

* chore(spanner): lint fix

* feat(spanner): add docs

* feat(spanner): add test for ILB with change stream exclusion

* feat(spanner): update default value and add optional
harshachinta authored Jun 20, 2024
1 parent c670ebc commit 00ccb7a
Showing 8 changed files with 346 additions and 17 deletions.
21 changes: 18 additions & 3 deletions google/cloud/spanner_v1/batch.py
@@ -147,7 +147,11 @@ def _check_state(self):
raise ValueError("Batch already committed")

def commit(
self, return_commit_stats=False, request_options=None, max_commit_delay=None
self,
return_commit_stats=False,
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
):
"""Commit mutations to the database.
@@ -178,7 +182,10 @@ def commit(
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
txn_options = TransactionOptions(
read_write=TransactionOptions.ReadWrite(),
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)
trace_attributes = {"num_mutations": len(self._mutations)}

if request_options is None:
@@ -270,7 +277,7 @@ def group(self):
self._mutation_groups.append(mutation_group)
return MutationGroup(self._session, mutation_group.mutations)

def batch_write(self, request_options=None):
def batch_write(self, request_options=None, exclude_txn_from_change_streams=False):
"""Executes batch_write.
:type request_options:
@@ -280,6 +287,13 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=False):
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.RequestOptions`.
:type exclude_txn_from_change_streams: bool
:param exclude_txn_from_change_streams:
(Optional) If true, instructs the transaction to be excluded from being recorded in change streams
with the DDL option `allow_txn_exclusion=true`. This does not exclude the transaction from
being recorded in the change streams with the DDL option `allow_txn_exclusion` being false or
unset.
:rtype: :class:`Iterable[google.cloud.spanner_v1.types.BatchWriteResponse]`
:returns: a sequence of responses for each batch.
"""
@@ -302,6 +316,7 @@ def batch_write(self, request_options=None):
session=self._session.name,
mutation_groups=self._mutation_groups,
request_options=request_options,
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)
with trace_call("CloudSpanner.BatchWrite", self._session, trace_attributes):
method = functools.partial(
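Both mutation entry points in `batch.py` now take the flag. A minimal usage sketch, assuming a google-cloud-spanner build that includes this change; the instance, database, and `Singers` table names are placeholders, not part of the diff:

```python
from google.cloud import spanner

# Placeholder resource names; substitute your own.
client = spanner.Client()
instance = client.instance("my-instance")
database = instance.database("my-database")

# The BatchCheckout context manager commits on exit and passes the flag
# through to Batch.commit(exclude_txn_from_change_streams=True).
with database.batch(exclude_txn_from_change_streams=True) as batch:
    batch.insert(
        table="Singers",
        columns=("SingerId", "FirstName"),
        values=[(1, "Marc")],
    )

# MutationGroups.batch_write() accepts the same flag for the whole request.
with database.mutation_groups() as groups:
    group = groups.group()
    group.insert("Singers", ("SingerId", "FirstName"), [(2, "Cara")])
    for response in groups.batch_write(exclude_txn_from_change_streams=True):
        print(response.status)
```

As the new docstrings note, the exclusion only applies to change streams created with the DDL option `allow_txn_exclusion=true`; streams without that option still record these writes.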
43 changes: 39 additions & 4 deletions google/cloud/spanner_v1/database.py
@@ -619,6 +619,7 @@ def execute_partitioned_dml(
param_types=None,
query_options=None,
request_options=None,
exclude_txn_from_change_streams=False,
):
"""Execute a partitionable DML statement.
@@ -651,6 +652,13 @@ def execute_partitioned_dml(
Please note, the `transactionTag` setting will be ignored as it is
not supported for partitioned DML.
:type exclude_txn_from_change_streams: bool
:param exclude_txn_from_change_streams:
(Optional) If true, instructs the transaction to be excluded from being recorded in change streams
with the DDL option `allow_txn_exclusion=true`. This does not exclude the transaction from
being recorded in the change streams with the DDL option `allow_txn_exclusion` being false or
unset.
:rtype: int
:returns: Count of rows affected by the DML statement.
"""
@@ -673,7 +681,8 @@ def execute_partitioned_dml(
api = self.spanner_api

txn_options = TransactionOptions(
partitioned_dml=TransactionOptions.PartitionedDml()
partitioned_dml=TransactionOptions.PartitionedDml(),
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)

metadata = _metadata_with_prefix(self.name)
@@ -752,7 +761,12 @@ def snapshot(self, **kw):
"""
return SnapshotCheckout(self, **kw)

def batch(self, request_options=None, max_commit_delay=None):
def batch(
self,
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
):
"""Return an object which wraps a batch.
The wrapper *must* be used as a context manager, with the batch
@@ -771,10 +785,19 @@ def batch(self, request_options=None, max_commit_delay=None):
in order to improve throughput. Value must be between 0ms and
500ms.
:type exclude_txn_from_change_streams: bool
:param exclude_txn_from_change_streams:
(Optional) If true, instructs the transaction to be excluded from being recorded in change streams
with the DDL option `allow_txn_exclusion=true`. This does not exclude the transaction from
being recorded in the change streams with the DDL option `allow_txn_exclusion` being false or
unset.
:rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout`
:returns: new wrapper
"""
return BatchCheckout(self, request_options, max_commit_delay)
return BatchCheckout(
self, request_options, max_commit_delay, exclude_txn_from_change_streams
)

def mutation_groups(self):
"""Return an object which wraps a mutation_group.
@@ -840,6 +863,10 @@ def run_in_transaction(self, func, *args, **kw):
"max_commit_delay" will be removed and used to set the
max_commit_delay for the request. Value must be between
0ms and 500ms.
"exclude_txn_from_change_streams" if true, instructs the transaction to be excluded
from being recorded in change streams with the DDL option `allow_txn_exclusion=true`.
This does not exclude the transaction from being recorded in the change streams with
the DDL option `allow_txn_exclusion` being false or unset.
:rtype: Any
:returns: The return value of ``func``.
@@ -1103,7 +1130,13 @@ class BatchCheckout(object):
in order to improve throughput.
"""

def __init__(self, database, request_options=None, max_commit_delay=None):
def __init__(
self,
database,
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
):
self._database = database
self._session = self._batch = None
if request_options is None:
@@ -1113,6 +1146,7 @@ def __init__(self, database, request_options=None, max_commit_delay=None):
else:
self._request_options = request_options
self._max_commit_delay = max_commit_delay
self._exclude_txn_from_change_streams = exclude_txn_from_change_streams

def __enter__(self):
"""Begin ``with`` block."""
@@ -1130,6 +1164,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
return_commit_stats=self._database.log_commit_stats,
request_options=self._request_options,
max_commit_delay=self._max_commit_delay,
exclude_txn_from_change_streams=self._exclude_txn_from_change_streams,
)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
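`Database.execute_partitioned_dml` forwards the flag into the `PartitionedDml` transaction options, and `Database.batch` / `run_in_transaction` pass it down to the underlying batch and transaction. A hedged sketch of the partitioned-DML path, reusing the placeholder `database` handle from the previous sketch:

```python
# Partitioned DML: the transaction that applies this statement is excluded
# from change streams created with allow_txn_exclusion=true.
row_count = database.execute_partitioned_dml(
    "UPDATE Singers SET FirstName = 'Marco' WHERE SingerId = 1",
    exclude_txn_from_change_streams=True,
)
print(f"{row_count} row(s) updated")
```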
8 changes: 8 additions & 0 deletions google/cloud/spanner_v1/session.py
@@ -387,6 +387,10 @@ def run_in_transaction(self, func, *args, **kw):
request options for the commit request.
"max_commit_delay" will be removed and used to set the max commit delay for the request.
"transaction_tag" will be removed and used to set the transaction tag for the request.
"exclude_txn_from_change_streams" if true, instructs the transaction to be excluded
from being recorded in change streams with the DDL option `allow_txn_exclusion=true`.
This does not exclude the transaction from being recorded in the change streams with
the DDL option `allow_txn_exclusion` being false or unset.
:rtype: Any
:returns: The return value of ``func``.
@@ -398,12 +402,16 @@ def run_in_transaction(self, func, *args, **kw):
commit_request_options = kw.pop("commit_request_options", None)
max_commit_delay = kw.pop("max_commit_delay", None)
transaction_tag = kw.pop("transaction_tag", None)
exclude_txn_from_change_streams = kw.pop(
"exclude_txn_from_change_streams", None
)
attempts = 0

while True:
if self._transaction is None:
txn = self.transaction()
txn.transaction_tag = transaction_tag
txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams
else:
txn = self._transaction

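`Session.run_in_transaction` pops `exclude_txn_from_change_streams` out of `**kw` and sets it on the transaction before the retry loop, so callers pass it as a keyword argument alongside the arguments destined for their callback. A sketch, again using the placeholder `database` and `Singers` table from above:

```python
from google.cloud.spanner_v1 import param_types


def rename_singer(transaction, singer_id, new_name):
    # Runs inside the (possibly retried) read-write transaction.
    transaction.execute_update(
        "UPDATE Singers SET FirstName = @name WHERE SingerId = @id",
        params={"id": singer_id, "name": new_name},
        param_types={"id": param_types.INT64, "name": param_types.STRING},
    )


# The flag is consumed by run_in_transaction; the remaining positional
# arguments are forwarded to rename_singer.
database.run_in_transaction(
    rename_singer, 1, "Marco", exclude_txn_from_change_streams=True
)
```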
11 changes: 9 additions & 2 deletions google/cloud/spanner_v1/transaction.py
@@ -55,6 +55,7 @@ class Transaction(_SnapshotBase, _BatchBase):
_execute_sql_count = 0
_lock = threading.Lock()
_read_only = False
exclude_txn_from_change_streams = False

def __init__(self, session):
if session._transaction is not None:
@@ -86,7 +87,10 @@ def _make_txn_selector(self):

if self._transaction_id is None:
return TransactionSelector(
begin=TransactionOptions(read_write=TransactionOptions.ReadWrite())
begin=TransactionOptions(
read_write=TransactionOptions.ReadWrite(),
exclude_txn_from_change_streams=self.exclude_txn_from_change_streams,
)
)
else:
return TransactionSelector(id=self._transaction_id)
@@ -137,7 +141,10 @@ def begin(self):
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
txn_options = TransactionOptions(
read_write=TransactionOptions.ReadWrite(),
exclude_txn_from_change_streams=self.exclude_txn_from_change_streams,
)
with trace_call("CloudSpanner.BeginTransaction", self._session):
method = functools.partial(
api.begin_transaction,
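In `transaction.py`, both `_make_txn_selector` and `begin` copy the new attribute into `TransactionOptions`, where the field lives at the top level next to `read_write` rather than inside `ReadWrite`. A small illustration of the resulting options message, mirroring what the diff builds (proto construction only, no RPC):

```python
from google.cloud.spanner_v1 import TransactionOptions

options = TransactionOptions(
    read_write=TransactionOptions.ReadWrite(),
    exclude_txn_from_change_streams=True,
)

# The flag is a top-level TransactionOptions field, alongside read_write.
assert options.exclude_txn_from_change_streams
assert TransactionOptions.pb(options).HasField("read_write")
```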
42 changes: 38 additions & 4 deletions tests/unit/test_batch.py
@@ -259,7 +259,12 @@ def test_commit_ok(self):
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
)

def _test_commit_with_options(self, request_options=None, max_commit_delay_in=None):
def _test_commit_with_options(
self,
request_options=None,
max_commit_delay_in=None,
exclude_txn_from_change_streams=False,
):
import datetime
from google.cloud.spanner_v1 import CommitResponse
from google.cloud.spanner_v1 import TransactionOptions
@@ -276,7 +281,9 @@ def _test_commit_with_options(self, request_options=None, max_commit_delay_in=No
batch.transaction_tag = self.TRANSACTION_TAG
batch.insert(TABLE_NAME, COLUMNS, VALUES)
committed = batch.commit(
request_options=request_options, max_commit_delay=max_commit_delay_in
request_options=request_options,
max_commit_delay=max_commit_delay_in,
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)

self.assertEqual(committed, now)
@@ -301,6 +308,10 @@ def _test_commit_with_options(self, request_options=None, max_commit_delay_in=No
self.assertEqual(mutations, batch._mutations)
self.assertIsInstance(single_use_txn, TransactionOptions)
self.assertTrue(type(single_use_txn).pb(single_use_txn).HasField("read_write"))
self.assertEqual(
single_use_txn.exclude_txn_from_change_streams,
exclude_txn_from_change_streams,
)
self.assertEqual(
metadata,
[
@@ -355,6 +366,14 @@ def test_commit_w_max_commit_delay(self):
max_commit_delay_in=datetime.timedelta(milliseconds=100),
)

def test_commit_w_exclude_txn_from_change_streams(self):
request_options = RequestOptions(
request_tag="tag-1",
)
self._test_commit_with_options(
request_options=request_options, exclude_txn_from_change_streams=True
)

def test_context_mgr_already_committed(self):
import datetime
from google.cloud._helpers import UTC
@@ -499,7 +518,9 @@ def test_batch_write_grpc_error(self):
attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1),
)

def _test_batch_write_with_request_options(self, request_options=None):
def _test_batch_write_with_request_options(
self, request_options=None, exclude_txn_from_change_streams=False
):
import datetime
from google.cloud.spanner_v1 import BatchWriteResponse
from google.cloud._helpers import UTC
@@ -519,7 +540,10 @@ def _test_batch_write_with_request_options(self, request_options=None):
group = groups.group()
group.insert(TABLE_NAME, COLUMNS, VALUES)

response_iter = groups.batch_write(request_options)
response_iter = groups.batch_write(
request_options,
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)
self.assertEqual(len(response_iter), 1)
self.assertEqual(response_iter[0], response)

@@ -528,6 +552,7 @@ def _test_batch_write_with_request_options(self, request_options=None):
mutation_groups,
actual_request_options,
metadata,
request_exclude_txn_from_change_streams,
) = api._batch_request
self.assertEqual(session, self.SESSION_NAME)
self.assertEqual(mutation_groups, groups._mutation_groups)
@@ -545,6 +570,9 @@ def _test_batch_write_with_request_options(self, request_options=None):
else:
expected_request_options = request_options
self.assertEqual(actual_request_options, expected_request_options)
self.assertEqual(
request_exclude_txn_from_change_streams, exclude_txn_from_change_streams
)

self.assertSpanAttributes(
"CloudSpanner.BatchWrite",
@@ -567,6 +595,11 @@ def test_batch_write_w_incorrect_tag_dictionary_error(self):
with self.assertRaises(ValueError):
self._test_batch_write_with_request_options({"incorrect_tag": "tag-1-1"})

def test_batch_write_w_exclude_txn_from_change_streams(self):
self._test_batch_write_with_request_options(
exclude_txn_from_change_streams=True
)


class _Session(object):
def __init__(self, database=None, name=TestBatch.SESSION_NAME):
@@ -625,6 +658,7 @@ def batch_write(
request.mutation_groups,
request.request_options,
metadata,
request.exclude_txn_from_change_streams,
)
if self._rpc_error:
raise Unknown("error")
16 changes: 14 additions & 2 deletions tests/unit/test_database.py
@@ -1083,6 +1083,7 @@ def _execute_partitioned_dml_helper(
query_options=None,
request_options=None,
retried=False,
exclude_txn_from_change_streams=False,
):
from google.api_core.exceptions import Aborted
from google.api_core.retry import Retry
@@ -1129,13 +1130,19 @@ def _execute_partitioned_dml_helper(
api.execute_streaming_sql.return_value = iterator

row_count = database.execute_partitioned_dml(
dml, params, param_types, query_options, request_options
dml,
params,
param_types,
query_options,
request_options,
exclude_txn_from_change_streams,
)

self.assertEqual(row_count, 2)

txn_options = TransactionOptions(
partitioned_dml=TransactionOptions.PartitionedDml()
partitioned_dml=TransactionOptions.PartitionedDml(),
exclude_txn_from_change_streams=exclude_txn_from_change_streams,
)

api.begin_transaction.assert_called_with(
Expand Down Expand Up @@ -1250,6 +1257,11 @@ def test_execute_partitioned_dml_w_req_tag_used(self):
def test_execute_partitioned_dml_wo_params_retry_aborted(self):
self._execute_partitioned_dml_helper(dml=DML_WO_PARAM, retried=True)

def test_execute_partitioned_dml_w_exclude_txn_from_change_streams(self):
self._execute_partitioned_dml_helper(
dml=DML_WO_PARAM, exclude_txn_from_change_streams=True
)

def test_session_factory_defaults(self):
from google.cloud.spanner_v1.session import Session
