-
Notifications
You must be signed in to change notification settings - Fork 0
/
stratifications.py
103 lines (87 loc) · 3.76 KB
/
stratifications.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import numpy as np
from itertools import product
# noinspection PyUnresolvedReferences
from pgenlib import PgenReader
from utils import pairwise
def stratify_by_tertiles(array):
    """Split a floating-point array into low/medium/high tertile masks.

    Cut points are the 0, 1/3, 2/3 and 1 quantiles of the non-NaN values,
    computed with linear interpolation (the default of both
    ``numpy.quantile`` and ``pandas.Series.quantile``). The 'low' and
    'medium' bins are half-open ``[lower, upper)``; the 'high' bin is closed
    ``[lower, upper]`` so the maximum value is included. NaN entries fall
    into no bin.

    Accepts a pandas Series or a NumPy ndarray. Masks are produced by
    comparing ``array`` itself, so a Series input yields Series masks that
    keep its index, and an ndarray input yields ndarray masks.

    Args:
        array: floating-point Series or ndarray.

    Returns:
        dict mapping 'low', 'medium', 'high' to boolean masks.

    Raises:
        AssertionError: if ``array`` does not have a floating dtype.
    """
    # noinspection PyUnresolvedReferences
    assert np.issubdtype(array.dtype, np.floating)
    tertiles = 0, 1/3, 2/3, 1
    tertile_names = 'low', 'medium', 'high'
    values = np.asarray(array)
    non_nan = values[~np.isnan(values)]
    if non_nan.size:
        edges = np.quantile(non_nan, tertiles)
    else:
        # No usable values: NaN edges make every comparison False, so every
        # mask is empty (pandas' Series.quantile would also return NaN here).
        edges = np.full(len(tertiles), np.nan)
    tertile_masks = {}
    # Consecutive edges bound each bin: (q0, q1/3), (q1/3, q2/3), (q2/3, q1).
    for tertile_name, low, high in zip(tertile_names, edges[:-1], edges[1:]):
        if tertile_name == 'high':
            mask = (array >= low) & (array <= high)
        else:
            mask = (array >= low) & (array < high)
        tertile_masks[tertile_name] = mask
    return tertile_masks
def stratify_binary(array):
    """Split a 0/1 indicator into 'yes' (== 1) and 'no' (== 0) masks.

    Missing values (NaN) are tolerated and land in neither mask; any value
    other than 0, 1 or NaN trips an AssertionError.

    Returns:
        dict {'yes': mask, 'no': mask}.
    """
    is_yes = array == 1
    is_no = array == 0
    # Every entry must be accounted for as yes, no, or missing.
    accounted_for = is_yes | is_no | np.isnan(array)
    assert accounted_for.all()
    return {'yes': is_yes, 'no': is_no}
def stratify_BMI(exposure):
    """Split a BMI Series into three range masks.

    Bins: 'normal' is BMI < 25 (everything below 25, with no lower bound),
    'overweight' is 25 <= BMI < 30, and 'obese' is BMI >= 30. NaN values
    fall into no bin.

    Args:
        exposure: Series whose ``name`` must be 'Body mass index (BMI)'.

    Returns:
        dict {'normal', 'overweight', 'obese'} -> boolean mask.
    """
    assert exposure.name == 'Body mass index (BMI)'
    below_25 = exposure < 25
    from_25_to_30 = (exposure >= 25) & (exposure < 30)
    at_least_30 = exposure >= 30
    return {
        'normal': below_25,
        'overweight': from_25_to_30,
        'obese': at_least_30,
    }
def stratify_BMI_fine_grained(exposure):
    """Split a BMI Series into five range masks.

    Bins: 'normal' (< 25), 'overweight_low' [25, 27.5),
    'overweight_high' [27.5, 30), 'obese_low' [30, 35) and
    'obese_high' (>= 35). NaN values fall into no bin.

    Args:
        exposure: Series whose ``name`` must be 'Body mass index (BMI)'.

    Returns:
        dict of bin name -> boolean mask, in the order listed above.
    """
    assert exposure.name == 'Body mass index (BMI)'

    def within(lower, upper):
        # Half-open interval [lower, upper).
        return (exposure >= lower) & (exposure < upper)

    return {
        'normal': exposure < 25,
        'overweight_low': within(25, 27.5),
        'overweight_high': within(27.5, 30),
        'obese_low': within(30, 35),
        'obese_high': exposure >= 35,
    }
