From 763919d4152ffa13433e2489fec85ed286b7b196 Mon Sep 17 00:00:00 2001 From: josh-fell <48934154+josh-fell@users.noreply.github.com> Date: Sun, 25 Jul 2021 06:05:59 -0400 Subject: [PATCH] Adding custom Salesforce connection type + SalesforceToS3Operator updates (#17162) --- .../example_dags/example_salesforce_to_s3.py | 7 +-- .../amazon/aws/transfers/salesforce_to_s3.py | 6 +- .../cloud/transfers/salesforce_to_gcs.py | 2 +- .../providers/salesforce/hooks/salesforce.py | 58 ++++++++++++++----- airflow/providers/salesforce/provider.yaml | 1 + .../connections/salesforce.rst | 49 +++++----------- .../aws/transfers/test_salesforce_to_s3.py | 8 +-- .../salesforce/hooks/test_salesforce.py | 11 ++-- 8 files changed, 71 insertions(+), 71 deletions(-) diff --git a/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py index ab8edcb8a023..f8fd0dba917a 100644 --- a/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py +++ b/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py @@ -54,19 +54,18 @@ store_to_s3_data_lake = S3CopyObjectOperator( task_id="store_to_s3_data_lake", - source_bucket_key=upload_salesforce_data_to_s3_landing.output["s3_uri"], + source_bucket_key=upload_salesforce_data_to_s3_landing.output, dest_bucket_name="data_lake", dest_bucket_key=f"{BASE_PATH}/{date_prefixes}/{FILE_NAME}", ) delete_data_from_s3_landing = S3DeleteObjectsOperator( task_id="delete_data_from_s3_landing", - bucket=upload_salesforce_data_to_s3_landing.output["s3_bucket_name"], - keys=upload_salesforce_data_to_s3_landing.output["s3_key"], + bucket=upload_salesforce_data_to_s3_landing.s3_bucket_name, + keys=upload_salesforce_data_to_s3_landing.s3_key, ) store_to_s3_data_lake >> delete_data_from_s3_landing # Task dependencies created via `XComArgs`: # upload_salesforce_data_to_s3_landing >> store_to_s3_data_lake - # upload_salesforce_data_to_s3_landing >> delete_data_from_s3_landing diff --git a/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py b/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py index 25df2fbbba91..09b49c1971ea 100644 --- a/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py @@ -106,8 +106,8 @@ def __init__( self.gzip = gzip self.acl_policy = acl_policy - def execute(self, context: Dict) -> Dict: - salesforce_hook = SalesforceHook(conn_id=self.salesforce_conn_id) + def execute(self, context: Dict) -> str: + salesforce_hook = SalesforceHook(salesforce_conn_id=self.salesforce_conn_id) response = salesforce_hook.make_query( query=self.salesforce_query, include_deleted=self.include_deleted, @@ -138,4 +138,4 @@ def execute(self, context: Dict) -> Dict: s3_uri = f"s3://{self.s3_bucket_name}/{self.s3_key}" self.log.info(f"Salesforce data uploaded to S3 at {s3_uri}.") - return {"s3_uri": s3_uri, "s3_bucket_name": self.s3_bucket_name, "s3_key": self.s3_key} + return s3_uri diff --git a/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py b/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py index d8179e743a41..5564b410d206 100644 --- a/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/salesforce_to_gcs.py @@ -97,7 +97,7 @@ def __init__( self.query_params = query_params def execute(self, context: Dict): - salesforce = SalesforceHook(conn_id=self.salesforce_conn_id) + salesforce = SalesforceHook(salesforce_conn_id=self.salesforce_conn_id) response = salesforce.make_query( query=self.query, include_deleted=self.include_deleted, query_params=self.query_params ) diff --git a/airflow/providers/salesforce/hooks/salesforce.py b/airflow/providers/salesforce/hooks/salesforce.py index d76baac2c727..1b32cc972a0d 100644 --- a/airflow/providers/salesforce/hooks/salesforce.py +++ b/airflow/providers/salesforce/hooks/salesforce.py @@ -25,7 +25,7 @@ """ import logging import time -from typing import Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional import pandas as pd from simple_salesforce import Salesforce, api @@ -37,29 +37,58 @@ class SalesforceHook(BaseHook): """ - Create new connection to Salesforce and allows you to pull data out of SFDC and save it to a file. + Creates new connection to Salesforce and allows you to pull data out of SFDC and save it to a file. You can then use that file with other Airflow operators to move the data into another data source. - :param conn_id: the name of the connection that has the parameters we need to connect to Salesforce. - The connection should be type `http` and include a user's security token in the `Extras` field. + :param conn_id: The name of the connection that has the parameters needed to connect to Salesforce. + The connection should be of type `Salesforce`. :type conn_id: str .. note:: - For the HTTP connection type, you can include a - JSON structure in the `Extras` field. - We need a user's security token to connect to Salesforce. - So we define it in the `Extras` field as `{"security_token":"YOUR_SECURITY_TOKEN"}` - - For sandbox mode, add `{"domain":"test"}` in the `Extras` field + To connect to Salesforce make sure the connection includes a Username, Password, and Security Token. + If in sandbox, enter a Domain value of 'test'. Login methods such as IP filtering and JWT are not + supported currently. """ - def __init__(self, conn_id: str) -> None: + conn_name_attr = "salesforce_conn_id" + default_conn_name = "salesforce_default" + conn_type = "salesforce" + hook_name = "Salesforce" + + def __init__(self, salesforce_conn_id: str = default_conn_name) -> None: super().__init__() - self.conn_id = conn_id + self.conn_id = salesforce_conn_id self.conn = None + @staticmethod + def get_connection_form_widgets() -> Dict[str, Any]: + """Returns connection widgets to add to connection form""" + from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import PasswordField, StringField + + return { + "extra__salesforce__security_token": PasswordField( + lazy_gettext("Security Token"), widget=BS3PasswordFieldWidget() + ), + "extra__salesforce__domain": StringField(lazy_gettext("Domain"), widget=BS3TextFieldWidget()), + } + + @staticmethod + def get_ui_field_behaviour() -> Dict: + """Returns custom field behaviour""" + return { + "hidden_fields": ["schema", "port", "extra", "host"], + "relabeling": { + "login": "Username", + }, + "placeholders": { + "extra__salesforce__domain": "(Optional) Set to 'test' if working in sandbox mode.", + }, + } + def get_conn(self) -> api.Salesforce: """Sign into Salesforce, only if we are not already signed in.""" if not self.conn: @@ -68,9 +97,8 @@ def get_conn(self) -> api.Salesforce: self.conn = Salesforce( username=connection.login, password=connection.password, - security_token=extras['security_token'], - instance_url=connection.host, - domain=extras.get('domain'), + security_token=extras["extra__salesforce__security_token"], + domain=extras["extra__salesforce__domain"] or "login", ) return self.conn diff --git a/airflow/providers/salesforce/provider.yaml b/airflow/providers/salesforce/provider.yaml index edfc925cf06f..9eadab7fd0c0 100644 --- a/airflow/providers/salesforce/provider.yaml +++ b/airflow/providers/salesforce/provider.yaml @@ -55,4 +55,5 @@ hooks: - airflow.providers.salesforce.hooks.salesforce hook-class-names: + - airflow.providers.salesforce.hooks.salesforce.SalesforceHook - airflow.providers.salesforce.hooks.tableau.TableauHook diff --git a/docs/apache-airflow-providers-salesforce/connections/salesforce.rst b/docs/apache-airflow-providers-salesforce/connections/salesforce.rst index 0508e5ce8c49..baab53f64332 100644 --- a/docs/apache-airflow-providers-salesforce/connections/salesforce.rst +++ b/docs/apache-airflow-providers-salesforce/connections/salesforce.rst @@ -19,55 +19,32 @@ Salesforce Connection ===================== -The HTTP connection type provides connection to Salesforce. +The Salesforce connection type provides connection to Salesforce. Configuring the Connection -------------------------- -Host (required) - specify the host address to connect: ``https://your_host.lightning.force.com`` - -Login (required) +Username (required) Specify the email address used to login to your account. Password (required) Specify the password associated with the account. -Extra (required) - Specify the extra parameters (as json dictionary) that can be used in Salesforce - connection. - The following parameter is required: - - * ``security_token``: Salesforce token. - - The following parameter is optional: - - * ``domain``: set to ``test`` if working in sandbox mode. - - For security reason we suggest you to use one of the secrets Backend to create this - connection (Using ENVIRONMENT VARIABLE or Hashicorp Vault, GCP Secrets Manager etc). - - - When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` variable) you should specify it - following the standard syntax of DB connections - where extras are passed as parameters - of the URI. - - For example: - - .. code-block:: bash +Security Token (required) + Specify the Salesforce security token for the username. - export AIRFLOW_CONN_SALESFORCE_DEFAULT='http://your_username:your_password@https%3A%2F%2Fyour_host.lightning.force.com?security_token=your_token' +Domain (optional) + The domain to using for connecting to Salesforce. Use common domains, such as 'login' + or 'test', or Salesforce My domain. If not used, will default to 'login'. +For security reason we suggest you to use one of the secrets Backend to create this +connection (Using ENVIRONMENT VARIABLE or Hashicorp Vault, GCP Secrets Manager etc). -Examples for the **Extra** field --------------------------------- -Setting up sandbox mode: +When specifying the connection as URI (in :envvar:`AIRFLOW_CONN_{CONN_ID}` variable) you should specify it +following the standard syntax of DB connections - where extras are passed as parameters of the URI. For example: -.. code-block:: json + .. code-block:: bash - { - "security_token": "your_token", - "domain":"test" - } + export AIRFLOW_CONN_SALESFORCE_DEFAULT='http://your_username:your_password@https%3A%2F%2Fyour_host.lightning.force.com?security_token=your_token' .. note:: Airflow currently does not support other login methods such as IP filtering and JWT. diff --git a/tests/providers/amazon/aws/transfers/test_salesforce_to_s3.py b/tests/providers/amazon/aws/transfers/test_salesforce_to_s3.py index 9c9fe82eafee..8a0d15a82f7b 100644 --- a/tests/providers/amazon/aws/transfers/test_salesforce_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_salesforce_to_s3.py @@ -95,13 +95,7 @@ def test_execute(self, mock_make_query, mock_write_object_to_file, mock_load_fil assert operator.gzip == GZIP assert operator.acl_policy == ACL_POLICY - expected_op_output = { - "s3_uri": f"s3://{S3_BUCKET}/{S3_KEY}", - "s3_bucket_name": S3_BUCKET, - "s3_key": S3_KEY, - } - - assert expected_op_output == operator.execute({}) + assert f"s3://{S3_BUCKET}/{S3_KEY}" == operator.execute({}) mock_make_query.assert_called_once_with( query=QUERY, include_deleted=INCLUDE_DELETED, query_params=QUERY_PARAMS diff --git a/tests/providers/salesforce/hooks/test_salesforce.py b/tests/providers/salesforce/hooks/test_salesforce.py index bf6041bea5ca..6eb9591f0934 100644 --- a/tests/providers/salesforce/hooks/test_salesforce.py +++ b/tests/providers/salesforce/hooks/test_salesforce.py @@ -31,7 +31,7 @@ class TestSalesforceHook(unittest.TestCase): def setUp(self): - self.salesforce_hook = SalesforceHook(conn_id="conn_id") + self.salesforce_hook = SalesforceHook(salesforce_conn_id="conn_id") def test_get_conn_exists(self): self.salesforce_hook.conn = Mock(spec=Salesforce) @@ -43,7 +43,9 @@ def test_get_conn_exists(self): @patch( "airflow.providers.salesforce.hooks.salesforce.SalesforceHook.get_connection", return_value=Connection( - login="username", password="password", extra='{"security_token": "token", "domain": "test"}' + login="username", + password="password", + extra='{"extra__salesforce__security_token": "token", "extra__salesforce__domain": "login"}', ), ) @patch("airflow.providers.salesforce.hooks.salesforce.Salesforce") @@ -54,9 +56,8 @@ def test_get_conn(self, mock_salesforce, mock_get_connection): mock_salesforce.assert_called_once_with( username=mock_get_connection.return_value.login, password=mock_get_connection.return_value.password, - security_token=mock_get_connection.return_value.extra_dejson["security_token"], - instance_url=mock_get_connection.return_value.host, - domain=mock_get_connection.return_value.extra_dejson.get("domain"), + security_token=mock_get_connection.return_value.extra_dejson["extra__salesforce__security_token"], + domain=mock_get_connection.return_value.extra_dejson.get("extra__salesforce__domain"), ) @patch("airflow.providers.salesforce.hooks.salesforce.Salesforce")