From 4212c4932433a50bda09f3e771a02f5ded4553a7 Mon Sep 17 00:00:00 2001
From: Josh Fell <48934154+josh-fell@users.noreply.github.com>
Date: Sun, 14 Nov 2021 15:56:35 -0500
Subject: [PATCH] Update Azure modules to comply with AIP-21 (#19431)

---
 .../hooks/azure_container_instance_hook.py | 12 +-
 .../hooks/azure_container_registry_hook.py | 12 +-
 .../hooks/azure_container_volume_hook.py | 6 +-
 airflow/contrib/hooks/azure_cosmos_hook.py | 6 +-
 airflow/contrib/hooks/azure_data_lake_hook.py | 6 +-
 airflow/contrib/hooks/azure_fileshare_hook.py | 6 +-
 .../azure_container_instances_operator.py | 6 +-
 .../operators/azure_cosmos_operator.py | 6 +-
 airflow/contrib/secrets/azure_key_vault.py | 6 +-
 .../contrib/sensors/azure_cosmos_sensor.py | 6 +-
 .../google/cloud/transfers/adls_to_gcs.py | 2 +-
 .../cloud/transfers/azure_fileshare_to_gcs.py | 2 +-
 .../example_azure_container_instances.py | 4 +-
 .../example_dags/example_azure_cosmosdb.py | 4 +-
 .../azure/example_dags/example_fileshare.py | 2 +-
 .../microsoft/azure/hooks/azure_batch.py | 384 +----
 .../azure/hooks/azure_container_instance.py | 144 +------
 .../azure/hooks/azure_container_registry.py | 55 +--
 .../azure/hooks/azure_container_volume.py | 95 +----
 .../microsoft/azure/hooks/azure_cosmos.py | 353 +---
 .../microsoft/azure/hooks/azure_data_lake.py | 234 +----
 .../microsoft/azure/hooks/azure_fileshare.py | 325 +-------
 .../providers/microsoft/azure/hooks/batch.py | 395 ++++++
 .../azure/hooks/container_instance.py | 157 +++++
 .../azure/hooks/container_registry.py | 66 +++
 .../microsoft/azure/hooks/container_volume.py | 106 +++++
 .../providers/microsoft/azure/hooks/cosmos.py | 353 ++++
 .../microsoft/azure/hooks/data_lake.py | 245 +++++
 .../microsoft/azure/hooks/fileshare.py | 336 +++++
 .../microsoft/azure/operators/adls.py | 2 +-
 .../microsoft/azure/operators/azure_batch.py | 347 +--
 .../operators/azure_container_instances.py | 381 +----
 .../microsoft/azure/operators/azure_cosmos.py | 59 +--
 .../microsoft/azure/operators/batch.py | 358 ++++
 .../azure/operators/container_instances.py | 390 +++++
 .../microsoft/azure/operators/cosmos.py | 70 ++++
 .../providers/microsoft/azure/provider.yaml | 40 +-
 .../azure/secrets/azure_key_vault.py | 165 +-------
 .../microsoft/azure/secrets/key_vault.py | 176 ++++
 .../microsoft/azure/sensors/azure_cosmos.py | 57 +--
 .../microsoft/azure/sensors/cosmos.py | 68 +++
 .../azure/transfers/local_to_adls.py | 2 +-
 .../transfers/oracle_to_azure_data_lake.py | 2 +-
 .../prepare_provider_packages.py | 13 +
 .../secrets-backends/azure-key-vault.rst | 6 +-
 tests/deprecated_classes.py | 15 +-
 .../microsoft/azure/hooks/test_azure_batch.py | 16 +-
 .../hooks/test_azure_container_instance.py | 2 +-
 .../hooks/test_azure_container_registry.py | 2 +-
 .../hooks/test_azure_container_volume.py | 2 +-
 .../azure/hooks/test_azure_cosmos.py | 28 +-
 .../azure/hooks/test_azure_data_lake.py | 52 +--
 .../azure/hooks/test_azure_fileshare.py | 26 +-
 .../azure/operators/test_azure_batch.py | 8 +-
 .../test_azure_container_instances.py | 46 +-
 .../azure/operators/test_azure_cosmos.py | 4 +-
 .../azure/secrets/test_azure_key_vault.py | 24 +-
 .../azure/sensors/test_azure_cosmos.py | 6 +-
 tests/test_utils/azure_system_helpers.py | 2 +-
 59 files changed, 3007 insertions(+), 2696 deletions(-)
 create mode 100644 airflow/providers/microsoft/azure/hooks/batch.py
 create mode 100644 
airflow/providers/microsoft/azure/hooks/container_instance.py create mode 100644 airflow/providers/microsoft/azure/hooks/container_registry.py create mode 100644 airflow/providers/microsoft/azure/hooks/container_volume.py create mode 100644 airflow/providers/microsoft/azure/hooks/cosmos.py create mode 100644 airflow/providers/microsoft/azure/hooks/data_lake.py create mode 100644 airflow/providers/microsoft/azure/hooks/fileshare.py create mode 100644 airflow/providers/microsoft/azure/operators/batch.py create mode 100644 airflow/providers/microsoft/azure/operators/container_instances.py create mode 100644 airflow/providers/microsoft/azure/operators/cosmos.py create mode 100644 airflow/providers/microsoft/azure/secrets/key_vault.py create mode 100644 airflow/providers/microsoft/azure/sensors/cosmos.py diff --git a/airflow/contrib/hooks/azure_container_instance_hook.py b/airflow/contrib/hooks/azure_container_instance_hook.py index 5b40f9c9edaa1..9fefa5c679d38 100644 --- a/airflow/contrib/hooks/azure_container_instance_hook.py +++ b/airflow/contrib/hooks/azure_container_instance_hook.py @@ -15,20 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This module is deprecated. -Please use :mod:`airflow.providers.microsoft.azure.hooks.azure_container_instance`. -""" +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.container_instance`.""" import warnings -from airflow.providers.microsoft.azure.hooks.azure_container_instance import ( # noqa - AzureContainerInstanceHook, -) +from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook # noqa warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.microsoft.azure.hooks.azure_container_instance`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_instance`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_container_registry_hook.py b/airflow/contrib/hooks/azure_container_registry_hook.py index 840cf89cf943c..14e55ef820737 100644 --- a/airflow/contrib/hooks/azure_container_registry_hook.py +++ b/airflow/contrib/hooks/azure_container_registry_hook.py @@ -15,20 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This module is deprecated. -Please use `airflow.providers.microsoft.azure.hooks.azure_container_registry`. -""" +"""This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_registry`.""" import warnings -from airflow.providers.microsoft.azure.hooks.azure_container_registry import ( # noqa - AzureContainerRegistryHook, -) +from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook # noqa warnings.warn( - "This module is deprecated. " - "Please use `airflow.providers.microsoft.azure.hooks.azure_container_registry`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_registry`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_container_volume_hook.py b/airflow/contrib/hooks/azure_container_volume_hook.py index 4b325ad266770..facfdaca4cc8a 100644 --- a/airflow/contrib/hooks/azure_container_volume_hook.py +++ b/airflow/contrib/hooks/azure_container_volume_hook.py @@ -17,15 +17,15 @@ # under the License. """ This module is deprecated. 
-Please use :mod:`airflow.providers.microsoft.azure.hooks.azure_container_volume`. +Please use :mod:`airflow.providers.microsoft.azure.hooks.container_volume`. """ import warnings -from airflow.providers.microsoft.azure.hooks.azure_container_volume import AzureContainerVolumeHook # noqa +from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook # noqa warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_container_volume`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_volume`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_cosmos_hook.py b/airflow/contrib/hooks/azure_cosmos_hook.py index 26abe6194a641..4152f15e9d7b1 100644 --- a/airflow/contrib/hooks/azure_cosmos_hook.py +++ b/airflow/contrib/hooks/azure_cosmos_hook.py @@ -15,14 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.azure_cosmos`.""" +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.cosmos`.""" import warnings -from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook # noqa +from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook # noqa warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_cosmos`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.cosmos`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_data_lake_hook.py b/airflow/contrib/hooks/azure_data_lake_hook.py index a89961d73c6f5..3442d1345078c 100644 --- a/airflow/contrib/hooks/azure_data_lake_hook.py +++ b/airflow/contrib/hooks/azure_data_lake_hook.py @@ -15,14 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.azure_data_lake`.""" +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.data_lake`.""" import warnings -from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook # noqa +from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook # noqa warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_data_lake`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.data_lake`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/contrib/hooks/azure_fileshare_hook.py b/airflow/contrib/hooks/azure_fileshare_hook.py index 2c49d41b825fa..f0a5b2ec4c9bb 100644 --- a/airflow/contrib/hooks/azure_fileshare_hook.py +++ b/airflow/contrib/hooks/azure_fileshare_hook.py @@ -15,14 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.azure_fileshare`.""" +"""This module is deprecated. 
Please use :mod:`airflow.providers.microsoft.azure.hooks.fileshare`.""" import warnings -from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook # noqa +from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook # noqa warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.azure_fileshare`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.fileshare`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/contrib/operators/azure_container_instances_operator.py b/airflow/contrib/operators/azure_container_instances_operator.py index 8b6a32a8ae216..d084748ca06df 100644 --- a/airflow/contrib/operators/azure_container_instances_operator.py +++ b/airflow/contrib/operators/azure_container_instances_operator.py @@ -17,17 +17,17 @@ # under the License. """ This module is deprecated. Please use -`airflow.providers.microsoft.azure.operators.azure_container_instances`. +`airflow.providers.microsoft.azure.operators.container_instances`. """ import warnings -from airflow.providers.microsoft.azure.operators.azure_container_instances import ( # noqa +from airflow.providers.microsoft.azure.operators.container_instances import ( # noqa AzureContainerInstancesOperator, ) warnings.warn( "This module is deprecated. " - "Please use `airflow.providers.microsoft.azure.operators.azure_container_instances`.", + "Please use `airflow.providers.microsoft.azure.operators.container_instances`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/contrib/operators/azure_cosmos_operator.py b/airflow/contrib/operators/azure_cosmos_operator.py index 6c087285ce8f7..269c8357c02d3 100644 --- a/airflow/contrib/operators/azure_cosmos_operator.py +++ b/airflow/contrib/operators/azure_cosmos_operator.py @@ -15,14 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.azure_cosmos`.""" +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.cosmos`.""" import warnings -from airflow.providers.microsoft.azure.operators.azure_cosmos import AzureCosmosInsertDocumentOperator # noqa +from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator # noqa warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.azure_cosmos`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.cosmos`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/contrib/secrets/azure_key_vault.py b/airflow/contrib/secrets/azure_key_vault.py index f254dedf23c3f..000ae92b3ac28 100644 --- a/airflow/contrib/secrets/azure_key_vault.py +++ b/airflow/contrib/secrets/azure_key_vault.py @@ -16,14 +16,14 @@ # specific language governing permissions and limitations # under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.secrets.azure_key_vault`.""" +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.secrets.key_vault`.""" import warnings -from airflow.providers.microsoft.azure.secrets.azure_key_vault import AzureKeyVaultBackend # noqa +from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend # noqa warnings.warn( - "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.secrets.azure_key_vault`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.secrets.key_vault`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/contrib/sensors/azure_cosmos_sensor.py b/airflow/contrib/sensors/azure_cosmos_sensor.py index b7c357d9658c8..fc3df4f26615e 100644 --- a/airflow/contrib/sensors/azure_cosmos_sensor.py +++ b/airflow/contrib/sensors/azure_cosmos_sensor.py @@ -15,14 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.sensors.azure_cosmos`.""" +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.sensors.cosmos`.""" import warnings -from airflow.providers.microsoft.azure.sensors.azure_cosmos import AzureCosmosDocumentSensor # noqa +from airflow.providers.microsoft.azure.sensors.cosmos import AzureCosmosDocumentSensor # noqa warnings.warn( - "This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.azure_cosmos`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.cosmos`.", DeprecationWarning, stacklevel=2, ) diff --git a/airflow/providers/google/cloud/transfers/adls_to_gcs.py b/airflow/providers/google/cloud/transfers/adls_to_gcs.py index 78a2d8e2d8646..763e779a02120 100644 --- a/airflow/providers/google/cloud/transfers/adls_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/adls_to_gcs.py @@ -25,7 +25,7 @@ from typing import Optional, Sequence, Union from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url -from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook +from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator diff --git a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py index 8c3e4295dc478..949c3e232ba42 100644 --- a/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/azure_fileshare_to_gcs.py @@ -22,7 +22,7 @@ from airflow import AirflowException from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url, gcs_object_is_directory -from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook +from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook class AzureFileShareToGCSOperator(BaseOperator): diff --git a/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py b/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py index 481c727445ae8..42f258a54a415 100644 --- a/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py +++ b/airflow/providers/microsoft/azure/example_dags/example_azure_container_instances.py @@ -21,9 +21,7 @@ from datetime import datetime, timedelta from airflow import DAG -from airflow.providers.microsoft.azure.operators.azure_container_instances import ( - AzureContainerInstancesOperator, -) +from airflow.providers.microsoft.azure.operators.container_instances import AzureContainerInstancesOperator with DAG( dag_id='aci_example', diff --git 
a/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py b/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py index 5736fa0d29fd7..249c2bed0a06b 100644 --- a/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py +++ b/airflow/providers/microsoft/azure/example_dags/example_azure_cosmosdb.py @@ -29,8 +29,8 @@ from datetime import datetime from airflow import DAG -from airflow.providers.microsoft.azure.operators.azure_cosmos import AzureCosmosInsertDocumentOperator -from airflow.providers.microsoft.azure.sensors.azure_cosmos import AzureCosmosDocumentSensor +from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator +from airflow.providers.microsoft.azure.sensors.cosmos import AzureCosmosDocumentSensor with DAG( dag_id='example_azure_cosmosdb_sensor', diff --git a/airflow/providers/microsoft/azure/example_dags/example_fileshare.py b/airflow/providers/microsoft/azure/example_dags/example_fileshare.py index c6c702d1e78c3..d50db3cb04027 100644 --- a/airflow/providers/microsoft/azure/example_dags/example_fileshare.py +++ b/airflow/providers/microsoft/azure/example_dags/example_fileshare.py @@ -19,7 +19,7 @@ from airflow.decorators import task from airflow.models import DAG -from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook +from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook NAME = 'myfileshare' DIRECTORY = "mydirectory" diff --git a/airflow/providers/microsoft/azure/hooks/azure_batch.py b/airflow/providers/microsoft/azure/hooks/azure_batch.py index d60ab0579bb61..96e468c173460 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_batch.py +++ b/airflow/providers/microsoft/azure/hooks/azure_batch.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,381 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -import time -from datetime import timedelta -from typing import Any, Dict, Optional, Set - -from azure.batch import BatchServiceClient, batch_auth, models as batch_models -from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter - -from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook -from airflow.models import Connection -from airflow.utils import timezone - - -class AzureBatchHook(BaseHook): - """ - Hook for Azure Batch APIs - - :param azure_batch_conn_id: :ref:`Azure Batch connection id` - of a service principal which will be used to start the container instance. 
- :type azure_batch_conn_id: str - """ - - conn_name_attr = 'azure_batch_conn_id' - default_conn_name = 'azure_batch_default' - conn_type = 'azure_batch' - hook_name = 'Azure Batch Service' - - @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: - """Returns connection widgets to add to connection form""" - from flask_appbuilder.fieldwidgets import BS3TextFieldWidget - from flask_babel import lazy_gettext - from wtforms import StringField - - return { - "extra__azure_batch__account_url": StringField( - lazy_gettext('Batch Account URL'), widget=BS3TextFieldWidget() - ), - } - - @staticmethod - def get_ui_field_behaviour() -> Dict: - """Returns custom field behaviour""" - return { - "hidden_fields": ['schema', 'port', 'host', 'extra'], - "relabeling": { - 'login': 'Batch Account Name', - 'password': 'Batch Account Access Key', - }, - } - - def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None: - super().__init__() - self.conn_id = azure_batch_conn_id - self.connection = self.get_conn() - - def _connection(self) -> Connection: - """Get connected to Azure Batch service""" - conn = self.get_connection(self.conn_id) - return conn - - def get_conn(self): - """ - Get the Batch client connection - - :return: Azure Batch client - """ - conn = self._connection() - - batch_account_url = conn.extra_dejson.get('extra__azure_batch__account_url') - if not batch_account_url: - raise AirflowException('Batch Account URL parameter is missing.') - - credentials = batch_auth.SharedKeyCredentials(conn.login, conn.password) - batch_client = BatchServiceClient(credentials, batch_url=batch_account_url) - return batch_client - - def configure_pool( - self, - pool_id: str, - vm_size: Optional[str] = None, - vm_publisher: Optional[str] = None, - vm_offer: Optional[str] = None, - sku_starts_with: Optional[str] = None, - vm_sku: Optional[str] = None, - vm_version: Optional[str] = None, - vm_node_agent_sku_id: Optional[str] = None, - os_family: Optional[str] = None, - os_version: Optional[str] = None, - display_name: Optional[str] = None, - target_dedicated_nodes: Optional[int] = None, - use_latest_image_and_sku: bool = False, - **kwargs, - ) -> PoolAddParameter: - """ - Configures a pool - - :param pool_id: A string that uniquely identifies the Pool within the Account - :type pool_id: str - - :param vm_size: The size of virtual machines in the Pool. - :type vm_size: str - - :param display_name: The display name for the Pool - :type display_name: str - - :param target_dedicated_nodes: The desired number of dedicated Compute Nodes in the Pool. - :type target_dedicated_nodes: Optional[int] - - :param use_latest_image_and_sku: Whether to use the latest verified vm image and sku - :type use_latest_image_and_sku: bool - - :param vm_publisher: The publisher of the Azure Virtual Machines Marketplace Image. - For example, Canonical or MicrosoftWindowsServer. - :type vm_publisher: Optional[str] - - :param vm_offer: The offer type of the Azure Virtual Machines Marketplace Image. - For example, UbuntuServer or WindowsServer. 
- :type vm_offer: Optional[str] - - :param sku_starts_with: The start name of the sku to search - :type sku_starts_with: Optional[str] - - :param vm_sku: The name of the virtual machine sku to use - :type vm_sku: Optional[str] - - :param vm_version: The version of the virtual machine - :param vm_version: str - - :param vm_node_agent_sku_id: The node agent sku id of the virtual machine - :type vm_node_agent_sku_id: Optional[str] - - :param os_family: The Azure Guest OS family to be installed on the virtual machines in the Pool. - :type os_family: Optional[str] - - :param os_version: The OS family version - :type os_version: Optional[str] - - """ - if use_latest_image_and_sku: - self.log.info('Using latest verified virtual machine image with node agent sku') - sku_to_use, image_ref_to_use = self._get_latest_verified_image_vm_and_sku( - publisher=vm_publisher, offer=vm_offer, sku_starts_with=sku_starts_with - ) - pool = batch_models.PoolAddParameter( - id=pool_id, - vm_size=vm_size, - display_name=display_name, - virtual_machine_configuration=batch_models.VirtualMachineConfiguration( - image_reference=image_ref_to_use, node_agent_sku_id=sku_to_use - ), - target_dedicated_nodes=target_dedicated_nodes, - **kwargs, - ) - - elif os_family: - self.log.info( - 'Using cloud service configuration to create pool, virtual machine configuration ignored' - ) - pool = batch_models.PoolAddParameter( - id=pool_id, - vm_size=vm_size, - display_name=display_name, - cloud_service_configuration=batch_models.CloudServiceConfiguration( - os_family=os_family, os_version=os_version - ), - target_dedicated_nodes=target_dedicated_nodes, - **kwargs, - ) - - else: - self.log.info('Using virtual machine configuration to create a pool') - pool = batch_models.PoolAddParameter( - id=pool_id, - vm_size=vm_size, - display_name=display_name, - virtual_machine_configuration=batch_models.VirtualMachineConfiguration( - image_reference=batch_models.ImageReference( - publisher=vm_publisher, - offer=vm_offer, - sku=vm_sku, - version=vm_version, - ), - node_agent_sku_id=vm_node_agent_sku_id, - ), - target_dedicated_nodes=target_dedicated_nodes, - **kwargs, - ) - return pool - - def create_pool(self, pool: PoolAddParameter) -> None: - """ - Creates a pool if not already existing - - :param pool: the pool object to create - :type pool: batch_models.PoolAddParameter - - """ - try: - self.log.info("Attempting to create a pool: %s", pool.id) - self.connection.pool.add(pool) - self.log.info("Created pool: %s", pool.id) - except batch_models.BatchErrorException as e: - if e.error.code != "PoolExists": - raise - else: - self.log.info("Pool %s already exists", pool.id) - - def _get_latest_verified_image_vm_and_sku( - self, - publisher: Optional[str] = None, - offer: Optional[str] = None, - sku_starts_with: Optional[str] = None, - ) -> tuple: - """ - Get latest verified image vm and sku - - :param publisher: The publisher of the Azure Virtual Machines Marketplace Image. - For example, Canonical or MicrosoftWindowsServer. - :type publisher: str - :param offer: The offer type of the Azure Virtual Machines Marketplace Image. - For example, UbuntuServer or WindowsServer. 
- :type offer: str - :param sku_starts_with: The start name of the sku to search - :type sku_starts_with: str - """ - options = batch_models.AccountListSupportedImagesOptions(filter="verificationType eq 'verified'") - images = self.connection.account.list_supported_images(account_list_supported_images_options=options) - # pick the latest supported sku - skus_to_use = [ - (image.node_agent_sku_id, image.image_reference) - for image in images - if image.image_reference.publisher.lower() == publisher - and image.image_reference.offer.lower() == offer - and image.image_reference.sku.startswith(sku_starts_with) - ] - - # pick first - agent_sku_id, image_ref_to_use = skus_to_use[0] - return agent_sku_id, image_ref_to_use - - def wait_for_all_node_state(self, pool_id: str, node_state: Set) -> list: - """ - Wait for all nodes in a pool to reach given states - - :param pool_id: A string that identifies the pool - :type pool_id: str - :param node_state: A set of batch_models.ComputeNodeState - :type node_state: set - """ - self.log.info('waiting for all nodes in pool %s to reach one of: %s', pool_id, node_state) - while True: - # refresh pool to ensure that there is no resize error - pool = self.connection.pool.get(pool_id) - if pool.resize_errors is not None: - resize_errors = "\n".join(repr(e) for e in pool.resize_errors) - raise RuntimeError(f'resize error encountered for pool {pool.id}:\n{resize_errors}') - nodes = list(self.connection.compute_node.list(pool.id)) - if len(nodes) >= pool.target_dedicated_nodes and all(node.state in node_state for node in nodes): - return nodes - # Allow the timeout to be controlled by the AzureBatchOperator - # specified timeout. This way we don't interrupt a startTask inside - # the pool - time.sleep(10) - - def configure_job( - self, - job_id: str, - pool_id: str, - display_name: Optional[str] = None, - **kwargs, - ) -> JobAddParameter: - """ - Configures a job for use in the pool - - :param job_id: A string that uniquely identifies the job within the account - :type job_id: str - :param pool_id: A string that identifies the pool - :type pool_id: str - :param display_name: The display name for the job - :type display_name: str - """ - job = batch_models.JobAddParameter( - id=job_id, - pool_info=batch_models.PoolInformation(pool_id=pool_id), - display_name=display_name, - **kwargs, - ) - return job - - def create_job(self, job: JobAddParameter) -> None: - """ - Creates a job in the pool - - :param job: The job object to create - :type job: batch_models.JobAddParameter - """ - try: - self.connection.job.add(job) - self.log.info("Job %s created", job.id) - except batch_models.BatchErrorException as err: - if err.error.code != "JobExists": - raise - else: - self.log.info("Job %s already exists", job.id) - - def configure_task( - self, - task_id: str, - command_line: str, - display_name: Optional[str] = None, - container_settings=None, - **kwargs, - ) -> TaskAddParameter: - """ - Creates a task - - :param task_id: A string that identifies the task to create - :type task_id: str - :param command_line: The command line of the Task. - :type command_line: str - :param display_name: A display name for the Task - :type display_name: str - :param container_settings: The settings for the container under which the Task runs. - If the Pool that will run this Task has containerConfiguration set, - this must be set as well. If the Pool that will run this Task doesn't have - containerConfiguration set, this must not be set. 
- :type container_settings: batch_models.TaskContainerSettings - """ - task = batch_models.TaskAddParameter( - id=task_id, - command_line=command_line, - display_name=display_name, - container_settings=container_settings, - **kwargs, - ) - self.log.info("Task created: %s", task_id) - return task - - def add_single_task_to_job(self, job_id: str, task: TaskAddParameter) -> None: - """ - Add a single task to given job if it doesn't exist - - :param job_id: A string that identifies the given job - :type job_id: str - :param task: The task to add - :type task: batch_models.TaskAddParameter - """ - try: - - self.connection.task.add(job_id=job_id, task=task) - except batch_models.BatchErrorException as err: - if err.error.code != "TaskExists": - raise - else: - self.log.info("Task %s already exists", task.id) +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.batch`.""" - def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> None: - """ - Wait for tasks in a particular job to complete +import warnings - :param job_id: A string that identifies the job - :type job_id: str - :param timeout: The amount of time to wait before timing out in minutes - :type timeout: int - """ - timeout_time = timezone.utcnow() + timedelta(minutes=timeout) - while timezone.utcnow() < timeout_time: - tasks = self.connection.task.list(job_id) +from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook # noqa - incomplete_tasks = [task for task in tasks if task.state != batch_models.TaskState.completed] - if not incomplete_tasks: - return - for task in incomplete_tasks: - self.log.info("Waiting for %s to complete, currently on %s state", task.id, task.state) - time.sleep(15) - raise TimeoutError("Timed out waiting for tasks to complete") +warnings.warn( + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.batch`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/hooks/azure_container_instance.py b/airflow/providers/microsoft/azure/hooks/azure_container_instance.py index 9f4c0cc0f09c0..29ffe4e3ede91 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_container_instance.py +++ b/airflow/providers/microsoft/azure/hooks/azure_container_instance.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,143 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.container_instance`.""" import warnings -from typing import Any - -from azure.mgmt.containerinstance import ContainerInstanceManagementClient -from azure.mgmt.containerinstance.models import ContainerGroup - -from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook - - -class AzureContainerInstanceHook(AzureBaseHook): - """ - A hook to communicate with Azure Container Instances. - - This hook requires a service principal in order to work. - After creating this service principal - (Azure Active Directory/App Registrations), you need to fill in the - client_id (Application ID) as login, the generated password as password, - and tenantId and subscriptionId in the extra's field as a json. 
- - :param conn_id: :ref:`Azure connection id` of - a service principal which will be used to start the container instance. - :type azure_conn_id: str - """ - - conn_name_attr = 'azure_conn_id' - default_conn_name = 'azure_default' - conn_type = 'azure_container_instance' - hook_name = 'Azure Container Instance' - - def __init__(self, *args, **kwargs) -> None: - super().__init__(sdk_client=ContainerInstanceManagementClient, *args, **kwargs) - self.connection = self.get_conn() - - def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None: - """ - Create a new container group - - :param resource_group: the name of the resource group - :type resource_group: str - :param name: the name of the container group - :type name: str - :param container_group: the properties of the container group - :type container_group: azure.mgmt.containerinstance.models.ContainerGroup - """ - self.connection.container_groups.create_or_update(resource_group, name, container_group) - - def get_state_exitcode_details(self, resource_group: str, name: str) -> tuple: - """ - Get the state and exitcode of a container group - - :param resource_group: the name of the resource group - :type resource_group: str - :param name: the name of the container group - :type name: str - :return: A tuple with the state, exitcode, and details. - If the exitcode is unknown 0 is returned. - :rtype: tuple(state,exitcode,details) - """ - warnings.warn( - "get_state_exitcode_details() is deprecated. Related method is get_state()", - DeprecationWarning, - stacklevel=2, - ) - cg_state = self.get_state(resource_group, name) - c_state = cg_state.containers[0].instance_view.current_state - return (c_state.state, c_state.exit_code, c_state.detail_status) - - def get_messages(self, resource_group: str, name: str) -> list: - """ - Get the messages of a container group - - :param resource_group: the name of the resource group - :type resource_group: str - :param name: the name of the container group - :type name: str - :return: A list of the event messages - :rtype: list[str] - """ - warnings.warn( - "get_messages() is deprecated. 
Related method is get_state()", DeprecationWarning, stacklevel=2 - ) - cg_state = self.get_state(resource_group, name) - instance_view = cg_state.containers[0].instance_view - return [event.message for event in instance_view.events] - - def get_state(self, resource_group: str, name: str) -> Any: - """ - Get the state of a container group - - :param resource_group: the name of the resource group - :type resource_group: str - :param name: the name of the container group - :type name: str - :return: ContainerGroup - :rtype: ~azure.mgmt.containerinstance.models.ContainerGroup - """ - return self.connection.container_groups.get(resource_group, name, raw=False) - - def get_logs(self, resource_group: str, name: str, tail: int = 1000) -> list: - """ - Get the tail from logs of a container group - - :param resource_group: the name of the resource group - :type resource_group: str - :param name: the name of the container group - :type name: str - :param tail: the size of the tail - :type tail: int - :return: A list of log messages - :rtype: list[str] - """ - logs = self.connection.container.list_logs(resource_group, name, name, tail=tail) - return logs.content.splitlines(True) - - def delete(self, resource_group: str, name: str) -> None: - """ - Delete a container group - - :param resource_group: the name of the resource group - :type resource_group: str - :param name: the name of the container group - :type name: str - """ - self.connection.container_groups.delete(resource_group, name) - def exists(self, resource_group: str, name: str) -> bool: - """ - Test if a container group exists +from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook # noqa - :param resource_group: the name of the resource group - :type resource_group: str - :param name: the name of the container group - :type name: str - """ - for container in self.connection.container_groups.list_by_resource_group(resource_group): - if container.name == name: - return True - return False +warnings.warn( + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_instance`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/hooks/azure_container_registry.py b/airflow/providers/microsoft/azure/hooks/azure_container_registry.py index f4c5d1adb40aa..50ef42b0bde54 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_container_registry.py +++ b/airflow/providers/microsoft/azure/hooks/azure_container_registry.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,52 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Hook for Azure Container Registry""" - -from typing import Dict - -from azure.mgmt.containerinstance.models import ImageRegistryCredential - -from airflow.hooks.base import BaseHook - - -class AzureContainerRegistryHook(BaseHook): - """ - A hook to communicate with a Azure Container Registry. 
- - :param conn_id: :ref:`Azure Container Registry connection id` - of a service principal which will be used to start the container instance - - :type conn_id: str - """ - - conn_name_attr = 'azure_container_registry_conn_id' - default_conn_name = 'azure_container_registry_default' - conn_type = 'azure_container_registry' - hook_name = 'Azure Container Registry' +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.container_registry`.""" - @staticmethod - def get_ui_field_behaviour() -> Dict: - """Returns custom field behaviour""" - return { - "hidden_fields": ['schema', 'port', 'extra'], - "relabeling": { - 'login': 'Registry Username', - 'password': 'Registry Password', - 'host': 'Registry Server', - }, - "placeholders": { - 'login': 'private registry username', - 'password': 'private registry password', - 'host': 'docker image registry server', - }, - } +import warnings - def __init__(self, conn_id: str = 'azure_registry') -> None: - super().__init__() - self.conn_id = conn_id - self.connection = self.get_conn() +from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook # noqa - def get_conn(self) -> ImageRegistryCredential: - conn = self.get_connection(self.conn_id) - return ImageRegistryCredential(server=conn.host, username=conn.login, password=conn.password) +warnings.warn( + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_registry`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/hooks/azure_container_volume.py b/airflow/providers/microsoft/azure/hooks/azure_container_volume.py index 8aae16b491a34..83a69e8a41cc8 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_container_volume.py +++ b/airflow/providers/microsoft/azure/hooks/azure_container_volume.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,92 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict - -from azure.mgmt.containerinstance.models import AzureFileVolume, Volume - -from airflow.hooks.base import BaseHook - - -class AzureContainerVolumeHook(BaseHook): - """ - A hook which wraps an Azure Volume. - - :param azure_container_volume_conn_id: Reference to the - :ref:`Azure Container Volume connection id ` - of an Azure account of which container volumes should be used. 
- :type azure_container_volume_conn_id: str - """ - - conn_name_attr = "azure_container_volume_conn_id" - default_conn_name = 'azure_container_volume_default' - conn_type = 'azure_container_volume' - hook_name = 'Azure Container Volume' - - def __init__(self, azure_container_volume_conn_id: str = 'azure_container_volume_default') -> None: - super().__init__() - self.conn_id = azure_container_volume_conn_id - - @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: - """Returns connection widgets to add to connection form""" - from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget - from flask_babel import lazy_gettext - from wtforms import PasswordField - - return { - "extra__azure_container_volume__connection_string": PasswordField( - lazy_gettext('Blob Storage Connection String (optional)'), widget=BS3PasswordFieldWidget() - ), - } - - @staticmethod - def get_ui_field_behaviour() -> Dict: - """Returns custom field behaviour""" - import json - - return { - "hidden_fields": ['schema', 'port', 'host', "extra"], - "relabeling": { - 'login': 'Azure Client ID', - 'password': 'Azure Secret', - }, - "placeholders": { - 'extra': json.dumps( - { - "key_path": "path to json file for auth", - "key_json": "specifies json dict for auth", - }, - indent=1, - ), - 'login': 'client_id (token credentials auth)', - 'password': 'secret (token credentials auth)', - 'extra__azure_container_volume__connection_string': 'connection string auth', - }, - } +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.container_volume`.""" - def get_storagekey(self) -> str: - """Get Azure File Volume storage key""" - conn = self.get_connection(self.conn_id) - service_options = conn.extra_dejson +import warnings - if 'extra__azure_container_volume__connection_string' in service_options: - for keyvalue in service_options['extra__azure_container_volume__connection_string'].split(";"): - key, value = keyvalue.split("=", 1) - if key == "AccountKey": - return value - return conn.password +from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook # noqa - def get_file_volume( - self, mount_name: str, share_name: str, storage_account_name: str, read_only: bool = False - ) -> Volume: - """Get Azure File Volume""" - return Volume( - name=mount_name, - azure_file=AzureFileVolume( - share_name=share_name, - storage_account_name=storage_account_name, - read_only=read_only, - storage_account_key=self.get_storagekey(), - ), - ) +warnings.warn( + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_volume`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/hooks/azure_cosmos.py b/airflow/providers/microsoft/azure/hooks/azure_cosmos.py index b75d75bfbfa06..9f1da045e4b89 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_cosmos.py +++ b/airflow/providers/microsoft/azure/hooks/azure_cosmos.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,339 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This module contains integration with Azure CosmosDB. - -AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that a -Airflow connection of type `azure_cosmos` exists. 
Authorization can be done by supplying a -login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify -the default database and collection to use (see connection `azure_cosmos_default` for an example). -""" -import uuid -from typing import Any, Dict, Optional - -from azure.cosmos.cosmos_client import CosmosClient -from azure.cosmos.exceptions import CosmosHttpResponseError - -from airflow.exceptions import AirflowBadRequest -from airflow.hooks.base import BaseHook - - -class AzureCosmosDBHook(BaseHook): - """ - Interacts with Azure CosmosDB. - - login should be the endpoint uri, password should be the master key - optionally, you can use the following extras to default these values - {"database_name": "", "collection_name": "COLLECTION_NAME"}. - - :param azure_cosmos_conn_id: Reference to the - :ref:`Azure CosmosDB connection`. - :type azure_cosmos_conn_id: str - """ - - conn_name_attr = 'azure_cosmos_conn_id' - default_conn_name = 'azure_cosmos_default' - conn_type = 'azure_cosmos' - hook_name = 'Azure CosmosDB' - - @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: - """Returns connection widgets to add to connection form""" - from flask_appbuilder.fieldwidgets import BS3TextFieldWidget - from flask_babel import lazy_gettext - from wtforms import StringField - - return { - "extra__azure_cosmos__database_name": StringField( - lazy_gettext('Cosmos Database Name (optional)'), widget=BS3TextFieldWidget() - ), - "extra__azure_cosmos__collection_name": StringField( - lazy_gettext('Cosmos Collection Name (optional)'), widget=BS3TextFieldWidget() - ), - } - - @staticmethod - def get_ui_field_behaviour() -> Dict: - """Returns custom field behaviour""" - return { - "hidden_fields": ['schema', 'port', 'host', 'extra'], - "relabeling": { - 'login': 'Cosmos Endpoint URI', - 'password': 'Cosmos Master Key Token', - }, - "placeholders": { - 'login': 'endpoint uri', - 'password': 'master key', - 'extra__azure_cosmos__database_name': 'database name', - 'extra__azure_cosmos__collection_name': 'collection name', - }, - } - - def __init__(self, azure_cosmos_conn_id: str = default_conn_name) -> None: - super().__init__() - self.conn_id = azure_cosmos_conn_id - self._conn = None - - self.default_database_name = None - self.default_collection_name = None - - def get_conn(self) -> CosmosClient: - """Return a cosmos db client.""" - if not self._conn: - conn = self.get_connection(self.conn_id) - extras = conn.extra_dejson - endpoint_uri = conn.login - master_key = conn.password - - self.default_database_name = extras.get('database_name') or extras.get( - 'extra__azure_cosmos__database_name' - ) - self.default_collection_name = extras.get('collection_name') or extras.get( - 'extra__azure_cosmos__collection_name' - ) - - # Initialize the Python Azure Cosmos DB client - self._conn = CosmosClient(endpoint_uri, {'masterKey': master_key}) - return self._conn - - def __get_database_name(self, database_name: Optional[str] = None) -> str: - self.get_conn() - db_name = database_name - if db_name is None: - db_name = self.default_database_name - - if db_name is None: - raise AirflowBadRequest("Database name must be specified") - - return db_name - - def __get_collection_name(self, collection_name: Optional[str] = None) -> str: - self.get_conn() - coll_name = collection_name - if coll_name is None: - coll_name = self.default_collection_name - - if coll_name is None: - raise AirflowBadRequest("Collection name must be specified") - - return coll_name - - def 
does_collection_exist(self, collection_name: str, database_name: str) -> bool: - """Checks if a collection exists in CosmosDB.""" - if collection_name is None: - raise AirflowBadRequest("Collection name cannot be None.") - - existing_container = list( - self.get_conn().QueryContainers( - get_database_link(self.__get_database_name(database_name)), - { - "query": "SELECT * FROM r WHERE r.id=@id", - "parameters": [{"name": "@id", "value": collection_name}], - }, - ) - ) - if len(existing_container) == 0: - return False - - return True - - def create_collection(self, collection_name: str, database_name: Optional[str] = None) -> None: - """Creates a new collection in the CosmosDB database.""" - if collection_name is None: - raise AirflowBadRequest("Collection name cannot be None.") - - # We need to check to see if this container already exists so we don't try - # to create it twice - existing_container = list( - self.get_conn().QueryContainers( - get_database_link(self.__get_database_name(database_name)), - { - "query": "SELECT * FROM r WHERE r.id=@id", - "parameters": [{"name": "@id", "value": collection_name}], - }, - ) - ) - - # Only create if we did not find it already existing - if len(existing_container) == 0: - self.get_conn().CreateContainer( - get_database_link(self.__get_database_name(database_name)), {"id": collection_name} - ) - - def does_database_exist(self, database_name: str) -> bool: - """Checks if a database exists in CosmosDB.""" - if database_name is None: - raise AirflowBadRequest("Database name cannot be None.") - - existing_database = list( - self.get_conn().QueryDatabases( - { - "query": "SELECT * FROM r WHERE r.id=@id", - "parameters": [{"name": "@id", "value": database_name}], - } - ) - ) - if len(existing_database) == 0: - return False - - return True - - def create_database(self, database_name: str) -> None: - """Creates a new database in CosmosDB.""" - if database_name is None: - raise AirflowBadRequest("Database name cannot be None.") - - # We need to check to see if this database already exists so we don't try - # to create it twice - existing_database = list( - self.get_conn().QueryDatabases( - { - "query": "SELECT * FROM r WHERE r.id=@id", - "parameters": [{"name": "@id", "value": database_name}], - } - ) - ) - - # Only create if we did not find it already existing - if len(existing_database) == 0: - self.get_conn().CreateDatabase({"id": database_name}) - - def delete_database(self, database_name: str) -> None: - """Deletes an existing database in CosmosDB.""" - if database_name is None: - raise AirflowBadRequest("Database name cannot be None.") - - self.get_conn().DeleteDatabase(get_database_link(database_name)) - - def delete_collection(self, collection_name: str, database_name: Optional[str] = None) -> None: - """Deletes an existing collection in the CosmosDB database.""" - if collection_name is None: - raise AirflowBadRequest("Collection name cannot be None.") - - self.get_conn().DeleteContainer( - get_collection_link(self.__get_database_name(database_name), collection_name) - ) - - def upsert_document(self, document, database_name=None, collection_name=None, document_id=None): - """ - Inserts a new document (or updates an existing one) into an existing - collection in the CosmosDB database. 
- """ - # Assign unique ID if one isn't provided - if document_id is None: - document_id = str(uuid.uuid4()) - - if document is None: - raise AirflowBadRequest("You cannot insert a None document") - - # Add document id if isn't found - if 'id' in document: - if document['id'] is None: - document['id'] = document_id - else: - document['id'] = document_id - - created_document = self.get_conn().CreateItem( - get_collection_link( - self.__get_database_name(database_name), self.__get_collection_name(collection_name) - ), - document, - ) - - return created_document - - def insert_documents( - self, documents, database_name: Optional[str] = None, collection_name: Optional[str] = None - ) -> list: - """Insert a list of new documents into an existing collection in the CosmosDB database.""" - if documents is None: - raise AirflowBadRequest("You cannot insert empty documents") - - created_documents = [] - for single_document in documents: - created_documents.append( - self.get_conn().CreateItem( - get_collection_link( - self.__get_database_name(database_name), self.__get_collection_name(collection_name) - ), - single_document, - ) - ) - - return created_documents - - def delete_document( - self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None - ) -> None: - """Delete an existing document out of a collection in the CosmosDB database.""" - if document_id is None: - raise AirflowBadRequest("Cannot delete a document without an id") - - self.get_conn().DeleteItem( - get_document_link( - self.__get_database_name(database_name), - self.__get_collection_name(collection_name), - document_id, - ) - ) - - def get_document( - self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None - ): - """Get a document from an existing collection in the CosmosDB database.""" - if document_id is None: - raise AirflowBadRequest("Cannot get a document without an id") - - try: - return self.get_conn().ReadItem( - get_document_link( - self.__get_database_name(database_name), - self.__get_collection_name(collection_name), - document_id, - ) - ) - except CosmosHttpResponseError: - return None - - def get_documents( - self, - sql_string: str, - database_name: Optional[str] = None, - collection_name: Optional[str] = None, - partition_key: Optional[str] = None, - ) -> Optional[list]: - """Get a list of documents from an existing collection in the CosmosDB database via SQL query.""" - if sql_string is None: - raise AirflowBadRequest("SQL query string cannot be None") - - # Query them in SQL - query = {'query': sql_string} - - try: - result_iterable = self.get_conn().QueryItems( - get_collection_link( - self.__get_database_name(database_name), self.__get_collection_name(collection_name) - ), - query, - partition_key, - ) - - return list(result_iterable) - except CosmosHttpResponseError: - return None - - -def get_database_link(database_id: str) -> str: - """Get Azure CosmosDB database link""" - return "dbs/" + database_id - - -def get_collection_link(database_id: str, collection_id: str) -> str: - """Get Azure CosmosDB collection link""" - return get_database_link(database_id) + "/colls/" + collection_id - - -def get_document_link(database_id: str, collection_id: str, document_id: str) -> str: - """Get Azure CosmosDB document link""" - return get_collection_link(database_id, collection_id) + "/docs/" + document_id +"""This module is deprecated. 
Please use :mod:`airflow.providers.microsoft.azure.hooks.cosmos`.""" + +import warnings + +from airflow.providers.microsoft.azure.hooks.cosmos import ( # noqa + AzureCosmosDBHook, + get_collection_link, + get_database_link, + get_document_link, +) + +warnings.warn( + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.cosmos`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/hooks/azure_data_lake.py b/airflow/providers/microsoft/azure/hooks/azure_data_lake.py index 8cbd001686d1e..aae7eec8db141 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_data_lake.py +++ b/airflow/providers/microsoft/azure/hooks/azure_data_lake.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,231 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -""" -This module contains integration with Azure Data Lake. - -AzureDataLakeHook communicates via a REST API compatible with WebHDFS. Make sure that a -Airflow connection of type `azure_data_lake` exists. Authorization can be done by supplying a -login (=Client ID), password (=Client Secret) and extra fields tenant (Tenant) and account_name (Account Name) -(see connection `azure_data_lake_default` for an example). -""" -from typing import Any, Dict, Optional - -from azure.datalake.store import core, lib, multithread - -from airflow.exceptions import AirflowException -from airflow.hooks.base import BaseHook - - -class AzureDataLakeHook(BaseHook): - """ - Interacts with Azure Data Lake. - - Client ID and client secret should be in user and password parameters. - Tenant and account name should be extra field as - {"tenant": "", "account_name": "ACCOUNT_NAME"}. - - :param azure_data_lake_conn_id: Reference to the :ref:`Azure Data Lake connection`. 
- :type azure_data_lake_conn_id: str - """ - - conn_name_attr = 'azure_data_lake_conn_id' - default_conn_name = 'azure_data_lake_default' - conn_type = 'azure_data_lake' - hook_name = 'Azure Data Lake' - - @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: - """Returns connection widgets to add to connection form""" - from flask_appbuilder.fieldwidgets import BS3TextFieldWidget - from flask_babel import lazy_gettext - from wtforms import StringField - - return { - "extra__azure_data_lake__tenant": StringField( - lazy_gettext('Azure Tenant ID'), widget=BS3TextFieldWidget() - ), - "extra__azure_data_lake__account_name": StringField( - lazy_gettext('Azure DataLake Store Name'), widget=BS3TextFieldWidget() - ), - } - - @staticmethod - def get_ui_field_behaviour() -> Dict: - """Returns custom field behaviour""" - return { - "hidden_fields": ['schema', 'port', 'host', 'extra'], - "relabeling": { - 'login': 'Azure Client ID', - 'password': 'Azure Client Secret', - }, - "placeholders": { - 'login': 'client id', - 'password': 'secret', - 'extra__azure_data_lake__tenant': 'tenant id', - 'extra__azure_data_lake__account_name': 'datalake store', - }, - } - - def __init__(self, azure_data_lake_conn_id: str = default_conn_name) -> None: - super().__init__() - self.conn_id = azure_data_lake_conn_id - self._conn: Optional[core.AzureDLFileSystem] = None - self.account_name: Optional[str] = None - - def get_conn(self) -> core.AzureDLFileSystem: - """Return a AzureDLFileSystem object.""" - if not self._conn: - conn = self.get_connection(self.conn_id) - service_options = conn.extra_dejson - self.account_name = service_options.get('account_name') or service_options.get( - 'extra__azure_data_lake__account_name' - ) - tenant = service_options.get('tenant') or service_options.get('extra__azure_data_lake__tenant') - - adl_creds = lib.auth(tenant_id=tenant, client_secret=conn.password, client_id=conn.login) - self._conn = core.AzureDLFileSystem(adl_creds, store_name=self.account_name) - self._conn.connect() - return self._conn - - def check_for_file(self, file_path: str) -> bool: - """ - Check if a file exists on Azure Data Lake. - - :param file_path: Path and name of the file. - :type file_path: str - :return: True if the file exists, False otherwise. - :rtype: bool - """ - try: - files = self.get_conn().glob(file_path, details=False, invalidate_cache=True) - return len(files) == 1 - except FileNotFoundError: - return False - - def upload_file( - self, - local_path: str, - remote_path: str, - nthreads: int = 64, - overwrite: bool = True, - buffersize: int = 4194304, - blocksize: int = 4194304, - **kwargs, - ) -> None: - """ - Upload a file to Azure Data Lake. - - :param local_path: local path. Can be single file, directory (in which case, - upload recursively) or glob pattern. Recursive glob patterns using `**` - are not supported. - :type local_path: str - :param remote_path: Remote path to upload to; if multiple files, this is the - directory root to write within. - :type remote_path: str - :param nthreads: Number of threads to use. If None, uses the number of cores. - :type nthreads: int - :param overwrite: Whether to forcibly overwrite existing files/directories. - If False and remote path is a directory, will quit regardless if any files - would be overwritten or not. If True, only matching filenames are actually - overwritten. - :type overwrite: bool - :param buffersize: int [2**22] - Number of bytes for internal buffer. 
This block cannot be bigger than - a chunk and cannot be smaller than a block. - :type buffersize: int - :param blocksize: int [2**22] - Number of bytes for a block. Within each chunk, we write a smaller - block for each API call. This block cannot be bigger than a chunk. - :type blocksize: int - """ - multithread.ADLUploader( - self.get_conn(), - lpath=local_path, - rpath=remote_path, - nthreads=nthreads, - overwrite=overwrite, - buffersize=buffersize, - blocksize=blocksize, - **kwargs, - ) - - def download_file( - self, - local_path: str, - remote_path: str, - nthreads: int = 64, - overwrite: bool = True, - buffersize: int = 4194304, - blocksize: int = 4194304, - **kwargs, - ) -> None: - """ - Download a file from Azure Blob Storage. - - :param local_path: local path. If downloading a single file, will write to this - specific file, unless it is an existing directory, in which case a file is - created within it. If downloading multiple files, this is the root - directory to write within. Will create directories as required. - :type local_path: str - :param remote_path: remote path/globstring to use to find remote files. - Recursive glob patterns using `**` are not supported. - :type remote_path: str - :param nthreads: Number of threads to use. If None, uses the number of cores. - :type nthreads: int - :param overwrite: Whether to forcibly overwrite existing files/directories. - If False and remote path is a directory, will quit regardless if any files - would be overwritten or not. If True, only matching filenames are actually - overwritten. - :type overwrite: bool - :param buffersize: int [2**22] - Number of bytes for internal buffer. This block cannot be bigger than - a chunk and cannot be smaller than a block. - :type buffersize: int - :param blocksize: int [2**22] - Number of bytes for a block. Within each chunk, we write a smaller - block for each API call. This block cannot be bigger than a chunk. - :type blocksize: int - """ - multithread.ADLDownloader( - self.get_conn(), - lpath=local_path, - rpath=remote_path, - nthreads=nthreads, - overwrite=overwrite, - buffersize=buffersize, - blocksize=blocksize, - **kwargs, - ) - - def list(self, path: str) -> list: - """ - List files in Azure Data Lake Storage +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.data_lake`.""" - :param path: full path/globstring to use to list files in ADLS - :type path: str - """ - if "*" in path: - return self.get_conn().glob(path) - else: - return self.get_conn().walk(path) +import warnings - def remove(self, path: str, recursive: bool = False, ignore_not_found: bool = True) -> None: - """ - Remove files in Azure Data Lake Storage +from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook # noqa - :param path: A directory or file to remove in ADLS - :type path: str - :param recursive: Whether to loop into directories in the location and remove the files - :type recursive: bool - :param ignore_not_found: Whether to raise error if file to delete is not found - :type ignore_not_found: bool - """ - try: - self.get_conn().remove(path=path, recursive=recursive) - except FileNotFoundError: - if ignore_not_found: - self.log.info("File %s not found", path) - else: - raise AirflowException(f"File {path} not found") +warnings.warn( + "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.hooks.data_lake`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/hooks/azure_fileshare.py b/airflow/providers/microsoft/azure/hooks/azure_fileshare.py index acfa881ddb266..ec5da4b3d117f 100644 --- a/airflow/providers/microsoft/azure/hooks/azure_fileshare.py +++ b/airflow/providers/microsoft/azure/hooks/azure_fileshare.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,322 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -import warnings -from typing import Any, Dict, List, Optional - -from azure.storage.file import File, FileService - -from airflow.hooks.base import BaseHook - - -class AzureFileShareHook(BaseHook): - """ - Interacts with Azure FileShare Storage. - - :param azure_fileshare_conn_id: Reference to the - :ref:`Azure Container Volume connection id` - of an Azure account of which container volumes should be used. - - """ - - conn_name_attr = "azure_fileshare_conn_id" - default_conn_name = 'azure_fileshare_default' - conn_type = 'azure_fileshare' - hook_name = 'Azure FileShare' - - def __init__(self, azure_fileshare_conn_id: str = 'azure_fileshare_default') -> None: - super().__init__() - self.conn_id = azure_fileshare_conn_id - self._conn = None - - @staticmethod - def get_connection_form_widgets() -> Dict[str, Any]: - """Returns connection widgets to add to connection form""" - from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget - from flask_babel import lazy_gettext - from wtforms import PasswordField, StringField - - return { - "extra__azure_fileshare__sas_token": PasswordField( - lazy_gettext('SAS Token (optional)'), widget=BS3PasswordFieldWidget() - ), - "extra__azure_fileshare__connection_string": StringField( - lazy_gettext('Connection String (optional)'), widget=BS3TextFieldWidget() - ), - "extra__azure_fileshare__protocol": StringField( - lazy_gettext('Account URL or token (optional)'), widget=BS3TextFieldWidget() - ), - } - - @staticmethod - def get_ui_field_behaviour() -> Dict: - """Returns custom field behaviour""" - return { - "hidden_fields": ['schema', 'port', 'host', 'extra'], - "relabeling": { - 'login': 'Blob Storage Login (optional)', - 'password': 'Blob Storage Key (optional)', - 'host': 'Account Name (Active Directory Auth)', - }, - "placeholders": { - 'login': 'account name', - 'password': 'secret', - 'host': 'account url', - 'extra__azure_fileshare__sas_token': 'account url or token (optional)', - 'extra__azure_fileshare__connection_string': 'account url or token (optional)', - 'extra__azure_fileshare__protocol': 'account url or token (optional)', - }, - } - - def get_conn(self) -> FileService: - """Return the FileService object.""" - prefix = "extra__azure_fileshare__" - if self._conn: - return self._conn - conn = self.get_connection(self.conn_id) - service_options_with_prefix = conn.extra_dejson - service_options = {} - for key, value in service_options_with_prefix.items(): - # in case dedicated FileShareHook is used, the connection will use the extras from UI. 
- # in case deprecated wasb hook is used, the old extras will work as well - if key.startswith(prefix): - if value != '': - service_options[key[len(prefix) :]] = value - else: - # warn if the deprecated wasb_connection is used - warnings.warn( - "You are using deprecated connection for AzureFileShareHook." - " Please change it to `Azure FileShare`.", - DeprecationWarning, - ) - else: - service_options[key] = value - # warn if the old non-prefixed value is used - warnings.warn( - "You are using deprecated connection for AzureFileShareHook." - " Please change it to `Azure FileShare`.", - DeprecationWarning, - ) - self._conn = FileService(account_name=conn.login, account_key=conn.password, **service_options) - return self._conn - - def check_for_directory(self, share_name: str, directory_name: str, **kwargs) -> bool: - """ - Check if a directory exists on Azure File Share. - - :param share_name: Name of the share. - :type share_name: str - :param directory_name: Name of the directory. - :type directory_name: str - :param kwargs: Optional keyword arguments that - `FileService.exists()` takes. - :type kwargs: object - :return: True if the file exists, False otherwise. - :rtype: bool - """ - return self.get_conn().exists(share_name, directory_name, **kwargs) - - def check_for_file(self, share_name: str, directory_name: str, file_name: str, **kwargs) -> bool: - """ - Check if a file exists on Azure File Share. - - :param share_name: Name of the share. - :type share_name: str - :param directory_name: Name of the directory. - :type directory_name: str - :param file_name: Name of the file. - :type file_name: str - :param kwargs: Optional keyword arguments that - `FileService.exists()` takes. - :type kwargs: object - :return: True if the file exists, False otherwise. - :rtype: bool - """ - return self.get_conn().exists(share_name, directory_name, file_name, **kwargs) +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.fileshare`.""" - def list_directories_and_files( - self, share_name: str, directory_name: Optional[str] = None, **kwargs - ) -> list: - """ - Return the list of directories and files stored on a Azure File Share. - - :param share_name: Name of the share. - :type share_name: str - :param directory_name: Name of the directory. - :type directory_name: str - :param kwargs: Optional keyword arguments that - `FileService.list_directories_and_files()` takes. - :type kwargs: object - :return: A list of files and directories - :rtype: list - """ - return self.get_conn().list_directories_and_files(share_name, directory_name, **kwargs) - - def list_files(self, share_name: str, directory_name: Optional[str] = None, **kwargs) -> List[str]: - """ - Return the list of files stored on a Azure File Share. - - :param share_name: Name of the share. - :type share_name: str - :param directory_name: Name of the directory. - :type directory_name: str - :param kwargs: Optional keyword arguments that - `FileService.list_directories_and_files()` takes. - :type kwargs: object - :return: A list of files - :rtype: list - """ - return [ - obj.name - for obj in self.list_directories_and_files(share_name, directory_name, **kwargs) - if isinstance(obj, File) - ] - - def create_share(self, share_name: str, **kwargs) -> bool: - """ - Create new Azure File Share. - - :param share_name: Name of the share. - :type share_name: str - :param kwargs: Optional keyword arguments that - `FileService.create_share()` takes. 
- :type kwargs: object - :return: True if share is created, False if share already exists. - :rtype: bool - """ - return self.get_conn().create_share(share_name, **kwargs) - - def delete_share(self, share_name: str, **kwargs) -> bool: - """ - Delete existing Azure File Share. - - :param share_name: Name of the share. - :type share_name: str - :param kwargs: Optional keyword arguments that - `FileService.delete_share()` takes. - :type kwargs: object - :return: True if share is deleted, False if share does not exist. - :rtype: bool - """ - return self.get_conn().delete_share(share_name, **kwargs) - - def create_directory(self, share_name: str, directory_name: str, **kwargs) -> list: - """ - Create a new directory on a Azure File Share. - - :param share_name: Name of the share. - :type share_name: str - :param directory_name: Name of the directory. - :type directory_name: str - :param kwargs: Optional keyword arguments that - `FileService.create_directory()` takes. - :type kwargs: object - :return: A list of files and directories - :rtype: list - """ - return self.get_conn().create_directory(share_name, directory_name, **kwargs) - - def get_file( - self, file_path: str, share_name: str, directory_name: str, file_name: str, **kwargs - ) -> None: - """ - Download a file from Azure File Share. - - :param file_path: Where to store the file. - :type file_path: str - :param share_name: Name of the share. - :type share_name: str - :param directory_name: Name of the directory. - :type directory_name: str - :param file_name: Name of the file. - :type file_name: str - :param kwargs: Optional keyword arguments that - `FileService.get_file_to_path()` takes. - :type kwargs: object - """ - self.get_conn().get_file_to_path(share_name, directory_name, file_name, file_path, **kwargs) - - def get_file_to_stream( - self, stream: str, share_name: str, directory_name: str, file_name: str, **kwargs - ) -> None: - """ - Download a file from Azure File Share. - - :param stream: A filehandle to store the file to. - :type stream: file-like object - :param share_name: Name of the share. - :type share_name: str - :param directory_name: Name of the directory. - :type directory_name: str - :param file_name: Name of the file. - :type file_name: str - :param kwargs: Optional keyword arguments that - `FileService.get_file_to_stream()` takes. - :type kwargs: object - """ - self.get_conn().get_file_to_stream(share_name, directory_name, file_name, stream, **kwargs) - - def load_file( - self, file_path: str, share_name: str, directory_name: str, file_name: str, **kwargs - ) -> None: - """ - Upload a file to Azure File Share. - - :param file_path: Path to the file to load. - :type file_path: str - :param share_name: Name of the share. - :type share_name: str - :param directory_name: Name of the directory. - :type directory_name: str - :param file_name: Name of the file. - :type file_name: str - :param kwargs: Optional keyword arguments that - `FileService.create_file_from_path()` takes. - :type kwargs: object - """ - self.get_conn().create_file_from_path(share_name, directory_name, file_name, file_path, **kwargs) - - def load_string( - self, string_data: str, share_name: str, directory_name: str, file_name: str, **kwargs - ) -> None: - """ - Upload a string to Azure File Share. - - :param string_data: String to load. - :type string_data: str - :param share_name: Name of the share. - :type share_name: str - :param directory_name: Name of the directory. - :type directory_name: str - :param file_name: Name of the file. 
- :type file_name: str - :param kwargs: Optional keyword arguments that - `FileService.create_file_from_text()` takes. - :type kwargs: object - """ - self.get_conn().create_file_from_text(share_name, directory_name, file_name, string_data, **kwargs) +import warnings - def load_stream( - self, stream: str, share_name: str, directory_name: str, file_name: str, count: str, **kwargs - ) -> None: - """ - Upload a stream to Azure File Share. +from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook # noqa - :param stream: Opened file/stream to upload as the file content. - :type stream: file-like - :param share_name: Name of the share. - :type share_name: str - :param directory_name: Name of the directory. - :type directory_name: str - :param file_name: Name of the file. - :type file_name: str - :param count: Size of the stream in bytes - :type count: int - :param kwargs: Optional keyword arguments that - `FileService.create_file_from_stream()` takes. - :type kwargs: object - """ - self.get_conn().create_file_from_stream( - share_name, directory_name, file_name, stream, count, **kwargs - ) +warnings.warn( + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.fileshare`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/hooks/batch.py b/airflow/providers/microsoft/azure/hooks/batch.py new file mode 100644 index 0000000000000..d60ab0579bb61 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/batch.py @@ -0,0 +1,395 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import time +from datetime import timedelta +from typing import Any, Dict, Optional, Set + +from azure.batch import BatchServiceClient, batch_auth, models as batch_models +from azure.batch.models import JobAddParameter, PoolAddParameter, TaskAddParameter + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.models import Connection +from airflow.utils import timezone + + +class AzureBatchHook(BaseHook): + """ + Hook for Azure Batch APIs + + :param azure_batch_conn_id: :ref:`Azure Batch connection id` + of a service principal which will be used to start the container instance. 
+ :type azure_batch_conn_id: str + """ + + conn_name_attr = 'azure_batch_conn_id' + default_conn_name = 'azure_batch_default' + conn_type = 'azure_batch' + hook_name = 'Azure Batch Service' + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "extra__azure_batch__account_url": StringField( + lazy_gettext('Batch Account URL'), widget=BS3TextFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ['schema', 'port', 'host', 'extra'], + "relabeling": { + 'login': 'Batch Account Name', + 'password': 'Batch Account Access Key', + }, + } + + def __init__(self, azure_batch_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_batch_conn_id + self.connection = self.get_conn() + + def _connection(self) -> Connection: + """Get connected to Azure Batch service""" + conn = self.get_connection(self.conn_id) + return conn + + def get_conn(self): + """ + Get the Batch client connection + + :return: Azure Batch client + """ + conn = self._connection() + + batch_account_url = conn.extra_dejson.get('extra__azure_batch__account_url') + if not batch_account_url: + raise AirflowException('Batch Account URL parameter is missing.') + + credentials = batch_auth.SharedKeyCredentials(conn.login, conn.password) + batch_client = BatchServiceClient(credentials, batch_url=batch_account_url) + return batch_client + + def configure_pool( + self, + pool_id: str, + vm_size: Optional[str] = None, + vm_publisher: Optional[str] = None, + vm_offer: Optional[str] = None, + sku_starts_with: Optional[str] = None, + vm_sku: Optional[str] = None, + vm_version: Optional[str] = None, + vm_node_agent_sku_id: Optional[str] = None, + os_family: Optional[str] = None, + os_version: Optional[str] = None, + display_name: Optional[str] = None, + target_dedicated_nodes: Optional[int] = None, + use_latest_image_and_sku: bool = False, + **kwargs, + ) -> PoolAddParameter: + """ + Configures a pool + + :param pool_id: A string that uniquely identifies the Pool within the Account + :type pool_id: str + + :param vm_size: The size of virtual machines in the Pool. + :type vm_size: str + + :param display_name: The display name for the Pool + :type display_name: str + + :param target_dedicated_nodes: The desired number of dedicated Compute Nodes in the Pool. + :type target_dedicated_nodes: Optional[int] + + :param use_latest_image_and_sku: Whether to use the latest verified vm image and sku + :type use_latest_image_and_sku: bool + + :param vm_publisher: The publisher of the Azure Virtual Machines Marketplace Image. + For example, Canonical or MicrosoftWindowsServer. + :type vm_publisher: Optional[str] + + :param vm_offer: The offer type of the Azure Virtual Machines Marketplace Image. + For example, UbuntuServer or WindowsServer. 
+ :type vm_offer: Optional[str] + + :param sku_starts_with: The start name of the sku to search + :type sku_starts_with: Optional[str] + + :param vm_sku: The name of the virtual machine sku to use + :type vm_sku: Optional[str] + + :param vm_version: The version of the virtual machine + :param vm_version: str + + :param vm_node_agent_sku_id: The node agent sku id of the virtual machine + :type vm_node_agent_sku_id: Optional[str] + + :param os_family: The Azure Guest OS family to be installed on the virtual machines in the Pool. + :type os_family: Optional[str] + + :param os_version: The OS family version + :type os_version: Optional[str] + + """ + if use_latest_image_and_sku: + self.log.info('Using latest verified virtual machine image with node agent sku') + sku_to_use, image_ref_to_use = self._get_latest_verified_image_vm_and_sku( + publisher=vm_publisher, offer=vm_offer, sku_starts_with=sku_starts_with + ) + pool = batch_models.PoolAddParameter( + id=pool_id, + vm_size=vm_size, + display_name=display_name, + virtual_machine_configuration=batch_models.VirtualMachineConfiguration( + image_reference=image_ref_to_use, node_agent_sku_id=sku_to_use + ), + target_dedicated_nodes=target_dedicated_nodes, + **kwargs, + ) + + elif os_family: + self.log.info( + 'Using cloud service configuration to create pool, virtual machine configuration ignored' + ) + pool = batch_models.PoolAddParameter( + id=pool_id, + vm_size=vm_size, + display_name=display_name, + cloud_service_configuration=batch_models.CloudServiceConfiguration( + os_family=os_family, os_version=os_version + ), + target_dedicated_nodes=target_dedicated_nodes, + **kwargs, + ) + + else: + self.log.info('Using virtual machine configuration to create a pool') + pool = batch_models.PoolAddParameter( + id=pool_id, + vm_size=vm_size, + display_name=display_name, + virtual_machine_configuration=batch_models.VirtualMachineConfiguration( + image_reference=batch_models.ImageReference( + publisher=vm_publisher, + offer=vm_offer, + sku=vm_sku, + version=vm_version, + ), + node_agent_sku_id=vm_node_agent_sku_id, + ), + target_dedicated_nodes=target_dedicated_nodes, + **kwargs, + ) + return pool + + def create_pool(self, pool: PoolAddParameter) -> None: + """ + Creates a pool if not already existing + + :param pool: the pool object to create + :type pool: batch_models.PoolAddParameter + + """ + try: + self.log.info("Attempting to create a pool: %s", pool.id) + self.connection.pool.add(pool) + self.log.info("Created pool: %s", pool.id) + except batch_models.BatchErrorException as e: + if e.error.code != "PoolExists": + raise + else: + self.log.info("Pool %s already exists", pool.id) + + def _get_latest_verified_image_vm_and_sku( + self, + publisher: Optional[str] = None, + offer: Optional[str] = None, + sku_starts_with: Optional[str] = None, + ) -> tuple: + """ + Get latest verified image vm and sku + + :param publisher: The publisher of the Azure Virtual Machines Marketplace Image. + For example, Canonical or MicrosoftWindowsServer. + :type publisher: str + :param offer: The offer type of the Azure Virtual Machines Marketplace Image. + For example, UbuntuServer or WindowsServer. 
+ :type offer: str + :param sku_starts_with: The start name of the sku to search + :type sku_starts_with: str + """ + options = batch_models.AccountListSupportedImagesOptions(filter="verificationType eq 'verified'") + images = self.connection.account.list_supported_images(account_list_supported_images_options=options) + # pick the latest supported sku + skus_to_use = [ + (image.node_agent_sku_id, image.image_reference) + for image in images + if image.image_reference.publisher.lower() == publisher + and image.image_reference.offer.lower() == offer + and image.image_reference.sku.startswith(sku_starts_with) + ] + + # pick first + agent_sku_id, image_ref_to_use = skus_to_use[0] + return agent_sku_id, image_ref_to_use + + def wait_for_all_node_state(self, pool_id: str, node_state: Set) -> list: + """ + Wait for all nodes in a pool to reach given states + + :param pool_id: A string that identifies the pool + :type pool_id: str + :param node_state: A set of batch_models.ComputeNodeState + :type node_state: set + """ + self.log.info('waiting for all nodes in pool %s to reach one of: %s', pool_id, node_state) + while True: + # refresh pool to ensure that there is no resize error + pool = self.connection.pool.get(pool_id) + if pool.resize_errors is not None: + resize_errors = "\n".join(repr(e) for e in pool.resize_errors) + raise RuntimeError(f'resize error encountered for pool {pool.id}:\n{resize_errors}') + nodes = list(self.connection.compute_node.list(pool.id)) + if len(nodes) >= pool.target_dedicated_nodes and all(node.state in node_state for node in nodes): + return nodes + # Allow the timeout to be controlled by the AzureBatchOperator + # specified timeout. This way we don't interrupt a startTask inside + # the pool + time.sleep(10) + + def configure_job( + self, + job_id: str, + pool_id: str, + display_name: Optional[str] = None, + **kwargs, + ) -> JobAddParameter: + """ + Configures a job for use in the pool + + :param job_id: A string that uniquely identifies the job within the account + :type job_id: str + :param pool_id: A string that identifies the pool + :type pool_id: str + :param display_name: The display name for the job + :type display_name: str + """ + job = batch_models.JobAddParameter( + id=job_id, + pool_info=batch_models.PoolInformation(pool_id=pool_id), + display_name=display_name, + **kwargs, + ) + return job + + def create_job(self, job: JobAddParameter) -> None: + """ + Creates a job in the pool + + :param job: The job object to create + :type job: batch_models.JobAddParameter + """ + try: + self.connection.job.add(job) + self.log.info("Job %s created", job.id) + except batch_models.BatchErrorException as err: + if err.error.code != "JobExists": + raise + else: + self.log.info("Job %s already exists", job.id) + + def configure_task( + self, + task_id: str, + command_line: str, + display_name: Optional[str] = None, + container_settings=None, + **kwargs, + ) -> TaskAddParameter: + """ + Creates a task + + :param task_id: A string that identifies the task to create + :type task_id: str + :param command_line: The command line of the Task. + :type command_line: str + :param display_name: A display name for the Task + :type display_name: str + :param container_settings: The settings for the container under which the Task runs. + If the Pool that will run this Task has containerConfiguration set, + this must be set as well. If the Pool that will run this Task doesn't have + containerConfiguration set, this must not be set. 
+ :type container_settings: batch_models.TaskContainerSettings + """ + task = batch_models.TaskAddParameter( + id=task_id, + command_line=command_line, + display_name=display_name, + container_settings=container_settings, + **kwargs, + ) + self.log.info("Task created: %s", task_id) + return task + + def add_single_task_to_job(self, job_id: str, task: TaskAddParameter) -> None: + """ + Add a single task to given job if it doesn't exist + + :param job_id: A string that identifies the given job + :type job_id: str + :param task: The task to add + :type task: batch_models.TaskAddParameter + """ + try: + + self.connection.task.add(job_id=job_id, task=task) + except batch_models.BatchErrorException as err: + if err.error.code != "TaskExists": + raise + else: + self.log.info("Task %s already exists", task.id) + + def wait_for_job_tasks_to_complete(self, job_id: str, timeout: int) -> None: + """ + Wait for tasks in a particular job to complete + + :param job_id: A string that identifies the job + :type job_id: str + :param timeout: The amount of time to wait before timing out in minutes + :type timeout: int + """ + timeout_time = timezone.utcnow() + timedelta(minutes=timeout) + while timezone.utcnow() < timeout_time: + tasks = self.connection.task.list(job_id) + + incomplete_tasks = [task for task in tasks if task.state != batch_models.TaskState.completed] + if not incomplete_tasks: + return + for task in incomplete_tasks: + self.log.info("Waiting for %s to complete, currently on %s state", task.id, task.state) + time.sleep(15) + raise TimeoutError("Timed out waiting for tasks to complete") diff --git a/airflow/providers/microsoft/azure/hooks/container_instance.py b/airflow/providers/microsoft/azure/hooks/container_instance.py new file mode 100644 index 0000000000000..9f4c0cc0f09c0 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/container_instance.py @@ -0,0 +1,157 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import warnings +from typing import Any + +from azure.mgmt.containerinstance import ContainerInstanceManagementClient +from azure.mgmt.containerinstance.models import ContainerGroup + +from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook + + +class AzureContainerInstanceHook(AzureBaseHook): + """ + A hook to communicate with Azure Container Instances. + + This hook requires a service principal in order to work. + After creating this service principal + (Azure Active Directory/App Registrations), you need to fill in the + client_id (Application ID) as login, the generated password as password, + and tenantId and subscriptionId in the extra's field as a json. + + :param conn_id: :ref:`Azure connection id` of + a service principal which will be used to start the container instance. 
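Taken together, the batch hook methods above compose into a pool -> job -> task flow. A minimal sketch, assuming an `azure_batch_default` connection with the Batch account URL extra set; the pool, job, and task ids, VM values, and command line are placeholders (lowercase publisher/offer values are used because `_get_latest_verified_image_vm_and_sku` lowercases only the image side of the comparison):

    from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook

    hook = AzureBatchHook(azure_batch_conn_id="azure_batch_default")

    # Configure and create a pool from the latest verified Ubuntu image (placeholder names/sizes)
    pool = hook.configure_pool(
        pool_id="demo-pool",
        vm_size="Standard_A1",
        vm_publisher="canonical",
        vm_offer="ubuntuserver",
        sku_starts_with="18.04",
        use_latest_image_and_sku=True,
        target_dedicated_nodes=1,
    )
    hook.create_pool(pool)

    # Create a job in that pool and submit a single task
    job = hook.configure_job(job_id="demo-job", pool_id="demo-pool")
    hook.create_job(job)
    task = hook.configure_task(task_id="demo-task", command_line="/bin/bash -c 'echo hello'")
    hook.add_single_task_to_job(job_id="demo-job", task=task)

    # Block until the job's tasks finish; timeout is in minutes
    hook.wait_for_job_tasks_to_complete(job_id="demo-job", timeout=25)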
+ :type azure_conn_id: str + """ + + conn_name_attr = 'azure_conn_id' + default_conn_name = 'azure_default' + conn_type = 'azure_container_instance' + hook_name = 'Azure Container Instance' + + def __init__(self, *args, **kwargs) -> None: + super().__init__(sdk_client=ContainerInstanceManagementClient, *args, **kwargs) + self.connection = self.get_conn() + + def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None: + """ + Create a new container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + :param container_group: the properties of the container group + :type container_group: azure.mgmt.containerinstance.models.ContainerGroup + """ + self.connection.container_groups.create_or_update(resource_group, name, container_group) + + def get_state_exitcode_details(self, resource_group: str, name: str) -> tuple: + """ + Get the state and exitcode of a container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + :return: A tuple with the state, exitcode, and details. + If the exitcode is unknown 0 is returned. + :rtype: tuple(state,exitcode,details) + """ + warnings.warn( + "get_state_exitcode_details() is deprecated. Related method is get_state()", + DeprecationWarning, + stacklevel=2, + ) + cg_state = self.get_state(resource_group, name) + c_state = cg_state.containers[0].instance_view.current_state + return (c_state.state, c_state.exit_code, c_state.detail_status) + + def get_messages(self, resource_group: str, name: str) -> list: + """ + Get the messages of a container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + :return: A list of the event messages + :rtype: list[str] + """ + warnings.warn( + "get_messages() is deprecated. 
Related method is get_state()", DeprecationWarning, stacklevel=2 + ) + cg_state = self.get_state(resource_group, name) + instance_view = cg_state.containers[0].instance_view + return [event.message for event in instance_view.events] + + def get_state(self, resource_group: str, name: str) -> Any: + """ + Get the state of a container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + :return: ContainerGroup + :rtype: ~azure.mgmt.containerinstance.models.ContainerGroup + """ + return self.connection.container_groups.get(resource_group, name, raw=False) + + def get_logs(self, resource_group: str, name: str, tail: int = 1000) -> list: + """ + Get the tail from logs of a container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + :param tail: the size of the tail + :type tail: int + :return: A list of log messages + :rtype: list[str] + """ + logs = self.connection.container.list_logs(resource_group, name, name, tail=tail) + return logs.content.splitlines(True) + + def delete(self, resource_group: str, name: str) -> None: + """ + Delete a container group + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + """ + self.connection.container_groups.delete(resource_group, name) + + def exists(self, resource_group: str, name: str) -> bool: + """ + Test if a container group exists + + :param resource_group: the name of the resource group + :type resource_group: str + :param name: the name of the container group + :type name: str + """ + for container in self.connection.container_groups.list_by_resource_group(resource_group): + if container.name == name: + return True + return False diff --git a/airflow/providers/microsoft/azure/hooks/container_registry.py b/airflow/providers/microsoft/azure/hooks/container_registry.py new file mode 100644 index 0000000000000..f4c5d1adb40aa --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/container_registry.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hook for Azure Container Registry""" + +from typing import Dict + +from azure.mgmt.containerinstance.models import ImageRegistryCredential + +from airflow.hooks.base import BaseHook + + +class AzureContainerRegistryHook(BaseHook): + """ + A hook to communicate with a Azure Container Registry. 
+ + :param conn_id: :ref:`Azure Container Registry connection id` + of a service principal which will be used to start the container instance + + :type conn_id: str + """ + + conn_name_attr = 'azure_container_registry_conn_id' + default_conn_name = 'azure_container_registry_default' + conn_type = 'azure_container_registry' + hook_name = 'Azure Container Registry' + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ['schema', 'port', 'extra'], + "relabeling": { + 'login': 'Registry Username', + 'password': 'Registry Password', + 'host': 'Registry Server', + }, + "placeholders": { + 'login': 'private registry username', + 'password': 'private registry password', + 'host': 'docker image registry server', + }, + } + + def __init__(self, conn_id: str = 'azure_registry') -> None: + super().__init__() + self.conn_id = conn_id + self.connection = self.get_conn() + + def get_conn(self) -> ImageRegistryCredential: + conn = self.get_connection(self.conn_id) + return ImageRegistryCredential(server=conn.host, username=conn.login, password=conn.password) diff --git a/airflow/providers/microsoft/azure/hooks/container_volume.py b/airflow/providers/microsoft/azure/hooks/container_volume.py new file mode 100644 index 0000000000000..8aae16b491a34 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/container_volume.py @@ -0,0 +1,106 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Any, Dict + +from azure.mgmt.containerinstance.models import AzureFileVolume, Volume + +from airflow.hooks.base import BaseHook + + +class AzureContainerVolumeHook(BaseHook): + """ + A hook which wraps an Azure Volume. + + :param azure_container_volume_conn_id: Reference to the + :ref:`Azure Container Volume connection id ` + of an Azure account of which container volumes should be used. 
+ :type azure_container_volume_conn_id: str + """ + + conn_name_attr = "azure_container_volume_conn_id" + default_conn_name = 'azure_container_volume_default' + conn_type = 'azure_container_volume' + hook_name = 'Azure Container Volume' + + def __init__(self, azure_container_volume_conn_id: str = 'azure_container_volume_default') -> None: + super().__init__() + self.conn_id = azure_container_volume_conn_id + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget + from flask_babel import lazy_gettext + from wtforms import PasswordField + + return { + "extra__azure_container_volume__connection_string": PasswordField( + lazy_gettext('Blob Storage Connection String (optional)'), widget=BS3PasswordFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + import json + + return { + "hidden_fields": ['schema', 'port', 'host', "extra"], + "relabeling": { + 'login': 'Azure Client ID', + 'password': 'Azure Secret', + }, + "placeholders": { + 'extra': json.dumps( + { + "key_path": "path to json file for auth", + "key_json": "specifies json dict for auth", + }, + indent=1, + ), + 'login': 'client_id (token credentials auth)', + 'password': 'secret (token credentials auth)', + 'extra__azure_container_volume__connection_string': 'connection string auth', + }, + } + + def get_storagekey(self) -> str: + """Get Azure File Volume storage key""" + conn = self.get_connection(self.conn_id) + service_options = conn.extra_dejson + + if 'extra__azure_container_volume__connection_string' in service_options: + for keyvalue in service_options['extra__azure_container_volume__connection_string'].split(";"): + key, value = keyvalue.split("=", 1) + if key == "AccountKey": + return value + return conn.password + + def get_file_volume( + self, mount_name: str, share_name: str, storage_account_name: str, read_only: bool = False + ) -> Volume: + """Get Azure File Volume""" + return Volume( + name=mount_name, + azure_file=AzureFileVolume( + share_name=share_name, + storage_account_name=storage_account_name, + read_only=read_only, + storage_account_key=self.get_storagekey(), + ), + ) diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py b/airflow/providers/microsoft/azure/hooks/cosmos.py new file mode 100644 index 0000000000000..b75d75bfbfa06 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/cosmos.py @@ -0,0 +1,353 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains integration with Azure CosmosDB. + +AzureCosmosDBHook communicates via the Azure Cosmos library. Make sure that a +Airflow connection of type `azure_cosmos` exists. 
Authorization can be done by supplying a +login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify +the default database and collection to use (see connection `azure_cosmos_default` for an example). +""" +import uuid +from typing import Any, Dict, Optional + +from azure.cosmos.cosmos_client import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError + +from airflow.exceptions import AirflowBadRequest +from airflow.hooks.base import BaseHook + + +class AzureCosmosDBHook(BaseHook): + """ + Interacts with Azure CosmosDB. + + login should be the endpoint uri, password should be the master key + optionally, you can use the following extras to default these values + {"database_name": "", "collection_name": "COLLECTION_NAME"}. + + :param azure_cosmos_conn_id: Reference to the + :ref:`Azure CosmosDB connection`. + :type azure_cosmos_conn_id: str + """ + + conn_name_attr = 'azure_cosmos_conn_id' + default_conn_name = 'azure_cosmos_default' + conn_type = 'azure_cosmos' + hook_name = 'Azure CosmosDB' + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "extra__azure_cosmos__database_name": StringField( + lazy_gettext('Cosmos Database Name (optional)'), widget=BS3TextFieldWidget() + ), + "extra__azure_cosmos__collection_name": StringField( + lazy_gettext('Cosmos Collection Name (optional)'), widget=BS3TextFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ['schema', 'port', 'host', 'extra'], + "relabeling": { + 'login': 'Cosmos Endpoint URI', + 'password': 'Cosmos Master Key Token', + }, + "placeholders": { + 'login': 'endpoint uri', + 'password': 'master key', + 'extra__azure_cosmos__database_name': 'database name', + 'extra__azure_cosmos__collection_name': 'collection name', + }, + } + + def __init__(self, azure_cosmos_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_cosmos_conn_id + self._conn = None + + self.default_database_name = None + self.default_collection_name = None + + def get_conn(self) -> CosmosClient: + """Return a cosmos db client.""" + if not self._conn: + conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson + endpoint_uri = conn.login + master_key = conn.password + + self.default_database_name = extras.get('database_name') or extras.get( + 'extra__azure_cosmos__database_name' + ) + self.default_collection_name = extras.get('collection_name') or extras.get( + 'extra__azure_cosmos__collection_name' + ) + + # Initialize the Python Azure Cosmos DB client + self._conn = CosmosClient(endpoint_uri, {'masterKey': master_key}) + return self._conn + + def __get_database_name(self, database_name: Optional[str] = None) -> str: + self.get_conn() + db_name = database_name + if db_name is None: + db_name = self.default_database_name + + if db_name is None: + raise AirflowBadRequest("Database name must be specified") + + return db_name + + def __get_collection_name(self, collection_name: Optional[str] = None) -> str: + self.get_conn() + coll_name = collection_name + if coll_name is None: + coll_name = self.default_collection_name + + if coll_name is None: + raise AirflowBadRequest("Collection name must be specified") + + return coll_name + + def 
does_collection_exist(self, collection_name: str, database_name: str) -> bool: + """Checks if a collection exists in CosmosDB.""" + if collection_name is None: + raise AirflowBadRequest("Collection name cannot be None.") + + existing_container = list( + self.get_conn().QueryContainers( + get_database_link(self.__get_database_name(database_name)), + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": collection_name}], + }, + ) + ) + if len(existing_container) == 0: + return False + + return True + + def create_collection(self, collection_name: str, database_name: Optional[str] = None) -> None: + """Creates a new collection in the CosmosDB database.""" + if collection_name is None: + raise AirflowBadRequest("Collection name cannot be None.") + + # We need to check to see if this container already exists so we don't try + # to create it twice + existing_container = list( + self.get_conn().QueryContainers( + get_database_link(self.__get_database_name(database_name)), + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": collection_name}], + }, + ) + ) + + # Only create if we did not find it already existing + if len(existing_container) == 0: + self.get_conn().CreateContainer( + get_database_link(self.__get_database_name(database_name)), {"id": collection_name} + ) + + def does_database_exist(self, database_name: str) -> bool: + """Checks if a database exists in CosmosDB.""" + if database_name is None: + raise AirflowBadRequest("Database name cannot be None.") + + existing_database = list( + self.get_conn().QueryDatabases( + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": database_name}], + } + ) + ) + if len(existing_database) == 0: + return False + + return True + + def create_database(self, database_name: str) -> None: + """Creates a new database in CosmosDB.""" + if database_name is None: + raise AirflowBadRequest("Database name cannot be None.") + + # We need to check to see if this database already exists so we don't try + # to create it twice + existing_database = list( + self.get_conn().QueryDatabases( + { + "query": "SELECT * FROM r WHERE r.id=@id", + "parameters": [{"name": "@id", "value": database_name}], + } + ) + ) + + # Only create if we did not find it already existing + if len(existing_database) == 0: + self.get_conn().CreateDatabase({"id": database_name}) + + def delete_database(self, database_name: str) -> None: + """Deletes an existing database in CosmosDB.""" + if database_name is None: + raise AirflowBadRequest("Database name cannot be None.") + + self.get_conn().DeleteDatabase(get_database_link(database_name)) + + def delete_collection(self, collection_name: str, database_name: Optional[str] = None) -> None: + """Deletes an existing collection in the CosmosDB database.""" + if collection_name is None: + raise AirflowBadRequest("Collection name cannot be None.") + + self.get_conn().DeleteContainer( + get_collection_link(self.__get_database_name(database_name), collection_name) + ) + + def upsert_document(self, document, database_name=None, collection_name=None, document_id=None): + """ + Inserts a new document (or updates an existing one) into an existing + collection in the CosmosDB database. 
+ """ + # Assign unique ID if one isn't provided + if document_id is None: + document_id = str(uuid.uuid4()) + + if document is None: + raise AirflowBadRequest("You cannot insert a None document") + + # Add document id if isn't found + if 'id' in document: + if document['id'] is None: + document['id'] = document_id + else: + document['id'] = document_id + + created_document = self.get_conn().CreateItem( + get_collection_link( + self.__get_database_name(database_name), self.__get_collection_name(collection_name) + ), + document, + ) + + return created_document + + def insert_documents( + self, documents, database_name: Optional[str] = None, collection_name: Optional[str] = None + ) -> list: + """Insert a list of new documents into an existing collection in the CosmosDB database.""" + if documents is None: + raise AirflowBadRequest("You cannot insert empty documents") + + created_documents = [] + for single_document in documents: + created_documents.append( + self.get_conn().CreateItem( + get_collection_link( + self.__get_database_name(database_name), self.__get_collection_name(collection_name) + ), + single_document, + ) + ) + + return created_documents + + def delete_document( + self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None + ) -> None: + """Delete an existing document out of a collection in the CosmosDB database.""" + if document_id is None: + raise AirflowBadRequest("Cannot delete a document without an id") + + self.get_conn().DeleteItem( + get_document_link( + self.__get_database_name(database_name), + self.__get_collection_name(collection_name), + document_id, + ) + ) + + def get_document( + self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None + ): + """Get a document from an existing collection in the CosmosDB database.""" + if document_id is None: + raise AirflowBadRequest("Cannot get a document without an id") + + try: + return self.get_conn().ReadItem( + get_document_link( + self.__get_database_name(database_name), + self.__get_collection_name(collection_name), + document_id, + ) + ) + except CosmosHttpResponseError: + return None + + def get_documents( + self, + sql_string: str, + database_name: Optional[str] = None, + collection_name: Optional[str] = None, + partition_key: Optional[str] = None, + ) -> Optional[list]: + """Get a list of documents from an existing collection in the CosmosDB database via SQL query.""" + if sql_string is None: + raise AirflowBadRequest("SQL query string cannot be None") + + # Query them in SQL + query = {'query': sql_string} + + try: + result_iterable = self.get_conn().QueryItems( + get_collection_link( + self.__get_database_name(database_name), self.__get_collection_name(collection_name) + ), + query, + partition_key, + ) + + return list(result_iterable) + except CosmosHttpResponseError: + return None + + +def get_database_link(database_id: str) -> str: + """Get Azure CosmosDB database link""" + return "dbs/" + database_id + + +def get_collection_link(database_id: str, collection_id: str) -> str: + """Get Azure CosmosDB collection link""" + return get_database_link(database_id) + "/colls/" + collection_id + + +def get_document_link(database_id: str, collection_id: str, document_id: str) -> str: + """Get Azure CosmosDB document link""" + return get_collection_link(database_id, collection_id) + "/docs/" + document_id diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py b/airflow/providers/microsoft/azure/hooks/data_lake.py new file mode 
100644 index 0000000000000..8cbd001686d1e --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/data_lake.py @@ -0,0 +1,245 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +""" +This module contains integration with Azure Data Lake. + +AzureDataLakeHook communicates via a REST API compatible with WebHDFS. Make sure that a +Airflow connection of type `azure_data_lake` exists. Authorization can be done by supplying a +login (=Client ID), password (=Client Secret) and extra fields tenant (Tenant) and account_name (Account Name) +(see connection `azure_data_lake_default` for an example). +""" +from typing import Any, Dict, Optional + +from azure.datalake.store import core, lib, multithread + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook + + +class AzureDataLakeHook(BaseHook): + """ + Interacts with Azure Data Lake. + + Client ID and client secret should be in user and password parameters. + Tenant and account name should be extra field as + {"tenant": "", "account_name": "ACCOUNT_NAME"}. + + :param azure_data_lake_conn_id: Reference to the :ref:`Azure Data Lake connection`. 
+ :type azure_data_lake_conn_id: str + """ + + conn_name_attr = 'azure_data_lake_conn_id' + default_conn_name = 'azure_data_lake_default' + conn_type = 'azure_data_lake' + hook_name = 'Azure Data Lake' + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "extra__azure_data_lake__tenant": StringField( + lazy_gettext('Azure Tenant ID'), widget=BS3TextFieldWidget() + ), + "extra__azure_data_lake__account_name": StringField( + lazy_gettext('Azure DataLake Store Name'), widget=BS3TextFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ['schema', 'port', 'host', 'extra'], + "relabeling": { + 'login': 'Azure Client ID', + 'password': 'Azure Client Secret', + }, + "placeholders": { + 'login': 'client id', + 'password': 'secret', + 'extra__azure_data_lake__tenant': 'tenant id', + 'extra__azure_data_lake__account_name': 'datalake store', + }, + } + + def __init__(self, azure_data_lake_conn_id: str = default_conn_name) -> None: + super().__init__() + self.conn_id = azure_data_lake_conn_id + self._conn: Optional[core.AzureDLFileSystem] = None + self.account_name: Optional[str] = None + + def get_conn(self) -> core.AzureDLFileSystem: + """Return a AzureDLFileSystem object.""" + if not self._conn: + conn = self.get_connection(self.conn_id) + service_options = conn.extra_dejson + self.account_name = service_options.get('account_name') or service_options.get( + 'extra__azure_data_lake__account_name' + ) + tenant = service_options.get('tenant') or service_options.get('extra__azure_data_lake__tenant') + + adl_creds = lib.auth(tenant_id=tenant, client_secret=conn.password, client_id=conn.login) + self._conn = core.AzureDLFileSystem(adl_creds, store_name=self.account_name) + self._conn.connect() + return self._conn + + def check_for_file(self, file_path: str) -> bool: + """ + Check if a file exists on Azure Data Lake. + + :param file_path: Path and name of the file. + :type file_path: str + :return: True if the file exists, False otherwise. + :rtype: bool + """ + try: + files = self.get_conn().glob(file_path, details=False, invalidate_cache=True) + return len(files) == 1 + except FileNotFoundError: + return False + + def upload_file( + self, + local_path: str, + remote_path: str, + nthreads: int = 64, + overwrite: bool = True, + buffersize: int = 4194304, + blocksize: int = 4194304, + **kwargs, + ) -> None: + """ + Upload a file to Azure Data Lake. + + :param local_path: local path. Can be single file, directory (in which case, + upload recursively) or glob pattern. Recursive glob patterns using `**` + are not supported. + :type local_path: str + :param remote_path: Remote path to upload to; if multiple files, this is the + directory root to write within. + :type remote_path: str + :param nthreads: Number of threads to use. If None, uses the number of cores. + :type nthreads: int + :param overwrite: Whether to forcibly overwrite existing files/directories. + If False and remote path is a directory, will quit regardless if any files + would be overwritten or not. If True, only matching filenames are actually + overwritten. + :type overwrite: bool + :param buffersize: int [2**22] + Number of bytes for internal buffer. 
This block cannot be bigger than + a chunk and cannot be smaller than a block. + :type buffersize: int + :param blocksize: int [2**22] + Number of bytes for a block. Within each chunk, we write a smaller + block for each API call. This block cannot be bigger than a chunk. + :type blocksize: int + """ + multithread.ADLUploader( + self.get_conn(), + lpath=local_path, + rpath=remote_path, + nthreads=nthreads, + overwrite=overwrite, + buffersize=buffersize, + blocksize=blocksize, + **kwargs, + ) + + def download_file( + self, + local_path: str, + remote_path: str, + nthreads: int = 64, + overwrite: bool = True, + buffersize: int = 4194304, + blocksize: int = 4194304, + **kwargs, + ) -> None: + """ + Download a file from Azure Data Lake. + + :param local_path: local path. If downloading a single file, will write to this + specific file, unless it is an existing directory, in which case a file is + created within it. If downloading multiple files, this is the root + directory to write within. Will create directories as required. + :type local_path: str + :param remote_path: remote path/globstring to use to find remote files. + Recursive glob patterns using `**` are not supported. + :type remote_path: str + :param nthreads: Number of threads to use. If None, uses the number of cores. + :type nthreads: int + :param overwrite: Whether to forcibly overwrite existing files/directories. + If False and remote path is a directory, will quit regardless if any files + would be overwritten or not. If True, only matching filenames are actually + overwritten. + :type overwrite: bool + :param buffersize: int [2**22] + Number of bytes for internal buffer. This block cannot be bigger than + a chunk and cannot be smaller than a block. + :type buffersize: int + :param blocksize: int [2**22] + Number of bytes for a block. Within each chunk, we write a smaller + block for each API call. This block cannot be bigger than a chunk. + :type blocksize: int + """ + multithread.ADLDownloader( + self.get_conn(), + lpath=local_path, + rpath=remote_path, + nthreads=nthreads, + overwrite=overwrite, + buffersize=buffersize, + blocksize=blocksize, + **kwargs, + ) + + def list(self, path: str) -> list: + """ + List files in Azure Data Lake Storage + + :param path: full path/globstring to use to list files in ADLS + :type path: str + """ + if "*" in path: + return self.get_conn().glob(path) + else: + return self.get_conn().walk(path) + + def remove(self, path: str, recursive: bool = False, ignore_not_found: bool = True) -> None: + """ + Remove files in Azure Data Lake Storage + + :param path: A directory or file to remove in ADLS + :type path: str + :param recursive: Whether to loop into directories in the location and remove the files + :type recursive: bool + :param ignore_not_found: Whether to ignore, rather than raise an error, if the file to delete is not found + :type ignore_not_found: bool + """ + try: + self.get_conn().remove(path=path, recursive=recursive) + except FileNotFoundError: + if ignore_not_found: + self.log.info("File %s not found", path) + else: + raise AirflowException(f"File {path} not found") diff --git a/airflow/providers/microsoft/azure/hooks/fileshare.py b/airflow/providers/microsoft/azure/hooks/fileshare.py new file mode 100644 index 0000000000000..acfa881ddb266 --- /dev/null +++ b/airflow/providers/microsoft/azure/hooks/fileshare.py @@ -0,0 +1,336 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
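Taken together, the hook's transfer and maintenance methods can be exercised roughly as follows; the connection id, store layout and paths are illustrative assumptions::

    from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook

    hook = AzureDataLakeHook(azure_data_lake_conn_id="azure_data_lake_default")

    # Upload a local file, confirm it landed, then download it to another path.
    hook.upload_file(local_path="/tmp/report.csv", remote_path="raw/report.csv")
    assert hook.check_for_file("raw/report.csv")
    hook.download_file(local_path="/tmp/report_copy.csv", remote_path="raw/report.csv")

    # A glob pattern lists matching files; a plain path walks the directory tree.
    print(hook.list("raw/*.csv"))

    # With ignore_not_found=True (the default) a missing path is only logged.
    hook.remove("raw/report.csv")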
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import warnings +from typing import Any, Dict, List, Optional + +from azure.storage.file import File, FileService + +from airflow.hooks.base import BaseHook + + +class AzureFileShareHook(BaseHook): + """ + Interacts with Azure FileShare Storage. + + :param azure_fileshare_conn_id: Reference to the + :ref:`Azure Container Volume connection id` + of an Azure account of which container volumes should be used. + + """ + + conn_name_attr = "azure_fileshare_conn_id" + default_conn_name = 'azure_fileshare_default' + conn_type = 'azure_fileshare' + hook_name = 'Azure FileShare' + + def __init__(self, azure_fileshare_conn_id: str = 'azure_fileshare_default') -> None: + super().__init__() + self.conn_id = azure_fileshare_conn_id + self._conn = None + + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import PasswordField, StringField + + return { + "extra__azure_fileshare__sas_token": PasswordField( + lazy_gettext('SAS Token (optional)'), widget=BS3PasswordFieldWidget() + ), + "extra__azure_fileshare__connection_string": StringField( + lazy_gettext('Connection String (optional)'), widget=BS3TextFieldWidget() + ), + "extra__azure_fileshare__protocol": StringField( + lazy_gettext('Account URL or token (optional)'), widget=BS3TextFieldWidget() + ), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ['schema', 'port', 'host', 'extra'], + "relabeling": { + 'login': 'Blob Storage Login (optional)', + 'password': 'Blob Storage Key (optional)', + 'host': 'Account Name (Active Directory Auth)', + }, + "placeholders": { + 'login': 'account name', + 'password': 'secret', + 'host': 'account url', + 'extra__azure_fileshare__sas_token': 'account url or token (optional)', + 'extra__azure_fileshare__connection_string': 'account url or token (optional)', + 'extra__azure_fileshare__protocol': 'account url or token (optional)', + }, + } + + def get_conn(self) -> FileService: + """Return the FileService object.""" + prefix = "extra__azure_fileshare__" + if self._conn: + return self._conn + conn = self.get_connection(self.conn_id) + service_options_with_prefix = conn.extra_dejson + service_options = {} + for key, value in service_options_with_prefix.items(): + # in case dedicated FileShareHook is used, the connection will use the extras from UI. 
+ # in case deprecated wasb hook is used, the old extras will work as well + if key.startswith(prefix): + if value != '': + service_options[key[len(prefix) :]] = value + else: + # warn if the deprecated wasb_connection is used + warnings.warn( + "You are using deprecated connection for AzureFileShareHook." + " Please change it to `Azure FileShare`.", + DeprecationWarning, + ) + else: + service_options[key] = value + # warn if the old non-prefixed value is used + warnings.warn( + "You are using deprecated connection for AzureFileShareHook." + " Please change it to `Azure FileShare`.", + DeprecationWarning, + ) + self._conn = FileService(account_name=conn.login, account_key=conn.password, **service_options) + return self._conn + + def check_for_directory(self, share_name: str, directory_name: str, **kwargs) -> bool: + """ + Check if a directory exists on Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param kwargs: Optional keyword arguments that + `FileService.exists()` takes. + :type kwargs: object + :return: True if the file exists, False otherwise. + :rtype: bool + """ + return self.get_conn().exists(share_name, directory_name, **kwargs) + + def check_for_file(self, share_name: str, directory_name: str, file_name: str, **kwargs) -> bool: + """ + Check if a file exists on Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param kwargs: Optional keyword arguments that + `FileService.exists()` takes. + :type kwargs: object + :return: True if the file exists, False otherwise. + :rtype: bool + """ + return self.get_conn().exists(share_name, directory_name, file_name, **kwargs) + + def list_directories_and_files( + self, share_name: str, directory_name: Optional[str] = None, **kwargs + ) -> list: + """ + Return the list of directories and files stored on a Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param kwargs: Optional keyword arguments that + `FileService.list_directories_and_files()` takes. + :type kwargs: object + :return: A list of files and directories + :rtype: list + """ + return self.get_conn().list_directories_and_files(share_name, directory_name, **kwargs) + + def list_files(self, share_name: str, directory_name: Optional[str] = None, **kwargs) -> List[str]: + """ + Return the list of files stored on a Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param kwargs: Optional keyword arguments that + `FileService.list_directories_and_files()` takes. + :type kwargs: object + :return: A list of files + :rtype: list + """ + return [ + obj.name + for obj in self.list_directories_and_files(share_name, directory_name, **kwargs) + if isinstance(obj, File) + ] + + def create_share(self, share_name: str, **kwargs) -> bool: + """ + Create new Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param kwargs: Optional keyword arguments that + `FileService.create_share()` takes. + :type kwargs: object + :return: True if share is created, False if share already exists. 
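A short sketch of the share-management and listing methods above; the connection id, share and directory names are made-up examples::

    from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook

    hook = AzureFileShareHook(azure_fileshare_conn_id="azure_fileshare_default")

    # Returns True if the share was created, False if it already existed.
    hook.create_share("reports")

    if hook.check_for_directory(share_name="reports", directory_name="2021"):
        # list_files() filters list_directories_and_files() down to file names only.
        for file_name in hook.list_files(share_name="reports", directory_name="2021"):
            print(file_name)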
+ :rtype: bool + """ + return self.get_conn().create_share(share_name, **kwargs) + + def delete_share(self, share_name: str, **kwargs) -> bool: + """ + Delete existing Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param kwargs: Optional keyword arguments that + `FileService.delete_share()` takes. + :type kwargs: object + :return: True if share is deleted, False if share does not exist. + :rtype: bool + """ + return self.get_conn().delete_share(share_name, **kwargs) + + def create_directory(self, share_name: str, directory_name: str, **kwargs) -> list: + """ + Create a new directory on a Azure File Share. + + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param kwargs: Optional keyword arguments that + `FileService.create_directory()` takes. + :type kwargs: object + :return: A list of files and directories + :rtype: list + """ + return self.get_conn().create_directory(share_name, directory_name, **kwargs) + + def get_file( + self, file_path: str, share_name: str, directory_name: str, file_name: str, **kwargs + ) -> None: + """ + Download a file from Azure File Share. + + :param file_path: Where to store the file. + :type file_path: str + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param kwargs: Optional keyword arguments that + `FileService.get_file_to_path()` takes. + :type kwargs: object + """ + self.get_conn().get_file_to_path(share_name, directory_name, file_name, file_path, **kwargs) + + def get_file_to_stream( + self, stream: str, share_name: str, directory_name: str, file_name: str, **kwargs + ) -> None: + """ + Download a file from Azure File Share. + + :param stream: A filehandle to store the file to. + :type stream: file-like object + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param kwargs: Optional keyword arguments that + `FileService.get_file_to_stream()` takes. + :type kwargs: object + """ + self.get_conn().get_file_to_stream(share_name, directory_name, file_name, stream, **kwargs) + + def load_file( + self, file_path: str, share_name: str, directory_name: str, file_name: str, **kwargs + ) -> None: + """ + Upload a file to Azure File Share. + + :param file_path: Path to the file to load. + :type file_path: str + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param kwargs: Optional keyword arguments that + `FileService.create_file_from_path()` takes. + :type kwargs: object + """ + self.get_conn().create_file_from_path(share_name, directory_name, file_name, file_path, **kwargs) + + def load_string( + self, string_data: str, share_name: str, directory_name: str, file_name: str, **kwargs + ) -> None: + """ + Upload a string to Azure File Share. + + :param string_data: String to load. + :type string_data: str + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. 
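And a companion sketch for moving data in and out of a share with the upload/download methods above, again with purely illustrative names::

    from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook

    hook = AzureFileShareHook(azure_fileshare_conn_id="azure_fileshare_default")
    hook.create_directory(share_name="reports", directory_name="2021")

    # Upload from a local path and from an in-memory string.
    hook.load_file(
        file_path="/tmp/summary.csv", share_name="reports", directory_name="2021", file_name="summary.csv"
    )
    hook.load_string(
        string_data="ok", share_name="reports", directory_name="2021", file_name="status.txt"
    )

    # Download back to a local path.
    hook.get_file(
        file_path="/tmp/summary_copy.csv",
        share_name="reports",
        directory_name="2021",
        file_name="summary.csv",
    )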
+ :type file_name: str + :param kwargs: Optional keyword arguments that + `FileService.create_file_from_text()` takes. + :type kwargs: object + """ + self.get_conn().create_file_from_text(share_name, directory_name, file_name, string_data, **kwargs) + + def load_stream( + self, stream: str, share_name: str, directory_name: str, file_name: str, count: str, **kwargs + ) -> None: + """ + Upload a stream to Azure File Share. + + :param stream: Opened file/stream to upload as the file content. + :type stream: file-like + :param share_name: Name of the share. + :type share_name: str + :param directory_name: Name of the directory. + :type directory_name: str + :param file_name: Name of the file. + :type file_name: str + :param count: Size of the stream in bytes + :type count: int + :param kwargs: Optional keyword arguments that + `FileService.create_file_from_stream()` takes. + :type kwargs: object + """ + self.get_conn().create_file_from_stream( + share_name, directory_name, file_name, stream, count, **kwargs + ) diff --git a/airflow/providers/microsoft/azure/operators/adls.py b/airflow/providers/microsoft/azure/operators/adls.py index 4672726ff598c..33f895ea283d2 100644 --- a/airflow/providers/microsoft/azure/operators/adls.py +++ b/airflow/providers/microsoft/azure/operators/adls.py @@ -18,7 +18,7 @@ from typing import Any, Sequence from airflow.models import BaseOperator -from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook +from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook class ADLSDeleteOperator(BaseOperator): diff --git a/airflow/providers/microsoft/azure/operators/azure_batch.py b/airflow/providers/microsoft/azure/operators/azure_batch.py index d58513b444cbe..baa931e6c76e3 100644 --- a/airflow/providers/microsoft/azure/operators/azure_batch.py +++ b/airflow/providers/microsoft/azure/operators/azure_batch.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,344 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -from typing import Any, List, Optional - -from azure.batch import models as batch_models - -from airflow.exceptions import AirflowException -from airflow.models import BaseOperator -from airflow.providers.microsoft.azure.hooks.azure_batch import AzureBatchHook - - -class AzureBatchOperator(BaseOperator): - """ - Executes a job on Azure Batch Service - - :param batch_pool_id: A string that uniquely identifies the Pool within the Account. - :type batch_pool_id: str - :param batch_pool_vm_size: The size of virtual machines in the Pool - :type batch_pool_vm_size: str - :param batch_job_id: A string that uniquely identifies the Job within the Account. - :type batch_job_id: str - :param batch_task_command_line: The command line of the Task - :type batch_task_command_line: str - :param batch_task_id: A string that uniquely identifies the task within the Job. - :type batch_task_id: str - :param batch_pool_display_name: The display name for the Pool. - The display name need not be unique - :type batch_pool_display_name: Optional[str] - :param batch_job_display_name: The display name for the Job. 
- The display name need not be unique - :type batch_job_display_name: Optional[str] - :param batch_job_manager_task: Details of a Job Manager Task to be launched when the Job is started. - :type batch_job_manager_task: Optional[batch_models.JobManagerTask] - :param batch_job_preparation_task: The Job Preparation Task. If set, the Batch service will - run the Job Preparation Task on a Node before starting any Tasks of that - Job on that Compute Node. Required if batch_job_release_task is set. - :type batch_job_preparation_task: Optional[batch_models.JobPreparationTask] - :param batch_job_release_task: The Job Release Task. Use to undo changes to Compute Nodes - made by the Job Preparation Task - :type batch_job_release_task: Optional[batch_models.JobReleaseTask] - :param batch_task_display_name: The display name for the task. - The display name need not be unique - :type batch_task_display_name: Optional[str] - :param batch_task_container_settings: The settings for the container under which the Task runs - :type batch_task_container_settings: Optional[batch_models.TaskContainerSettings] - :param batch_start_task: A Task specified to run on each Compute Node as it joins the Pool. - The Task runs when the Compute Node is added to the Pool or - when the Compute Node is restarted. - :type batch_start_task: Optional[batch_models.StartTask] - :param batch_max_retries: The number of times to retry this batch operation before it's - considered a failed operation. Default is 3 - :type batch_max_retries: int - :param batch_task_resource_files: A list of files that the Batch service will - download to the Compute Node before running the command line. - :type batch_task_resource_files: Optional[List[batch_models.ResourceFile]] - :param batch_task_output_files: A list of files that the Batch service will upload - from the Compute Node after running the command line. - :type batch_task_output_files: Optional[List[batch_models.OutputFile]] - :param batch_task_user_identity: The user identity under which the Task runs. - If omitted, the Task runs as a non-administrative user unique to the Task. - :type batch_task_user_identity: Optional[batch_models.UserIdentity] - :param target_low_priority_nodes: The desired number of low-priority Compute Nodes in the Pool. - This property must not be specified if enable_auto_scale is set to true. - :type target_low_priority_nodes: Optional[int] - :param target_dedicated_nodes: The desired number of dedicated Compute Nodes in the Pool. - This property must not be specified if enable_auto_scale is set to true. - :type target_dedicated_nodes: Optional[int] - :param enable_auto_scale: Whether the Pool size should automatically adjust over time. Default is false - :type enable_auto_scale: bool - :param auto_scale_formula: A formula for the desired number of Compute Nodes in the Pool. - This property must not be specified if enableAutoScale is set to false. - It is required if enableAutoScale is set to true. - :type auto_scale_formula: Optional[str] - :param azure_batch_conn_id: The :ref:`Azure Batch connection id` - :type azure_batch_conn_id: str - :param use_latest_verified_vm_image_and_sku: Whether to use the latest verified virtual - machine image and sku in the batch account. Default is false. - :type use_latest_verified_vm_image_and_sku: bool - :param vm_publisher: The publisher of the Azure Virtual Machines Marketplace Image. - For example, Canonical or MicrosoftWindowsServer. 
Required if - use_latest_image_and_sku is set to True - :type vm_publisher: Optional[str] - :param vm_offer: The offer type of the Azure Virtual Machines Marketplace Image. - For example, UbuntuServer or WindowsServer. Required if - use_latest_image_and_sku is set to True - :type vm_offer: Optional[str] - :param sku_starts_with: The starting string of the Virtual Machine SKU. Required if - use_latest_image_and_sku is set to True - :type sku_starts_with: Optional[str] - :param vm_sku: The name of the virtual machine sku to use - :type vm_sku: Optional[str] - :param vm_version: The version of the virtual machine - :param vm_version: Optional[str] - :param vm_node_agent_sku_id: The node agent sku id of the virtual machine - :type vm_node_agent_sku_id: Optional[str] - :param os_family: The Azure Guest OS family to be installed on the virtual machines in the Pool. - :type os_family: Optional[str] - :param os_version: The OS family version - :type os_version: Optional[str] - :param timeout: The amount of time to wait for the job to complete in minutes. Default is 25 - :type timeout: int - :param should_delete_job: Whether to delete job after execution. Default is False - :type should_delete_job: bool - :param should_delete_pool: Whether to delete pool after execution of jobs. Default is False - :type should_delete_pool: bool - """ - - template_fields = ( - 'batch_pool_id', - 'batch_pool_vm_size', - 'batch_job_id', - 'batch_task_id', - 'batch_task_command_line', - ) - ui_color = '#f0f0e4' - - def __init__( - self, - *, - batch_pool_id: str, - batch_pool_vm_size: str, - batch_job_id: str, - batch_task_command_line: str, - batch_task_id: str, - vm_publisher: Optional[str] = None, - vm_offer: Optional[str] = None, - sku_starts_with: Optional[str] = None, - vm_sku: Optional[str] = None, - vm_version: Optional[str] = None, - vm_node_agent_sku_id: Optional[str] = None, - os_family: Optional[str] = None, - os_version: Optional[str] = None, - batch_pool_display_name: Optional[str] = None, - batch_job_display_name: Optional[str] = None, - batch_job_manager_task: Optional[batch_models.JobManagerTask] = None, - batch_job_preparation_task: Optional[batch_models.JobPreparationTask] = None, - batch_job_release_task: Optional[batch_models.JobReleaseTask] = None, - batch_task_display_name: Optional[str] = None, - batch_task_container_settings: Optional[batch_models.TaskContainerSettings] = None, - batch_start_task: Optional[batch_models.StartTask] = None, - batch_max_retries: int = 3, - batch_task_resource_files: Optional[List[batch_models.ResourceFile]] = None, - batch_task_output_files: Optional[List[batch_models.OutputFile]] = None, - batch_task_user_identity: Optional[batch_models.UserIdentity] = None, - target_low_priority_nodes: Optional[int] = None, - target_dedicated_nodes: Optional[int] = None, - enable_auto_scale: bool = False, - auto_scale_formula: Optional[str] = None, - azure_batch_conn_id='azure_batch_default', - use_latest_verified_vm_image_and_sku: bool = False, - timeout: int = 25, - should_delete_job: bool = False, - should_delete_pool: bool = False, - **kwargs, - ) -> None: - - super().__init__(**kwargs) - self.batch_pool_id = batch_pool_id - self.batch_pool_vm_size = batch_pool_vm_size - self.batch_job_id = batch_job_id - self.batch_task_id = batch_task_id - self.batch_task_command_line = batch_task_command_line - self.batch_pool_display_name = batch_pool_display_name - self.batch_job_display_name = batch_job_display_name - self.batch_job_manager_task = batch_job_manager_task - 
self.batch_job_preparation_task = batch_job_preparation_task - self.batch_job_release_task = batch_job_release_task - self.batch_task_display_name = batch_task_display_name - self.batch_task_container_settings = batch_task_container_settings - self.batch_start_task = batch_start_task - self.batch_max_retries = batch_max_retries - self.batch_task_resource_files = batch_task_resource_files - self.batch_task_output_files = batch_task_output_files - self.batch_task_user_identity = batch_task_user_identity - self.target_low_priority_nodes = target_low_priority_nodes - self.target_dedicated_nodes = target_dedicated_nodes - self.enable_auto_scale = enable_auto_scale - self.auto_scale_formula = auto_scale_formula - self.azure_batch_conn_id = azure_batch_conn_id - self.use_latest_image = use_latest_verified_vm_image_and_sku - self.vm_publisher = vm_publisher - self.vm_offer = vm_offer - self.sku_starts_with = sku_starts_with - self.vm_sku = vm_sku - self.vm_version = vm_version - self.vm_node_agent_sku_id = vm_node_agent_sku_id - self.os_family = os_family - self.os_version = os_version - self.timeout = timeout - self.should_delete_job = should_delete_job - self.should_delete_pool = should_delete_pool - self.hook = self.get_hook() - - def _check_inputs(self) -> Any: - if not self.os_family and not self.vm_publisher: - raise AirflowException("You must specify either vm_publisher or os_family") - if self.os_family and self.vm_publisher: - raise AirflowException( - "Cloud service configuration and virtual machine configuration " - "are mutually exclusive. You must specify either of os_family and" - " vm_publisher" - ) - - if self.use_latest_image: - if not all(elem for elem in [self.vm_publisher, self.vm_offer]): - raise AirflowException( - f"If use_latest_image_and_sku is set to True then the parameters vm_publisher, " - f"vm_offer, must all be set. " - f"Found vm_publisher={self.vm_publisher}, vm_offer={self.vm_offer}" - ) - if self.vm_publisher: - if not all([self.vm_sku, self.vm_offer, self.vm_node_agent_sku_id]): - raise AirflowException( - "If vm_publisher is set, then the parameters vm_sku, vm_offer," - "vm_node_agent_sku_id must be set. Found " - f"vm_publisher={self.vm_publisher}, vm_offer={self.vm_offer} " - f"vm_node_agent_sku_id={self.vm_node_agent_sku_id}, " - f"vm_version={self.vm_version}" - ) - - if not self.target_dedicated_nodes and not self.enable_auto_scale: - raise AirflowException( - "Either target_dedicated_nodes or enable_auto_scale must be set. None was set" - ) - if self.enable_auto_scale: - if self.target_dedicated_nodes or self.target_low_priority_nodes: - raise AirflowException( - f"If enable_auto_scale is set, then the parameters target_dedicated_nodes and " - f"target_low_priority_nodes must not be set. Found " - f"target_dedicated_nodes={self.target_dedicated_nodes}, " - f"target_low_priority_nodes={self.target_low_priority_nodes}" - ) - if not self.auto_scale_formula: - raise AirflowException("The auto_scale_formula is required when enable_auto_scale is set") - if self.batch_job_release_task and not self.batch_job_preparation_task: - raise AirflowException( - "A batch_job_release_task cannot be specified without also " - " specifying a batch_job_preparation_task for the Job." - ) - if not all( - [ - self.batch_pool_id, - self.batch_job_id, - self.batch_pool_vm_size, - self.batch_task_id, - self.batch_task_command_line, - ] - ): - raise AirflowException( - "Some required parameters are missing.Please you must set all the required parameters. 
" - ) - - def execute(self, context: dict) -> None: - self._check_inputs() - self.hook.connection.config.retry_policy = self.batch_max_retries - - pool = self.hook.configure_pool( - pool_id=self.batch_pool_id, - vm_size=self.batch_pool_vm_size, - display_name=self.batch_pool_display_name, - target_dedicated_nodes=self.target_dedicated_nodes, - use_latest_image_and_sku=self.use_latest_image, - vm_publisher=self.vm_publisher, - vm_offer=self.vm_offer, - sku_starts_with=self.sku_starts_with, - vm_sku=self.vm_sku, - vm_version=self.vm_version, - vm_node_agent_sku_id=self.vm_node_agent_sku_id, - os_family=self.os_family, - os_version=self.os_version, - target_low_priority_nodes=self.target_low_priority_nodes, - enable_auto_scale=self.enable_auto_scale, - auto_scale_formula=self.auto_scale_formula, - start_task=self.batch_start_task, - ) - self.hook.create_pool(pool) - # Wait for nodes to reach complete state - self.hook.wait_for_all_node_state( - self.batch_pool_id, - { - batch_models.ComputeNodeState.start_task_failed, - batch_models.ComputeNodeState.unusable, - batch_models.ComputeNodeState.idle, - }, - ) - # Create job if not already exist - job = self.hook.configure_job( - job_id=self.batch_job_id, - pool_id=self.batch_pool_id, - display_name=self.batch_job_display_name, - job_manager_task=self.batch_job_manager_task, - job_preparation_task=self.batch_job_preparation_task, - job_release_task=self.batch_job_release_task, - ) - self.hook.create_job(job) - # Create task - task = self.hook.configure_task( - task_id=self.batch_task_id, - command_line=self.batch_task_command_line, - display_name=self.batch_task_display_name, - container_settings=self.batch_task_container_settings, - resource_files=self.batch_task_resource_files, - output_files=self.batch_task_output_files, - user_identity=self.batch_task_user_identity, - ) - # Add task to job - self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task) - # Wait for tasks to complete - self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, timeout=self.timeout) - # Clean up - if self.should_delete_job: - # delete job first - self.clean_up(job_id=self.batch_job_id) - if self.should_delete_pool: - self.clean_up(self.batch_pool_id) - - def on_kill(self) -> None: - response = self.hook.connection.job.terminate( - job_id=self.batch_job_id, terminate_reason='Job killed by user' - ) - self.log.info("Azure Batch job (%s) terminated: %s", self.batch_job_id, response) - - def get_hook(self) -> AzureBatchHook: - """Create and return an AzureBatchHook.""" - return AzureBatchHook(azure_batch_conn_id=self.azure_batch_conn_id) +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.batch`.""" - def clean_up(self, pool_id: Optional[str] = None, job_id: Optional[str] = None) -> None: - """ - Delete the given pool and job in the batch account +import warnings - :param pool_id: The id of the pool to delete - :type pool_id: str - :param job_id: The id of the job to delete - :type job_id: str +from airflow.providers.microsoft.azure.operators.batch import AzureBatchOperator # noqa - """ - if job_id: - self.log.info("Deleting job: %s", job_id) - self.hook.connection.job.delete(job_id) - if pool_id: - self.log.info("Deleting pool: %s", pool_id) - self.hook.connection.pool.delete(pool_id) +warnings.warn( + "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.operators.batch`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/operators/azure_container_instances.py b/airflow/providers/microsoft/azure/operators/azure_container_instances.py index cd8cd449f9002..bb9bd05f24831 100644 --- a/airflow/providers/microsoft/azure/operators/azure_container_instances.py +++ b/airflow/providers/microsoft/azure/operators/azure_container_instances.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,376 +14,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +This module is deprecated. +Please use :mod:`airflow.providers.microsoft.azure.operators.container_instances`. +""" -import re -from collections import namedtuple -from time import sleep -from typing import Any, Dict, List, Optional, Sequence, Union +import warnings -from azure.mgmt.containerinstance.models import ( - Container, - ContainerGroup, - ContainerPort, - EnvironmentVariable, - IpAddress, - ResourceRequests, - ResourceRequirements, - VolumeMount, +from airflow.providers.microsoft.azure.operators.container_instances import ( # noqa + AzureContainerInstancesOperator, ) -from msrestazure.azure_exceptions import CloudError - -from airflow.exceptions import AirflowException, AirflowTaskTimeout -from airflow.models import BaseOperator -from airflow.providers.microsoft.azure.hooks.azure_container_instance import AzureContainerInstanceHook -from airflow.providers.microsoft.azure.hooks.azure_container_registry import AzureContainerRegistryHook -from airflow.providers.microsoft.azure.hooks.azure_container_volume import AzureContainerVolumeHook -Volume = namedtuple( - 'Volume', - ['conn_id', 'account_name', 'share_name', 'mount_path', 'read_only'], +warnings.warn( + "This module is deprecated. " + "Please use `airflow.providers.microsoft.azure.operators.container_instances`.", + DeprecationWarning, + stacklevel=2, ) - - -DEFAULT_ENVIRONMENT_VARIABLES: Dict[str, str] = {} -DEFAULT_SECURED_VARIABLES: Sequence[str] = [] -DEFAULT_VOLUMES: Sequence[Volume] = [] -DEFAULT_MEMORY_IN_GB = 2.0 -DEFAULT_CPU = 1.0 - - -class AzureContainerInstancesOperator(BaseOperator): - """ - Start a container on Azure Container Instances - - :param ci_conn_id: connection id of a service principal which will be used - to start the container instance - :type ci_conn_id: str - :param registry_conn_id: connection id of a user which can login to a - private docker registry. For Azure use :ref:`Azure connection id` - :type azure_conn_id: str If None, we assume a public registry - :type registry_conn_id: Optional[str] - :param resource_group: name of the resource group wherein this container - instance should be started - :type resource_group: str - :param name: name of this container instance. Please note this name has - to be unique in order to run containers in parallel. 
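The shim above (a re-export from the new module plus a module-level warnings.warn) is the pattern used for every renamed module in this patch; a small sketch of how that surfaces to callers, assuming the old module has not already been imported in the current interpreter::

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # The old path still works, but it is now a thin wrapper around
        # airflow.providers.microsoft.azure.operators.container_instances.
        from airflow.providers.microsoft.azure.operators.azure_container_instances import (  # noqa: F401
            AzureContainerInstancesOperator,
        )

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)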
- :type name: str - :param image: the docker image to be used - :type image: str - :param region: the region wherein this container instance should be started - :type region: str - :param environment_variables: key,value pairs containing environment - variables which will be passed to the running container - :type environment_variables: Optional[dict] - :param secured_variables: names of environmental variables that should not - be exposed outside the container (typically passwords). - :type secured_variables: Optional[str] - :param volumes: list of ``Volume`` tuples to be mounted to the container. - Currently only Azure Fileshares are supported. - :type volumes: list[] - :param memory_in_gb: the amount of memory to allocate to this container - :type memory_in_gb: double - :param cpu: the number of cpus to allocate to this container - :type cpu: double - :param gpu: GPU Resource for the container. - :type gpu: azure.mgmt.containerinstance.models.GpuResource - :param command: the command to run inside the container - :type command: Optional[List[str]] - :param container_timeout: max time allowed for the execution of - the container instance. - :type container_timeout: datetime.timedelta - :param tags: azure tags as dict of str:str - :type tags: Optional[dict[str, str]] - :param os_type: The operating system type required by the containers - in the container group. Possible values include: 'Windows', 'Linux' - :type os_type: str - :param restart_policy: Restart policy for all containers within the container group. - Possible values include: 'Always', 'OnFailure', 'Never' - :type restart_policy: str - :param ip_address: The IP address type of the container group. - :type ip_address: IpAddress - - **Example**:: - - AzureContainerInstancesOperator( - ci_conn_id = "azure_service_principal", - registry_conn_id = "azure_registry_user", - resource_group = "my-resource-group", - name = "my-container-name-{{ ds }}", - image = "myprivateregistry.azurecr.io/my_container:latest", - region = "westeurope", - environment_variables = {"MODEL_PATH": "my_value", - "POSTGRES_LOGIN": "{{ macros.connection('postgres_default').login }}", - "POSTGRES_PASSWORD": "{{ macros.connection('postgres_default').password }}", - "JOB_GUID": "{{ ti.xcom_pull(task_ids='task1', key='guid') }}" }, - secured_variables = ['POSTGRES_PASSWORD'], - volumes = [("azure_container_instance_conn_id", - "my_storage_container", - "my_fileshare", - "/input-data", - True),], - memory_in_gb=14.0, - cpu=4.0, - gpu=GpuResource(count=1, sku='K80'), - command=["/bin/echo", "world"], - task_id="start_container" - ) - """ - - template_fields = ('name', 'image', 'command', 'environment_variables') - template_fields_renderers = {"command": "bash", "environment_variables": "json"} - - def __init__( - self, - *, - ci_conn_id: str, - registry_conn_id: Optional[str], - resource_group: str, - name: str, - image: str, - region: str, - environment_variables: Optional[dict] = None, - secured_variables: Optional[str] = None, - volumes: Optional[list] = None, - memory_in_gb: Optional[Any] = None, - cpu: Optional[Any] = None, - gpu: Optional[Any] = None, - command: Optional[List[str]] = None, - remove_on_error: bool = True, - fail_if_exists: bool = True, - tags: Optional[Dict[str, str]] = None, - os_type: str = 'Linux', - restart_policy: str = 'Never', - ip_address: Optional[IpAddress] = None, - ports: Optional[List[ContainerPort]] = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.ci_conn_id = ci_conn_id - self.resource_group = 
resource_group - self.name = self._check_name(name) - self.image = image - self.region = region - self.registry_conn_id = registry_conn_id - self.environment_variables = environment_variables or DEFAULT_ENVIRONMENT_VARIABLES - self.secured_variables = secured_variables or DEFAULT_SECURED_VARIABLES - self.volumes = volumes or DEFAULT_VOLUMES - self.memory_in_gb = memory_in_gb or DEFAULT_MEMORY_IN_GB - self.cpu = cpu or DEFAULT_CPU - self.gpu = gpu - self.command = command - self.remove_on_error = remove_on_error - self.fail_if_exists = fail_if_exists - self._ci_hook: Any = None - self.tags = tags - self.os_type = os_type - if self.os_type not in ['Linux', 'Windows']: - raise AirflowException( - "Invalid value for the os_type argument. " - "Please set 'Linux' or 'Windows' as the os_type. " - f"Found `{self.os_type}`." - ) - self.restart_policy = restart_policy - if self.restart_policy not in ['Always', 'OnFailure', 'Never']: - raise AirflowException( - "Invalid value for the restart_policy argument. " - "Please set one of 'Always', 'OnFailure','Never' as the restart_policy. " - f"Found `{self.restart_policy}`" - ) - self.ip_address = ip_address - self.ports = ports - - def execute(self, context: dict) -> int: - # Check name again in case it was templated. - self._check_name(self.name) - - self._ci_hook = AzureContainerInstanceHook(self.ci_conn_id) - - if self.fail_if_exists: - self.log.info("Testing if container group already exists") - if self._ci_hook.exists(self.resource_group, self.name): - raise AirflowException("Container group exists") - - if self.registry_conn_id: - registry_hook = AzureContainerRegistryHook(self.registry_conn_id) - image_registry_credentials: Optional[list] = [ - registry_hook.connection, - ] - else: - image_registry_credentials = None - - environment_variables = [] - for key, value in self.environment_variables.items(): - if key in self.secured_variables: - e = EnvironmentVariable(name=key, secure_value=value) - else: - e = EnvironmentVariable(name=key, value=value) - environment_variables.append(e) - - volumes: List[Union[Volume, Volume]] = [] - volume_mounts: List[Union[VolumeMount, VolumeMount]] = [] - for conn_id, account_name, share_name, mount_path, read_only in self.volumes: - hook = AzureContainerVolumeHook(conn_id) - - mount_name = "mount-%d" % len(volumes) - volumes.append(hook.get_file_volume(mount_name, share_name, account_name, read_only)) - volume_mounts.append(VolumeMount(name=mount_name, mount_path=mount_path, read_only=read_only)) - - exit_code = 1 - try: - self.log.info("Starting container group with %.1f cpu %.1f mem", self.cpu, self.memory_in_gb) - if self.gpu: - self.log.info("GPU count: %.1f, GPU SKU: %s", self.gpu.count, self.gpu.sku) - - resources = ResourceRequirements( - requests=ResourceRequests(memory_in_gb=self.memory_in_gb, cpu=self.cpu, gpu=self.gpu) - ) - - if self.ip_address and not self.ports: - self.ports = [ContainerPort(port=80)] - self.log.info("Default port set. 
Container will listen on port 80") - - container = Container( - name=self.name, - image=self.image, - resources=resources, - command=self.command, - environment_variables=environment_variables, - volume_mounts=volume_mounts, - ports=self.ports, - ) - - container_group = ContainerGroup( - location=self.region, - containers=[ - container, - ], - image_registry_credentials=image_registry_credentials, - volumes=volumes, - restart_policy=self.restart_policy, - os_type=self.os_type, - tags=self.tags, - ip_address=self.ip_address, - ) - - self._ci_hook.create_or_update(self.resource_group, self.name, container_group) - - self.log.info("Container group started %s/%s", self.resource_group, self.name) - - exit_code = self._monitor_logging(self.resource_group, self.name) - - self.log.info("Container had exit code: %s", exit_code) - if exit_code != 0: - raise AirflowException(f"Container had a non-zero exit code, {exit_code}") - return exit_code - - except CloudError: - self.log.exception("Could not start container group") - raise AirflowException("Could not start container group") - - finally: - if exit_code == 0 or self.remove_on_error: - self.on_kill() - - def on_kill(self) -> None: - if self.remove_on_error: - self.log.info("Deleting container group") - try: - self._ci_hook.delete(self.resource_group, self.name) - except Exception: - self.log.exception("Could not delete container group") - - def _monitor_logging(self, resource_group: str, name: str) -> int: - last_state = None - last_message_logged = None - last_line_logged = None - - while True: - try: - cg_state = self._ci_hook.get_state(resource_group, name) - instance_view = cg_state.containers[0].instance_view - - # If there is no instance view, we show the provisioning state - if instance_view is not None: - c_state = instance_view.current_state - state, exit_code, detail_status = ( - c_state.state, - c_state.exit_code, - c_state.detail_status, - ) - else: - state = cg_state.provisioning_state - exit_code = 0 - detail_status = "Provisioning" - - if instance_view is not None and instance_view.events is not None: - messages = [event.message for event in instance_view.events] - last_message_logged = self._log_last(messages, last_message_logged) - - if state != last_state: - self.log.info("Container group state changed to %s", state) - last_state = state - - if state in ["Running", "Terminated", "Succeeded"]: - try: - logs = self._ci_hook.get_logs(resource_group, name) - last_line_logged = self._log_last(logs, last_line_logged) - except CloudError: - self.log.exception( - "Exception while getting logs from container instance, retrying..." - ) - - if state == "Terminated": - self.log.info("Container exited with detail_status %s", detail_status) - return exit_code - - if state == "Failed": - self.log.error("Azure provision failure") - return 1 - - except AirflowTaskTimeout: - raise - except CloudError as err: - if 'ResourceNotFound' in str(err): - self.log.warning( - "ResourceNotFound, container is probably removed " - "by another process " - "(make sure that the name is unique)." 
- ) - return 1 - else: - self.log.exception("Exception while getting container groups") - except Exception: - self.log.exception("Exception while getting container groups") - - sleep(1) - - def _log_last(self, logs: Optional[list], last_line_logged: Any) -> Optional[Any]: - if logs: - # determine the last line which was logged before - last_line_index = 0 - for i in range(len(logs) - 1, -1, -1): - if logs[i] == last_line_logged: - # this line is the same, hence print from i+1 - last_line_index = i + 1 - break - - # log all new ones - for line in logs[last_line_index:]: - self.log.info(line.rstrip()) - - return logs[-1] - return None - - @staticmethod - def _check_name(name: str) -> str: - if '{{' in name: - # Let macros pass as they cannot be checked at construction time - return name - regex_check = re.match("[a-z0-9]([-a-z0-9]*[a-z0-9])?", name) - if regex_check is None or regex_check.group() != name: - raise AirflowException('ACI name must match regex [a-z0-9]([-a-z0-9]*[a-z0-9])? (like "my-name")') - if len(name) > 63: - raise AirflowException('ACI name cannot be longer than 63 characters') - return name diff --git a/airflow/providers/microsoft/azure/operators/azure_cosmos.py b/airflow/providers/microsoft/azure/operators/azure_cosmos.py index b096bcc8c8af4..8ef095350ea7b 100644 --- a/airflow/providers/microsoft/azure/operators/azure_cosmos.py +++ b/airflow/providers/microsoft/azure/operators/azure_cosmos.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,56 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.cosmos`.""" -from airflow.models import BaseOperator -from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook - - -class AzureCosmosInsertDocumentOperator(BaseOperator): - """ - Inserts a new document into the specified Cosmos database and collection - It will create both the database and collection if they do not already exist - - :param database_name: The name of the database. (templated) - :type database_name: str - :param collection_name: The name of the collection. (templated) - :type collection_name: str - :param document: The document to insert - :type document: dict - :param azure_cosmos_conn_id: Reference to the - :ref:`Azure CosmosDB connection`. 
- :type azure_cosmos_conn_id: str - """ - - template_fields = ('database_name', 'collection_name') - ui_color = '#e4f0e8' - - def __init__( - self, - *, - database_name: str, - collection_name: str, - document: dict, - azure_cosmos_conn_id: str = 'azure_cosmos_default', - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.database_name = database_name - self.collection_name = collection_name - self.document = document - self.azure_cosmos_conn_id = azure_cosmos_conn_id - - def execute(self, context: dict) -> None: - # Create the hook - hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.azure_cosmos_conn_id) - - # Create the DB if it doesn't already exist - if not hook.does_database_exist(self.database_name): - hook.create_database(self.database_name) +import warnings - # Create the collection as well - if not hook.does_collection_exist(self.collection_name, self.database_name): - hook.create_collection(self.collection_name, self.database_name) +from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator # noqa - # finally insert the document - hook.upsert_document(self.document, self.database_name, self.collection_name) +warnings.warn( + "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.cosmos`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/operators/batch.py b/airflow/providers/microsoft/azure/operators/batch.py new file mode 100644 index 0000000000000..81bbe62b66d3d --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/batch.py @@ -0,0 +1,358 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from typing import Any, List, Optional + +from azure.batch import models as batch_models + +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook + + +class AzureBatchOperator(BaseOperator): + """ + Executes a job on Azure Batch Service + + :param batch_pool_id: A string that uniquely identifies the Pool within the Account. + :type batch_pool_id: str + :param batch_pool_vm_size: The size of virtual machines in the Pool + :type batch_pool_vm_size: str + :param batch_job_id: A string that uniquely identifies the Job within the Account. + :type batch_job_id: str + :param batch_task_command_line: The command line of the Task + :type batch_task_command_line: str + :param batch_task_id: A string that uniquely identifies the task within the Job. + :type batch_task_id: str + :param batch_pool_display_name: The display name for the Pool. + The display name need not be unique + :type batch_pool_display_name: Optional[str] + :param batch_job_display_name: The display name for the Job. 
+ The display name need not be unique + :type batch_job_display_name: Optional[str] + :param batch_job_manager_task: Details of a Job Manager Task to be launched when the Job is started. + :type batch_job_manager_task: Optional[batch_models.JobManagerTask] + :param batch_job_preparation_task: The Job Preparation Task. If set, the Batch service will + run the Job Preparation Task on a Node before starting any Tasks of that + Job on that Compute Node. Required if batch_job_release_task is set. + :type batch_job_preparation_task: Optional[batch_models.JobPreparationTask] + :param batch_job_release_task: The Job Release Task. Use to undo changes to Compute Nodes + made by the Job Preparation Task + :type batch_job_release_task: Optional[batch_models.JobReleaseTask] + :param batch_task_display_name: The display name for the task. + The display name need not be unique + :type batch_task_display_name: Optional[str] + :param batch_task_container_settings: The settings for the container under which the Task runs + :type batch_task_container_settings: Optional[batch_models.TaskContainerSettings] + :param batch_start_task: A Task specified to run on each Compute Node as it joins the Pool. + The Task runs when the Compute Node is added to the Pool or + when the Compute Node is restarted. + :type batch_start_task: Optional[batch_models.StartTask] + :param batch_max_retries: The number of times to retry this batch operation before it's + considered a failed operation. Default is 3 + :type batch_max_retries: int + :param batch_task_resource_files: A list of files that the Batch service will + download to the Compute Node before running the command line. + :type batch_task_resource_files: Optional[List[batch_models.ResourceFile]] + :param batch_task_output_files: A list of files that the Batch service will upload + from the Compute Node after running the command line. + :type batch_task_output_files: Optional[List[batch_models.OutputFile]] + :param batch_task_user_identity: The user identity under which the Task runs. + If omitted, the Task runs as a non-administrative user unique to the Task. + :type batch_task_user_identity: Optional[batch_models.UserIdentity] + :param target_low_priority_nodes: The desired number of low-priority Compute Nodes in the Pool. + This property must not be specified if enable_auto_scale is set to true. + :type target_low_priority_nodes: Optional[int] + :param target_dedicated_nodes: The desired number of dedicated Compute Nodes in the Pool. + This property must not be specified if enable_auto_scale is set to true. + :type target_dedicated_nodes: Optional[int] + :param enable_auto_scale: Whether the Pool size should automatically adjust over time. Default is false + :type enable_auto_scale: bool + :param auto_scale_formula: A formula for the desired number of Compute Nodes in the Pool. + This property must not be specified if enableAutoScale is set to false. + It is required if enableAutoScale is set to true. + :type auto_scale_formula: Optional[str] + :param azure_batch_conn_id: The :ref:`Azure Batch connection id` + :type azure_batch_conn_id: str + :param use_latest_verified_vm_image_and_sku: Whether to use the latest verified virtual + machine image and sku in the batch account. Default is false. + :type use_latest_verified_vm_image_and_sku: bool + :param vm_publisher: The publisher of the Azure Virtual Machines Marketplace Image. + For example, Canonical or MicrosoftWindowsServer. 
Required if + use_latest_image_and_sku is set to True + :type vm_publisher: Optional[str] + :param vm_offer: The offer type of the Azure Virtual Machines Marketplace Image. + For example, UbuntuServer or WindowsServer. Required if + use_latest_image_and_sku is set to True + :type vm_offer: Optional[str] + :param sku_starts_with: The starting string of the Virtual Machine SKU. Required if + use_latest_image_and_sku is set to True + :type sku_starts_with: Optional[str] + :param vm_sku: The name of the virtual machine sku to use + :type vm_sku: Optional[str] + :param vm_version: The version of the virtual machine + :param vm_version: Optional[str] + :param vm_node_agent_sku_id: The node agent sku id of the virtual machine + :type vm_node_agent_sku_id: Optional[str] + :param os_family: The Azure Guest OS family to be installed on the virtual machines in the Pool. + :type os_family: Optional[str] + :param os_version: The OS family version + :type os_version: Optional[str] + :param timeout: The amount of time to wait for the job to complete in minutes. Default is 25 + :type timeout: int + :param should_delete_job: Whether to delete job after execution. Default is False + :type should_delete_job: bool + :param should_delete_pool: Whether to delete pool after execution of jobs. Default is False + :type should_delete_pool: bool + """ + + template_fields = ( + 'batch_pool_id', + 'batch_pool_vm_size', + 'batch_job_id', + 'batch_task_id', + 'batch_task_command_line', + ) + ui_color = '#f0f0e4' + + def __init__( + self, + *, + batch_pool_id: str, + batch_pool_vm_size: str, + batch_job_id: str, + batch_task_command_line: str, + batch_task_id: str, + vm_publisher: Optional[str] = None, + vm_offer: Optional[str] = None, + sku_starts_with: Optional[str] = None, + vm_sku: Optional[str] = None, + vm_version: Optional[str] = None, + vm_node_agent_sku_id: Optional[str] = None, + os_family: Optional[str] = None, + os_version: Optional[str] = None, + batch_pool_display_name: Optional[str] = None, + batch_job_display_name: Optional[str] = None, + batch_job_manager_task: Optional[batch_models.JobManagerTask] = None, + batch_job_preparation_task: Optional[batch_models.JobPreparationTask] = None, + batch_job_release_task: Optional[batch_models.JobReleaseTask] = None, + batch_task_display_name: Optional[str] = None, + batch_task_container_settings: Optional[batch_models.TaskContainerSettings] = None, + batch_start_task: Optional[batch_models.StartTask] = None, + batch_max_retries: int = 3, + batch_task_resource_files: Optional[List[batch_models.ResourceFile]] = None, + batch_task_output_files: Optional[List[batch_models.OutputFile]] = None, + batch_task_user_identity: Optional[batch_models.UserIdentity] = None, + target_low_priority_nodes: Optional[int] = None, + target_dedicated_nodes: Optional[int] = None, + enable_auto_scale: bool = False, + auto_scale_formula: Optional[str] = None, + azure_batch_conn_id='azure_batch_default', + use_latest_verified_vm_image_and_sku: bool = False, + timeout: int = 25, + should_delete_job: bool = False, + should_delete_pool: bool = False, + **kwargs, + ) -> None: + + super().__init__(**kwargs) + self.batch_pool_id = batch_pool_id + self.batch_pool_vm_size = batch_pool_vm_size + self.batch_job_id = batch_job_id + self.batch_task_id = batch_task_id + self.batch_task_command_line = batch_task_command_line + self.batch_pool_display_name = batch_pool_display_name + self.batch_job_display_name = batch_job_display_name + self.batch_job_manager_task = batch_job_manager_task + 
self.batch_job_preparation_task = batch_job_preparation_task + self.batch_job_release_task = batch_job_release_task + self.batch_task_display_name = batch_task_display_name + self.batch_task_container_settings = batch_task_container_settings + self.batch_start_task = batch_start_task + self.batch_max_retries = batch_max_retries + self.batch_task_resource_files = batch_task_resource_files + self.batch_task_output_files = batch_task_output_files + self.batch_task_user_identity = batch_task_user_identity + self.target_low_priority_nodes = target_low_priority_nodes + self.target_dedicated_nodes = target_dedicated_nodes + self.enable_auto_scale = enable_auto_scale + self.auto_scale_formula = auto_scale_formula + self.azure_batch_conn_id = azure_batch_conn_id + self.use_latest_image = use_latest_verified_vm_image_and_sku + self.vm_publisher = vm_publisher + self.vm_offer = vm_offer + self.sku_starts_with = sku_starts_with + self.vm_sku = vm_sku + self.vm_version = vm_version + self.vm_node_agent_sku_id = vm_node_agent_sku_id + self.os_family = os_family + self.os_version = os_version + self.timeout = timeout + self.should_delete_job = should_delete_job + self.should_delete_pool = should_delete_pool + self.hook = self.get_hook() + + def _check_inputs(self) -> Any: + if not self.os_family and not self.vm_publisher: + raise AirflowException("You must specify either vm_publisher or os_family") + if self.os_family and self.vm_publisher: + raise AirflowException( + "Cloud service configuration and virtual machine configuration " + "are mutually exclusive. You must specify either of os_family and" + " vm_publisher" + ) + + if self.use_latest_image: + if not all(elem for elem in [self.vm_publisher, self.vm_offer]): + raise AirflowException( + f"If use_latest_image_and_sku is set to True then the parameters vm_publisher, " + f"vm_offer, must all be set. " + f"Found vm_publisher={self.vm_publisher}, vm_offer={self.vm_offer}" + ) + if self.vm_publisher: + if not all([self.vm_sku, self.vm_offer, self.vm_node_agent_sku_id]): + raise AirflowException( + "If vm_publisher is set, then the parameters vm_sku, vm_offer," + "vm_node_agent_sku_id must be set. Found " + f"vm_publisher={self.vm_publisher}, vm_offer={self.vm_offer} " + f"vm_node_agent_sku_id={self.vm_node_agent_sku_id}, " + f"vm_version={self.vm_version}" + ) + + if not self.target_dedicated_nodes and not self.enable_auto_scale: + raise AirflowException( + "Either target_dedicated_nodes or enable_auto_scale must be set. None was set" + ) + if self.enable_auto_scale: + if self.target_dedicated_nodes or self.target_low_priority_nodes: + raise AirflowException( + f"If enable_auto_scale is set, then the parameters target_dedicated_nodes and " + f"target_low_priority_nodes must not be set. Found " + f"target_dedicated_nodes={self.target_dedicated_nodes}, " + f"target_low_priority_nodes={self.target_low_priority_nodes}" + ) + if not self.auto_scale_formula: + raise AirflowException("The auto_scale_formula is required when enable_auto_scale is set") + if self.batch_job_release_task and not self.batch_job_preparation_task: + raise AirflowException( + "A batch_job_release_task cannot be specified without also " + " specifying a batch_job_preparation_task for the Job." + ) + if not all( + [ + self.batch_pool_id, + self.batch_job_id, + self.batch_pool_vm_size, + self.batch_task_id, + self.batch_task_command_line, + ] + ): + raise AirflowException( + "Some required parameters are missing.Please you must set all the required parameters. 
" + ) + + def execute(self, context: dict) -> None: + self._check_inputs() + self.hook.connection.config.retry_policy = self.batch_max_retries + + pool = self.hook.configure_pool( + pool_id=self.batch_pool_id, + vm_size=self.batch_pool_vm_size, + display_name=self.batch_pool_display_name, + target_dedicated_nodes=self.target_dedicated_nodes, + use_latest_image_and_sku=self.use_latest_image, + vm_publisher=self.vm_publisher, + vm_offer=self.vm_offer, + sku_starts_with=self.sku_starts_with, + vm_sku=self.vm_sku, + vm_version=self.vm_version, + vm_node_agent_sku_id=self.vm_node_agent_sku_id, + os_family=self.os_family, + os_version=self.os_version, + target_low_priority_nodes=self.target_low_priority_nodes, + enable_auto_scale=self.enable_auto_scale, + auto_scale_formula=self.auto_scale_formula, + start_task=self.batch_start_task, + ) + self.hook.create_pool(pool) + # Wait for nodes to reach complete state + self.hook.wait_for_all_node_state( + self.batch_pool_id, + { + batch_models.ComputeNodeState.start_task_failed, + batch_models.ComputeNodeState.unusable, + batch_models.ComputeNodeState.idle, + }, + ) + # Create job if not already exist + job = self.hook.configure_job( + job_id=self.batch_job_id, + pool_id=self.batch_pool_id, + display_name=self.batch_job_display_name, + job_manager_task=self.batch_job_manager_task, + job_preparation_task=self.batch_job_preparation_task, + job_release_task=self.batch_job_release_task, + ) + self.hook.create_job(job) + # Create task + task = self.hook.configure_task( + task_id=self.batch_task_id, + command_line=self.batch_task_command_line, + display_name=self.batch_task_display_name, + container_settings=self.batch_task_container_settings, + resource_files=self.batch_task_resource_files, + output_files=self.batch_task_output_files, + user_identity=self.batch_task_user_identity, + ) + # Add task to job + self.hook.add_single_task_to_job(job_id=self.batch_job_id, task=task) + # Wait for tasks to complete + self.hook.wait_for_job_tasks_to_complete(job_id=self.batch_job_id, timeout=self.timeout) + # Clean up + if self.should_delete_job: + # delete job first + self.clean_up(job_id=self.batch_job_id) + if self.should_delete_pool: + self.clean_up(self.batch_pool_id) + + def on_kill(self) -> None: + response = self.hook.connection.job.terminate( + job_id=self.batch_job_id, terminate_reason='Job killed by user' + ) + self.log.info("Azure Batch job (%s) terminated: %s", self.batch_job_id, response) + + def get_hook(self) -> AzureBatchHook: + """Create and return an AzureBatchHook.""" + return AzureBatchHook(azure_batch_conn_id=self.azure_batch_conn_id) + + def clean_up(self, pool_id: Optional[str] = None, job_id: Optional[str] = None) -> None: + """ + Delete the given pool and job in the batch account + + :param pool_id: The id of the pool to delete + :type pool_id: str + :param job_id: The id of the job to delete + :type job_id: str + + """ + if job_id: + self.log.info("Deleting job: %s", job_id) + self.hook.connection.job.delete(job_id) + if pool_id: + self.log.info("Deleting pool: %s", pool_id) + self.hook.connection.pool.delete(pool_id) diff --git a/airflow/providers/microsoft/azure/operators/container_instances.py b/airflow/providers/microsoft/azure/operators/container_instances.py new file mode 100644 index 0000000000000..da297c0a65462 --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/container_instances.py @@ -0,0 +1,390 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import re +from collections import namedtuple +from time import sleep +from typing import Any, Dict, List, Optional, Sequence, Union + +from azure.mgmt.containerinstance.models import ( + Container, + ContainerGroup, + ContainerPort, + EnvironmentVariable, + IpAddress, + ResourceRequests, + ResourceRequirements, + VolumeMount, +) +from msrestazure.azure_exceptions import CloudError + +from airflow.exceptions import AirflowException, AirflowTaskTimeout +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook +from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook +from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook + +Volume = namedtuple( + 'Volume', + ['conn_id', 'account_name', 'share_name', 'mount_path', 'read_only'], +) + + +DEFAULT_ENVIRONMENT_VARIABLES: Dict[str, str] = {} +DEFAULT_SECURED_VARIABLES: Sequence[str] = [] +DEFAULT_VOLUMES: Sequence[Volume] = [] +DEFAULT_MEMORY_IN_GB = 2.0 +DEFAULT_CPU = 1.0 + + +class AzureContainerInstancesOperator(BaseOperator): + """ + Start a container on Azure Container Instances + + :param ci_conn_id: connection id of a service principal which will be used + to start the container instance + :type ci_conn_id: str + :param registry_conn_id: connection id of a user which can login to a + private docker registry. For Azure use :ref:`Azure connection id`. + If None, we assume a public registry. + :type registry_conn_id: Optional[str] + :param resource_group: name of the resource group wherein this container + instance should be started + :type resource_group: str + :param name: name of this container instance. Please note this name has + to be unique in order to run containers in parallel. + :type name: str + :param image: the docker image to be used + :type image: str + :param region: the region wherein this container instance should be started + :type region: str + :param environment_variables: key,value pairs containing environment + variables which will be passed to the running container + :type environment_variables: Optional[dict] + :param secured_variables: names of environmental variables that should not + be exposed outside the container (typically passwords). + :type secured_variables: Optional[str] + :param volumes: list of ``Volume`` tuples to be mounted to the container. + Currently only Azure Fileshares are supported. + :type volumes: list[] + :param memory_in_gb: the amount of memory to allocate to this container + :type memory_in_gb: double + :param cpu: the number of cpus to allocate to this container + :type cpu: double + :param gpu: GPU Resource for the container.
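For the ``volumes`` argument documented above, each entry is a five-field tuple; a short sketch using the module-level ``Volume`` namedtuple, where the connection id, storage account and share names are placeholders::

    from airflow.providers.microsoft.azure.operators.container_instances import Volume

    input_volume = Volume(
        conn_id="azure_container_volume_default",  # placeholder Azure FileShare connection id
        account_name="mystorageaccount",
        share_name="my-fileshare",
        mount_path="/input-data",
        read_only=True,
    )
    # Passed to the operator as ``volumes=[input_volume]``; a plain 5-tuple in the
    # same order works too, since the operator simply unpacks each entry.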
+ :type gpu: azure.mgmt.containerinstance.models.GpuResource + :param command: the command to run inside the container + :type command: Optional[List[str]] + :param container_timeout: max time allowed for the execution of + the container instance. + :type container_timeout: datetime.timedelta + :param tags: azure tags as dict of str:str + :type tags: Optional[dict[str, str]] + :param os_type: The operating system type required by the containers + in the container group. Possible values include: 'Windows', 'Linux' + :type os_type: str + :param restart_policy: Restart policy for all containers within the container group. + Possible values include: 'Always', 'OnFailure', 'Never' + :type restart_policy: str + :param ip_address: The IP address type of the container group. + :type ip_address: IpAddress + + **Example**:: + + AzureContainerInstancesOperator( + ci_conn_id = "azure_service_principal", + registry_conn_id = "azure_registry_user", + resource_group = "my-resource-group", + name = "my-container-name-{{ ds }}", + image = "myprivateregistry.azurecr.io/my_container:latest", + region = "westeurope", + environment_variables = {"MODEL_PATH": "my_value", + "POSTGRES_LOGIN": "{{ macros.connection('postgres_default').login }}", + "POSTGRES_PASSWORD": "{{ macros.connection('postgres_default').password }}", + "JOB_GUID": "{{ ti.xcom_pull(task_ids='task1', key='guid') }}" }, + secured_variables = ['POSTGRES_PASSWORD'], + volumes = [("azure_container_instance_conn_id", + "my_storage_container", + "my_fileshare", + "/input-data", + True),], + memory_in_gb=14.0, + cpu=4.0, + gpu=GpuResource(count=1, sku='K80'), + command=["/bin/echo", "world"], + task_id="start_container" + ) + """ + + template_fields = ('name', 'image', 'command', 'environment_variables') + template_fields_renderers = {"command": "bash", "environment_variables": "json"} + + def __init__( + self, + *, + ci_conn_id: str, + registry_conn_id: Optional[str], + resource_group: str, + name: str, + image: str, + region: str, + environment_variables: Optional[dict] = None, + secured_variables: Optional[str] = None, + volumes: Optional[list] = None, + memory_in_gb: Optional[Any] = None, + cpu: Optional[Any] = None, + gpu: Optional[Any] = None, + command: Optional[List[str]] = None, + remove_on_error: bool = True, + fail_if_exists: bool = True, + tags: Optional[Dict[str, str]] = None, + os_type: str = 'Linux', + restart_policy: str = 'Never', + ip_address: Optional[IpAddress] = None, + ports: Optional[List[ContainerPort]] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.ci_conn_id = ci_conn_id + self.resource_group = resource_group + self.name = self._check_name(name) + self.image = image + self.region = region + self.registry_conn_id = registry_conn_id + self.environment_variables = environment_variables or DEFAULT_ENVIRONMENT_VARIABLES + self.secured_variables = secured_variables or DEFAULT_SECURED_VARIABLES + self.volumes = volumes or DEFAULT_VOLUMES + self.memory_in_gb = memory_in_gb or DEFAULT_MEMORY_IN_GB + self.cpu = cpu or DEFAULT_CPU + self.gpu = gpu + self.command = command + self.remove_on_error = remove_on_error + self.fail_if_exists = fail_if_exists + self._ci_hook: Any = None + self.tags = tags + self.os_type = os_type + if self.os_type not in ['Linux', 'Windows']: + raise AirflowException( + "Invalid value for the os_type argument. " + "Please set 'Linux' or 'Windows' as the os_type. " + f"Found `{self.os_type}`." 
+ ) + self.restart_policy = restart_policy + if self.restart_policy not in ['Always', 'OnFailure', 'Never']: + raise AirflowException( + "Invalid value for the restart_policy argument. " + "Please set one of 'Always', 'OnFailure','Never' as the restart_policy. " + f"Found `{self.restart_policy}`" + ) + self.ip_address = ip_address + self.ports = ports + + def execute(self, context: dict) -> int: + # Check name again in case it was templated. + self._check_name(self.name) + + self._ci_hook = AzureContainerInstanceHook(self.ci_conn_id) + + if self.fail_if_exists: + self.log.info("Testing if container group already exists") + if self._ci_hook.exists(self.resource_group, self.name): + raise AirflowException("Container group exists") + + if self.registry_conn_id: + registry_hook = AzureContainerRegistryHook(self.registry_conn_id) + image_registry_credentials: Optional[list] = [ + registry_hook.connection, + ] + else: + image_registry_credentials = None + + environment_variables = [] + for key, value in self.environment_variables.items(): + if key in self.secured_variables: + e = EnvironmentVariable(name=key, secure_value=value) + else: + e = EnvironmentVariable(name=key, value=value) + environment_variables.append(e) + + volumes: List[Union[Volume, Volume]] = [] + volume_mounts: List[Union[VolumeMount, VolumeMount]] = [] + for conn_id, account_name, share_name, mount_path, read_only in self.volumes: + hook = AzureContainerVolumeHook(conn_id) + + mount_name = "mount-%d" % len(volumes) + volumes.append(hook.get_file_volume(mount_name, share_name, account_name, read_only)) + volume_mounts.append(VolumeMount(name=mount_name, mount_path=mount_path, read_only=read_only)) + + exit_code = 1 + try: + self.log.info("Starting container group with %.1f cpu %.1f mem", self.cpu, self.memory_in_gb) + if self.gpu: + self.log.info("GPU count: %.1f, GPU SKU: %s", self.gpu.count, self.gpu.sku) + + resources = ResourceRequirements( + requests=ResourceRequests(memory_in_gb=self.memory_in_gb, cpu=self.cpu, gpu=self.gpu) + ) + + if self.ip_address and not self.ports: + self.ports = [ContainerPort(port=80)] + self.log.info("Default port set. 
Container will listen on port 80") + + container = Container( + name=self.name, + image=self.image, + resources=resources, + command=self.command, + environment_variables=environment_variables, + volume_mounts=volume_mounts, + ports=self.ports, + ) + + container_group = ContainerGroup( + location=self.region, + containers=[ + container, + ], + image_registry_credentials=image_registry_credentials, + volumes=volumes, + restart_policy=self.restart_policy, + os_type=self.os_type, + tags=self.tags, + ip_address=self.ip_address, + ) + + self._ci_hook.create_or_update(self.resource_group, self.name, container_group) + + self.log.info("Container group started %s/%s", self.resource_group, self.name) + + exit_code = self._monitor_logging(self.resource_group, self.name) + + self.log.info("Container had exit code: %s", exit_code) + if exit_code != 0: + raise AirflowException(f"Container had a non-zero exit code, {exit_code}") + return exit_code + + except CloudError: + self.log.exception("Could not start container group") + raise AirflowException("Could not start container group") + + finally: + if exit_code == 0 or self.remove_on_error: + self.on_kill() + + def on_kill(self) -> None: + if self.remove_on_error: + self.log.info("Deleting container group") + try: + self._ci_hook.delete(self.resource_group, self.name) + except Exception: + self.log.exception("Could not delete container group") + + def _monitor_logging(self, resource_group: str, name: str) -> int: + last_state = None + last_message_logged = None + last_line_logged = None + + while True: + try: + cg_state = self._ci_hook.get_state(resource_group, name) + instance_view = cg_state.containers[0].instance_view + + # If there is no instance view, we show the provisioning state + if instance_view is not None: + c_state = instance_view.current_state + state, exit_code, detail_status = ( + c_state.state, + c_state.exit_code, + c_state.detail_status, + ) + else: + state = cg_state.provisioning_state + exit_code = 0 + detail_status = "Provisioning" + + if instance_view is not None and instance_view.events is not None: + messages = [event.message for event in instance_view.events] + last_message_logged = self._log_last(messages, last_message_logged) + + if state != last_state: + self.log.info("Container group state changed to %s", state) + last_state = state + + if state in ["Running", "Terminated", "Succeeded"]: + try: + logs = self._ci_hook.get_logs(resource_group, name) + last_line_logged = self._log_last(logs, last_line_logged) + except CloudError: + self.log.exception( + "Exception while getting logs from container instance, retrying..." + ) + + if state == "Terminated": + self.log.info("Container exited with detail_status %s", detail_status) + return exit_code + + if state == "Failed": + self.log.error("Azure provision failure") + return 1 + + except AirflowTaskTimeout: + raise + except CloudError as err: + if 'ResourceNotFound' in str(err): + self.log.warning( + "ResourceNotFound, container is probably removed " + "by another process " + "(make sure that the name is unique)." 
+ ) + return 1 + else: + self.log.exception("Exception while getting container groups") + except Exception: + self.log.exception("Exception while getting container groups") + + sleep(1) + + def _log_last(self, logs: Optional[list], last_line_logged: Any) -> Optional[Any]: + if logs: + # determine the last line which was logged before + last_line_index = 0 + for i in range(len(logs) - 1, -1, -1): + if logs[i] == last_line_logged: + # this line is the same, hence print from i+1 + last_line_index = i + 1 + break + + # log all new ones + for line in logs[last_line_index:]: + self.log.info(line.rstrip()) + + return logs[-1] + return None + + @staticmethod + def _check_name(name: str) -> str: + if '{{' in name: + # Let macros pass as they cannot be checked at construction time + return name + regex_check = re.match("[a-z0-9]([-a-z0-9]*[a-z0-9])?", name) + if regex_check is None or regex_check.group() != name: + raise AirflowException('ACI name must match regex [a-z0-9]([-a-z0-9]*[a-z0-9])? (like "my-name")') + if len(name) > 63: + raise AirflowException('ACI name cannot be longer than 63 characters') + return name diff --git a/airflow/providers/microsoft/azure/operators/cosmos.py b/airflow/providers/microsoft/azure/operators/cosmos.py new file mode 100644 index 0000000000000..f4d50ef4b7558 --- /dev/null +++ b/airflow/providers/microsoft/azure/operators/cosmos.py @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.models import BaseOperator +from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook + + +class AzureCosmosInsertDocumentOperator(BaseOperator): + """ + Inserts a new document into the specified Cosmos database and collection + It will create both the database and collection if they do not already exist + + :param database_name: The name of the database. (templated) + :type database_name: str + :param collection_name: The name of the collection. (templated) + :type collection_name: str + :param document: The document to insert + :type document: dict + :param azure_cosmos_conn_id: Reference to the + :ref:`Azure CosmosDB connection`. 
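A usage sketch for this operator through the new ``airflow.providers.microsoft.azure.operators.cosmos`` path; the database, collection and document values are placeholders, and ``azure_cosmos_default`` is the constructor's default connection id::

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator

    with DAG(
        dag_id="example_cosmos_insert_sketch",
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        insert_document = AzureCosmosInsertDocumentOperator(
            task_id="insert_document",
            database_name="example-database",      # created on the fly if missing
            collection_name="example-collection",  # created on the fly if missing
            document={"id": "unique-doc-id", "run_date": "{{ ds }}"},
            azure_cosmos_conn_id="azure_cosmos_default",
        )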
+ :type azure_cosmos_conn_id: str + """ + + template_fields = ('database_name', 'collection_name') + ui_color = '#e4f0e8' + + def __init__( + self, + *, + database_name: str, + collection_name: str, + document: dict, + azure_cosmos_conn_id: str = 'azure_cosmos_default', + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.database_name = database_name + self.collection_name = collection_name + self.document = document + self.azure_cosmos_conn_id = azure_cosmos_conn_id + + def execute(self, context: dict) -> None: + # Create the hook + hook = AzureCosmosDBHook(azure_cosmos_conn_id=self.azure_cosmos_conn_id) + + # Create the DB if it doesn't already exist + if not hook.does_database_exist(self.database_name): + hook.create_database(self.database_name) + + # Create the collection as well + if not hook.does_collection_exist(self.collection_name, self.database_name): + hook.create_collection(self.collection_name, self.database_name) + + # finally insert the document + hook.upsert_document(self.document, self.database_name, self.collection_name) diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index d014d44c1d725..84e73c1be31e1 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -92,12 +92,15 @@ operators: - airflow.providers.microsoft.azure.operators.adx - integration-name: Microsoft Azure Batch python-modules: + - airflow.providers.microsoft.azure.operators.batch - airflow.providers.microsoft.azure.operators.azure_batch - integration-name: Microsoft Azure Container Instances python-modules: + - airflow.providers.microsoft.azure.operators.container_instances - airflow.providers.microsoft.azure.operators.azure_container_instances - integration-name: Microsoft Azure Cosmos DB python-modules: + - airflow.providers.microsoft.azure.operators.cosmos - airflow.providers.microsoft.azure.operators.azure_cosmos - integration-name: Microsoft Azure Blob Storage python-modules: @@ -109,6 +112,7 @@ operators: sensors: - integration-name: Microsoft Azure Cosmos DB python-modules: + - airflow.providers.microsoft.azure.sensors.cosmos - airflow.providers.microsoft.azure.sensors.azure_cosmos - integration-name: Microsoft Azure Blob Storage python-modules: @@ -120,6 +124,9 @@ sensors: hooks: - integration-name: Microsoft Azure Container Instances python-modules: + - airflow.providers.microsoft.azure.hooks.container_volume + - airflow.providers.microsoft.azure.hooks.container_registry + - airflow.providers.microsoft.azure.hooks.container_instance - airflow.providers.microsoft.azure.hooks.azure_container_volume - airflow.providers.microsoft.azure.hooks.azure_container_registry - airflow.providers.microsoft.azure.hooks.azure_container_instance @@ -128,18 +135,22 @@ hooks: - airflow.providers.microsoft.azure.hooks.adx - integration-name: Microsoft Azure FileShare python-modules: + - airflow.providers.microsoft.azure.hooks.fileshare - airflow.providers.microsoft.azure.hooks.azure_fileshare - integration-name: Microsoft Azure python-modules: - airflow.providers.microsoft.azure.hooks.base_azure - integration-name: Microsoft Azure Batch python-modules: + - airflow.providers.microsoft.azure.hooks.batch - airflow.providers.microsoft.azure.hooks.azure_batch - integration-name: Microsoft Azure Data Lake Storage python-modules: + - airflow.providers.microsoft.azure.hooks.data_lake - airflow.providers.microsoft.azure.hooks.azure_data_lake - integration-name: Microsoft Azure Cosmos DB 
python-modules: + - airflow.providers.microsoft.azure.hooks.cosmos - airflow.providers.microsoft.azure.hooks.azure_cosmos - integration-name: Microsoft Azure Blob Storage python-modules: @@ -171,43 +182,44 @@ transfers: hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+ - airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook - airflow.providers.microsoft.azure.hooks.adx.AzureDataExplorerHook - - airflow.providers.microsoft.azure.hooks.azure_batch.AzureBatchHook - - airflow.providers.microsoft.azure.hooks.azure_cosmos.AzureCosmosDBHook - - airflow.providers.microsoft.azure.hooks.azure_data_lake.AzureDataLakeHook - - airflow.providers.microsoft.azure.hooks.azure_fileshare.AzureFileShareHook - - airflow.providers.microsoft.azure.hooks.azure_container_volume.AzureContainerVolumeHook - - airflow.providers.microsoft.azure.hooks.azure_container_instance.AzureContainerInstanceHook + - airflow.providers.microsoft.azure.hooks.batch.AzureBatchHook + - airflow.providers.microsoft.azure.hooks.cosmos.AzureCosmosDBHook + - airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeHook + - airflow.providers.microsoft.azure.hooks.fileshare.AzureFileShareHook + - airflow.providers.microsoft.azure.hooks.container_volume.AzureContainerVolumeHook + - airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook - airflow.providers.microsoft.azure.hooks.wasb.WasbHook - airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook - - airflow.providers.microsoft.azure.hooks.azure_container_registry.AzureContainerRegistryHook + - airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook connection-types: - hook-class-name: airflow.providers.microsoft.azure.hooks.base_azure.AzureBaseHook connection-type: azure - hook-class-name: airflow.providers.microsoft.azure.hooks.adx.AzureDataExplorerHook connection-type: azure_data_explorer - - hook-class-name: airflow.providers.microsoft.azure.hooks.azure_batch.AzureBatchHook + - hook-class-name: airflow.providers.microsoft.azure.hooks.batch.AzureBatchHook connection-type: azure_batch - - hook-class-name: airflow.providers.microsoft.azure.hooks.azure_cosmos.AzureCosmosDBHook + - hook-class-name: airflow.providers.microsoft.azure.hooks.cosmos.AzureCosmosDBHook connection-type: azure_cosmos - - hook-class-name: airflow.providers.microsoft.azure.hooks.azure_data_lake.AzureDataLakeHook + - hook-class-name: airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeHook connection-type: azure_data_lake - - hook-class-name: airflow.providers.microsoft.azure.hooks.azure_fileshare.AzureFileShareHook + - hook-class-name: airflow.providers.microsoft.azure.hooks.fileshare.AzureFileShareHook connection-type: azure_fileshare - - hook-class-name: airflow.providers.microsoft.azure.hooks.azure_container_volume.AzureContainerVolumeHook + - hook-class-name: airflow.providers.microsoft.azure.hooks.container_volume.AzureContainerVolumeHook connection-type: azure_container_volume - hook-class-name: >- - airflow.providers.microsoft.azure.hooks.azure_container_instance.AzureContainerInstanceHook + airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook connection-type: azure_container_instance - hook-class-name: airflow.providers.microsoft.azure.hooks.wasb.WasbHook connection-type: wasb - hook-class-name: airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook connection-type: azure_data_factory - hook-class-name: >- - 
airflow.providers.microsoft.azure.hooks.azure_container_registry.AzureContainerRegistryHook + airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook connection-type: azure_container_registry secrets-backends: + - airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend - airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend logging: diff --git a/airflow/providers/microsoft/azure/secrets/azure_key_vault.py b/airflow/providers/microsoft/azure/secrets/azure_key_vault.py index e7d3813a572ee..f15ed17fde4c2 100644 --- a/airflow/providers/microsoft/azure/secrets/azure_key_vault.py +++ b/airflow/providers/microsoft/azure/secrets/azure_key_vault.py @@ -14,163 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Optional +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.secrets.key_vault`.""" -from azure.core.exceptions import ResourceNotFoundError -from azure.identity import DefaultAzureCredential -from azure.keyvault.secrets import SecretClient +import warnings -try: - from functools import cached_property -except ImportError: - from cached_property import cached_property +from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend # noqa -from airflow.secrets import BaseSecretsBackend -from airflow.utils.log.logging_mixin import LoggingMixin - - -class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin): - """ - Retrieves Airflow Connections or Variables from Azure Key Vault secrets. - - The Azure Key Vault can be configured as a secrets backend in the ``airflow.cfg``: - - .. code-block:: ini - - [secrets] - backend = airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend - backend_kwargs = {"connections_prefix": "airflow-connections", "vault_url": ""} - - For example, if the secrets prefix is ``airflow-connections-smtp-default``, this would be accessible - if you provide ``{"connections_prefix": "airflow-connections"}`` and request conn_id ``smtp-default``. - And if variables prefix is ``airflow-variables-hello``, this would be accessible - if you provide ``{"variables_prefix": "airflow-variables"}`` and request variable key ``hello``. - - For client authentication, the ``DefaultAzureCredential`` from the Azure Python SDK is used as - credential provider, which supports service principal, managed identity and user credentials - - For example, to specify a service principal with secret you can set the environment variables - ``AZURE_TENANT_ID``, ``AZURE_CLIENT_ID`` and ``AZURE_CLIENT_SECRET``. - - .. seealso:: - For more details on client authentication refer to the ``DefaultAzureCredential`` Class reference: - https://docs.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python - - :param connections_prefix: Specifies the prefix of the secret to read to get Connections - If set to None (null), requests for connections will not be sent to Azure Key Vault - :type connections_prefix: str - :param variables_prefix: Specifies the prefix of the secret to read to get Variables - If set to None (null), requests for variables will not be sent to Azure Key Vault - :type variables_prefix: str - :param config_prefix: Specifies the prefix of the secret to read to get Variables. 
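The deprecated module is kept as a thin re-export that warns at import time; a quick sketch of the behaviour a caller can expect, assuming the old path has not already been imported earlier in the interpreter session::

    import importlib
    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        old = importlib.import_module("airflow.providers.microsoft.azure.secrets.azure_key_vault")
        new = importlib.import_module("airflow.providers.microsoft.azure.secrets.key_vault")

    # The same class object is reachable through both module names ...
    assert old.AzureKeyVaultBackend is new.AzureKeyVaultBackend
    # ... but the old module body emitted a DeprecationWarning pointing at secrets.key_vault.
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)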
- If set to None (null), requests for configurations will not be sent to Azure Key Vault - :type config_prefix: str - :param vault_url: The URL of an Azure Key Vault to use - :type vault_url: str - :param sep: separator used to concatenate secret_prefix and secret_id. Default: "-" - :type sep: str - """ - - def __init__( - self, - connections_prefix: str = 'airflow-connections', - variables_prefix: str = 'airflow-variables', - config_prefix: str = 'airflow-config', - vault_url: str = '', - sep: str = '-', - **kwargs, - ) -> None: - super().__init__() - self.vault_url = vault_url - if connections_prefix is not None: - self.connections_prefix = connections_prefix.rstrip(sep) - else: - self.connections_prefix = connections_prefix - if variables_prefix is not None: - self.variables_prefix = variables_prefix.rstrip(sep) - else: - self.variables_prefix = variables_prefix - if config_prefix is not None: - self.config_prefix = config_prefix.rstrip(sep) - else: - self.config_prefix = config_prefix - self.sep = sep - self.kwargs = kwargs - - @cached_property - def client(self) -> SecretClient: - """Create a Azure Key Vault client.""" - credential = DefaultAzureCredential() - client = SecretClient(vault_url=self.vault_url, credential=credential, **self.kwargs) - return client - - def get_conn_uri(self, conn_id: str) -> Optional[str]: - """ - Get an Airflow Connection URI from an Azure Key Vault secret - - :param conn_id: The Airflow connection id to retrieve - :type conn_id: str - """ - if self.connections_prefix is None: - return None - - return self._get_secret(self.connections_prefix, conn_id) - - def get_variable(self, key: str) -> Optional[str]: - """ - Get an Airflow Variable from an Azure Key Vault secret. - - :param key: Variable Key - :type key: str - :return: Variable Value - """ - if self.variables_prefix is None: - return None - - return self._get_secret(self.variables_prefix, key) - - def get_config(self, key: str) -> Optional[str]: - """ - Get Airflow Configuration - - :param key: Configuration Option Key - :return: Configuration Option Value - """ - if self.config_prefix is None: - return None - - return self._get_secret(self.config_prefix, key) - - @staticmethod - def build_path(path_prefix: str, secret_id: str, sep: str = '-') -> str: - """ - Given a path_prefix and secret_id, build a valid secret name for the Azure Key Vault Backend. - Also replaces underscore in the path with dashes to support easy switching between - environment variables, so ``connection_default`` becomes ``connection-default``. - - :param path_prefix: The path prefix of the secret to retrieve - :type path_prefix: str - :param secret_id: Name of the secret - :type secret_id: str - :param sep: Separator used to concatenate path_prefix and secret_id - :type sep: str - """ - path = f'{path_prefix}{sep}{secret_id}' - return path.replace('_', sep) - - def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: - """ - Get an Azure Key Vault secret value - - :param path_prefix: Prefix for the Path to get Secret - :type path_prefix: str - :param secret_id: Secret Key - :type secret_id: str - """ - name = self.build_path(path_prefix, secret_id, self.sep) - try: - secret = self.client.get_secret(name=name) - return secret.value - except ResourceNotFoundError as ex: - self.log.debug('Secret %s not found: %s', name, ex) - return None +warnings.warn( + "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.secrets.key_vault`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/secrets/key_vault.py b/airflow/providers/microsoft/azure/secrets/key_vault.py new file mode 100644 index 0000000000000..354c3dcd089d0 --- /dev/null +++ b/airflow/providers/microsoft/azure/secrets/key_vault.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Optional + +from azure.core.exceptions import ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.keyvault.secrets import SecretClient + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from airflow.secrets import BaseSecretsBackend +from airflow.utils.log.logging_mixin import LoggingMixin + + +class AzureKeyVaultBackend(BaseSecretsBackend, LoggingMixin): + """ + Retrieves Airflow Connections or Variables from Azure Key Vault secrets. + + The Azure Key Vault can be configured as a secrets backend in the ``airflow.cfg``: + + .. code-block:: ini + + [secrets] + backend = airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend + backend_kwargs = {"connections_prefix": "airflow-connections", "vault_url": ""} + + For example, if the secrets prefix is ``airflow-connections-smtp-default``, this would be accessible + if you provide ``{"connections_prefix": "airflow-connections"}`` and request conn_id ``smtp-default``. + And if variables prefix is ``airflow-variables-hello``, this would be accessible + if you provide ``{"variables_prefix": "airflow-variables"}`` and request variable key ``hello``. + + For client authentication, the ``DefaultAzureCredential`` from the Azure Python SDK is used as + credential provider, which supports service principal, managed identity and user credentials + + For example, to specify a service principal with secret you can set the environment variables + ``AZURE_TENANT_ID``, ``AZURE_CLIENT_ID`` and ``AZURE_CLIENT_SECRET``. + + .. seealso:: + For more details on client authentication refer to the ``DefaultAzureCredential`` Class reference: + https://docs.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python + + :param connections_prefix: Specifies the prefix of the secret to read to get Connections + If set to None (null), requests for connections will not be sent to Azure Key Vault + :type connections_prefix: str + :param variables_prefix: Specifies the prefix of the secret to read to get Variables + If set to None (null), requests for variables will not be sent to Azure Key Vault + :type variables_prefix: str + :param config_prefix: Specifies the prefix of the secret to read to get Variables. 
+ If set to None (null), requests for configurations will not be sent to Azure Key Vault + :type config_prefix: str + :param vault_url: The URL of an Azure Key Vault to use + :type vault_url: str + :param sep: separator used to concatenate secret_prefix and secret_id. Default: "-" + :type sep: str + """ + + def __init__( + self, + connections_prefix: str = 'airflow-connections', + variables_prefix: str = 'airflow-variables', + config_prefix: str = 'airflow-config', + vault_url: str = '', + sep: str = '-', + **kwargs, + ) -> None: + super().__init__() + self.vault_url = vault_url + if connections_prefix is not None: + self.connections_prefix = connections_prefix.rstrip(sep) + else: + self.connections_prefix = connections_prefix + if variables_prefix is not None: + self.variables_prefix = variables_prefix.rstrip(sep) + else: + self.variables_prefix = variables_prefix + if config_prefix is not None: + self.config_prefix = config_prefix.rstrip(sep) + else: + self.config_prefix = config_prefix + self.sep = sep + self.kwargs = kwargs + + @cached_property + def client(self) -> SecretClient: + """Create a Azure Key Vault client.""" + credential = DefaultAzureCredential() + client = SecretClient(vault_url=self.vault_url, credential=credential, **self.kwargs) + return client + + def get_conn_uri(self, conn_id: str) -> Optional[str]: + """ + Get an Airflow Connection URI from an Azure Key Vault secret + + :param conn_id: The Airflow connection id to retrieve + :type conn_id: str + """ + if self.connections_prefix is None: + return None + + return self._get_secret(self.connections_prefix, conn_id) + + def get_variable(self, key: str) -> Optional[str]: + """ + Get an Airflow Variable from an Azure Key Vault secret. + + :param key: Variable Key + :type key: str + :return: Variable Value + """ + if self.variables_prefix is None: + return None + + return self._get_secret(self.variables_prefix, key) + + def get_config(self, key: str) -> Optional[str]: + """ + Get Airflow Configuration + + :param key: Configuration Option Key + :return: Configuration Option Value + """ + if self.config_prefix is None: + return None + + return self._get_secret(self.config_prefix, key) + + @staticmethod + def build_path(path_prefix: str, secret_id: str, sep: str = '-') -> str: + """ + Given a path_prefix and secret_id, build a valid secret name for the Azure Key Vault Backend. + Also replaces underscore in the path with dashes to support easy switching between + environment variables, so ``connection_default`` becomes ``connection-default``. 
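A concrete sketch of that naming rule, calling the staticmethod directly (no Key Vault access is involved)::

    from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend

    # Underscores in the secret id are mapped to the separator ("-" by default).
    assert AzureKeyVaultBackend.build_path("airflow-connections", "smtp_default") == (
        "airflow-connections-smtp-default"
    )
    assert AzureKeyVaultBackend.build_path("airflow-variables", "hello") == "airflow-variables-hello"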
+ + :param path_prefix: The path prefix of the secret to retrieve + :type path_prefix: str + :param secret_id: Name of the secret + :type secret_id: str + :param sep: Separator used to concatenate path_prefix and secret_id + :type sep: str + """ + path = f'{path_prefix}{sep}{secret_id}' + return path.replace('_', sep) + + def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]: + """ + Get an Azure Key Vault secret value + + :param path_prefix: Prefix for the Path to get Secret + :type path_prefix: str + :param secret_id: Secret Key + :type secret_id: str + """ + name = self.build_path(path_prefix, secret_id, self.sep) + try: + secret = self.client.get_secret(name=name) + return secret.value + except ResourceNotFoundError as ex: + self.log.debug('Secret %s not found: %s', name, ex) + return None diff --git a/airflow/providers/microsoft/azure/sensors/azure_cosmos.py b/airflow/providers/microsoft/azure/sensors/azure_cosmos.py index 277e193fff6af..0adeddac60a26 100644 --- a/airflow/providers/microsoft/azure/sensors/azure_cosmos.py +++ b/airflow/providers/microsoft/azure/sensors/azure_cosmos.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,54 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.sensors.cosmos`.""" -from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook -from airflow.sensors.base import BaseSensorOperator - - -class AzureCosmosDocumentSensor(BaseSensorOperator): - """ - Checks for the existence of a document which - matches the given query in CosmosDB. Example: - - .. code-block:: - - azure_cosmos_sensor = AzureCosmosDocumentSensor( - database_name="somedatabase_name", - collection_name="somecollection_name", - document_id="unique-doc-id", - azure_cosmos_conn_id="azure_cosmos_default", - task_id="azure_cosmos_sensor") - - :param database_name: Target CosmosDB database_name. - :type database_name: str - :param collection_name: Target CosmosDB collection_name. - :type collection_name: str - :param document_id: The ID of the target document. - :type document_id: str - :param azure_cosmos_conn_id: Reference to the - :ref:`Azure CosmosDB connection`. - :type azure_cosmos_conn_id: str - """ - - template_fields = ('database_name', 'collection_name', 'document_id') +import warnings - def __init__( - self, - *, - database_name: str, - collection_name: str, - document_id: str, - azure_cosmos_conn_id: str = "azure_cosmos_default", - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.azure_cosmos_conn_id = azure_cosmos_conn_id - self.database_name = database_name - self.collection_name = collection_name - self.document_id = document_id +from airflow.providers.microsoft.azure.sensors.cosmos import AzureCosmosDocumentSensor # noqa - def poke(self, context: dict) -> bool: - self.log.info("*** Intering poke") - hook = AzureCosmosDBHook(self.azure_cosmos_conn_id) - return hook.get_document(self.document_id, self.database_name, self.collection_name) is not None +warnings.warn( + "This module is deprecated. 
Please use `airflow.providers.microsoft.azure.sensors.cosmos`.", + DeprecationWarning, + stacklevel=2, +) diff --git a/airflow/providers/microsoft/azure/sensors/cosmos.py b/airflow/providers/microsoft/azure/sensors/cosmos.py new file mode 100644 index 0000000000000..46673f01ed1eb --- /dev/null +++ b/airflow/providers/microsoft/azure/sensors/cosmos.py @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook +from airflow.sensors.base import BaseSensorOperator + + +class AzureCosmosDocumentSensor(BaseSensorOperator): + """ + Checks for the existence of a document which + matches the given query in CosmosDB. Example: + + .. code-block:: + + azure_cosmos_sensor = AzureCosmosDocumentSensor( + database_name="somedatabase_name", + collection_name="somecollection_name", + document_id="unique-doc-id", + azure_cosmos_conn_id="azure_cosmos_default", + task_id="azure_cosmos_sensor") + + :param database_name: Target CosmosDB database_name. + :type database_name: str + :param collection_name: Target CosmosDB collection_name. + :type collection_name: str + :param document_id: The ID of the target document. + :type document_id: str + :param azure_cosmos_conn_id: Reference to the + :ref:`Azure CosmosDB connection`. 
+ :type azure_cosmos_conn_id: str + """ + + template_fields = ('database_name', 'collection_name', 'document_id') + + def __init__( + self, + *, + database_name: str, + collection_name: str, + document_id: str, + azure_cosmos_conn_id: str = "azure_cosmos_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.azure_cosmos_conn_id = azure_cosmos_conn_id + self.database_name = database_name + self.collection_name = collection_name + self.document_id = document_id + + def poke(self, context: dict) -> bool: + self.log.info("*** Intering poke") + hook = AzureCosmosDBHook(self.azure_cosmos_conn_id) + return hook.get_document(self.document_id, self.database_name, self.collection_name) is not None diff --git a/airflow/providers/microsoft/azure/transfers/local_to_adls.py b/airflow/providers/microsoft/azure/transfers/local_to_adls.py index 2305cbfbbaf1d..e0c5c968a727b 100644 --- a/airflow/providers/microsoft/azure/transfers/local_to_adls.py +++ b/airflow/providers/microsoft/azure/transfers/local_to_adls.py @@ -19,7 +19,7 @@ from airflow.exceptions import AirflowException from airflow.models import BaseOperator -from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook +from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook class LocalFilesystemToADLSOperator(BaseOperator): diff --git a/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py b/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py index 6a533b4cb25fe..b1bc147263ae6 100644 --- a/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py +++ b/airflow/providers/microsoft/azure/transfers/oracle_to_azure_data_lake.py @@ -23,7 +23,7 @@ import unicodecsv as csv from airflow.models import BaseOperator -from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook +from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook from airflow.providers.oracle.hooks.oracle import OracleHook diff --git a/dev/provider_packages/prepare_provider_packages.py b/dev/provider_packages/prepare_provider_packages.py index 57be3ddf99bd6..007f677bedb22 100755 --- a/dev/provider_packages/prepare_provider_packages.py +++ b/dev/provider_packages/prepare_provider_packages.py @@ -2114,7 +2114,20 @@ def summarise_total_vs_bad_and_warnings(total: int, bad: int, warns: List[warnin # ignore those messages when the warnings are generated directly by importlib - which means that # we imported it directly during module walk by the importlib library KNOWN_DEPRECATED_DIRECT_IMPORTS: Set[str] = { + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.batch`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_instance`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_registry`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_volume`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.cosmos`.", "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.data_factory`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.data_lake`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.fileshare`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.batch`.", + "This module is deprecated. 
" + "Please use `airflow.providers.microsoft.azure.operators.container_instances`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.cosmos`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.secrets.key_vault`.", + "This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.cosmos`.", "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.", "This module is deprecated. Please use `airflow.providers.microsoft.azure.transfers.local_to_wasb`.", "This module is deprecated. Please use `airflow.providers.tableau.operators.tableau_refresh_workbook`.", diff --git a/docs/apache-airflow-providers-microsoft-azure/secrets-backends/azure-key-vault.rst b/docs/apache-airflow-providers-microsoft-azure/secrets-backends/azure-key-vault.rst index e0f89f6a7f866..b7430d32697fc 100644 --- a/docs/apache-airflow-providers-microsoft-azure/secrets-backends/azure-key-vault.rst +++ b/docs/apache-airflow-providers-microsoft-azure/secrets-backends/azure-key-vault.rst @@ -20,7 +20,7 @@ Azure Key Vault Backend ^^^^^^^^^^^^^^^^^^^^^^^ To enable the Azure Key Vault as secrets backend, specify -:py:class:`~airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend` +:py:class:`~airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend` as the ``backend`` in ``[secrets]`` section of ``airflow.cfg``. Here is a sample configuration: @@ -28,7 +28,7 @@ Here is a sample configuration: .. code-block:: ini [secrets] - backend = airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend + backend = airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend backend_kwargs = {"connections_prefix": "airflow-connections", "variables_prefix": "airflow-variables", "vault_url": "https://example-akv-resource-name.vault.azure.net/"} For client authentication, the ``DefaultAzureCredential`` from the Azure Python SDK is used as credential provider, @@ -49,7 +49,7 @@ For example, if you want to set parameter ``connections_prefix`` to ``"airflow-c .. 
code-block:: ini [secrets] - backend = airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend + backend = airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend backend_kwargs = {"connections_prefix": "airflow-connections", "variables_prefix": null, "vault_url": "https://example-akv-resource-name.vault.azure.net/"} Storing and Retrieving Connections diff --git a/tests/deprecated_classes.py b/tests/deprecated_classes.py index fc8c3a82e6dc2..2cb927947f716 100644 --- a/tests/deprecated_classes.py +++ b/tests/deprecated_classes.py @@ -220,23 +220,23 @@ 'airflow.contrib.hooks.fs_hook.FSHook', ), ( - 'airflow.providers.microsoft.azure.hooks.azure_container_instance.AzureContainerInstanceHook', + 'airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook', 'airflow.contrib.hooks.azure_container_instance_hook.AzureContainerInstanceHook', ), ( - 'airflow.providers.microsoft.azure.hooks.azure_container_registry.AzureContainerRegistryHook', + 'airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook', 'airflow.contrib.hooks.azure_container_registry_hook.AzureContainerRegistryHook', ), ( - 'airflow.providers.microsoft.azure.hooks.azure_container_volume.AzureContainerVolumeHook', + 'airflow.providers.microsoft.azure.hooks.container_volume.AzureContainerVolumeHook', 'airflow.contrib.hooks.azure_container_volume_hook.AzureContainerVolumeHook', ), ( - 'airflow.providers.microsoft.azure.hooks.azure_cosmos.AzureCosmosDBHook', + 'airflow.providers.microsoft.azure.hooks.cosmos.AzureCosmosDBHook', 'airflow.contrib.hooks.azure_cosmos_hook.AzureCosmosDBHook', ), ( - 'airflow.providers.microsoft.azure.hooks.azure_fileshare.AzureFileShareHook', + 'airflow.providers.microsoft.azure.hooks.fileshare.AzureFileShareHook', 'airflow.contrib.hooks.azure_fileshare_hook.AzureFileShareHook', ), ( @@ -1028,12 +1028,11 @@ 'airflow.contrib.operators.adls_list_operator.AzureDataLakeStorageListOperator', ), ( - 'airflow.providers.microsoft.azure.operators' - '.azure_container_instances.AzureContainerInstancesOperator', + 'airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstancesOperator', 'airflow.contrib.operators.azure_container_instances_operator.AzureContainerInstancesOperator', ), ( - 'airflow.providers.microsoft.azure.operators.azure_cosmos.AzureCosmosInsertDocumentOperator', + 'airflow.providers.microsoft.azure.operators.cosmos.AzureCosmosInsertDocumentOperator', 'airflow.contrib.operators.azure_cosmos_operator.AzureCosmosInsertDocumentOperator', ), ( diff --git a/tests/providers/microsoft/azure/hooks/test_azure_batch.py b/tests/providers/microsoft/azure/hooks/test_azure_batch.py index 40e8b1f0c6386..daed87c30ea9e 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_batch.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_batch.py @@ -23,7 +23,7 @@ from azure.batch import BatchServiceClient, models as batch_models from airflow.models import Connection -from airflow.providers.microsoft.azure.hooks.azure_batch import AzureBatchHook +from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook from airflow.utils import db @@ -93,7 +93,7 @@ def test_configure_pool_with_cloud_config(self): def test_configure_pool_with_latest_vm(self): with mock.patch( "airflow.providers.microsoft.azure.hooks." 
- "azure_batch.AzureBatchHook._get_latest_verified_image_vm_and_sku" + "batch.AzureBatchHook._get_latest_verified_image_vm_and_sku" ) as mock_getvm: hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) getvm_instance = mock_getvm @@ -108,7 +108,7 @@ def test_configure_pool_with_latest_vm(self): ) assert isinstance(pool, batch_models.PoolAddParameter) - @mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient") + @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") def test_create_pool_with_vm_config(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) mock_instance = mock_batch.return_value.pool.add @@ -123,7 +123,7 @@ def test_create_pool_with_vm_config(self, mock_batch): hook.create_pool(pool=pool) mock_instance.assert_called_once_with(pool) - @mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient") + @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") def test_create_pool_with_cloud_config(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_cloud_conn_id) mock_instance = mock_batch.return_value.pool.add @@ -138,12 +138,12 @@ def test_create_pool_with_cloud_config(self, mock_batch): hook.create_pool(pool=pool) mock_instance.assert_called_once_with(pool) - @mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient") + @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") def test_wait_for_all_nodes(self, mock_batch): # TODO: Add test pass - @mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient") + @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") def test_job_configuration_and_create_job(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) mock_instance = mock_batch.return_value.job.add @@ -152,7 +152,7 @@ def test_job_configuration_and_create_job(self, mock_batch): assert isinstance(job, batch_models.JobAddParameter) mock_instance.assert_called_once_with(job) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient') def test_add_single_task_to_job(self, mock_batch): hook = AzureBatchHook(azure_batch_conn_id=self.test_vm_conn_id) mock_instance = mock_batch.return_value.task.add @@ -161,7 +161,7 @@ def test_add_single_task_to_job(self, mock_batch): assert isinstance(task, batch_models.TaskAddParameter) mock_instance.assert_called_once_with(job_id="myjob", task=task) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient') def test_wait_for_all_task_to_complete(self, mock_batch): # TODO: Add test pass diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py index b744ed2d04918..e29d9c2881e15 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_instance.py @@ -29,7 +29,7 @@ ) from airflow.models import Connection -from airflow.providers.microsoft.azure.hooks.azure_container_instance import AzureContainerInstanceHook +from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook from airflow.utils import db diff --git 
a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py index 7fc42113983d5..34a876d9447be 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_registry.py @@ -19,7 +19,7 @@ import unittest from airflow.models import Connection -from airflow.providers.microsoft.azure.hooks.azure_container_registry import AzureContainerRegistryHook +from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook from airflow.utils import db diff --git a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py index ccd3cb62df5a1..f6361df6a28b3 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_container_volume.py @@ -19,7 +19,7 @@ import unittest from airflow.models import Connection -from airflow.providers.microsoft.azure.hooks.azure_container_volume import AzureContainerVolumeHook +from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook from airflow.utils import db diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py index 1ff086d759f88..6f4b1801c9646 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py @@ -29,7 +29,7 @@ from airflow.exceptions import AirflowException from airflow.models import Connection -from airflow.providers.microsoft.azure.hooks.azure_cosmos import AzureCosmosDBHook +from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook from airflow.utils import db @@ -59,13 +59,13 @@ def setUp(self): ) ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient', autospec=True) def test_client(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') assert hook._conn is None assert isinstance(hook.get_conn(), CosmosClient) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_create_database(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') hook.create_database(self.test_database_name) @@ -73,19 +73,19 @@ def test_create_database(self, mock_cosmos): mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_create_database_exception(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') with pytest.raises(AirflowException): hook.create_database(None) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_create_container_exception(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') with pytest.raises(AirflowException): hook.create_collection(None) - 
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_create_container(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') hook.create_collection(self.test_collection_name, self.test_database_name) @@ -95,7 +95,7 @@ def test_create_container(self, mock_cosmos): mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_create_container_default(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') hook.create_collection(self.test_collection_name) @@ -105,7 +105,7 @@ def test_create_container_default(self, mock_cosmos): mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_upsert_document_default(self, mock_cosmos): test_id = str(uuid.uuid4()) mock_cosmos.return_value.CreateItem.return_value = {'id': test_id} @@ -122,7 +122,7 @@ def test_upsert_document_default(self, mock_cosmos): logging.getLogger().info(returned_item) assert returned_item['id'] == test_id - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_upsert_document(self, mock_cosmos): test_id = str(uuid.uuid4()) mock_cosmos.return_value.CreateItem.return_value = {'id': test_id} @@ -146,7 +146,7 @@ def test_upsert_document(self, mock_cosmos): logging.getLogger().info(returned_item) assert returned_item['id'] == test_id - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_insert_documents(self, mock_cosmos): test_id1 = str(uuid.uuid4()) test_id2 = str(uuid.uuid4()) @@ -177,7 +177,7 @@ def test_insert_documents(self, mock_cosmos): mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls, any_order=True) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_delete_database(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') hook.delete_database(self.test_database_name) @@ -185,7 +185,7 @@ def test_delete_database(self, mock_cosmos): mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_delete_database_exception(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') with pytest.raises(AirflowException): @@ -197,7 +197,7 @@ def test_delete_container_exception(self, mock_cosmos): with pytest.raises(AirflowException): hook.delete_collection(None) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + 
@mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_delete_container(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') hook.delete_collection(self.test_collection_name, self.test_database_name) @@ -205,7 +205,7 @@ def test_delete_container(self, mock_cosmos): mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key}) mock_cosmos.assert_has_calls(expected_calls) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_delete_container_default(self, mock_cosmos): hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id') hook.delete_collection(self.test_collection_name) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py index 046f5565ae059..330f3ccba3cd9 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_data_lake.py @@ -36,11 +36,11 @@ def setUp(self): ) ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True) def test_conn(self, mock_lib): from azure.datalake.store import core - from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') assert hook._conn is None @@ -48,23 +48,19 @@ def test_conn(self, mock_lib): assert isinstance(hook.get_conn(), core.AzureDLFileSystem) assert mock_lib.auth.called - @mock.patch( - 'airflow.providers.microsoft.azure.hooks.azure_data_lake.core.AzureDLFileSystem', autospec=True - ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True) def test_check_for_blob(self, mock_lib, mock_filesystem): - from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') hook.check_for_file('file_path') mock_filesystem.glob.called - @mock.patch( - 'airflow.providers.microsoft.azure.hooks.azure_data_lake.multithread.ADLUploader', autospec=True - ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLUploader', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True) def test_upload_file(self, mock_lib, mock_uploader): - from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') hook.upload_file( @@ -85,12 +81,10 @@ def test_upload_file(self, mock_lib, mock_uploader): blocksize=4194304, ) - @mock.patch( - 'airflow.providers.microsoft.azure.hooks.azure_data_lake.multithread.ADLDownloader', autospec=True - ) - 
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.multithread.ADLDownloader', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True) def test_download_file(self, mock_lib, mock_downloader): - from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') hook.download_file( @@ -111,34 +105,28 @@ def test_download_file(self, mock_lib, mock_downloader): blocksize=4194304, ) - @mock.patch( - 'airflow.providers.microsoft.azure.hooks.azure_data_lake.core.AzureDLFileSystem', autospec=True - ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True) def test_list_glob(self, mock_lib, mock_fs): - from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') hook.list('file_path/*') mock_fs.return_value.glob.assert_called_once_with('file_path/*') - @mock.patch( - 'airflow.providers.microsoft.azure.hooks.azure_data_lake.core.AzureDLFileSystem', autospec=True - ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True) def test_list_walk(self, mock_lib, mock_fs): - from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') hook.list('file_path/some_folder/') mock_fs.return_value.walk.assert_called_once_with('file_path/some_folder/') - @mock.patch( - 'airflow.providers.microsoft.azure.hooks.azure_data_lake.core.AzureDLFileSystem', autospec=True - ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_data_lake.lib', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.core.AzureDLFileSystem', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.data_lake.lib', autospec=True) def test_remove(self, mock_lib, mock_fs): - from airflow.providers.microsoft.azure.hooks.azure_data_lake import AzureDataLakeHook + from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook hook = AzureDataLakeHook(azure_data_lake_conn_id='adl_test_key') hook.remove('filepath', True) diff --git a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py index ac39db9170567..57c6faf82404f 100644 --- a/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py +++ b/tests/providers/microsoft/azure/hooks/test_azure_fileshare.py @@ -33,7 +33,7 @@ from azure.storage.file import Directory, File from airflow.models import Connection -from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook +from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook 
from airflow.utils import db @@ -153,7 +153,7 @@ def test_missing_credentials(self): with pytest.raises(ValueError, match=".*account_key or sas_token.*"): hook.get_conn() - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_check_for_file(self, mock_service): mock_instance = mock_service.return_value mock_instance.exists.return_value = True @@ -161,7 +161,7 @@ def test_check_for_file(self, mock_service): assert hook.check_for_file('share', 'directory', 'file', timeout=3) mock_instance.exists.assert_called_once_with('share', 'directory', 'file', timeout=3) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_check_for_directory(self, mock_service): mock_instance = mock_service.return_value mock_instance.exists.return_value = True @@ -169,7 +169,7 @@ def test_check_for_directory(self, mock_service): assert hook.check_for_directory('share', 'directory', timeout=3) mock_instance.exists.assert_called_once_with('share', 'directory', timeout=3) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_load_file(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras') @@ -178,7 +178,7 @@ def test_load_file(self, mock_service): 'share', 'directory', 'file', 'path', max_connections=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_load_string(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras') @@ -187,7 +187,7 @@ def test_load_string(self, mock_service): 'share', 'directory', 'file', 'big string', timeout=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_load_stream(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras') @@ -196,14 +196,14 @@ def test_load_stream(self, mock_service): 'share', 'directory', 'file', 'stream', 42, timeout=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_list_directories_and_files(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras') hook.list_directories_and_files('share', 'directory', timeout=1) mock_instance.list_directories_and_files.assert_called_once_with('share', 'directory', timeout=1) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_list_files(self, mock_service): mock_instance = mock_service.return_value mock_instance.list_directories_and_files.return_value = [ @@ -217,14 +217,14 
@@ def test_list_files(self, mock_service): assert files == ["file1", 'file2'] mock_instance.list_directories_and_files.assert_called_once_with('share', 'directory', timeout=1) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_create_directory(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras') hook.create_directory('share', 'directory', timeout=1) mock_instance.create_directory.assert_called_once_with('share', 'directory', timeout=1) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_get_file(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras') @@ -233,7 +233,7 @@ def test_get_file(self, mock_service): 'share', 'directory', 'file', 'path', max_connections=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_get_file_to_stream(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras') @@ -242,14 +242,14 @@ def test_get_file_to_stream(self, mock_service): 'share', 'directory', 'file', 'stream', max_connections=1 ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_create_share(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras') hook.create_share('my_share') mock_instance.create_share.assert_called_once_with('my_share') - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_fileshare.FileService', autospec=True) + @mock.patch('airflow.providers.microsoft.azure.hooks.fileshare.FileService', autospec=True) def test_delete_share(self, mock_service): mock_instance = mock_service.return_value hook = AzureFileShareHook(azure_fileshare_conn_id='azure_fileshare_extras') diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py b/tests/providers/microsoft/azure/operators/test_azure_batch.py index 7a699943cf615..9f7afcdd61d91 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_batch.py +++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py @@ -24,8 +24,8 @@ from airflow.exceptions import AirflowException from airflow.models import Connection -from airflow.providers.microsoft.azure.hooks.azure_batch import AzureBatchHook -from airflow.providers.microsoft.azure.operators.azure_batch import AzureBatchOperator +from airflow.providers.microsoft.azure.hooks.batch import AzureBatchHook +from airflow.providers.microsoft.azure.operators.batch import AzureBatchOperator from airflow.utils import db TASK_ID = "MyDag" @@ -42,8 +42,8 @@ class TestAzureBatchOperator(unittest.TestCase): # set up the test environment - @mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.AzureBatchHook") - @mock.patch("airflow.providers.microsoft.azure.hooks.azure_batch.BatchServiceClient") + 
@mock.patch("airflow.providers.microsoft.azure.hooks.batch.AzureBatchHook") + @mock.patch("airflow.providers.microsoft.azure.hooks.batch.BatchServiceClient") def setUp(self, mock_batch, mock_hook): # set up the test variable self.test_vm_conn_id = "test_azure_batch_vm2" diff --git a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py index 0367d3d15142b..ce46ee4876fd4 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_container_instances.py +++ b/tests/providers/microsoft/azure/operators/test_azure_container_instances.py @@ -26,9 +26,7 @@ from azure.mgmt.containerinstance.models import ContainerState, Event from airflow.exceptions import AirflowException -from airflow.providers.microsoft.azure.operators.azure_container_instances import ( - AzureContainerInstancesOperator, -) +from airflow.providers.microsoft.azure.operators.container_instances import AzureContainerInstancesOperator def make_mock_cg(container_state, events=None): @@ -67,9 +65,7 @@ def make_mock_cg_with_missing_events(container_state): class TestACIOperator(unittest.TestCase): - @mock.patch( - "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook" - ) + @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook") def test_execute(self, aci_mock): expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test') expected_cg = make_mock_cg(expected_c_state) @@ -105,9 +101,7 @@ def test_execute(self, aci_mock): assert aci_mock.return_value.delete.call_count == 1 - @mock.patch( - "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook" - ) + @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook") def test_execute_with_failures(self, aci_mock): expected_c_state = ContainerState(state='Terminated', exit_code=1, detail_status='test') expected_cg = make_mock_cg(expected_c_state) @@ -129,9 +123,7 @@ def test_execute_with_failures(self, aci_mock): assert aci_mock.return_value.delete.call_count == 1 - @mock.patch( - "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook" - ) + @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook") def test_execute_with_tags(self, aci_mock): expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test') expected_cg = make_mock_cg(expected_c_state) @@ -170,9 +162,7 @@ def test_execute_with_tags(self, aci_mock): assert aci_mock.return_value.delete.call_count == 1 - @mock.patch( - "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook" - ) + @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook") def test_execute_with_messages_logs(self, aci_mock): events = [Event(message="test"), Event(message="messages")] expected_c_state1 = ContainerState(state='Succeeded', exit_code=0, detail_status='test') @@ -220,9 +210,7 @@ def test_name_checker(self): checked_name = AzureContainerInstancesOperator._check_name(name) assert checked_name == name - @mock.patch( - "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook" - ) + @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook") def 
test_execute_with_ipaddress(self, aci_mock): expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test') expected_cg = make_mock_cg(expected_c_state) @@ -247,9 +235,7 @@ def test_execute_with_ipaddress(self, aci_mock): assert called_cg.ip_address == ipaddress - @mock.patch( - "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook" - ) + @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook") def test_execute_with_windows_os_and_diff_restart_policy(self, aci_mock): expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test') expected_cg = make_mock_cg(expected_c_state) @@ -275,9 +261,7 @@ def test_execute_with_windows_os_and_diff_restart_policy(self, aci_mock): assert called_cg.restart_policy == 'Always' assert called_cg.os_type == 'Windows' - @mock.patch( - "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook" - ) + @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook") def test_execute_fails_with_incorrect_os_type(self, aci_mock): expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test') expected_cg = make_mock_cg(expected_c_state) @@ -303,9 +287,7 @@ def test_execute_fails_with_incorrect_os_type(self, aci_mock): "Found `MacOs`." ) - @mock.patch( - "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook" - ) + @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook") def test_execute_fails_with_incorrect_restart_policy(self, aci_mock): expected_c_state = ContainerState(state='Terminated', exit_code=0, detail_status='test') expected_cg = make_mock_cg(expected_c_state) @@ -331,10 +313,8 @@ def test_execute_fails_with_incorrect_restart_policy(self, aci_mock): "Found `Everyday`" ) - @mock.patch( - "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook" - ) - @mock.patch('airflow.providers.microsoft.azure.operators.azure_container_instances.sleep') + @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook") + @mock.patch('airflow.providers.microsoft.azure.operators.container_instances.sleep') def test_execute_correct_sleep_cycle(self, sleep_mock, aci_mock): expected_c_state1 = ContainerState(state='Running', exit_code=0, detail_status='test') expected_cg1 = make_mock_cg(expected_c_state1) @@ -358,9 +338,7 @@ def test_execute_correct_sleep_cycle(self, sleep_mock, aci_mock): # sleep is called at the end of cycles. 
Thus, the Terminated call does not trigger sleep assert sleep_mock.call_count == 2 - @mock.patch( - "airflow.providers.microsoft.azure.operators.azure_container_instances.AzureContainerInstanceHook" - ) + @mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook") @mock.patch("logging.Logger.exception") def test_execute_with_missing_events(self, log_mock, aci_mock): expected_c_state1 = ContainerState(state='Running', exit_code=0, detail_status='test') diff --git a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py index 26144406cb892..b7755b1ca9746 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_cosmos.py +++ b/tests/providers/microsoft/azure/operators/test_azure_cosmos.py @@ -24,7 +24,7 @@ from unittest import mock from airflow.models import Connection -from airflow.providers.microsoft.azure.operators.azure_cosmos import AzureCosmosInsertDocumentOperator +from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator from airflow.utils import db @@ -49,7 +49,7 @@ def setUp(self): ) ) - @mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient') + @mock.patch('airflow.providers.microsoft.azure.hooks.cosmos.CosmosClient') def test_insert_document(self, cosmos_mock): test_id = str(uuid.uuid4()) cosmos_mock.return_value.CreateItem.return_value = {'id': test_id} diff --git a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py index 783dcf62f8b03..e037fbe632fd8 100644 --- a/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py +++ b/tests/providers/microsoft/azure/secrets/test_azure_key_vault.py @@ -20,19 +20,19 @@ from azure.core.exceptions import ResourceNotFoundError -from airflow.providers.microsoft.azure.secrets.azure_key_vault import AzureKeyVaultBackend +from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend class TestAzureKeyVaultBackend(TestCase): - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.get_conn_uri') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.get_conn_uri') def test_get_connections(self, mock_get_uri): mock_get_uri.return_value = 'scheme://user:pass@host:100' conn_list = AzureKeyVaultBackend().get_connections('fake_conn') conn = conn_list[0] assert conn.host == 'host' - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.DefaultAzureCredential') - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.SecretClient') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.DefaultAzureCredential') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.SecretClient') def test_get_conn_uri(self, mock_secret_client, mock_azure_cred): mock_cred = mock.Mock() mock_sec_client = mock.Mock() @@ -50,7 +50,7 @@ def test_get_conn_uri(self, mock_secret_client, mock_azure_cred): ) assert returned_uri == 'postgresql://airflow:airflow@host:5432/airflow' - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.client') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client') def test_get_conn_uri_non_existent_key(self, mock_client): """ Test that if the key with connection ID is not present, @@ -63,7 +63,7 @@ def test_get_conn_uri_non_existent_key(self, 
mock_client): assert backend.get_conn_uri(conn_id=conn_id) is None assert [] == backend.get_connections(conn_id=conn_id) - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.client') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client') def test_get_variable(self, mock_client): mock_client.get_secret.return_value = mock.Mock(value='world') backend = AzureKeyVaultBackend() @@ -71,7 +71,7 @@ def test_get_variable(self, mock_client): mock_client.get_secret.assert_called_with(name='airflow-variables-hello') assert 'world' == returned_uri - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.client') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client') def test_get_variable_non_existent_key(self, mock_client): """ Test that if Variable key is not present, @@ -81,7 +81,7 @@ def test_get_variable_non_existent_key(self, mock_client): backend = AzureKeyVaultBackend() assert backend.get_variable('test_mysql') is None - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.client') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client') def test_get_secret_value_not_found(self, mock_client): """ Test that if a non-existent secret returns None @@ -92,7 +92,7 @@ def test_get_secret_value_not_found(self, mock_client): backend._get_secret(path_prefix=backend.connections_prefix, secret_id='test_non_existent') is None ) - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend.client') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend.client') def test_get_secret_value(self, mock_client): """ Test that get_secret returns the secret value @@ -103,7 +103,7 @@ def test_get_secret_value(self, mock_client): mock_client.get_secret.assert_called_with(name='af-secrets-test-mysql-password') assert secret_val == 'super-secret' - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend._get_secret') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret') def test_connection_prefix_none_value(self, mock_get_secret): """ Test that if Connections prefix is None, @@ -116,7 +116,7 @@ def test_connection_prefix_none_value(self, mock_get_secret): assert backend.get_conn_uri('test_mysql') is None mock_get_secret.assert_not_called() - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend._get_secret') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret') def test_variable_prefix_none_value(self, mock_get_secret): """ Test that if Variables prefix is None, @@ -129,7 +129,7 @@ def test_variable_prefix_none_value(self, mock_get_secret): assert backend.get_variable('hello') is None mock_get_secret.assert_not_called() - @mock.patch('airflow.providers.microsoft.azure.secrets.azure_key_vault.AzureKeyVaultBackend._get_secret') + @mock.patch('airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend._get_secret') def test_config_prefix_none_value(self, mock_get_secret): """ Test that if Config prefix is None, diff --git a/tests/providers/microsoft/azure/sensors/test_azure_cosmos.py b/tests/providers/microsoft/azure/sensors/test_azure_cosmos.py index a2eafe7ad638f..07988004e6ad0 100644 --- 
a/tests/providers/microsoft/azure/sensors/test_azure_cosmos.py +++ b/tests/providers/microsoft/azure/sensors/test_azure_cosmos.py @@ -18,7 +18,7 @@ import unittest from unittest import mock -from airflow.providers.microsoft.azure.sensors.azure_cosmos import AzureCosmosDocumentSensor +from airflow.providers.microsoft.azure.sensors.cosmos import AzureCosmosDocumentSensor DB_NAME = 'test-db-name' COLLECTION_NAME = 'test-db-collection-name' @@ -26,7 +26,7 @@ class TestAzureCosmosSensor(unittest.TestCase): - @mock.patch('airflow.providers.microsoft.azure.sensors.azure_cosmos.AzureCosmosDBHook') + @mock.patch('airflow.providers.microsoft.azure.sensors.cosmos.AzureCosmosDBHook') def test_should_call_hook_with_args(self, mock_hook): mock_instance = mock_hook.return_value mock_instance.get_document.return_value = True # Indicate document returned @@ -40,7 +40,7 @@ def test_should_call_hook_with_args(self, mock_hook): mock_instance.get_document.assert_called_once_with(DOCUMENT_ID, DB_NAME, COLLECTION_NAME) assert result is True - @mock.patch('airflow.providers.microsoft.azure.sensors.azure_cosmos.AzureCosmosDBHook') + @mock.patch('airflow.providers.microsoft.azure.sensors.cosmos.AzureCosmosDBHook') def test_should_return_false_on_no_document(self, mock_hook): mock_instance = mock_hook.return_value mock_instance.get_document.return_value = None # Indicate document not returned diff --git a/tests/test_utils/azure_system_helpers.py b/tests/test_utils/azure_system_helpers.py index 79e7e4fed6930..e6599b5b39639 100644 --- a/tests/test_utils/azure_system_helpers.py +++ b/tests/test_utils/azure_system_helpers.py @@ -25,7 +25,7 @@ from airflow.exceptions import AirflowException from airflow.models import Connection -from airflow.providers.microsoft.azure.hooks.azure_fileshare import AzureFileShareHook +from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook from airflow.utils.process_utils import patch_environ from tests.test_utils import AIRFLOW_MAIN_FOLDER from tests.test_utils.system_tests_class import SystemTest
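
Note (illustrative, not part of the patch): the short sketch below shows how downstream code would consume the renamed, AIP-21-compliant module paths used throughout the diffs above, such as airflow.providers.microsoft.azure.hooks.cosmos and airflow.providers.microsoft.azure.secrets.key_vault. The connection ID "azure_cosmos_default" is an assumed placeholder, and the statement that the old "azure_"-prefixed provider modules stay importable while emitting a DeprecationWarning is an assumption about the rest of this change rather than something shown in the hunks above; only the new import paths themselves are confirmed here.

    import warnings

    # New AIP-21-compliant import paths (the redundant "azure_" prefix is
    # dropped from the module names), matching the paths patched in the tests above.
    from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook
    from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend  # noqa: F401

    # Assumption: the old "azure_"-prefixed module paths remain importable as
    # deprecation shims; surfacing DeprecationWarning makes any call sites that
    # still use them easy to spot while migrating.
    warnings.simplefilter("always", DeprecationWarning)

    # "azure_cosmos_default" is a hypothetical connection ID used only for
    # illustration; the hook does not open a connection until get_conn() is called.
    hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_default")

In practice, migrating a DAG or test only requires updating the import (or mock.patch target) string to the new module path, exactly as the test diffs above do; no call signatures change.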