From b06300222b88c2345d4f00ddbe834697f890708c Mon Sep 17 00:00:00 2001 From: Linus Date: Thu, 17 Oct 2024 14:15:33 +0200 Subject: [PATCH 1/6] adding cert auth --- src/droid/__main__.py | 23 +---- src/droid/platforms/ms_xdr.py | 168 ++++++++++++++++++++++++++-------- 2 files changed, 133 insertions(+), 58 deletions(-) diff --git a/src/droid/__main__.py b/src/droid/__main__.py index 22db6d3..f6fe4e7 100644 --- a/src/droid/__main__.py +++ b/src/droid/__main__.py @@ -199,7 +199,7 @@ def droid_platform_config(args, config_path): if "export_auth" in config and config["export_auth"] not in auth_methods: raise ValueError(f"Invalid export_auth: {config['export_auth']}") - if config["search_auth"] == "app" and not "credential_file" in config: + if (config["search_auth"] == "app" and not "credential_file" in config) or (config["export_auth"] == "app" and args.export and not "credential_file" in config): if environ.get("DROID_AZURE_TENANT_ID"): tenant_id = environ.get("DROID_AZURE_TENANT_ID") @@ -219,25 +219,10 @@ def droid_platform_config(args, config_path): else: raise Exception("Please use: export DROID_AZURE_CLIENT_SECRET=") - elif config["export_auth"] == "app" and args.export and not "credential_file" in config: - - if environ.get("DROID_AZURE_TENANT_ID"): - tenant_id = environ.get("DROID_AZURE_TENANT_ID") - config["tenant_id"] = tenant_id + if environ.get("DROID_AZURE_CERT_PASS"): + config["cert_pass"] = environ.get("DROID_AZURE_CERT_PASS") else: - raise Exception("Please use: export DROID_AZURE_TENANT_ID=") - - if environ.get("DROID_AZURE_CLIENT_ID"): - client_id = environ.get("DROID_AZURE_CLIENT_ID") - config["client_id"] = client_id - else: - raise Exception("Please use: export DROID_AZURE_CLIENT_ID=") - - if environ.get("DROID_AZURE_CLIENT_SECRET"): - client_secret = environ.get("DROID_AZURE_CLIENT_SECRET") - config["client_secret"] = client_secret - else: - raise Exception("Please use: export DROID_AZURE_CLIENT_SECRET=") + config["cert_pass"] = None return config diff --git a/src/droid/platforms/ms_xdr.py b/src/droid/platforms/ms_xdr.py index a2501eb..ec0d0d2 100644 --- a/src/droid/platforms/ms_xdr.py +++ b/src/droid/platforms/ms_xdr.py @@ -13,10 +13,17 @@ from droid.platforms.common import get_pipeline_group_match from msal import ConfidentialClientApplication from azure.identity import DefaultAzureCredential +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.hazmat.backends import default_backend + class MicrosoftXDRPlatform(AbstractPlatform): - def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False) -> None: + def __init__( + self, parameters: dict, logger_param: dict, export_mssp: bool = False + ) -> None: super().__init__(name="Microsoft XDR") @@ -25,8 +32,10 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False self._parameters = parameters self._export_mssp = export_mssp - if 'query_period_groups' in self._parameters['rule_parameters']: - self._query_period_groups = self._parameters['rule_parameters']['query_period_groups'] + if "query_period_groups" in self._parameters["rule_parameters"]: + self._query_period_groups = self._parameters["rule_parameters"][ + "query_period_groups" + ] if "query_period" not in self._parameters: raise Exception( @@ -45,13 +54,22 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False self._query_period = self._parameters["query_period"] + if "auth_cert" in self._parameters: + self._auth_cert = self._parameters["auth_cert"] + else: + self._auth_cert = None + if "credential_file" in self._parameters: try: with open(self._parameters["credential_file"], "r") as file: credentials = yaml.safe_load(file) - self._client_id = credentials["client_id"] - self._client_secret = credentials["client_secret"] - self._tenant_id = credentials["tenant_id"] + self._client_id = credentials["client_id"] + self._client_secret = credentials["client_secret"] + self._tenant_id = credentials["tenant_id"] + if "cert_pass" in credentials: + self._cert_pass = credentials["cert_pass"] + else: + self._cert_pass = None except Exception as e: raise Exception(f"Error while reading the credential file {e}") elif "app" in ( @@ -60,6 +78,7 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False self._tenant_id = self._parameters["tenant_id"] self._client_id = self._parameters["client_id"] self._client_secret = self._parameters["client_secret"] + self._cert_pass = self._parameters["cert_pass"] elif "default" in ( self._parameters["search_auth"] or self._parameters["export_auth"] ): @@ -82,7 +101,7 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False else: self._alert_prefix = None - if 'export_list_mssp' in self._parameters: + if "export_list_mssp" in self._parameters: self._export_list_mssp = self._parameters["export_list_mssp"] def get_export_list_mssp(self) -> list: @@ -98,9 +117,13 @@ def run_xdr_search(self, rule_converted, rule_file, tenant_id=None): payload = {"Query": rule_converted, "Timespan": "P1D"} try: if tenant_id: - self.logger.info(f"Searching for rule {rule_file} on tenant {tenant_id}") + self.logger.info( + f"Searching for rule {rule_file} on tenant {tenant_id}" + ) else: - self.logger.info(f"Searching for rule {rule_file} on tenant {self._tenant_id}") + self.logger.info( + f"Searching for rule {rule_file} on tenant {self._tenant_id}" + ) results, status_code = self._post( url="/security/runHuntingQuery", payload=payload, tenant_id=tenant_id @@ -118,8 +141,7 @@ def run_xdr_search(self, rule_converted, rule_file, tenant_id=None): raise def get_rule(self, rule_id, tenant_id=None): - """Retrieve a scheduled alert rule in Microsoft XDR - """ + """Retrieve a scheduled alert rule in Microsoft XDR""" try: params = {"$filter": f"contains(displayName, '{rule_id}')"} rule, status_code = self._get( @@ -194,10 +216,30 @@ def acquire_token(self, tenant_id=None): else: authority = f"https://login.microsoftonline.com/{tenant_id}" # Create a confidential client application + if self._auth_cert: + with open(self._auth_cert, "r") as file: + certificate_data = file.read() + + private_key = serialization.load_pem_private_key( + certificate_data, self._cert_pass, backend=default_backend() + ) + cert = x509.load_pem_x509_certificate( + certificate_data, default_backend() + ) + fingerprint = cert.fingerprint(hashes.SHA1()) + + client_credential = { + "private_key": private_key, + "thumbprint": fingerprint, + "passphrase": self._cert_pass, + } + else: + client_credential = self._client_secret + app = ConfidentialClientApplication( self._client_id, authority=authority, - client_credential=self._client_secret, + client_credential=client_credential, ) # Acquire a token @@ -316,15 +358,24 @@ def create_rule(self, rule_content, rule_converted, rule_file): except Exception as e: self.logger.error(e) - if 'query_period_groups' in self._parameters['rule_parameters']: - query_period_group = get_pipeline_group_match(rule_content, self._query_period_groups) + if "query_period_groups" in self._parameters["rule_parameters"]: + query_period_group = get_pipeline_group_match( + rule_content, self._query_period_groups + ) if query_period_group: - self.logger.debug(f"Applying the query_period value from group {query_period_group}") - alert_rule["schedule"]["period"] = self.process_query_period(self._query_period_groups[query_period_group]['query_period'], rule_file) + self.logger.debug( + f"Applying the query_period value from group {query_period_group}" + ) + alert_rule["schedule"]["period"] = self.process_query_period( + self._query_period_groups[query_period_group]["query_period"], + rule_file, + ) if "custom" in rule_content: if "query_period" in rule_content["custom"]: - alert_rule["schedule"]["period"] = self.process_query_period(rule_content["custom"]["query_period"], rule_file) + alert_rule["schedule"]["period"] = self.process_query_period( + rule_content["custom"]["query_period"], rule_file + ) if "actions" in rule_content["custom"]: responseActions = self.parse_actions( rule_content["custom"]["actions"], rule_file=rule_file @@ -343,8 +394,10 @@ def create_rule(self, rule_content, rule_converted, rule_file): self.logger.info("Exporting to designated customers") for group, info in self._export_list_mssp.items(): - tenant_id = info['tenant_id'] - self.logger.debug(f"Exporting to tenant {tenant_id} from group id {group}") + tenant_id = info["tenant_id"] + self.logger.debug( + f"Exporting to tenant {tenant_id} from group id {group}" + ) try: self.push_detection_rule( @@ -352,7 +405,7 @@ def create_rule(self, rule_content, rule_converted, rule_file): rule_content=rule_content, rule_file=rule_file, rule_converted=rule_converted, - tenant_id=tenant_id + tenant_id=tenant_id, ) except Exception as e: self.logger.error( @@ -366,7 +419,9 @@ def create_rule(self, rule_content, rule_converted, rule_file): ) raise else: - self.logger.error("Export list not found. Please provide the list of designated customers") + self.logger.error( + "Export list not found. Please provide the list of designated customers" + ) raise else: try: @@ -440,8 +495,12 @@ def check_rule_changes(self, existing_rule, new_rule): return False def push_detection_rule( - self, alert_rule=None, rule_content=None, rule_file=None, - rule_converted=None, tenant_id=None + self, + alert_rule=None, + rule_content=None, + rule_file=None, + rule_converted=None, + tenant_id=None, ): existing_rule = self.get_rule(rule_content["id"]) if existing_rule: @@ -450,10 +509,14 @@ def push_detection_rule( return True else: api_url = f"/security/rules/detectionRules/{existing_rule['id']}" - response, status_code = self._patch(url=api_url, payload=alert_rule, tenant_id=tenant_id) + response, status_code = self._patch( + url=api_url, payload=alert_rule, tenant_id=tenant_id + ) else: api_url = "/security/rules/detectionRules" - response, status_code = self._post(url=api_url, payload=alert_rule, tenant_id=tenant_id) + response, status_code = self._post( + url=api_url, payload=alert_rule, tenant_id=tenant_id + ) if status_code == 400: self.logger.error( @@ -659,8 +722,16 @@ def parse_impactedAssets(self, impactedAssets, rule_file=None): raise return impactedAssetsList - def _get(self, url=None, headers=None, params=None, tenant_id=None, - timeout=120, max_retries=3, retry_delay=60): + def _get( + self, + url=None, + headers=None, + params=None, + tenant_id=None, + timeout=120, + max_retries=3, + retry_delay=60, + ): """ Sends a GET request to Microsoft Graph Security API with error handling and retries for specific cases like 429, 500+ errors. @@ -682,9 +753,9 @@ def _get(self, url=None, headers=None, params=None, tenant_id=None, while retry_count < max_retries: try: - response = requests.get(api_url, headers=headers, - params=params, timeout=timeout) - + response = requests.get( + api_url, headers=headers, params=params, timeout=timeout + ) if response.status_code == 429: self.logger.warning( @@ -720,9 +791,17 @@ def _get(self, url=None, headers=None, params=None, tenant_id=None, self.logger.error(f"Failed to get a valid response after {max_retries} retries") return None, 500 - - def _post(self, url=None, payload=None, headers=None, params=None, - tenant_id=None, timeout=120, max_retries=3, retry_delay=60): + def _post( + self, + url=None, + payload=None, + headers=None, + params=None, + tenant_id=None, + timeout=120, + max_retries=3, + retry_delay=60, + ): """ Sends a POST request to Microsoft Graph Security API with error handling and retries for specific cases like 429, 500+ errors. @@ -744,8 +823,9 @@ def _post(self, url=None, payload=None, headers=None, params=None, while retry_count < max_retries: try: - response = requests.post(api_url, headers=headers, - json=payload, timeout=timeout) + response = requests.post( + api_url, headers=headers, json=payload, timeout=timeout + ) if response.status_code == 429: self.logger.warning( @@ -781,8 +861,17 @@ def _post(self, url=None, payload=None, headers=None, params=None, self.logger.error(f"Failed to get a valid response after {max_retries} retries") return None, 500 - def _patch(self, url=None, payload=None, headers=None, params=None, - tenant_id=None, timeout=120, max_retries=3, retry_delay=60): + def _patch( + self, + url=None, + payload=None, + headers=None, + params=None, + tenant_id=None, + timeout=120, + max_retries=3, + retry_delay=60, + ): """ Sends a PATCH request to Microsoft Graph Security API with error handling and retries for specific cases like 429, 500+ errors. @@ -804,8 +893,9 @@ def _patch(self, url=None, payload=None, headers=None, params=None, while retry_count < max_retries: try: - response = requests.patch(api_url, headers=headers, - json=payload, timeout=timeout) + response = requests.patch( + api_url, headers=headers, json=payload, timeout=timeout + ) if response.status_code == 429: self.logger.warning( From d5ff4e60869b136288dd13d7c831b5956e0a5e5e Mon Sep 17 00:00:00 2001 From: Linus Date: Mon, 11 Nov 2024 17:07:16 +0100 Subject: [PATCH 2/6] bugfix --- src/droid/platforms/ms_xdr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/droid/platforms/ms_xdr.py b/src/droid/platforms/ms_xdr.py index 028944b..9a078f1 100644 --- a/src/droid/platforms/ms_xdr.py +++ b/src/droid/platforms/ms_xdr.py @@ -32,8 +32,8 @@ def __init__( self._parameters = parameters self._export_mssp = export_mssp - if "query_period_groups" in self._parameters["rule_parameters"]: - self._query_period_groups = self._parameters["rule_parameters"][ + if "query_period_groups" in self._parameters: + self._query_period_groups = self._parameters[ "query_period_groups" ] From 9ccb19c08adde70f173dfc42eb24cbb868d7b05f Mon Sep 17 00:00:00 2001 From: Linus Date: Mon, 11 Nov 2024 17:34:20 +0100 Subject: [PATCH 3/6] bugfix --- src/droid/platforms/ms_xdr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/droid/platforms/ms_xdr.py b/src/droid/platforms/ms_xdr.py index 9a078f1..f6aeac1 100644 --- a/src/droid/platforms/ms_xdr.py +++ b/src/droid/platforms/ms_xdr.py @@ -357,7 +357,7 @@ def create_rule(self, rule_content, rule_converted, rule_file): except Exception as e: self.logger.error(e) - if "query_period_groups" in self._parameters["rule_parameters"]: + if "query_period_groups" in self._parameters: query_period_group = get_pipeline_group_match( rule_content, self._query_period_groups ) From ece43e7631f8659a173333789566f7d9fc3a590a Mon Sep 17 00:00:00 2001 From: Linus Date: Mon, 11 Nov 2024 18:16:01 +0100 Subject: [PATCH 4/6] converted to binary --- src/droid/platforms/ms_xdr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/droid/platforms/ms_xdr.py b/src/droid/platforms/ms_xdr.py index f6aeac1..0edd738 100644 --- a/src/droid/platforms/ms_xdr.py +++ b/src/droid/platforms/ms_xdr.py @@ -216,11 +216,11 @@ def acquire_token(self, tenant_id=None): authority = f"https://login.microsoftonline.com/{tenant_id}" # Create a confidential client application if self._auth_cert: - with open(self._auth_cert, "r") as file: + with open(self._auth_cert, "rb") as file: certificate_data = file.read() - + cert_pass_bytes = self._cert_pass.encode() if isinstance(self._cert_pass, str) else self._cert_pass private_key = serialization.load_pem_private_key( - certificate_data, self._cert_pass, backend=default_backend() + certificate_data, cert_pass_bytes, backend=default_backend() ) cert = x509.load_pem_x509_certificate( certificate_data, default_backend() @@ -230,7 +230,7 @@ def acquire_token(self, tenant_id=None): client_credential = { "private_key": private_key, "thumbprint": fingerprint, - "passphrase": self._cert_pass, + "passphrase": cert_pass_bytes, } else: client_credential = self._client_secret From 6e446329878dfca9247d563d593d0f09bc16a3f5 Mon Sep 17 00:00:00 2001 From: Linus Date: Mon, 11 Nov 2024 18:24:08 +0100 Subject: [PATCH 5/6] added no password function for key retrieval --- src/droid/platforms/ms_xdr.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/droid/platforms/ms_xdr.py b/src/droid/platforms/ms_xdr.py index 0edd738..b72b5ca 100644 --- a/src/droid/platforms/ms_xdr.py +++ b/src/droid/platforms/ms_xdr.py @@ -33,9 +33,7 @@ def __init__( self._export_mssp = export_mssp if "query_period_groups" in self._parameters: - self._query_period_groups = self._parameters[ - "query_period_groups" - ] + self._query_period_groups = self._parameters["query_period_groups"] if "query_period" not in self._parameters: raise Exception( @@ -218,7 +216,14 @@ def acquire_token(self, tenant_id=None): if self._auth_cert: with open(self._auth_cert, "rb") as file: certificate_data = file.read() - cert_pass_bytes = self._cert_pass.encode() if isinstance(self._cert_pass, str) else self._cert_pass + if self._cert_pass: + cert_pass_bytes = ( + self._cert_pass.encode() + if isinstance(self._cert_pass, str) + else self._cert_pass + ) + else: + cert_pass_bytes = None private_key = serialization.load_pem_private_key( certificate_data, cert_pass_bytes, backend=default_backend() ) @@ -251,7 +256,9 @@ def acquire_token(self, tenant_id=None): self.logger.error( f'Failed to acquire token: {result["error_description"]}' ) - raise Exception(f"Token acquisition failed: {result.get('error', 'Unknown error')}") + raise Exception( + f"Token acquisition failed: {result.get('error', 'Unknown error')}" + ) def process_query_period(self, query_period: str, rule_file: str): """Process the query period time From 3c910ca3afbeb5fca38501a776b6f53a8985ae2f Mon Sep 17 00:00:00 2001 From: Linus Date: Mon, 11 Nov 2024 19:00:41 +0100 Subject: [PATCH 6/6] removed private key extraction --- src/droid/platforms/ms_xdr.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/droid/platforms/ms_xdr.py b/src/droid/platforms/ms_xdr.py index b72b5ca..657f828 100644 --- a/src/droid/platforms/ms_xdr.py +++ b/src/droid/platforms/ms_xdr.py @@ -216,26 +216,17 @@ def acquire_token(self, tenant_id=None): if self._auth_cert: with open(self._auth_cert, "rb") as file: certificate_data = file.read() - if self._cert_pass: - cert_pass_bytes = ( - self._cert_pass.encode() - if isinstance(self._cert_pass, str) - else self._cert_pass - ) - else: - cert_pass_bytes = None - private_key = serialization.load_pem_private_key( - certificate_data, cert_pass_bytes, backend=default_backend() - ) + cert = x509.load_pem_x509_certificate( certificate_data, default_backend() ) fingerprint = cert.fingerprint(hashes.SHA1()) + fingerprint = fingerprint.hex() client_credential = { - "private_key": private_key, + "private_key": certificate_data, "thumbprint": fingerprint, - "passphrase": cert_pass_bytes, + "passphrase": self._cert_pass, } else: client_credential = self._client_secret