Skip to content

Commit

Permalink
[AIRFLOW-7105] Unify Secrets Backend method interfaces (#7830)
Browse files Browse the repository at this point in the history
  • Loading branch information
xinbinhuang authored Mar 23, 2020
1 parent 3cc37b1 commit eef87b9
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 105 deletions.
28 changes: 2 additions & 26 deletions airflow/providers/amazon/aws/secrets/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
"""
Objects relating to sourcing connections from AWS SSM Parameter Store
"""
from typing import List, Optional
from typing import Optional

import boto3
from cached_property import cached_property

from airflow.models import Connection
from airflow.secrets import BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin

Expand Down Expand Up @@ -62,16 +61,6 @@ def client(self):
session = boto3.Session(profile_name=self.profile_name)
return session.client("ssm")

def build_ssm_path(self, conn_id: str):
"""
Given conn_id, build SSM path.
:param conn_id: connection id
:type conn_id: str
"""
param_path = self.connections_prefix + "/" + conn_id
return param_path

def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Get param value
Expand All @@ -80,7 +69,7 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]:
:type conn_id: str
"""

ssm_path = self.build_ssm_path(conn_id=conn_id)
ssm_path = self.build_path(connections_prefix=self.connections_prefix, conn_id=conn_id)
try:
response = self.client.get_parameter(
Name=ssm_path, WithDecryption=False
Expand All @@ -93,16 +82,3 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]:
"Parameter %s not found.", ssm_path
)
return None

def get_connections(self, conn_id: str) -> List[Connection]:
"""
Create connection object.
:param conn_id: connection id
:type conn_id: str
"""
conn_uri = self.get_conn_uri(conn_id=conn_id)
if not conn_uri:
return []
conn = Connection(conn_id=conn_id, uri=conn_uri)
return [conn]
28 changes: 2 additions & 26 deletions airflow/providers/google/cloud/secrets/secrets_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@
"""
Objects relating to sourcing connections from GCP Secrets Manager
"""
from typing import List, Optional
from typing import Optional

from cached_property import cached_property
from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud.secretmanager_v1 import SecretManagerServiceClient

from airflow import version
from airflow.models import Connection
from airflow.providers.google.cloud.utils.credentials_provider import (
_get_scopes, get_credentials_and_project_id,
)
Expand Down Expand Up @@ -87,24 +86,14 @@ def client(self) -> SecretManagerServiceClient:
)
return _client

def build_secret_id(self, conn_id: str) -> str:
"""
Given conn_id, build path for Secrets Manager
:param conn_id: connection id
:type conn_id: str
"""
secret_id = f"{self.connections_prefix}/{conn_id}"
return secret_id

def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Get secret value from Secrets Manager.
:param conn_id: connection id
:type conn_id: str
"""
secret_id = self.build_secret_id(conn_id=conn_id)
secret_id = self.build_path(connections_prefix=self.connections_prefix, conn_id=conn_id)
# always return the latest version of the secret
secret_version = "latest"
name = self.client.secret_version_path(self.project_id, secret_id, secret_version)
Expand All @@ -117,16 +106,3 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]:
"GCP API Call Error (NotFound): Secret ID %s not found.", secret_id
)
return None

def get_connections(self, conn_id: str) -> List[Connection]:
"""
Create connection object from GCP Secrets Manager
:param conn_id: connection id
:type conn_id: str
"""
conn_uri = self.get_conn_uri(conn_id=conn_id)
if not conn_uri:
return []
conn = Connection(conn_id=conn_id, uri=conn_uri)
return [conn]
27 changes: 2 additions & 25 deletions airflow/providers/hashicorp/secrets/vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@
"""
Objects relating to sourcing connections from Hashicorp Vault
"""
from typing import List, Optional
from typing import Optional

import hvac
from cached_property import cached_property
from hvac.exceptions import InvalidPath, VaultError

from airflow import AirflowException
from airflow.models import Connection
from airflow.secrets import BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin

Expand Down Expand Up @@ -144,23 +143,14 @@ def client(self) -> hvac.Client:
else:
raise VaultError("Vault Authentication Error!")

def build_path(self, conn_id: str):
"""
Given conn_id, build path for Vault Secret
:param conn_id: connection id
:type conn_id: str
"""
return self.connections_path + "/" + conn_id

