Skip to content

Commit

Permalink
Retry requests in case of error in Google ML Engine Hook (#11712)
Browse files Browse the repository at this point in the history
  • Loading branch information
kosteev authored Oct 21, 2020
1 parent a182291 commit 950c16d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 40 deletions.
47 changes: 34 additions & 13 deletions airflow/providers/google/cloud/hooks/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,30 @@
_AIRFLOW_VERSION = 'v' + airflow_version.replace('.', '-').replace('+', '-')


def _poll_with_exponential_delay(request, max_n, is_done_func, is_error_func):

def _poll_with_exponential_delay(request, execute_num_retries, max_n, is_done_func, is_error_func):
"""
Execute request with exponential delay.
This method is intended to handle and retry in case of api-specific errors,
such as 429 "Too Many Requests", unlike the `request.execute` which handles
lower level errors like `ConnectionError`/`socket.timeout`/`ssl.SSLError`.
:param request: request to be executed.
:type request: googleapiclient.http.HttpRequest
:param execute_num_retries: num_retries for `request.execute` method.
:type execute_num_retries: int
:param max_n: number of times to retry request in this method.
:type max_n: int
:param is_done_func: callable to determine if operation is done.
:type is_done_func: callable
:param is_error_func: callable to determine if operation is failed.
:type is_error_func: callable
:return: response
:rtype: httplib2.Response
"""
for i in range(0, max_n):
try:
response = request.execute()
response = request.execute(num_retries=execute_num_retries)
if is_error_func(response):
raise ValueError('The response contained an error: {}'.format(response))
if is_done_func(response):
Expand Down Expand Up @@ -113,7 +132,7 @@ def create_job(self, job: dict, project_id: str, use_existing_job_fn: Optional[C
job_id = job['jobId']

try:
request.execute()
request.execute(num_retries=self.num_retries)
except HttpError as e:
# 409 means there is an existing job with the same job ID.
if e.resp.status == 409:
Expand Down Expand Up @@ -158,7 +177,7 @@ def cancel_job(
request = hook.projects().jobs().cancel(name=f'projects/{project_id}/jobs/{job_id}')

try:
return request.execute()
return request.execute(num_retries=self.num_retries)
except HttpError as e:
if e.resp.status == 404:
self.log.error('Job with job_id %s does not exist. ', job_id)
Expand Down Expand Up @@ -188,7 +207,7 @@ def _get_job(self, project_id: str, job_id: str) -> dict:
request = hook.projects().jobs().get(name=job_name) # pylint: disable=no-member
while True:
try:
return request.execute()
return request.execute(num_retries=self.num_retries)
except HttpError as e:
if e.resp.status == 429:
# polling after 30 seconds when quota failure occurs
Expand Down Expand Up @@ -253,11 +272,12 @@ def create_version(

# pylint: disable=no-member
create_request = hook.projects().models().versions().create(parent=parent_name, body=version_spec)
response = create_request.execute()
response = create_request.execute(num_retries=self.num_retries)
get_request = hook.projects().operations().get(name=response['name']) # pylint: disable=no-member

return _poll_with_exponential_delay(
request=get_request,
execute_num_retries=self.num_retries,
max_n=9,
is_done_func=lambda resp: resp.get('done', False),
is_error_func=lambda resp: resp.get('error', None) is not None,
Expand Down Expand Up @@ -292,7 +312,7 @@ def set_default_version(
request = hook.projects().models().versions().setDefault(name=full_version_name, body={})

try:
response = request.execute()
response = request.execute(num_retries=self.num_retries)
self.log.info('Successfully set version: %s to default', response)
return response
except HttpError as e:
Expand Down Expand Up @@ -325,7 +345,7 @@ def list_versions(
request = hook.projects().models().versions().list(parent=full_parent_name, pageSize=100)

while request is not None:
response = request.execute()
response = request.execute(num_retries=self.num_retries)
result.extend(response.get('versions', []))
# pylint: disable=no-member
request = (
Expand Down Expand Up @@ -362,11 +382,12 @@ def delete_version(
delete_request = (
hook.projects().models().versions().delete(name=full_name) # pylint: disable=no-member
)
response = delete_request.execute()
response = delete_request.execute(num_retries=self.num_retries)
get_request = hook.projects().operations().get(name=response['name']) # pylint: disable=no-member

return _poll_with_exponential_delay(
request=get_request,
execute_num_retries=self.num_retries,
max_n=9,
is_done_func=lambda resp: resp.get('done', False),
is_error_func=lambda resp: resp.get('error', None) is not None,
Expand Down Expand Up @@ -399,7 +420,7 @@ def create_model(
self._append_label(model)
try:
request = hook.projects().models().create(parent=project, body=model) # pylint: disable=no-member
respone = request.execute()
respone = request.execute(num_retries=self.num_retries)
except HttpError as e:
if e.resp.status != 409:
raise e
Expand Down Expand Up @@ -449,7 +470,7 @@ def get_model(
full_model_name = 'projects/{}/models/{}'.format(project_id, model_name)
request = hook.projects().models().get(name=full_model_name) # pylint: disable=no-member
try:
return request.execute()
return request.execute(num_retries=self.num_retries)
except HttpError as e:
if e.resp.status == 404:
self.log.error('Model was not found: %s', e)
Expand Down Expand Up @@ -486,7 +507,7 @@ def delete_model(
self._delete_all_versions(model_name, project_id)
request = hook.projects().models().delete(name=model_path) # pylint: disable=no-member
try:
request.execute()
request.execute(num_retries=self.num_retries)
except HttpError as e:
if e.resp.status == 404:
self.log.error('Model was not found: %s', e)
Expand Down
58 changes: 31 additions & 27 deletions tests/providers/google/cloud/hooks/test_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_create_version(self, mock_get_conn):
.models()
.versions()
.create(body=version_with_airflow_version, parent=model_path),
mock.call().projects().models().versions().create().execute(),
mock.call().projects().models().versions().create().execute(num_retries=5),
mock.call().projects().operations().get(name=version_name),
],
any_order=True,
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_create_version_with_labels(self, mock_get_conn):
.models()
.versions()
.create(body=version_with_airflow_version, parent=model_path),
mock.call().projects().models().versions().create().execute(),
mock.call().projects().models().versions().create().execute(num_retries=5),
mock.call().projects().operations().get(name=version_name),
],
any_order=True,
Expand Down Expand Up @@ -171,7 +171,7 @@ def test_set_default_version(self, mock_get_conn):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().setDefault(body={}, name=version_path),
mock.call().projects().models().versions().setDefault().execute(),
mock.call().projects().models().versions().setDefault().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -208,7 +208,7 @@ def test_list_versions(self, mock_get_conn, mock_sleep):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().list(pageSize=100, parent=model_path),
mock.call().projects().models().versions().list().execute(),
mock.call().projects().models().versions().list().execute(num_retries=5),
]
+ [
mock.call()
Expand Down Expand Up @@ -257,9 +257,9 @@ def test_delete_version(self, mock_get_conn):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().delete(name=version_path),
mock.call().projects().models().versions().delete().execute(),
mock.call().projects().models().versions().delete().execute(num_retries=5),
mock.call().projects().operations().get(name=operation_path),
mock.call().projects().operations().get().execute(),
mock.call().projects().operations().get().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -291,7 +291,7 @@ def test_create_model(self, mock_get_conn):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path),
mock.call().projects().models().create().execute(),
mock.call().projects().models().create().execute(num_retries=5),
]
)

Expand Down Expand Up @@ -354,13 +354,13 @@ def test_create_model_idempotency(self, mock_get_conn):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path),
mock.call().projects().models().create().execute(),
mock.call().projects().models().create().execute(num_retries=5),
]
)
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().get(name='projects/test-project/models/test-model'),
mock.call().projects().models().get().execute(),
mock.call().projects().models().get().execute(num_retries=5),
]
)

Expand Down Expand Up @@ -391,7 +391,7 @@ def test_create_model_with_labels(self, mock_get_conn):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().create(body=model_with_airflow_version, parent=project_path),
mock.call().projects().models().create().execute(),
mock.call().projects().models().create().execute(num_retries=5),
]
)

Expand All @@ -416,7 +416,7 @@ def test_get_model(self, mock_get_conn):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().get(name=model_path),
mock.call().projects().models().get().execute(),
mock.call().projects().models().get().execute(num_retries=5),
]
)

