Skip to content

Commit

Permalink
Add support for external IdP OIDC token retrieval for Google Cloud Op…
Browse files Browse the repository at this point in the history
…erators. (#39873)

* Add support for external IdP OIDC token retrieval
using OAuth2.0 Crient Credentials Grant for
Google Cloud Operators.

This feature enables OIDC token retrieval from
any generic Identity Provider (IdP) that uses the OAuth 2.0
Credentials Grant Flow. Additionally, it lays the groundwork
for integrating other custom OIDC token retrieval methods.

related: #35899

Co-authored-by: Gonçalo Azevedo <[email protected]>

---------

Co-authored-by: Gonçalo Azevedo <[email protected]>
  • Loading branch information
dybolo and gazev committed Jun 11, 2024
1 parent e7d036a commit a586ea8
Show file tree
Hide file tree
Showing 6 changed files with 543 additions and 0 deletions.
68 changes: 68 additions & 0 deletions airflow/providers/google/cloud/utils/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud._internal_client.secret_manager_client import _SecretManagerClient
from airflow.providers.google.cloud.utils.external_token_supplier import (
ClientCredentialsGrantFlowTokenSupplier,
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.process_utils import patch_environ

Expand Down Expand Up @@ -210,6 +213,10 @@ def __init__(
target_principal: str | None = None,
delegates: Sequence[str] | None = None,
is_anonymous: bool | None = None,
idp_issuer_url: str | None = None,
client_id: str | None = None,
client_secret: str | None = None,
idp_extra_params_dict: dict[str, str] | None = None,
) -> None:
super().__init__()
key_options = [key_path, keyfile_dict, credential_config_file, key_secret_name, is_anonymous]
Expand All @@ -229,6 +236,10 @@ def __init__(
self.target_principal = target_principal
self.delegates = delegates
self.is_anonymous = is_anonymous
self.idp_issuer_url = idp_issuer_url
self.client_id = client_id
self.client_secret = client_secret
self.idp_extra_params_dict = idp_extra_params_dict

def get_credentials_and_project(self) -> tuple[Credentials, str]:
"""
Expand All @@ -248,6 +259,10 @@ def get_credentials_and_project(self) -> tuple[Credentials, str]:
credentials, project_id = self._get_credentials_using_key_secret_name()
elif self.keyfile_dict:
credentials, project_id = self._get_credentials_using_keyfile_dict()
elif self.idp_issuer_url:
credentials, project_id = (
self._get_credentials_using_credential_config_file_and_token_supplier()
)
elif self.credential_config_file:
credentials, project_id = self._get_credentials_using_credential_config_file()
else:
Expand Down Expand Up @@ -357,6 +372,24 @@ def _get_credentials_using_credential_config_file(self) -> tuple[Credentials, st

return credentials, project_id

def _get_credentials_using_credential_config_file_and_token_supplier(self):
self._log_info(
"Getting connection using credential configuration file and external Identity Provider."
)

if not self.credential_config_file:
raise AirflowException(
"Credential Configuration File is needed to use authentication by External Identity Provider."
)

info = _get_info_from_credential_configuration_file(self.credential_config_file)
info["subject_token_supplier"] = ClientCredentialsGrantFlowTokenSupplier(
oidc_issuer_url=self.idp_issuer_url, client_id=self.client_id, client_secret=self.client_secret
)

credentials, project_id = google.auth.load_credentials_from_dict(info=info, scopes=self.scopes)
return credentials, project_id

def _get_credentials_using_adc(self) -> tuple[Credentials, str]:
self._log_info(
"Getting connection using `google.auth.default()` since no explicit credentials are provided."
Expand Down Expand Up @@ -426,3 +459,38 @@ def _get_project_id_from_service_account_email(service_account_email: str) -> st
raise AirflowException(
f"Could not extract project_id from service account's email: {service_account_email}."
)


def _get_info_from_credential_configuration_file(
credential_configuration_file: str | dict[str, str],
) -> dict[str, str]:
"""
Extract the Credential Configuration File information, either from a json file, json string or dictionary.
:param credential_configuration_file: File path or content (as json string or dictionary) of a GCP credential configuration file.
:return: Returns a dictionary containing the Credential Configuration File information.
"""
# if it's already a dict, just return it
if isinstance(credential_configuration_file, dict):
return credential_configuration_file

if not isinstance(credential_configuration_file, str):
raise AirflowException(
f"Invalid argument type, expected str or dict, got {type(credential_configuration_file)}."
)

if os.path.exists(credential_configuration_file): # attempts to load from json file
with open(credential_configuration_file) as file_obj:
try:
return json.load(file_obj)
except ValueError:
raise AirflowException(
f"Credential Configuration File '{credential_configuration_file}' is not a valid json file."
)

# if not a file, attempt to load it from a json string
try:
return json.loads(credential_configuration_file)
except ValueError:
raise AirflowException("Credential Configuration File is not a valid json string.")
175 changes: 175 additions & 0 deletions airflow/providers/google/cloud/utils/external_token_supplier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import abc
import time
from functools import wraps
from typing import TYPE_CHECKING, Any

import requests
from google.auth.exceptions import RefreshError
from google.auth.identity_pool import SubjectTokenSupplier

if TYPE_CHECKING:
from google.auth.external_account import SupplierContext
from google.auth.transport import Request

from airflow.utils.log.logging_mixin import LoggingMixin


def cache_token_decorator(get_subject_token_method):
"""Cache calls to ``SubjectTokenSupplier`` instances' ``get_token_supplier`` methods.
Different instances of a same SubjectTokenSupplier class with the same attributes
share the OIDC token cache.
:param get_subject_token_method: A method that returns both a token and an integer specifying
the time in seconds until the token expires
See also:
https://googleapis.dev/python/google-auth/latest/reference/google.auth.identity_pool.html#google.auth.identity_pool.SubjectTokenSupplier.get_subject_token
"""
cache = {}

@wraps(get_subject_token_method)
def wrapper(supplier_instance: CacheTokenSupplier, *args, **kwargs) -> str:
"""Obeys the interface set by ``SubjectTokenSupplier`` for ``get_subject_token`` methods.
:param supplier_instance: the SubjectTokenSupplier instance whose get_subject_token method is being decorated
:return: The token string
"""
nonlocal cache

cache_key = supplier_instance.get_subject_key()
token: dict[str, str | float] = {}

if cache_key not in cache or cache[cache_key]["expiration_time"] < time.monotonic():
supplier_instance.log.info("OIDC token missing or expired")
try:
access_token, expires_in = get_subject_token_method(supplier_instance, *args, **kwargs)
if not isinstance(expires_in, int) or not isinstance(access_token, str):
raise RefreshError # assume error if strange values are provided

except RefreshError:
supplier_instance.log.error("Failed retrieving new OIDC Token from IdP")
raise

expiration_time = time.monotonic() + float(expires_in)
token["access_token"] = access_token
token["expiration_time"] = expiration_time
cache[cache_key] = token

supplier_instance.log.info("New OIDC token retrieved, expires in %s seconds.", expires_in)

return cache[cache_key]["access_token"]

return wrapper


class CacheTokenSupplier(LoggingMixin, SubjectTokenSupplier):
"""
A superclass for all Subject Token Supplier classes that wish to implement a caching mechanism.
Child classes must implement the ``get_subject_key`` method to generate a string that serves as the cache key,
ensuring that tokens are shared appropriately among instances.
Methods:
get_subject_key: Abstract method to be implemented by child classes. It should return a string that serves as the cache key.
"""

def __init__(self):
super().__init__()

@abc.abstractmethod
def get_subject_key(self) -> str:
raise NotImplementedError("")


class ClientCredentialsGrantFlowTokenSupplier(CacheTokenSupplier):
"""
Class that retrieves an OIDC token from an external IdP using OAuth2.0 Client Credentials Grant flow.
This class implements the ``SubjectTokenSupplier`` interface class used by ``google.auth.identity_pool.Credentials``
:params oidc_issuer_url: URL of the IdP that performs OAuth2.0 Client Credentials Grant flow and returns an OIDC token.
:params client_id: Client ID of the application requesting the token
:params client_secret: Client secret of the application requesting the token
:params extra_params_kwargs: Extra parameters to be passed in the payload of the POST request to the `oidc_issuer_url`
See also:
https://googleapis.dev/python/google-auth/latest/reference/google.auth.identity_pool.html#google.auth.identity_pool.SubjectTokenSupplier
"""

def __init__(
self,
oidc_issuer_url: str,
client_id: str,
client_secret: str,
**extra_params_kwargs: Any,
) -> None:
super().__init__()
self.oidc_issuer_url = oidc_issuer_url
self.client_id = client_id
self.client_secret = client_secret
self.extra_params_kwargs = extra_params_kwargs

@cache_token_decorator
def get_subject_token(self, context: SupplierContext, request: Request) -> tuple[str, int]:
"""Perform Client Credentials Grant flow with IdP and retrieves an OIDC token and expiration time."""
self.log.info("Requesting new OIDC token from external IdP.")
try:
response = requests.post(
self.oidc_issuer_url,
data={
"grant_type": "client_credentials",
"client_id": self.client_id,
"client_secret": self.client_secret,
**self.extra_params_kwargs,
},
)
response.raise_for_status()
except requests.HTTPError as e:
raise RefreshError(str(e))
except requests.ConnectionError as e:
raise RefreshError(str(e))

try:
response_dict = response.json()
except requests.JSONDecodeError:
raise RefreshError(f"Didn't get a json response from {self.oidc_issuer_url}")

# These fields are required
if {"access_token", "expires_in"} - set(response_dict.keys()):
# TODO more information about the error can be provided in the exception by inspecting the response
raise RefreshError(f"No access token returned from {self.oidc_issuer_url}")

return response_dict["access_token"], response_dict["expires_in"]

def get_subject_key(self) -> str:
"""
Create a cache key using the OIDC issuer URL, client ID, client secret and additional parameters.
Instances with the same credentials will share tokens.
"""
cache_key = (
self.oidc_issuer_url
+ self.client_id
+ self.client_secret
+ ",".join(sorted(self.extra_params_kwargs))
)
return cache_key
30 changes: 30 additions & 0 deletions airflow/providers/google/common/hooks/base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,20 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
"impersonation_chain": StringField(
lazy_gettext("Impersonation Chain"), widget=BS3TextFieldWidget()
),
"idp_issuer_url": StringField(
lazy_gettext("IdP Token Issue URL (Client Credentials Grant Flow)"),
widget=BS3TextFieldWidget(),
),
"client_id": StringField(
lazy_gettext("Client ID (Client Credentials Grant Flow)"), widget=BS3TextFieldWidget()
),
"client_secret": StringField(
lazy_gettext("Client Secret (Client Credentials Grant Flow)"),
widget=BS3PasswordFieldWidget(),
),
"idp_extra_parameters": StringField(
lazy_gettext("IdP Extra Request Parameters"), widget=BS3TextFieldWidget()
),
"is_anonymous": BooleanField(
lazy_gettext("Anonymous credentials (ignores all other settings)"), default=False
),
Expand Down Expand Up @@ -305,6 +319,18 @@ def get_credentials_and_project_id(self) -> tuple[Credentials, str | None]:
target_principal, delegates = _get_target_principal_and_delegates(self.impersonation_chain)
is_anonymous = self._get_field("is_anonymous")

idp_issuer_url: str | None = self._get_field("idp_issuer_url", None)
client_id: str | None = self._get_field("client_id", None)
client_secret: str | None = self._get_field("client_secret", None)
idp_extra_params: str | None = self._get_field("idp_extra_params", None)

idp_extra_params_dict: dict[str, str] | None = None
if idp_extra_params:
try:
idp_extra_params_dict = json.loads(idp_extra_params)
except json.decoder.JSONDecodeError:
raise AirflowException("Invalid JSON.")

credentials, project_id = get_credentials_and_project_id(
key_path=key_path,
keyfile_dict=keyfile_dict_json,
Expand All @@ -316,6 +342,10 @@ def get_credentials_and_project_id(self) -> tuple[Credentials, str | None]:
target_principal=target_principal,
delegates=delegates,
is_anonymous=is_anonymous,
idp_issuer_url=idp_issuer_url,
client_id=client_id,
client_secret=client_secret,
idp_extra_params_dict=idp_extra_params_dict,
)

overridden_project_id = self._get_field("project")
Expand Down
Loading

0 comments on commit a586ea8

Please sign in to comment.