Skip to content

Commit

Permalink
Remove deprecated method call (blob.download_as_string) (#20091)
Browse files Browse the repository at this point in the history
  • Loading branch information
kazanzhy authored Dec 7, 2021
1 parent 3a0f554 commit 50bf536
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def download(
self.log.info('File downloaded to %s', filename)
return filename
else:
return blob.download_as_string()
return blob.download_as_bytes()

except GoogleCloudError:
if num_file_attempts == num_max_attempts:
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ def execute(self, context) -> None:
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
schema_fields = json.loads(gcs_hook.download(gcs_bucket, gcs_object))
schema_fields = json.loads(gcs_hook.download(gcs_bucket, gcs_object).decode("utf-8"))
else:
schema_fields = self.schema_fields

Expand Down Expand Up @@ -1203,7 +1203,7 @@ def execute(self, context) -> None:
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object))
schema_fields = json.loads(gcs_hook.download(self.bucket, self.schema_object).decode("utf-8"))
else:
schema_fields = self.schema_fields

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def apply_validate_fn(*args, templates_dict, **kwargs):
raise ValueError(f"Wrong format prediction_path: {prediction_path}")
summary = os.path.join(obj.strip("/"), "prediction.summary.json")
gcs_hook = GCSHook()
summary = json.loads(gcs_hook.download(bucket, summary))
summary = json.loads(gcs_hook.download(bucket, summary).decode("utf-8"))
return validate_fn(summary)

evaluate_validation = PythonOperator(
Expand Down
16 changes: 8 additions & 8 deletions tests/providers/google/cloud/hooks/test_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,12 +641,12 @@ def test_compose_without_destination_object(self, mock_service):
assert str(ctx.value) == 'bucket_name and destination_object cannot be empty.'

@mock.patch(GCS_STRING.format('GCSHook.get_conn'))
def test_download_as_string(self, mock_service):
def test_download_as_bytes(self, mock_service):
test_bucket = 'test_bucket'
test_object = 'test_object'
test_object_bytes = io.BytesIO(b"input")

download_method = mock_service.return_value.bucket.return_value.blob.return_value.download_as_string
download_method = mock_service.return_value.bucket.return_value.blob.return_value.download_as_bytes
download_method.return_value = test_object_bytes

response = self.gcs_hook.download(bucket_name=test_bucket, object_name=test_object, filename=None)
Expand All @@ -666,10 +666,10 @@ def test_download_to_file(self, mock_service):
)
download_filename_method.return_value = None

download_as_a_string_method = (
mock_service.return_value.bucket.return_value.blob.return_value.download_as_string
download_as_a_bytes_method = (
mock_service.return_value.bucket.return_value.blob.return_value.download_as_bytes
)
download_as_a_string_method.return_value = test_object_bytes
download_as_a_bytes_method.return_value = test_object_bytes
response = self.gcs_hook.download(
bucket_name=test_bucket, object_name=test_object, filename=test_file
)
Expand All @@ -690,10 +690,10 @@ def test_provide_file(self, mock_service, mock_temp_file):
)
download_filename_method.return_value = None

download_as_a_string_method = (
mock_service.return_value.bucket.return_value.blob.return_value.download_as_string
download_as_a_bytes_method = (
mock_service.return_value.bucket.return_value.blob.return_value.download_as_bytes
)
download_as_a_string_method.return_value = test_object_bytes
download_as_a_bytes_method.return_value = test_object_bytes
mock_temp_file.return_value.__enter__.return_value = mock.MagicMock()
mock_temp_file.return_value.__enter__.return_value.name = test_file

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_successful_run(self):

with patch('airflow.providers.google.cloud.utils.mlengine_operator_utils.GCSHook') as mock_gcs_hook:
hook_instance = mock_gcs_hook.return_value
hook_instance.download.return_value = '{"err": 0.9, "count": 9}'
hook_instance.download.return_value = b'{"err": 0.9, "count": 9}'
result = validate.execute({})
hook_instance.download.assert_called_once_with(
'legal-bucket', 'fake-output-path/prediction.summary.json'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_apply_validate_fn(self, mock_beam_pipeline, mock_python, mock_download)

_, _, evaluate_validation = result

mock_download.return_value = json.dumps({"err": 0.3, "mse": 0.04, "count": 1100})
mock_download.return_value = json.dumps({"err": 0.3, "mse": 0.04, "count": 1100}).encode("utf-8")
templates_dict = {"prediction_path": PREDICTION_PATH}
with pytest.raises(ValueError) as ctx:
evaluate_validation.python_callable(templates_dict=templates_dict)
Expand Down

0 comments on commit 50bf536

Please sign in to comment.