Refactor: BigQuery to GCS Operator (#22506)
shuhoy authored Mar 27, 2022
1 parent 02526b3 commit 02976be
Showing 3 changed files with 24 additions and 50 deletions.
airflow/providers/google/cloud/hooks/bigquery.py (6 changes: 3 additions & 3 deletions)
@@ -1905,7 +1905,7 @@ def run_copy(
     def run_extract(
         self,
         source_project_dataset_table: str,
-        destination_cloud_storage_uris: str,
+        destination_cloud_storage_uris: List[str],
         compression: str = 'NONE',
         export_format: str = 'CSV',
         field_delimiter: str = ',',
@@ -1945,7 +1945,7 @@ def run_extract(
             var_name='source_project_dataset_table',
         )

-        configuration = {
+        configuration: Dict[str, Any] = {
             'extract': {
                 'sourceTable': {
                     'projectId': source_project,
@@ -1956,7 +1956,7 @@ def run_extract(
                 'destinationUris': destination_cloud_storage_uris,
                 'destinationFormat': export_format,
             }
-        }  # type: Dict[str, Any]
+        }

         if labels:
             configuration['labels'] = labels
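For context, run_extract builds the 'extract' job configuration shown above and submits it to BigQuery. A minimal sketch of calling the hook directly, assuming a default GCP connection; the table and bucket names are illustrative, not from this commit:

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

# Hypothetical connection, table, and bucket names, for illustration only.
hook = BigQueryHook(gcp_conn_id='google_cloud_default')
hook.run_extract(
    source_project_dataset_table='my-project:my_dataset.my_table',
    # destination_cloud_storage_uris is now typed List[str]: BigQuery accepts
    # several URIs, and wildcard URIs shard exports larger than 1 GB.
    destination_cloud_storage_uris=['gs://my-bucket/exports/my_table-*.csv'],
    compression='NONE',
    export_format='CSV',
    field_delimiter=',',
    print_header=True,
)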
airflow/providers/google/cloud/transfers/bigquery_to_gcs.py (36 changes: 10 additions & 26 deletions)
@@ -17,9 +17,7 @@
 # under the License.
 """This module contains Google BigQuery to Google Cloud Storage operator."""
 import warnings
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
-
-from google.cloud.bigquery.table import TableReference
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union

 from airflow.models import BaseOperator
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
@@ -128,26 +126,12 @@ def execute(self, context: 'Context'):
             location=self.location,
             impersonation_chain=self.impersonation_chain,
         )
-
-        table_ref = TableReference.from_string(self.source_project_dataset_table, hook.project_id)
-
-        configuration: Dict[str, Any] = {
-            'extract': {
-                'sourceTable': table_ref.to_api_repr(),
-                'compression': self.compression,
-                'destinationUris': self.destination_cloud_storage_uris,
-                'destinationFormat': self.export_format,
-            }
-        }
-
-        if self.labels:
-            configuration['labels'] = self.labels
-
-        if self.export_format == 'CSV':
-            # Only set fieldDelimiter and printHeader fields if using CSV.
-            # Google does not like it if you set these fields for other export
-            # formats.
-            configuration['extract']['fieldDelimiter'] = self.field_delimiter
-            configuration['extract']['printHeader'] = self.print_header
-
-        hook.insert_job(configuration=configuration)
+        hook.run_extract(
+            source_project_dataset_table=self.source_project_dataset_table,
+            destination_cloud_storage_uris=self.destination_cloud_storage_uris,
+            compression=self.compression,
+            export_format=self.export_format,
+            field_delimiter=self.field_delimiter,
+            print_header=self.print_header,
+            labels=self.labels,
+        )
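After the refactor, execute() reduces to this single hook call. A hedged usage sketch of the operator inside a DAG; the DAG id, schedule, and resource names are illustrative, not from this commit:

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator

with DAG(
    dag_id='example_bigquery_to_gcs',  # hypothetical DAG id
    start_date=datetime(2022, 3, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    export_table = BigQueryToGCSOperator(
        task_id='export_table',
        source_project_dataset_table='my-project:my_dataset.my_table',
        destination_cloud_storage_uris=['gs://my-bucket/exports/my_table-*.csv'],
        export_format='CSV',
        field_delimiter=',',
        print_header=True,
        labels={'team': 'data-eng'},  # optional BigQuery job labels
    )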
tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py (32 changes: 11 additions & 21 deletions)
@@ -27,35 +27,17 @@
 PROJECT_ID = 'test-project-id'


-class TestBigQueryToCloudStorageOperator(unittest.TestCase):
+class TestBigQueryToGCSOperator(unittest.TestCase):
     @mock.patch('airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook')
     def test_execute(self, mock_hook):
-        source_project_dataset_table = f'{TEST_DATASET}.{TEST_TABLE_ID}'
+        source_project_dataset_table = f'{PROJECT_ID}:{TEST_DATASET}.{TEST_TABLE_ID}'
         destination_cloud_storage_uris = ['gs://some-bucket/some-file.txt']
         compression = 'NONE'
         export_format = 'CSV'
         field_delimiter = ','
         print_header = True
         labels = {'k1': 'v1'}

-        mock_hook().project_id = PROJECT_ID
-
-        configuration = {
-            'extract': {
-                'sourceTable': {
-                    'projectId': mock_hook().project_id,
-                    'datasetId': TEST_DATASET,
-                    'tableId': TEST_TABLE_ID,
-                },
-                'compression': compression,
-                'destinationUris': destination_cloud_storage_uris,
-                'destinationFormat': export_format,
-                'fieldDelimiter': field_delimiter,
-                'printHeader': print_header,
-            },
-            'labels': labels,
-        }
-
         operator = BigQueryToGCSOperator(
             task_id=TASK_ID,
             source_project_dataset_table=source_project_dataset_table,
@@ -69,4 +51,12 @@ def test_execute(self, mock_hook):

         operator.execute(None)

-        mock_hook.return_value.insert_job.assert_called_once_with(configuration=configuration)
+        mock_hook.return_value.run_extract.assert_called_once_with(
+            source_project_dataset_table=source_project_dataset_table,
+            destination_cloud_storage_uris=destination_cloud_storage_uris,
+            compression=compression,
+            export_format=export_format,
+            field_delimiter=field_delimiter,
+            print_header=print_header,
+            labels=labels,
+        )
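The assertion targets mock_hook.return_value because patching BigQueryHook swaps the class for a MagicMock, and the hook instance that execute() constructs is that mock's return_value. A small sketch of this unittest.mock behavior, not code from the commit:

from unittest import mock

with mock.patch(
    'airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryHook'
) as mock_hook:
    # Calling the patched class (as execute() does) always yields the same
    # auto-created child mock, so assertions can be made on return_value.
    hook_instance = mock_hook(gcp_conn_id='google_cloud_default')
    assert hook_instance is mock_hook.return_value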
