Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding cert auth for xdr #22

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 4 additions & 19 deletions src/droid/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def droid_platform_config(args, config_path):
if "export_auth" in config and config["export_auth"] not in auth_methods:
raise ValueError(f"Invalid export_auth: {config['export_auth']}")

if config["search_auth"] == "app" and not "credential_file" in config:
if (config["search_auth"] == "app" and not "credential_file" in config) or (config["export_auth"] == "app" and args.export and not "credential_file" in config):

if environ.get("DROID_AZURE_TENANT_ID"):
tenant_id = environ.get("DROID_AZURE_TENANT_ID")
Expand All @@ -219,25 +219,10 @@ def droid_platform_config(args, config_path):
else:
raise Exception("Please use: export DROID_AZURE_CLIENT_SECRET=<client_secret>")

elif config["export_auth"] == "app" and args.export and not "credential_file" in config:

if environ.get("DROID_AZURE_TENANT_ID"):
tenant_id = environ.get("DROID_AZURE_TENANT_ID")
config["tenant_id"] = tenant_id
if environ.get("DROID_AZURE_CERT_PASS"):
config["cert_pass"] = environ.get("DROID_AZURE_CERT_PASS")
else:
raise Exception("Please use: export DROID_AZURE_TENANT_ID=<tenant_id>")

if environ.get("DROID_AZURE_CLIENT_ID"):
client_id = environ.get("DROID_AZURE_CLIENT_ID")
config["client_id"] = client_id
else:
raise Exception("Please use: export DROID_AZURE_CLIENT_ID=<client_id>")

if environ.get("DROID_AZURE_CLIENT_SECRET"):
client_secret = environ.get("DROID_AZURE_CLIENT_SECRET")
config["client_secret"] = client_secret
else:
raise Exception("Please use: export DROID_AZURE_CLIENT_SECRET=<client_secret>")
config["cert_pass"] = None

return config

Expand Down
168 changes: 128 additions & 40 deletions src/droid/platforms/ms_xdr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,17 @@
from droid.platforms.common import get_pipeline_group_match
from msal import ConfidentialClientApplication
from azure.identity import DefaultAzureCredential
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.backends import default_backend


class MicrosoftXDRPlatform(AbstractPlatform):

def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False) -> None:
def __init__(
self, parameters: dict, logger_param: dict, export_mssp: bool = False
) -> None:

super().__init__(name="Microsoft XDR")

Expand All @@ -25,8 +32,8 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False
self._parameters = parameters
self._export_mssp = export_mssp

if 'query_period_groups' in self._parameters['rule_parameters']:
self._query_period_groups = self._parameters['rule_parameters']['query_period_groups']
if "query_period_groups" in self._parameters:
self._query_period_groups = self._parameters["query_period_groups"]

