diff --git a/src/droid/__main__.py b/src/droid/__main__.py index 6049766..8db50b0 100644 --- a/src/droid/__main__.py +++ b/src/droid/__main__.py @@ -199,7 +199,7 @@ def droid_platform_config(args, config_path): if "export_auth" in config and config["export_auth"] not in auth_methods: raise ValueError(f"Invalid export_auth: {config['export_auth']}") - if config["search_auth"] == "app" and not "credential_file" in config: + if (config["search_auth"] == "app" and not "credential_file" in config) or (config["export_auth"] == "app" and args.export and not "credential_file" in config): if environ.get("DROID_AZURE_TENANT_ID"): tenant_id = environ.get("DROID_AZURE_TENANT_ID") @@ -219,25 +219,10 @@ def droid_platform_config(args, config_path): else: raise Exception("Please use: export DROID_AZURE_CLIENT_SECRET=") - elif config["export_auth"] == "app" and args.export and not "credential_file" in config: - - if environ.get("DROID_AZURE_TENANT_ID"): - tenant_id = environ.get("DROID_AZURE_TENANT_ID") - config["tenant_id"] = tenant_id + if environ.get("DROID_AZURE_CERT_PASS"): + config["cert_pass"] = environ.get("DROID_AZURE_CERT_PASS") else: - raise Exception("Please use: export DROID_AZURE_TENANT_ID=") - - if environ.get("DROID_AZURE_CLIENT_ID"): - client_id = environ.get("DROID_AZURE_CLIENT_ID") - config["client_id"] = client_id - else: - raise Exception("Please use: export DROID_AZURE_CLIENT_ID=") - - if environ.get("DROID_AZURE_CLIENT_SECRET"): - client_secret = environ.get("DROID_AZURE_CLIENT_SECRET") - config["client_secret"] = client_secret - else: - raise Exception("Please use: export DROID_AZURE_CLIENT_SECRET=") + config["cert_pass"] = None return config diff --git a/src/droid/platforms/ms_xdr.py b/src/droid/platforms/ms_xdr.py index 57b1993..657f828 100644 --- a/src/droid/platforms/ms_xdr.py +++ b/src/droid/platforms/ms_xdr.py @@ -13,10 +13,17 @@ from droid.platforms.common import get_pipeline_group_match from msal import ConfidentialClientApplication from azure.identity import DefaultAzureCredential +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from cryptography.hazmat.backends import default_backend + class MicrosoftXDRPlatform(AbstractPlatform): - def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False) -> None: + def __init__( + self, parameters: dict, logger_param: dict, export_mssp: bool = False + ) -> None: super().__init__(name="Microsoft XDR") @@ -25,8 +32,8 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False self._parameters = parameters self._export_mssp = export_mssp - if 'query_period_groups' in self._parameters['rule_parameters']: - self._query_period_groups = self._parameters['rule_parameters']['query_period_groups'] + if "query_period_groups" in self._parameters: + self._query_period_groups = self._parameters["query_period_groups"] if "query_period" not in self._parameters: raise Exception( @@ -45,13 +52,22 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False self._query_period = self._parameters["query_period"] + if "auth_cert" in self._parameters: + self._auth_cert = self._parameters["auth_cert"] + else: + self._auth_cert = None + if "credential_file" in self._parameters: try: with open(self._parameters["credential_file"], "r") as file: credentials = yaml.safe_load(file) - self._client_id = credentials["client_id"] - self._client_secret = credentials["client_secret"] - self._tenant_id = credentials["tenant_id"] + self._client_id = credentials["client_id"] + self._client_secret = credentials["client_secret"] + self._tenant_id = credentials["tenant_id"] + if "cert_pass" in credentials: + self._cert_pass = credentials["cert_pass"] + else: + self._cert_pass = None except Exception as e: raise Exception(f"Error while reading the credential file {e}") elif "app" in ( @@ -60,6 +76,7 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False self._tenant_id = self._parameters["tenant_id"] self._client_id = self._parameters["client_id"] self._client_secret = self._parameters["client_secret"] + self._cert_pass = self._parameters["cert_pass"] elif "default" in ( self._parameters["search_auth"] or self._parameters["export_auth"] ): @@ -82,7 +99,7 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False else: self._alert_prefix = None - if 'export_list_mssp' in self._parameters: + if "export_list_mssp" in self._parameters: self._export_list_mssp = self._parameters["export_list_mssp"] def get_export_list_mssp(self) -> list: @@ -97,9 +114,13 @@ def run_xdr_search(self, rule_converted, rule_file, tenant_id=None): payload = {"Query": rule_converted, "Timespan": "P1D"} try: if tenant_id: - self.logger.info(f"Searching for rule {rule_file} on tenant {tenant_id}") + self.logger.info( + f"Searching for rule {rule_file} on tenant {tenant_id}" + ) else: - self.logger.info(f"Searching for rule {rule_file} on tenant {self._tenant_id}") + self.logger.info( + f"Searching for rule {rule_file} on tenant {self._tenant_id}" + ) results, status_code = self._post( url="/security/runHuntingQuery", payload=payload, tenant_id=tenant_id @@ -117,8 +138,7 @@ def run_xdr_search(self, rule_converted, rule_file, tenant_id=None): raise def get_rule(self, rule_id, tenant_id=None): - """Retrieve a scheduled alert rule in Microsoft XDR - """ + """Retrieve a scheduled alert rule in Microsoft XDR""" try: params = {"$filter": f"contains(displayName, '{rule_id}')"} rule, status_code = self._get( @@ -193,10 +213,28 @@ def acquire_token(self, tenant_id=None): else: authority = f"https://login.microsoftonline.com/{tenant_id}" # Create a confidential client application + if self._auth_cert: + with open(self._auth_cert, "rb") as file: + certificate_data = file.read() + + cert = x509.load_pem_x509_certificate( + certificate_data, default_backend() + ) + fingerprint = cert.fingerprint(hashes.SHA1()) + fingerprint = fingerprint.hex() + + client_credential = { + "private_key": certificate_data, + "thumbprint": fingerprint, + "passphrase": self._cert_pass, + } + else: + client_credential = self._client_secret + app = ConfidentialClientApplication( self._client_id, authority=authority, - client_credential=self._client_secret, + client_credential=client_credential, ) # Acquire a token @@ -209,7 +247,9 @@ def acquire_token(self, tenant_id=None): self.logger.error( f'Failed to acquire token: {result["error_description"]}' ) - raise Exception(f"Token acquisition failed: {result.get('error', 'Unknown error')}") + raise Exception( + f"Token acquisition failed: {result.get('error', 'Unknown error')}" + ) def process_query_period(self, query_period: str, rule_file: str): """Process the query period time @@ -315,15 +355,24 @@ def create_rule(self, rule_content, rule_converted, rule_file): except Exception as e: self.logger.error(e) - if 'query_period_groups' in self._parameters['rule_parameters']: - query_period_group = get_pipeline_group_match(rule_content, self._query_period_groups) + if "query_period_groups" in self._parameters: + query_period_group = get_pipeline_group_match( + rule_content, self._query_period_groups + ) if query_period_group: - self.logger.debug(f"Applying the query_period value from group {query_period_group}") - alert_rule["schedule"]["period"] = self.process_query_period(self._query_period_groups[query_period_group]['query_period'], rule_file) + self.logger.debug( + f"Applying the query_period value from group {query_period_group}" + ) + alert_rule["schedule"]["period"] = self.process_query_period( + self._query_period_groups[query_period_group]["query_period"], + rule_file, + ) if "custom" in rule_content: if "query_period" in rule_content["custom"]: - alert_rule["schedule"]["period"] = self.process_query_period(rule_content["custom"]["query_period"], rule_file) + alert_rule["schedule"]["period"] = self.process_query_period( + rule_content["custom"]["query_period"], rule_file + ) if "actions" in rule_content["custom"]: responseActions = self.parse_actions( rule_content["custom"]["actions"], rule_file=rule_file @@ -343,8 +392,10 @@ def create_rule(self, rule_content, rule_converted, rule_file): self.logger.info("Exporting to designated customers") for group, info in self._export_list_mssp.items(): - tenant_id = info['tenant_id'] - self.logger.debug(f"Exporting to tenant {tenant_id} from group id {group}") + tenant_id = info["tenant_id"] + self.logger.debug( + f"Exporting to tenant {tenant_id} from group id {group}" + ) try: self.push_detection_rule( @@ -352,7 +403,7 @@ def create_rule(self, rule_content, rule_converted, rule_file): rule_content=rule_content, rule_file=rule_file, rule_converted=rule_converted, - tenant_id=tenant_id + tenant_id=tenant_id, ) except Exception as e: self.logger.error( @@ -368,7 +419,9 @@ def create_rule(self, rule_content, rule_converted, rule_file): if error: raise else: - self.logger.error("Export list not found. Please provide the list of designated customers") + self.logger.error( + "Export list not found. Please provide the list of designated customers" + ) raise else: try: @@ -442,8 +495,12 @@ def check_rule_changes(self, existing_rule, new_rule): return False def push_detection_rule( - self, alert_rule=None, rule_content=None, rule_file=None, - rule_converted=None, tenant_id=None + self, + alert_rule=None, + rule_content=None, + rule_file=None, + rule_converted=None, + tenant_id=None, ): existing_rule = self.get_rule(rule_content["id"], tenant_id=tenant_id) if existing_rule: @@ -452,10 +509,14 @@ def push_detection_rule( return True else: api_url = f"/security/rules/detectionRules/{existing_rule['id']}" - response, status_code = self._patch(url=api_url, payload=alert_rule, tenant_id=tenant_id) + response, status_code = self._patch( + url=api_url, payload=alert_rule, tenant_id=tenant_id + ) else: api_url = "/security/rules/detectionRules" - response, status_code = self._post(url=api_url, payload=alert_rule, tenant_id=tenant_id) + response, status_code = self._post( + url=api_url, payload=alert_rule, tenant_id=tenant_id + ) if status_code == 400: self.logger.error( @@ -661,8 +722,16 @@ def parse_impactedAssets(self, impactedAssets, rule_file=None): raise return impactedAssetsList - def _get(self, url=None, headers=None, params=None, tenant_id=None, - timeout=120, max_retries=3, retry_delay=60): + def _get( + self, + url=None, + headers=None, + params=None, + tenant_id=None, + timeout=120, + max_retries=3, + retry_delay=60, + ): """ Sends a GET request to Microsoft Graph Security API with error handling and retries for specific cases like 429, 500+ errors. @@ -684,9 +753,9 @@ def _get(self, url=None, headers=None, params=None, tenant_id=None, while retry_count < max_retries: try: - response = requests.get(api_url, headers=headers, - params=params, timeout=timeout) - + response = requests.get( + api_url, headers=headers, params=params, timeout=timeout + ) if response.status_code == 429: self.logger.warning( @@ -722,9 +791,17 @@ def _get(self, url=None, headers=None, params=None, tenant_id=None, self.logger.error(f"Failed to get a valid response after {max_retries} retries") return None, 500 - - def _post(self, url=None, payload=None, headers=None, params=None, - tenant_id=None, timeout=120, max_retries=3, retry_delay=60): + def _post( + self, + url=None, + payload=None, + headers=None, + params=None, + tenant_id=None, + timeout=120, + max_retries=3, + retry_delay=60, + ): """ Sends a POST request to Microsoft Graph Security API with error handling and retries for specific cases like 429, 500+ errors. @@ -746,8 +823,9 @@ def _post(self, url=None, payload=None, headers=None, params=None, while retry_count < max_retries: try: - response = requests.post(api_url, headers=headers, - json=payload, timeout=timeout) + response = requests.post( + api_url, headers=headers, json=payload, timeout=timeout + ) if response.status_code == 429: self.logger.warning( @@ -783,8 +861,17 @@ def _post(self, url=None, payload=None, headers=None, params=None, self.logger.error(f"Failed to get a valid response after {max_retries} retries") return None, 500 - def _patch(self, url=None, payload=None, headers=None, params=None, - tenant_id=None, timeout=120, max_retries=3, retry_delay=60): + def _patch( + self, + url=None, + payload=None, + headers=None, + params=None, + tenant_id=None, + timeout=120, + max_retries=3, + retry_delay=60, + ): """ Sends a PATCH request to Microsoft Graph Security API with error handling and retries for specific cases like 429, 500+ errors. @@ -806,8 +893,9 @@ def _patch(self, url=None, payload=None, headers=None, params=None, while retry_count < max_retries: try: - response = requests.patch(api_url, headers=headers, - json=payload, timeout=timeout) + response = requests.patch( + api_url, headers=headers, json=payload, timeout=timeout + ) if response.status_code == 429: self.logger.warning(