Expand All @@ -440,7 +440,7 @@ def test_delete_model(self, mock_get_conn):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().delete(name=model_path),
mock.call().projects().models().delete().execute(),
mock.call().projects().models().delete().execute(num_retries=5),
]
)

Expand All @@ -467,7 +467,7 @@ def test_delete_model_when_not_exists(self, mock_get_conn, mock_log):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().delete(name=model_path),
mock.call().projects().models().delete().execute(),
mock.call().projects().models().delete().execute(num_retries=5),
]
)
mock_log.error.assert_called_once_with('Model was not found: %s', http_error)
Expand Down Expand Up @@ -517,7 +517,7 @@ def test_delete_model_with_contents(self, mock_get_conn, mock_sleep):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().delete(name=model_path),
mock.call().projects().models().delete().execute(),
mock.call().projects().models().delete().execute(num_retries=5),
]
+ [
mock.call()
Expand Down Expand Up @@ -580,7 +580,7 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep):
[
mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path),
mock.call().projects().jobs().get(name=job_path),
mock.call().projects().jobs().get().execute(),
mock.call().projects().jobs().get().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -632,7 +632,7 @@ def test_create_mlengine_job_with_labels(self, mock_get_conn, mock_sleep):
[
mock.call().projects().jobs().create(body=new_job_with_airflow_version, parent=project_path),
mock.call().projects().jobs().get(name=job_path),
mock.call().projects().jobs().get().execute(),
mock.call().projects().jobs().get().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -671,9 +671,9 @@ def test_create_mlengine_job_reuse_existing_job_by_default(self, mock_get_conn):
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().create(body=job_succeeded, parent=project_path),
mock.call().projects().jobs().create().execute(),
mock.call().projects().jobs().create().execute(num_retries=5),
mock.call().projects().jobs().get(name=job_path),
mock.call().projects().jobs().get().execute(),
mock.call().projects().jobs().get().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -778,6 +778,7 @@ def test_cancel_mlengine_job(self, mock_get_conn):
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().cancel(name=job_path),
mock.call().projects().jobs().cancel().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -840,6 +841,7 @@ def test_cancel_mlengine_job_completed_job(self, mock_get_conn):
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().cancel(name=job_path),
mock.call().projects().jobs().cancel().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -892,8 +894,9 @@ def test_create_version(self, mock_get_conn, mock_project_id):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().create(body=version, parent=model_path),
mock.call().projects().models().versions().create().execute(),
mock.call().projects().models().versions().create().execute(num_retries=5),
mock.call().projects().operations().get(name=version_name),
mock.call().projects().operations().get().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -932,7 +935,7 @@ def test_set_default_version(self, mock_get_conn, mock_project_id):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().setDefault(body={}, name=version_path),
mock.call().projects().models().versions().setDefault().execute(),
mock.call().projects().models().versions().setDefault().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -974,7 +977,7 @@ def test_list_versions(self, mock_get_conn, mock_sleep, mock_project_id):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().list(pageSize=100, parent=model_path),
mock.call().projects().models().versions().list().execute(),
mock.call().projects().models().versions().list().execute(num_retries=5),
]
+ [
mock.call()
Expand Down Expand Up @@ -1032,9 +1035,9 @@ def test_delete_version(self, mock_get_conn, mock_project_id):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().versions().delete(name=version_path),
mock.call().projects().models().versions().delete().execute(),
mock.call().projects().models().versions().delete().execute(num_retries=5),
mock.call().projects().operations().get(name=operation_path),
mock.call().projects().operations().get().execute(),
mock.call().projects().operations().get().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -1066,7 +1069,7 @@ def test_create_model(self, mock_get_conn, mock_project_id):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().create(body=model, parent=project_path),
mock.call().projects().models().create().execute(),
mock.call().projects().models().create().execute(num_retries=5),
]
)