if "query_period" not in self._parameters:
raise Exception(
Expand All @@ -45,13 +52,22 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False

self._query_period = self._parameters["query_period"]

if "auth_cert" in self._parameters:
self._auth_cert = self._parameters["auth_cert"]
else:
self._auth_cert = None

if "credential_file" in self._parameters:
try:
with open(self._parameters["credential_file"], "r") as file:
credentials = yaml.safe_load(file)
self._client_id = credentials["client_id"]
self._client_secret = credentials["client_secret"]
self._tenant_id = credentials["tenant_id"]
self._client_id = credentials["client_id"]
self._client_secret = credentials["client_secret"]
self._tenant_id = credentials["tenant_id"]
if "cert_pass" in credentials:
self._cert_pass = credentials["cert_pass"]
else:
self._cert_pass = None
except Exception as e:
raise Exception(f"Error while reading the credential file {e}")
elif "app" in (
Expand All @@ -60,6 +76,7 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False
self._tenant_id = self._parameters["tenant_id"]
self._client_id = self._parameters["client_id"]
self._client_secret = self._parameters["client_secret"]
self._cert_pass = self._parameters["cert_pass"]
elif "default" in (
self._parameters["search_auth"] or self._parameters["export_auth"]
):
Expand All @@ -82,7 +99,7 @@ def __init__(self, parameters: dict, logger_param: dict, export_mssp: bool=False
else:
self._alert_prefix = None

if 'export_list_mssp' in self._parameters:
if "export_list_mssp" in self._parameters:
self._export_list_mssp = self._parameters["export_list_mssp"]

def get_export_list_mssp(self) -> list:
Expand All @@ -97,9 +114,13 @@ def run_xdr_search(self, rule_converted, rule_file, tenant_id=None):
payload = {"Query": rule_converted, "Timespan": "P1D"}
try:
if tenant_id:
self.logger.info(f"Searching for rule {rule_file} on tenant {tenant_id}")
self.logger.info(
f"Searching for rule {rule_file} on tenant {tenant_id}"
)
else:
self.logger.info(f"Searching for rule {rule_file} on tenant {self._tenant_id}")
self.logger.info(
f"Searching for rule {rule_file} on tenant {self._tenant_id}"
)

results, status_code = self._post(
url="/security/runHuntingQuery", payload=payload, tenant_id=tenant_id
Expand All @@ -117,8 +138,7 @@ def run_xdr_search(self, rule_converted, rule_file, tenant_id=None):
raise

def get_rule(self, rule_id, tenant_id=None):
"""Retrieve a scheduled alert rule in Microsoft XDR
"""
"""Retrieve a scheduled alert rule in Microsoft XDR"""
try:
params = {"$filter": f"contains(displayName, '{rule_id}')"}
rule, status_code = self._get(
Expand Down Expand Up @@ -193,10 +213,28 @@ def acquire_token(self, tenant_id=None):
else:
authority = f"https://login.microsoftonline.com/{tenant_id}"
# Create a confidential client application
if self._auth_cert:
with open(self._auth_cert, "rb") as file:
certificate_data = file.read()

cert = x509.load_pem_x509_certificate(
certificate_data, default_backend()
)
fingerprint = cert.fingerprint(hashes.SHA1())
fingerprint = fingerprint.hex()

client_credential = {
"private_key": certificate_data,
"thumbprint": fingerprint,
"passphrase": self._cert_pass,
}
else:
client_credential = self._client_secret

app = ConfidentialClientApplication(
self._client_id,
authority=authority,
client_credential=self._client_secret,
client_credential=client_credential,
)

# Acquire a token
Expand All @@ -209,7 +247,9 @@ def acquire_token(self, tenant_id=None):
self.logger.error(
f'Failed to acquire token: {result["error_description"]}'
)
raise Exception(f"Token acquisition failed: {result.get('error', 'Unknown error')}")
raise Exception(
f"Token acquisition failed: {result.get('error', 'Unknown error')}"
)

def process_query_period(self, query_period: str, rule_file: str):
"""Process the query period time
Expand Down Expand Up @@ -315,15 +355,24 @@ def create_rule(self, rule_content, rule_converted, rule_file):
except Exception as e:
self.logger.error(e)

if 'query_period_groups' in self._parameters['rule_parameters']:
query_period_group = get_pipeline_group_match(rule_content, self._query_period_groups)
if "query_period_groups" in self._parameters:
query_period_group = get_pipeline_group_match(
rule_content, self._query_period_groups
)
if query_period_group:
self.logger.debug(f"Applying the query_period value from group {query_period_group}")
alert_rule["schedule"]["period"] = self.process_query_period(self._query_period_groups[query_period_group]['query_period'], rule_file)
self.logger.debug(
f"Applying the query_period value from group {query_period_group}"
)
alert_rule["schedule"]["period"] = self.process_query_period(
self._query_period_groups[query_period_group]["query_period"],
rule_file,
)

if "custom" in rule_content:
if "query_period" in rule_content["custom"]:
alert_rule["schedule"]["period"] = self.process_query_period(rule_content["custom"]["query_period"], rule_file)
alert_rule["schedule"]["period"] = self.process_query_period(
rule_content["custom"]["query_period"], rule_file
)
if "actions" in rule_content["custom"]:
responseActions = self.parse_actions(
rule_content["custom"]["actions"], rule_file=rule_file
Expand All @@ -343,16 +392,18 @@ def create_rule(self, rule_content, rule_converted, rule_file):
self.logger.info("Exporting to designated customers")
for group, info in self._export_list_mssp.items():

tenant_id = info['tenant_id']
self.logger.debug(f"Exporting to tenant {tenant_id} from group id {group}")
tenant_id = info["tenant_id"]
self.logger.debug(
f"Exporting to tenant {tenant_id} from group id {group}"
)

try:
self.push_detection_rule(
alert_rule=alert_rule,
rule_content=rule_content,
rule_file=rule_file,
rule_converted=rule_converted,
tenant_id=tenant_id
tenant_id=tenant_id,
)
except Exception as e:
self.logger.error(
Expand All @@ -368,7 +419,9 @@ def create_rule(self, rule_content, rule_converted, rule_file):
if error:
raise
else:
self.logger.error("Export list not found. Please provide the list of designated customers")
self.logger.error(
"Export list not found. Please provide the list of designated customers"
)
raise
else:
try:
Expand Down Expand Up @@ -442,8 +495,12 @@ def check_rule_changes(self, existing_rule, new_rule):
return False

def push_detection_rule(
self, alert_rule=None, rule_content=None, rule_file=None,
rule_converted=None, tenant_id=None
self,
alert_rule=None,
rule_content=None,
rule_file=None,
rule_converted=None,
tenant_id=None,
):
existing_rule = self.get_rule(rule_content["id"], tenant_id=tenant_id)
if existing_rule:
Expand All @@ -452,10 +509,14 @@ def push_detection_rule(
return True
else:
api_url = f"/security/rules/detectionRules/{existing_rule['id']}"
response, status_code = self._patch(url=api_url, payload=alert_rule, tenant_id=tenant_id)
response, status_code = self._patch(
url=api_url, payload=alert_rule, tenant_id=tenant_id
)
else:
api_url = "/security/rules/detectionRules"
response, status_code = self._post(url=api_url, payload=alert_rule, tenant_id=tenant_id)
response, status_code = self._post(
url=api_url, payload=alert_rule, tenant_id=tenant_id
)

if status_code == 400:
self.logger.error(
Expand Down Expand Up @@ -661,8 +722,16 @@ def parse_impactedAssets(self, impactedAssets, rule_file=None):
raise
return impactedAssetsList

def _get(self, url=None, headers=None, params=None, tenant_id=None,
timeout=120, max_retries=3, retry_delay=60):
def _get(
self,
url=None,
headers=None,
params=None,
tenant_id=None,
timeout=120,
max_retries=3,
retry_delay=60,
):
"""
Sends a GET request to Microsoft Graph Security API with error handling
and retries for specific cases like 429, 500+ errors.
Expand All @@ -684,9 +753,9 @@ def _get(self, url=None, headers=None, params=None, tenant_id=None,

while retry_count < max_retries:
try:
response = requests.get(api_url, headers=headers,
params=params, timeout=timeout)

response = requests.get(
api_url, headers=headers, params=params, timeout=timeout
)

if response.status_code == 429:
self.logger.warning(
Expand Down Expand Up @@ -722,9 +791,17 @@ def _get(self, url=None, headers=None, params=None, tenant_id=None,
self.logger.error(f"Failed to get a valid response after {max_retries} retries")
return None, 500


def _post(self, url=None, payload=None, headers=None, params=None,
tenant_id=None, timeout=120, max_retries=3, retry_delay=60):
def _post(
self,
url=None,
payload=None,
headers=None,
params=None,
tenant_id=None,
timeout=120,
max_retries=3,
retry_delay=60,
):
"""
Sends a POST request to Microsoft Graph Security API with error handling
and retries for specific cases like 429, 500+ errors.
Expand All @@ -746,8 +823,9 @@ def _post(self, url=None, payload=None, headers=None, params=None,

while retry_count < max_retries:
try:
response = requests.post(api_url, headers=headers,
json=payload, timeout=timeout)
response = requests.post(
api_url, headers=headers, json=payload, timeout=timeout
)

if response.status_code == 429:
self.logger.warning(
Expand Down Expand Up @@ -783,8 +861,17 @@ def _post(self, url=None, payload=None, headers=None, params=None,
self.logger.error(f"Failed to get a valid response after {max_retries} retries")
return None, 500

def _patch(self, url=None, payload=None, headers=None, params=None,
tenant_id=None, timeout=120, max_retries=3, retry_delay=60):
def _patch(
self,
url=None,
payload=None,
headers=None,
params=None,
tenant_id=None,
timeout=120,
max_retries=3,
retry_delay=60,
):
"""
Sends a PATCH request to Microsoft Graph Security API with error handling
and retries for specific cases like 429, 500+ errors.
Expand All @@ -806,8 +893,9 @@ def _patch(self, url=None, payload=None, headers=None, params=None,

while retry_count < max_retries:
try:
response = requests.patch(api_url, headers=headers,
json=payload, timeout=timeout)
response = requests.patch(
api_url, headers=headers, json=payload, timeout=timeout
)

if response.status_code == 429:
self.logger.warning(
Expand Down