Skip to content

Commit

Permalink
Enable specifying dictionary paths in template_fields_renderers (#1…
Browse files Browse the repository at this point in the history
…7321)

Added the handling of paths in `template_fields_renderers` which enables information contained in dictionaries to be unpacked and rendered appropriately.
  • Loading branch information
nathadfield committed Aug 2, 2021
1 parent 97428ef commit 67cbb0f
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 5 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2196,7 +2196,7 @@ class BigQueryInsertJobOperator(BaseOperator):
"impersonation_chain",
)
template_ext = (".json",)
template_fields_renderers = {"configuration": "json"}
template_fields_renderers = {"configuration": "json", "configuration.query.query": "sql"}
ui_color = BigQueryUIColors.QUERY.value

def __init__(
Expand Down
47 changes: 45 additions & 2 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,29 @@ def task_group_to_dict(task_group):
}


def get_key_paths(input_dict):
"""Return a list of dot-separated dictionary paths"""
for key, value in input_dict.items():
if isinstance(value, dict):
for sub_key in get_key_paths(value):
yield '.'.join((key, sub_key))
else:
yield key


def get_value_from_path(key_path, content):
"""Return the value from a dictionary based on dot-separated path of keys"""
elem = content
for x in key_path.strip(".").split("."):
try:
x = int(x)
elem = elem[x]
except ValueError:
elem = elem.get(x)

return elem


def dag_edges(dag):
"""
Create the list of edges needed to construct the Graph view.
Expand Down Expand Up @@ -995,11 +1018,31 @@ def rendered_templates(self):
renderer = task.template_fields_renderers.get(template_field, template_field)
if renderer in renderers:
if isinstance(content, (dict, list)):
content = json.dumps(content, sort_keys=True, indent=4)
html_dict[template_field] = renderers[renderer](content)
json_content = json.dumps(content, sort_keys=True, indent=4)
html_dict[template_field] = renderers[renderer](json_content)
else:
html_dict[template_field] = renderers[renderer](content)
else:
html_dict[template_field] = Markup("<pre><code>{}</pre></code>").format(pformat(content))

if isinstance(content, dict):
if template_field == 'op_kwargs':
for key, value in content.items():
renderer = task.template_fields_renderers.get(key, key)
if renderer in renderers:
html_dict['.'.join([template_field, key])] = renderers[renderer](value)
else:
html_dict['.'.join([template_field, key])] = Markup(
"<pre><code>{}</pre></code>"
).format(pformat(value))
else:
for dict_keys in get_key_paths(content):
template_path = '.'.join((template_field, dict_keys))
renderer = task.template_fields_renderers.get(template_path, template_path)
if renderer in renderers:
content_value = get_value_from_path(dict_keys, content)
html_dict[template_path] = renderers[renderer](content_value)

return self.render_template(
'airflow/ti_code.html',
html_dict=html_dict,
Expand Down
38 changes: 37 additions & 1 deletion docs/apache-airflow/howto/custom-operator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ with actual value. Note that Jinja substitutes the operator attributes and not t
In the example, the ``template_fields`` should be ``['guest_name']`` and not ``['name']``

Additionally you may provide ``template_fields_renderers`` dictionary which defines in what style the value
Additionally you may provide ``template_fields_renderers`` a dictionary which defines in what style the value
from template field renders in Web UI. For example:

.. code-block:: python
Expand All @@ -208,12 +208,48 @@ from template field renders in Web UI. For example:
super().__init__(**kwargs)
self.request_body = request_body
In the situation where ``template_field`` is itself a dictionary, it is also possible to specify a
dot-separated key path to extract and render individual elements appropriately. For example:

.. code-block:: python
class MyConfigOperator(BaseOperator):
template_fields = ["configuration"]
template_fields_renderers = {
"configuration": "json",
"configuration.query.sql": "sql",
}
def __init__(self, configuration: dict, **kwargs) -> None:
super().__init__(**kwargs)
self.configuration = configuration
Then using this template as follows:

.. code-block:: python
with dag:
config_task = MyConfigOperator(
task_id="task_id_1",
configuration={"query": {"job_id": "123", "sql": "select * from my_table"}},
dag=dag,
)
This will result in the UI rendering ``configuration`` as json in addition to the value contained in the
configuration at ``query.sql`` to be rendered with the SQL lexer.

.. image:: ../img/template_field_renderer_path.png

Currently available lexers:

- bash
- doc
- hql
- html
- jinja
- json
- md
- powershell
- py
- rst
- sql
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 24 additions & 1 deletion tests/www/views/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from airflow.configuration import initialize_config
from airflow.plugins_manager import AirflowPlugin, EntryPointSource
from airflow.www.views import get_safe_url, truncate_task_duration
from airflow.www.views import get_key_paths, get_safe_url, get_value_from_path, truncate_task_duration
from tests.test_utils.config import conf_vars
from tests.test_utils.mock_plugins import mock_plugin_manager
from tests.test_utils.www import check_content_in_response, check_content_not_in_response
Expand Down Expand Up @@ -243,3 +243,26 @@ def get_task_instance(session, task):
dagrun.refresh_from_db(session=session)
# dagrun should be set to QUEUED
assert dagrun.get_state() == State.QUEUED


TEST_CONTENT_DICT = {"key1": {"key2": "val2", "key3": "val3", "key4": {"key5": "val5"}}}


@pytest.mark.parametrize(
"test_content_dict, expected_paths", [(TEST_CONTENT_DICT, ("key1.key2", "key1.key3", "key1.key4.key5"))]
)
def test_generate_key_paths(test_content_dict, expected_paths):
for key_path in get_key_paths(test_content_dict):
assert key_path in expected_paths


@pytest.mark.parametrize(
"test_content_dict, test_key_path, expected_value",
[
(TEST_CONTENT_DICT, "key1.key2", "val2"),
(TEST_CONTENT_DICT, "key1.key3", "val3"),
(TEST_CONTENT_DICT, "key1.key4.key5", "val5"),
],
)
def test_get_value_from_path(test_content_dict, test_key_path, expected_value):
assert expected_value == get_value_from_path(test_key_path, test_content_dict)

0 comments on commit 67cbb0f

Please sign in to comment.