Fix skipping non-GCS located jars (#22302)
* Fix #21989 indentation. A test is added to confirm the job is executed on Dataflow with a local jar file.

Co-authored-by: Kyaw <[email protected]>
yathit and okisan committed Mar 20, 2022
1 parent 43dfec3 commit a3ffbee
Showing 2 changed files with 85 additions and 22 deletions.
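For context on the fix: the failing case is an operator whose jar sits on the worker's local filesystem rather than in a gs:// bucket. A minimal sketch of such a DAG follows; the DAG id, jar path, job class, project, bucket, and region are hypothetical placeholders rather than values from this commit, and the sketch assumes the google provider is installed.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataflow import (
    CheckJobRunning,
    DataflowCreateJavaJobOperator,
)

with DAG(
    dag_id="dataflow_local_jar_example",  # hypothetical DAG id
    start_date=datetime(2022, 3, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    # The jar is a local path, not a gs:// URL; before this fix the operator
    # silently skipped submitting the job for such paths.
    start_java_job = DataflowCreateJavaJobOperator(
        task_id="start_java_job",
        jar="/opt/airflow/jars/wordcount.jar",  # hypothetical local path
        job_name="example-local-jar-job",
        job_class="org.example.WordCount",  # hypothetical main class
        location="europe-west3",  # hypothetical region
        check_if_running=CheckJobRunning.WaitForRun,
        dataflow_default_options={
            "project": "my-gcp-project",  # hypothetical project id
            "stagingLocation": "gs://my-bucket/staging",  # hypothetical bucket
        },
    )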
44 changes: 22 additions & 22 deletions airflow/providers/google/cloud/operators/dataflow.py
@@ -410,33 +410,33 @@ def set_current_job_id(job_id):
                 tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.jar))
                 self.jar = tmp_gcs_file.name

-                is_running = False
-                if self.check_if_running != CheckJobRunning.IgnoreJob:
+            is_running = False
+            if self.check_if_running != CheckJobRunning.IgnoreJob:
+                is_running = self.dataflow_hook.is_job_dataflow_running(
+                    name=self.job_name,
+                    variables=pipeline_options,
+                )
+                while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
+
                     is_running = self.dataflow_hook.is_job_dataflow_running(
                         name=self.job_name,
                         variables=pipeline_options,
                     )
-                    while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
-
-                        is_running = self.dataflow_hook.is_job_dataflow_running(
-                            name=self.job_name,
-                            variables=pipeline_options,
-                        )
-                if not is_running:
-                    pipeline_options["jobName"] = job_name
-                    with self.dataflow_hook.provide_authorized_gcloud():
-                        self.beam_hook.start_java_pipeline(
-                            variables=pipeline_options,
-                            jar=self.jar,
-                            job_class=self.job_class,
-                            process_line_callback=process_line_callback,
-                        )
-                self.dataflow_hook.wait_for_done(
-                    job_name=job_name,
-                    location=self.location,
-                    job_id=self.job_id,
-                    multiple_jobs=self.multiple_jobs,
+            if not is_running:
+                pipeline_options["jobName"] = job_name
+                with self.dataflow_hook.provide_authorized_gcloud():
+                    self.beam_hook.start_java_pipeline(
+                        variables=pipeline_options,
+                        jar=self.jar,
+                        job_class=self.job_class,
+                        process_line_callback=process_line_callback,
+                    )
+            self.dataflow_hook.wait_for_done(
+                job_name=job_name,
+                location=self.location,
+                job_id=self.job_id,
+                multiple_jobs=self.multiple_jobs,
                 )

         return {"job_id": self.job_id}

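The effect of the re-indentation above, in short: only the download of the jar remains guarded by the gs:// check, while the running-job check and the pipeline launch are dedented one level and therefore execute for local jars as well. The toy comparison below mirrors that change in shape; it is a self-contained illustration whose function names are illustrative, not the operator's real attributes, and the local path reuses LOCAL_JAR_FILE from the new test.

def submit_before_fix(jar: str) -> bool:
    """Old shape: the whole launch sequence lives under the gs:// branch."""
    submitted = False
    if jar.lower().startswith("gs://"):
        # ... download the jar from GCS to a temporary local file ...
        # ... check for an already-running job, then start the pipeline ...
        submitted = True
    return submitted


def submit_after_fix(jar: str) -> bool:
    """New shape: only the download is gs://-specific; the launch always runs."""
    if jar.lower().startswith("gs://"):
        # ... download the jar from GCS to a temporary local file ...
        pass
    # ... check for an already-running job, then start the pipeline ...
    return True


print(submit_before_fix("/mnt/dev/example/test.jar"))  # False: the job was skipped
print(submit_after_fix("/mnt/dev/example/test.jar"))   # True: the job is submitted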
63 changes: 63 additions & 0 deletions tests/providers/google/cloud/operators/test_dataflow.py
@@ -43,6 +43,7 @@
 PY_FILE = 'gs://my-bucket/my-object.py'
 PY_INTERPRETER = 'python3'
 JAR_FILE = 'gs://my-bucket/example/test.jar'
+LOCAL_JAR_FILE = '/mnt/dev/example/test.jar'
 JOB_CLASS = 'com.test.NotMain'
 PY_OPTIONS = ['-m']
 DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
@@ -380,6 +381,68 @@ def set_is_job_dataflow_running_variables(*args, **kwargs):
         )


+class TestDataflowJavaOperatorWithLocal(unittest.TestCase):
+    def setUp(self):
+        self.dataflow = DataflowCreateJavaJobOperator(
+            task_id=TASK_ID,
+            jar=LOCAL_JAR_FILE,
+            job_name=JOB_NAME,
+            job_class=JOB_CLASS,
+            dataflow_default_options=DEFAULT_OPTIONS_JAVA,
+            options=ADDITIONAL_OPTIONS,
+            poll_sleep=POLL_SLEEP,
+            location=TEST_LOCATION,
+        )
+        self.expected_airflow_version = 'v' + airflow.version.version.replace(".", "-").replace("+", "-")
+
+    def test_init(self):
"""Test DataflowTemplateOperator instance is properly initialized."""
+        assert self.dataflow.jar == LOCAL_JAR_FILE
+
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook')
+    @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
+    def test_check_job_not_running_exec(self, dataflow_hook_mock, beam_hook_mock):
"""Test DataflowHook is created and the right args are passed to
start_java_workflow with option to check if job is running
"""
+        is_job_dataflow_running_variables = None
+
+        def set_is_job_dataflow_running_variables(*args, **kwargs):
+            nonlocal is_job_dataflow_running_variables
+            is_job_dataflow_running_variables = copy.deepcopy(kwargs.get("variables"))
+
+        dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running
+        dataflow_running.side_effect = set_is_job_dataflow_running_variables
+        dataflow_running.return_value = False
+        start_java_mock = beam_hook_mock.return_value.start_java_pipeline
+        self.dataflow.check_if_running = True
+
+        self.dataflow.execute(None)
+        expected_variables = {
+            'project': dataflow_hook_mock.return_value.project_id,
+            'stagingLocation': 'gs://test/staging',
+            'jobName': JOB_NAME,
+            'region': TEST_LOCATION,
+            'output': 'gs://test/output',
+            'labels': {'foo': 'bar', 'airflow-version': self.expected_airflow_version},
+        }
+        self.assertEqual(expected_variables, is_job_dataflow_running_variables)
+        job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value
+        expected_variables["jobName"] = job_name
+        start_java_mock.assert_called_once_with(
+            variables=expected_variables,
+            jar=LOCAL_JAR_FILE,
+            job_class=JOB_CLASS,
+            process_line_callback=mock.ANY,
+        )
+        dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with(
+            job_id=mock.ANY,
+            job_name=job_name,
+            location=TEST_LOCATION,
+            multiple_jobs=False,
+        )
+
+
 class TestDataflowTemplateOperator(unittest.TestCase):
     def setUp(self):
         self.dataflow = DataflowTemplatedJobStartOperator(
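To exercise only the new test class from an Airflow source checkout, a programmatic pytest invocation along the following lines should work, assuming pytest and the test dependencies are installed; running the same node id from the command line is equivalent.

import pytest

# Select only the test class added in this commit, with verbose output.
pytest.main([
    "tests/providers/google/cloud/operators/test_dataflow.py::TestDataflowJavaOperatorWithLocal",
    "-v",
])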