def get_exposure_stratifications(exposure, fine_grained=None):
    """Stratify a BMI exposure Series into BMI-range masks.

    Uses the five-bin split (``stratify_BMI_fine_grained``) when fine-grained
    stratification is requested, otherwise the three-bin split
    (``stratify_BMI``).

    Args:
        exposure: Series whose ``name`` must be 'Body mass index (BMI)'.
        fine_grained: explicit override. When None (the default), the
            choice is read from an optional ``fine_grained_BMI`` attribute on
            ``exposure`` (presumably set by the caller as a flag — confirm);
            a missing attribute means the coarse three-bin split.

    Returns:
        dict of BMI-range name -> boolean mask.
    """
    assert exposure.name == 'Body mass index (BMI)'
    if fine_grained is None:
        # Honour the flag's value, not just its presence, so that
        # `exposure.fine_grained_BMI = False` selects the coarse split.
        fine_grained = bool(getattr(exposure, 'fine_grained_BMI', False))
    if fine_grained:
        return stratify_BMI_fine_grained(exposure)
    return stratify_BMI(exposure)
def get_PRS_stratifications(PRS):
    """Stratify individuals by PRS tertile.

    ``PRS`` is presumably a polygenic risk score per individual (confirm
    against callers); it must have a floating dtype. Returns the
    'low'/'medium'/'high' -> mask dictionary from ``stratify_by_tertiles``
    unchanged.
    """
    tertile_masks = stratify_by_tertiles(PRS)
    return tertile_masks
def get_family_history_stratifications(family_history):
    """Stratify individuals by a 0/1 family-history indicator.

    Returns the 'yes'/'no' -> mask dictionary from ``stratify_binary``
    unchanged; NaN entries belong to neither mask.
    """
    binary_masks = stratify_binary(family_history)
    return binary_masks
def _get_marginal_stratifications(exposure, PRS, family_history):
    """Build {factor name: {level name: mask}} for each factor that is given.

    Factors are inserted in the fixed order exposure, PRS, family_history;
    this order determines the order of parts in the joint stratum names.
    """
    marginal_stratifications = {}
    if exposure is not None:
        marginal_stratifications['exposure'] = \
            get_exposure_stratifications(exposure)
    if PRS is not None:
        marginal_stratifications['PRS'] = get_PRS_stratifications(PRS)
    if family_history is not None:
        marginal_stratifications['family_history'] = \
            get_family_history_stratifications(family_history)
    return marginal_stratifications


def _combine_stratifications(marginal_stratifications):
    """AND together one level from every factor, for every combination.

    Returns {joint name: boolean array}. A joint name joins
    ``f'{level}_{factor}'`` parts with '_', e.g. 'obese_exposure_high_PRS'.
    """
    factor_names = list(marginal_stratifications)
    stratifications = {}
    # Iterating each inner dict yields its level names, so product() walks
    # every combination of one level per factor.
    for level_names in product(*marginal_stratifications.values()):
        factor_levels = list(zip(factor_names, level_names))
        joint_name = '_'.join(
            f'{level}_{factor}' for factor, level in factor_levels)
        masks = [marginal_stratifications[factor][level]
                 for factor, level in factor_levels]
        # np.all over axis 0 ANDs the per-factor masks element-wise.
        stratifications[joint_name] = np.all(masks, axis=0)
    return stratifications


def get_stratifications(exposure=None, PRS=None, family_history=None, *,
                        min_individuals=51):
    """Stratify individuals by exposure, PRS and/or family history.

    Each provided factor is first split into its own levels (BMI ranges,
    PRS tertiles, family history yes/no). The returned stratifications are
    every combination of one level per provided factor, ANDed together.
    Per-stratification counts are printed.

    Args:
        exposure: BMI Series (see ``get_exposure_stratifications``) or None.
        PRS: floating-point scores to split into tertiles, or None.
        family_history: 0/1 (or NaN) indicator, or None.
        min_individuals: every stratification must contain at least this
            many individuals (default 51, i.e. strictly more than 50).

    Returns:
        dict {stratification name: boolean mask of individuals}.

    Raises:
        AssertionError: if no factor is provided, or if any stratification
            has fewer than ``min_individuals`` individuals.
    """
    marginal_stratifications = _get_marginal_stratifications(
        exposure, PRS, family_history)
    assert marginal_stratifications, \
        'Provide at least one of exposure, PRS or family_history'
    stratifications = _combine_stratifications(marginal_stratifications)
    individuals_per_stratification = {
        stratification_name: stratification.sum()
        for stratification_name, stratification in stratifications.items()}
    print(f'Individuals in each stratification: '
          f'{individuals_per_stratification}')
    too_small = {
        name: count
        for name, count in individuals_per_stratification.items()
        if count < min_individuals}
    assert not too_small, (
        f'Stratifications with fewer than {min_individuals} individuals: '
        f'{too_small}')
    return stratifications