Support partition_columns in BaseSQLToGCSOperator (#28677)
* Support partition_columns in BaseSQLToGCSOperator

Co-authored-by: eladkal <[email protected]>
vchiapaikeo and eladkal authored Jan 10, 2023
1 parent 07a17ba commit 35a8ffc
Showing 2 changed files with 188 additions and 25 deletions.
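For context, a minimal usage sketch of the new parameter (not part of this commit): it assumes the MySQLToGCSOperator subclass from the same Google provider package plus a hypothetical "orders" table, bucket name, and column names; any concrete BaseSQLToGCSOperator subclass should behave the same way. Note the ORDER BY over the partition columns, which the docstring added below requires.

from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator

# Hypothetical task: export a sorted query and let the operator split files per partition.
upload_orders = MySQLToGCSOperator(
    task_id="orders_to_gcs",
    sql="SELECT * FROM orders ORDER BY country, order_date",
    bucket="my-bucket",
    filename="orders/export_{}.parquet",
    export_format="parquet",
    partition_columns=["country", "order_date"],
)
# Objects are then uploaded under Hive-style prefixes, for example:
# orders/country=US/order_date=2023-01-10/export_0.parquet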
116 changes: 91 additions & 25 deletions airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -20,6 +20,7 @@

import abc
import json
import os
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence

@@ -77,6 +78,10 @@ class BaseSQLToGCSOperator(BaseOperator):
account from the list granting this role to the originating account (templated).
:param upload_metadata: whether to upload the row count metadata as blob metadata
:param exclude_columns: set of columns to exclude from transmission
:param partition_columns: list of columns to use for file partitioning. In order to use
this parameter, you must sort your dataset by partition_columns. Do this by
passing an ORDER BY clause to the sql query. Files are uploaded to GCS as objects
with a hive style partitioning directory structure (templated).
"""

template_fields: Sequence[str] = (
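Why the ORDER BY requirement in the docstring above matters (an illustration only, not code from this commit): the writer decides when to start a new file by comparing each row's partition values with the previous row's, which is the same consecutive-grouping behaviour as itertools.groupby. Unsorted input therefore produces one file per run of identical values rather than one file per distinct partition.

from itertools import groupby

rows = [("US", 1), ("US", 2), ("FR", 3), ("US", 4)]  # not sorted by country
print([k for k, _ in groupby(rows, key=lambda r: r[0])])
# ['US', 'FR', 'US']  -> three runs, so three files for two distinct partitions

rows = sorted(rows, key=lambda r: r[0])
print([k for k, _ in groupby(rows, key=lambda r: r[0])])
# ['FR', 'US']        -> one file per partition, as intended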
@@ -87,6 +92,7 @@ class BaseSQLToGCSOperator(BaseOperator):
"schema",
"parameters",
"impersonation_chain",
"partition_columns",
)
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"sql": "sql"}
@@ -111,7 +117,8 @@ def __init__(
delegate_to: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
upload_metadata: bool = False,
exclude_columns=None,
exclude_columns: set | None = None,
partition_columns: list | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -135,8 +142,16 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.upload_metadata = upload_metadata
self.exclude_columns = exclude_columns
self.partition_columns = partition_columns

def execute(self, context: Context):
if self.partition_columns:
self.log.info(
f"Found partition columns: {','.join(self.partition_columns)}. "
"Assuming the SQL statement is properly sorted by these columns in "
"ascending or descending order."
)

self.log.info("Executing query")
cursor = self.query()

@@ -158,6 +173,7 @@ def execute(self, context: Context):
total_files = 0
self.log.info("Writing local data files")
for file_to_upload in self._write_local_data_files(cursor):

# Flush file before uploading
file_to_upload["file_handle"].flush()

@@ -204,36 +220,56 @@ def _write_local_data_files(self, cursor):
names in GCS, and values are file handles to local files that
contain the data for the GCS objects.
"""
import os

org_schema = list(map(lambda schema_tuple: schema_tuple[0], cursor.description))
schema = [column for column in org_schema if column not in self.exclude_columns]

col_type_dict = self._get_col_type_dict()
file_no = 0

tmp_file_handle = NamedTemporaryFile(delete=True)
if self.export_format == "csv":
file_mime_type = "text/csv"
elif self.export_format == "parquet":
file_mime_type = "application/octet-stream"
else:
file_mime_type = "application/json"
file_to_upload = {
"file_name": self.filename.format(file_no),
"file_handle": tmp_file_handle,
"file_mime_type": file_mime_type,
"file_row_count": 0,
}
file_mime_type = self._get_file_mime_type()
file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)

if self.export_format == "csv":
csv_writer = self._configure_csv_file(tmp_file_handle, schema)
if self.export_format == "parquet":
parquet_schema = self._convert_parquet_schema(cursor)
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)

prev_partition_values = None
curr_partition_values = None
for row in cursor:
if self.partition_columns:
row_dict = dict(zip(schema, row))
curr_partition_values = tuple(
[row_dict.get(partition_column, "") for partition_column in self.partition_columns]
)

if prev_partition_values is None:
# We haven't set prev_partition_values before. Set to current
prev_partition_values = curr_partition_values

elif prev_partition_values != curr_partition_values:
# If the partition values differ, write the current local file out
# Yield first before we write the current record
file_no += 1

if self.export_format == "parquet":
parquet_writer.close()

file_to_upload["partition_values"] = prev_partition_values
yield file_to_upload
file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
if self.export_format == "csv":
csv_writer = self._configure_csv_file(tmp_file_handle, schema)
if self.export_format == "parquet":
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)

# Reset previous to current after writing out the file
prev_partition_values = curr_partition_values

# Incrementing file_row_count after partition yield ensures all rows are written
file_to_upload["file_row_count"] += 1

# Proceed to write the row to the localfile
if self.export_format == "csv":
row = self.convert_types(schema, col_type_dict, row)
if self.null_marker is not None:
@@ -268,24 +304,44 @@ def _write_local_data_files(self, cursor):

if self.export_format == "parquet":
parquet_writer.close()

file_to_upload["partition_values"] = curr_partition_values
yield file_to_upload
tmp_file_handle = NamedTemporaryFile(delete=True)
file_to_upload = {
"file_name": self.filename.format(file_no),
"file_handle": tmp_file_handle,
"file_mime_type": file_mime_type,
"file_row_count": 0,
}
file_to_upload, tmp_file_handle = self._get_file_to_upload(file_mime_type, file_no)
if self.export_format == "csv":
csv_writer = self._configure_csv_file(tmp_file_handle, schema)
if self.export_format == "parquet":
parquet_writer = self._configure_parquet_file(tmp_file_handle, parquet_schema)

if self.export_format == "parquet":
parquet_writer.close()
# Last file may have 0 rows, don't yield if empty
if file_to_upload["file_row_count"] > 0:
file_to_upload["partition_values"] = curr_partition_values
yield file_to_upload
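The partition handling in the generator above boils down to a yield-on-key-change loop. A distilled, standalone sketch (assumptions: rows are already sorted by the key, and plain lists stand in for the temporary file handles):

def batch_rows_by_consecutive_key(rows, key):
    """Yield (key_value, batch) pairs, starting a new batch whenever the key changes.

    Mirrors the prev/curr comparison used by _write_local_data_files above,
    with lists instead of NamedTemporaryFile handles.
    """
    prev_key, batch = None, []
    for row in rows:
        curr_key = key(row)
        if prev_key is not None and curr_key != prev_key:
            yield prev_key, batch  # key changed: flush the finished batch
            batch = []
        batch.append(row)
        prev_key = curr_key
    if batch:  # like the final yield above, skip if nothing was written
        yield prev_key, batch


rows = [("US", 1), ("US", 2), ("FR", 3)]
print(list(batch_rows_by_consecutive_key(rows, key=lambda r: r[0])))
# [('US', [('US', 1), ('US', 2)]), ('FR', [('FR', 3)])]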

def _get_file_to_upload(self, file_mime_type, file_no):
"""Returns a dictionary that represents the file to upload"""
tmp_file_handle = NamedTemporaryFile(delete=True)
return (
{
"file_name": self.filename.format(file_no),
"file_handle": tmp_file_handle,
"file_mime_type": file_mime_type,
"file_row_count": 0,
},
tmp_file_handle,
)

def _get_file_mime_type(self):
if self.export_format == "csv":
file_mime_type = "text/csv"
elif self.export_format == "parquet":
file_mime_type = "application/octet-stream"
else:
file_mime_type = "application/json"
return file_mime_type

def _configure_csv_file(self, file_handle, schema):
"""Configure a csv writer with the file_handle and write schema
as headers for the new file.
@@ -400,9 +456,19 @@ def _upload_to_gcs(self, file_to_upload):
if is_data_file and self.upload_metadata:
metadata = {"row_count": file_to_upload["file_row_count"]}

object_name = file_to_upload.get("file_name")
if is_data_file and self.partition_columns:
# Add partition column values to object_name
partition_values = file_to_upload.get("partition_values")
head_path, tail_path = os.path.split(object_name)
partition_subprefix = [
f"{col}={val}" for col, val in zip(self.partition_columns, partition_values)
]
object_name = os.path.join(head_path, *partition_subprefix, tail_path)

hook.upload(
self.bucket,
file_to_upload.get("file_name"),
object_name,
file_to_upload.get("file_handle").name,
mime_type=file_to_upload.get("file_mime_type"),
gzip=self.gzip if is_data_file else False,
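A small, self-contained illustration (file name and partition values are made up, not from the commit) of how the _upload_to_gcs hunk above assembles the Hive-style object name for a partitioned data file, using the same os.path calls:

import os

filename = "exports/daily/test_results_0.csv"  # self.filename.format(file_no)
partition_columns = ["column_b", "column_c"]
partition_values = ("b-1", "c-1")

head_path, tail_path = os.path.split(filename)
partition_subprefix = [f"{col}={val}" for col, val in zip(partition_columns, partition_values)]
print(os.path.join(head_path, *partition_subprefix, tail_path))
# exports/daily/column_b=b-1/column_c=c-1/test_results_0.csv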
97 changes: 97 additions & 0 deletions tests/providers/google/cloud/transfers/test_sql_to_gcs.py
@@ -62,6 +62,7 @@
OUTPUT_DF = pd.DataFrame([["convert_type_return_value"] * 3] * 3, columns=COLUMNS)

EXCLUDE_COLUMNS = set("column_c")
PARTITION_COLUMNS = ["column_b", "column_c"]
NEW_COLUMNS = [c for c in COLUMNS if c not in EXCLUDE_COLUMNS]
OUTPUT_DF_WITH_EXCLUDE_COLUMNS = pd.DataFrame(
[["convert_type_return_value"] * len(NEW_COLUMNS)] * 3, columns=NEW_COLUMNS
@@ -305,6 +306,74 @@ def test_exec(self, mock_convert_type, mock_query, mock_upload, mock_writerow, m
)
mock_close.assert_called_once()

mock_query.reset_mock()
mock_flush.reset_mock()
mock_upload.reset_mock()
mock_close.reset_mock()
cursor_mock.reset_mock()

cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))

# Test partition columns
operator = DummySQLToGCSOperator(
sql=SQL,
bucket=BUCKET,
filename=FILENAME,
task_id=TASK_ID,
export_format="parquet",
schema=SCHEMA,
partition_columns=PARTITION_COLUMNS,
)
result = operator.execute(context=dict())

assert result == {
"bucket": "TEST-BUCKET-1",
"total_row_count": 3,
"total_files": 3,
"files": [
{
"file_name": "test_results_0.csv",
"file_mime_type": "application/octet-stream",
"file_row_count": 1,
},
{
"file_name": "test_results_1.csv",
"file_mime_type": "application/octet-stream",
"file_row_count": 1,
},
{
"file_name": "test_results_2.csv",
"file_mime_type": "application/octet-stream",
"file_row_count": 1,
},
],
}

mock_query.assert_called_once()
assert mock_flush.call_count == 3
assert mock_close.call_count == 3
mock_upload.assert_has_calls(
[
mock.call(
BUCKET,
f"column_b={row[1]}/column_c={row[2]}/test_results_{i}.csv",
TMP_FILE_NAME,
mime_type="application/octet-stream",
gzip=False,
metadata=None,
)
for i, row in enumerate(INPUT_DATA)
]
)

mock_query.reset_mock()
mock_flush.reset_mock()
mock_upload.reset_mock()
mock_close.reset_mock()
cursor_mock.reset_mock()

cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))

# Test null marker
cursor_mock.__iter__ = Mock(return_value=iter(INPUT_DATA))
mock_convert_type.return_value = None
@@ -423,3 +492,31 @@ def test__write_local_data_files_json_with_exclude_columns(self):
file.flush()
df = pd.read_json(file.name, orient="records", lines=True)
assert df.equals(OUTPUT_DF_WITH_EXCLUDE_COLUMNS)

def test__write_local_data_files_parquet_with_partition_columns(self):
op = DummySQLToGCSOperator(
sql=SQL,
bucket=BUCKET,
filename=FILENAME,
task_id=TASK_ID,
schema_filename=SCHEMA_FILE,
export_format="parquet",
gzip=False,
schema=SCHEMA,
gcp_conn_id="google_cloud_default",
partition_columns=PARTITION_COLUMNS,
)
cursor = MagicMock()
cursor.__iter__.return_value = INPUT_DATA
cursor.description = CURSOR_DESCRIPTION

local_data_files = op._write_local_data_files(cursor)
concat_dfs = []
for local_data_file in local_data_files:
file = local_data_file["file_handle"]
file.flush()
df = pd.read_parquet(file.name)
concat_dfs.append(df)

concat_df = pd.concat(concat_dfs, ignore_index=True)
assert concat_df.equals(OUTPUT_DF)
