Add support for write_on_empty in BaseSQLToGCSOperator (#28959)
vchiapaikeo authored Jan 19, 2023
1 parent 7f2b065 commit 5350be2
Showing 2 changed files with 56 additions and 1 deletion.
airflow/providers/google/cloud/transfers/sql_to_gcs.py (7 additions & 1 deletion)
@@ -82,6 +82,9 @@ class BaseSQLToGCSOperator(BaseOperator):
         this parameter, you must sort your dataset by partition_columns. Do this by
         passing an ORDER BY clause to the sql query. Files are uploaded to GCS as objects
         with a hive style partitioning directory structure (templated).
+    :param write_on_empty: Optional parameter to specify whether to write a file if the
+        export does not return any rows. Default is False so we will not write a file
+        if the export returns no rows.
     """

     template_fields: Sequence[str] = (
@@ -119,6 +122,7 @@ def __init__(
         upload_metadata: bool = False,
         exclude_columns: set | None = None,
         partition_columns: list | None = None,
+        write_on_empty: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -143,6 +147,7 @@ def __init__(
         self.upload_metadata = upload_metadata
         self.exclude_columns = exclude_columns
         self.partition_columns = partition_columns
+        self.write_on_empty = write_on_empty

     def execute(self, context: Context):
         if self.partition_columns:
@@ -316,7 +321,8 @@ def _write_local_data_files(self, cursor):
         if self.export_format == "parquet":
             parquet_writer.close()
         # Last file may have 0 rows, don't yield if empty
-        if file_to_upload["file_row_count"] > 0:
+        # However, if it is the first file and self.write_on_empty is True, then yield to write an empty file
+        if file_to_upload["file_row_count"] > 0 or (file_no == 0 and self.write_on_empty):
             file_to_upload["partition_values"] = curr_partition_values
             yield file_to_upload

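For context, a minimal usage sketch of the new flag on a concrete subclass (not part of this commit): it assumes the provider's MySQLToGCSOperator, which forwards its keyword arguments to BaseSQLToGCSOperator; the query, bucket, and filename values are hypothetical placeholders.

from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

export_orders = MySQLToGCSOperator(
    task_id="export_orders",
    sql="SELECT * FROM orders WHERE order_date = '{{ ds }}'",  # hypothetical query
    bucket="my-export-bucket",  # hypothetical bucket
    filename="orders/{{ ds }}/part-{}.csv",  # hypothetical object path
    export_format="csv",
    # New in this commit: upload one empty file when the query returns no rows,
    # so downstream consumers always find an object at the expected path.
    write_on_empty=True,
)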
tests/providers/google/cloud/transfers/test_sql_to_gcs.py (49 additions & 0 deletions)

@@ -22,6 +22,7 @@
 from unittest.mock import MagicMock, Mock

 import pandas as pd
+import pytest
 import unicodecsv as csv

 from airflow.providers.google.cloud.hooks.gcs import GCSHook
@@ -44,6 +45,7 @@
("column_c", "10", 0, 0, 0, 0, False),
]
TMP_FILE_NAME = "temp-file"
EMPTY_INPUT_DATA = []
INPUT_DATA = [
["101", "school", "2015-01-01"],
["102", "business", "2017-05-24"],
@@ -520,3 +522,50 @@ def test__write_local_data_files_parquet_with_partition_columns(self):

         concat_df = pd.concat(concat_dfs, ignore_index=True)
         assert concat_df.equals(OUTPUT_DF)
+
+    def test__write_local_data_files_csv_does_not_write_on_empty_rows(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="csv",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id="google_cloud_default",
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = EMPTY_INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        files = op._write_local_data_files(cursor)
+        # Raises StopIteration when next is called because generator returns no files
+        with pytest.raises(StopIteration):
+            next(files)["file_handle"]
+
+        assert len([f for f in files]) == 0
+
+    def test__write_local_data_files_csv_writes_empty_file_with_write_on_empty(self):
+        op = DummySQLToGCSOperator(
+            sql=SQL,
+            bucket=BUCKET,
+            filename=FILENAME,
+            task_id=TASK_ID,
+            schema_filename=SCHEMA_FILE,
+            export_format="csv",
+            gzip=False,
+            schema=SCHEMA,
+            gcp_conn_id="google_cloud_default",
+            write_on_empty=True,
+        )
+        cursor = MagicMock()
+        cursor.__iter__.return_value = EMPTY_INPUT_DATA
+        cursor.description = CURSOR_DESCRIPTION
+
+        files = op._write_local_data_files(cursor)
+        file = next(files)["file_handle"]
+        file.flush()
+
+        df = pd.read_csv(file.name)
+        assert len(df.index) == 0
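
To make the behavior under test explicit, here is a distilled sketch (not part of this commit; the helper name is invented for illustration) of the final-file check in _write_local_data_files:

def should_yield_last_file(file_no: int, file_row_count: int, write_on_empty: bool) -> bool:
    # Mirrors the condition added in this commit: the last temp file is yielded
    # when it has rows, or when it is the first (and only) file and
    # write_on_empty is set.
    return file_row_count > 0 or (file_no == 0 and write_on_empty)

assert not should_yield_last_file(file_no=0, file_row_count=0, write_on_empty=False)  # no file uploaded
assert should_yield_last_file(file_no=0, file_row_count=0, write_on_empty=True)  # one empty file uploaded
assert not should_yield_last_file(file_no=3, file_row_count=0, write_on_empty=True)  # a trailing empty file after data files is still skipped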