def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Get secret value from Vault. Store the secret in the form of URI
:param conn_id: connection id
:type conn_id: str
"""
secret_path = self.build_path(conn_id=conn_id)
secret_path = self.build_path(connections_prefix=self.connections_path, conn_id=conn_id)

try:
if self.kv_engine_version == 1:
Expand All @@ -176,16 +166,3 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]:

return_data = response["data"] if self.kv_engine_version == 1 else response["data"]["data"]
return return_data.get("conn_uri")

def get_connections(self, conn_id: str) -> List[Connection]:
"""
Get connections with a specific ID
:param conn_id: connection id
:type conn_id: str
"""
conn_uri = self.get_conn_uri(conn_id=conn_id)
if not conn_uri:
return []
conn = Connection(conn_id=conn_id, uri=conn_uri)
return [conn]
20 changes: 1 addition & 19 deletions airflow/secrets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
__all__ = ['BaseSecretsBackend', 'get_connections']

import json
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import List

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.secrets.base_secrets import BaseSecretsBackend
from airflow.utils.module_loading import import_string

CONFIG_SECTION = "secrets"
Expand All @@ -41,24 +41,6 @@
]


class BaseSecretsBackend(ABC):
"""
Abstract base class to retrieve secrets given a conn_id and construct a Connection object
"""

def __init__(self, **kwargs):
pass

@abstractmethod
def get_connections(self, conn_id) -> List[Connection]:
"""
Return list of connection objects matching a given ``conn_id``.
:param conn_id: connection id to search for
:return:
"""


def get_connections(conn_id: str) -> List[Connection]:
"""
Get all connections as an iterable.
Expand Down
64 changes: 64 additions & 0 deletions airflow/secrets/base_secrets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from abc import ABC
from typing import List, Optional

from airflow.models import Connection


class BaseSecretsBackend(ABC):
"""
Abstract base class to retrieve secrets given a conn_id and construct a Connection object
"""

def __init__(self, **kwargs):
pass

@staticmethod
def build_path(connections_prefix: str, conn_id: str) -> str:
"""
Given conn_id, build path for Secrets Backend
:param connections_prefix: prefix of the secret to read to get Connections
:type connections_prefix: str
:param conn_id: connection id
:type conn_id: str
"""
return f"{connections_prefix}/{conn_id}"

def get_conn_uri(self, conn_id: str) -> Optional[str]:
"""
Get conn_uri from Secrets Backend
:param conn_id: connection id
:type conn_id: str
"""
raise NotImplementedError()

def get_connections(self, conn_id: str) -> List[Connection]:
"""
Return connection object with a given ``conn_id``.
:param conn_id: connection id
:type conn_id: str
"""
conn_uri = self.get_conn_uri(conn_id=conn_id)
if not conn_uri:
return []
conn = Connection(conn_id=conn_id, uri=conn_uri)
return [conn]
11 changes: 3 additions & 8 deletions airflow/secrets/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
"""

import os
from typing import List
from typing import Optional

from airflow.models import Connection
from airflow.secrets import BaseSecretsBackend

CONN_ENV_PREFIX = "AIRFLOW_CONN_"
Expand All @@ -34,10 +33,6 @@ class EnvironmentVariablesSecretsBackend(BaseSecretsBackend):
"""

# pylint: disable=missing-docstring
def get_connections(self, conn_id) -> List[Connection]:
def get_conn_uri(self, conn_id: str) -> Optional[str]:
environment_uri = os.environ.get(CONN_ENV_PREFIX + conn_id.upper())
if environment_uri:
conn = Connection(conn_id=conn_id, uri=environment_uri)
return [conn]
else:
return []
return environment_uri
6 changes: 6 additions & 0 deletions docs/howto/use-alternative-secrets-backend.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,12 @@ Roll your own secrets backend
A secrets backend is a subclass of :py:class:`airflow.secrets.BaseSecretsBackend`, and just has to implement the
:py:meth:`~airflow.secrets.BaseSecretsBackend.get_connections` method.

There are two options:

* Option 1: a base implmentation of the :py:meth:`~airflow.secrets.BaseSecretsBackend.get_connections` is provided, you just need to implement the
:py:meth:`~airflow.secrets.BaseSecretsBackend.get_conn_uri` method to make it functional.
* Option 2: simply override the :py:meth:`~airflow.secrets.BaseSecretsBackend.get_connections` method.

Just create your class, and put the fully qualified class name in ``backend`` key in the ``[secrets]``
section of ``airflow.cfg``. You can you can also pass kwargs to ``__init__`` by supplying json to the
``backend_kwargs`` config param. See :ref:`Configuration <secrets_backend_configuration>` for more details,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
MODULE_NAME = "airflow.providers.google.cloud.secrets.secrets_manager"


class TestGcpSecretsManagerBackend(TestCase):
class TestCloudSecretsManagerBackend(TestCase):
@parameterized.expand([
"airflow/connections",
"connections",
Expand Down

0 comments on commit eef87b9

Please sign in to comment.