Expand Down Expand Up @@ -1097,7 +1100,7 @@ def test_get_model(self, mock_get_conn, mock_project_id):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().get(name=model_path),
mock.call().projects().models().get().execute(),
mock.call().projects().models().get().execute(num_retries=5),
]
)

Expand Down Expand Up @@ -1125,7 +1128,7 @@ def test_delete_model(self, mock_get_conn, mock_project_id):
mock_get_conn.assert_has_calls(
[
mock.call().projects().models().delete(name=model_path),
mock.call().projects().models().delete().execute(),
mock.call().projects().models().delete().execute(num_retries=5),
]
)

Expand Down Expand Up @@ -1175,7 +1178,7 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep, mock_project_id):
[
mock.call().projects().jobs().create(body=new_job, parent=project_path),
mock.call().projects().jobs().get(name=job_path),
mock.call().projects().jobs().get().execute(),
mock.call().projects().jobs().get().execute(num_retries=5),
],
any_order=True,
)
Expand Down Expand Up @@ -1206,6 +1209,7 @@ def test_cancel_mlengine_job(self, mock_get_conn, mock_project_id):
mock_get_conn.assert_has_calls(
[
mock.call().projects().jobs().cancel(name=job_path),
mock.call().projects().jobs().cancel().execute(num_retries=5),
],
any_order=True,
)

0 comments on commit 950c16d

Please sign in to comment.