From a3ffbee7c9b5cd8cc5b7b246116f0254f1daa505 Mon Sep 17 00:00:00 2001
From: Kyaw Tun
Date: Sun, 20 Mar 2022 19:12:25 +0800
Subject: [PATCH] Fix skipping non-GCS located jars (#22302)

* Fix #21989 indentation.

A test is added to confirm the job is executed on Dataflow with a local jar file.

Co-authored-by: Kyaw
---
 .../google/cloud/operators/dataflow.py      | 44 ++++++-------
 .../google/cloud/operators/test_dataflow.py | 63 +++++++++++++++++++
 2 files changed, 85 insertions(+), 22 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataflow.py b/airflow/providers/google/cloud/operators/dataflow.py
index 24fe2cd42a202..8cd146965f494 100644
--- a/airflow/providers/google/cloud/operators/dataflow.py
+++ b/airflow/providers/google/cloud/operators/dataflow.py
@@ -410,33 +410,33 @@ def set_current_job_id(job_id):
                 tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.jar))
                 self.jar = tmp_gcs_file.name
 
-                is_running = False
-                if self.check_if_running != CheckJobRunning.IgnoreJob:
+            is_running = False
+            if self.check_if_running != CheckJobRunning.IgnoreJob:
+                is_running = self.dataflow_hook.is_job_dataflow_running(
+                    name=self.job_name,
+                    variables=pipeline_options,
+                )
+                while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
+                    is_running = self.dataflow_hook.is_job_dataflow_running(
                         name=self.job_name,
                         variables=pipeline_options,
                     )
-                    while is_running and self.check_if_running == CheckJobRunning.WaitForRun:
-
-                        is_running = self.dataflow_hook.is_job_dataflow_running(
-                            name=self.job_name,
-                            variables=pipeline_options,
-                        )
-                if not is_running:
-                    pipeline_options["jobName"] = job_name
-                    with self.dataflow_hook.provide_authorized_gcloud():
-                        self.beam_hook.start_java_pipeline(
-                            variables=pipeline_options,
-                            jar=self.jar,
-                            job_class=self.job_class,
-                            process_line_callback=process_line_callback,
-                        )
-                    self.dataflow_hook.wait_for_done(
-                        job_name=job_name,
-                        location=self.location,
-                        job_id=self.job_id,
-                        multiple_jobs=self.multiple_jobs,
+            if not is_running:
+                pipeline_options["jobName"] = job_name
+                with self.dataflow_hook.provide_authorized_gcloud():
+                    self.beam_hook.start_java_pipeline(
+                        variables=pipeline_options,
+                        jar=self.jar,
+                        job_class=self.job_class,
+                        process_line_callback=process_line_callback,
                     )
+                self.dataflow_hook.wait_for_done(
+                    job_name=job_name,
+                    location=self.location,
+                    job_id=self.job_id,
+                    multiple_jobs=self.multiple_jobs,
+                )
 
         return {"job_id": self.job_id}
 
 
diff --git a/tests/providers/google/cloud/operators/test_dataflow.py b/tests/providers/google/cloud/operators/test_dataflow.py
index 6501466540f50..333bf7c2725cf 100644
--- a/tests/providers/google/cloud/operators/test_dataflow.py
+++ b/tests/providers/google/cloud/operators/test_dataflow.py
@@ -43,6 +43,7 @@
 PY_FILE = 'gs://my-bucket/my-object.py'
 PY_INTERPRETER = 'python3'
 JAR_FILE = 'gs://my-bucket/example/test.jar'
+LOCAL_JAR_FILE = '/mnt/dev/example/test.jar'
 JOB_CLASS = 'com.test.NotMain'
 PY_OPTIONS = ['-m']
 DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
@@ -380,6 +381,68 @@ def set_is_job_dataflow_running_variables(*args, **kwargs):
         )
 
 
+class TestDataflowJavaOperatorWithLocal(unittest.TestCase):
+    def setUp(self):
+        self.dataflow = DataflowCreateJavaJobOperator(
+            task_id=TASK_ID,
+            jar=LOCAL_JAR_FILE,
+            job_name=JOB_NAME,
+            job_class=JOB_CLASS,
+            dataflow_default_options=DEFAULT_OPTIONS_JAVA,
+            options=ADDITIONAL_OPTIONS,
+            poll_sleep=POLL_SLEEP,
+            location=TEST_LOCATION,
+        )
+        self.expected_airflow_version = 'v' + airflow.version.version.replace(".", "-").replace("+", "-")
"-").replace("+", "-") + + def test_init(self): + """Test DataflowTemplateOperator instance is properly initialized.""" + assert self.dataflow.jar == LOCAL_JAR_FILE + + @mock.patch('airflow.providers.google.cloud.operators.dataflow.BeamHook') + @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook') + def test_check_job_not_running_exec(self, dataflow_hook_mock, beam_hook_mock): + """Test DataflowHook is created and the right args are passed to + start_java_workflow with option to check if job is running + """ + is_job_dataflow_running_variables = None + + def set_is_job_dataflow_running_variables(*args, **kwargs): + nonlocal is_job_dataflow_running_variables + is_job_dataflow_running_variables = copy.deepcopy(kwargs.get("variables")) + + dataflow_running = dataflow_hook_mock.return_value.is_job_dataflow_running + dataflow_running.side_effect = set_is_job_dataflow_running_variables + dataflow_running.return_value = False + start_java_mock = beam_hook_mock.return_value.start_java_pipeline + self.dataflow.check_if_running = True + + self.dataflow.execute(None) + expected_variables = { + 'project': dataflow_hook_mock.return_value.project_id, + 'stagingLocation': 'gs://test/staging', + 'jobName': JOB_NAME, + 'region': TEST_LOCATION, + 'output': 'gs://test/output', + 'labels': {'foo': 'bar', 'airflow-version': self.expected_airflow_version}, + } + self.assertEqual(expected_variables, is_job_dataflow_running_variables) + job_name = dataflow_hook_mock.return_value.build_dataflow_job_name.return_value + expected_variables["jobName"] = job_name + start_java_mock.assert_called_once_with( + variables=expected_variables, + jar=LOCAL_JAR_FILE, + job_class=JOB_CLASS, + process_line_callback=mock.ANY, + ) + dataflow_hook_mock.return_value.wait_for_done.assert_called_once_with( + job_id=mock.ANY, + job_name=job_name, + location=TEST_LOCATION, + multiple_jobs=False, + ) + + class TestDataflowTemplateOperator(unittest.TestCase): def setUp(self): self.dataflow = DataflowTemplatedJobStartOperator(