This repository has been archived by the owner on Apr 1, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
stats_getter.py
144 lines (122 loc) · 5.09 KB
/
stats_getter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# stats_getter.py - computes the TF-IDF score of each molecule's bonds, prints per-molecule fragment statistics and writes the heatmap (bond colour) data as JSON.
# author: Matthew Wampler-Doty
# author: Andrea Cadeddu
#
import sys, json
from collections import defaultdict
from math import log
from rdkit import Chem
def label_mol(mol):
    """Tag every atom of ``mol`` with its 1-based index.

    The index (``GetIdx() + 1``) is stored, as a string, in the atom
    property ``'molAtomMapNumber'``. The molecule is modified in place and
    nothing is returned.
    """
    atoms = mol.GetAtoms()
    for atom in atoms:
        map_number = atom.GetIdx() + 1
        atom.SetProp('molAtomMapNumber', str(map_number))
def weight_bonds(words, idfs, mol):
    """Compute a combined TF-IDF weight for every fragment-covered bond of ``mol``.

    Every fragment ("word") in ``words`` is matched against ``mol`` as a
    substructure (non-uniquified matches). The term frequency of a word is
    its number of matches divided by the largest match count of any word in
    this molecule; the score of a word is that frequency times ``idfs[i]``.
    A bond's weight is the sum of the scores of the distinct words that
    cover it in at least one match.

    Side effect: prints one tab-separated statistics line to stdout with,
    in order: the number of bonds in ``mol``, the number of matching
    fragments, the mean, median and standard deviation of the heavy-atom
    exact molecular weights of the matching fragments, and the exact
    molecular weight of ``mol``. The three fragment statistics are ``nan``
    when no fragment matches.

    Args:
        words: sequence of RDKit query molecules; the position in the
            sequence is the word id.
        idfs: sequence of IDF scores, indexed like ``words``.
        mol: RDKit molecule to score.

    Returns:
        dict mapping ``frozenset([atom_idx_a, atom_idx_b])`` (atom indices
        in ``mol``) to the bond's combined TF-IDF weight. Empty when no
        fragment matches or no matching fragment contains a bond.
    """
    # Kept as function-level imports, as in the original block.
    import numpy as np
    from rdkit.Chem import rdMolDescriptors

    # Word ids covering each bond of the molecule
    bond_words = defaultdict(set)
    # Match counts per word id, used as raw term frequencies
    doc_word_counts = defaultdict(float)
    # Heavy-atom exact weights of the matching fragments; collected in a
    # list because np.append copies the whole array on every call.
    frag_weights = []

    for i, w in enumerate(words):
        mol_matches = mol.GetSubstructMatches(w, uniquify=False)
        if not mol_matches:
            continue
        doc_word_counts[i] += len(mol_matches)
        frag_weights.append(rdMolDescriptors.CalcExactMolWt(w, onlyHeavy=True))
        # The fragment's bond endpoints do not depend on the match, so read
        # them once per fragment instead of once per match.
        frag_bonds = [(b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                      for b in w.GetBonds()]
        for m in mol_matches:
            # m[k] is the molecule atom index matched by fragment atom k
            for start, end in frag_bonds:
                bond_words[frozenset([m[start], m[end]])].add(i)

    if frag_weights:
        z = np.array(frag_weights)
        avg, median, stddev = np.mean(z), np.median(z), np.std(z)
    else:
        # Statistics of an empty array are nan; produce them directly so
        # numpy does not emit "mean of empty slice" warnings.
        avg = median = stddev = float('nan')

    # FIXME (carried over from the original): stats are emitted on stdout
    # as a side effect. A single parenthesised argument prints identically
    # under Python 2's print statement and Python 3's print function.
    print("{bondsno}\t{fragno}\t{avg}\t{median}\t{stddev}\t{mw}".format(
        bondsno=len(mol.GetBonds()),
        fragno=len(doc_word_counts),
        avg=avg,
        median=median,
        stddev=stddev,
        mw=rdMolDescriptors.CalcExactMolWt(mol)))

    if not doc_word_counts:
        # No fragment matched: max() below would raise ValueError on an
        # empty sequence, and there is nothing to weight anyway.
        return {}

    # Compute the TF-IDF score of each matching word
    maxtf = float(max(doc_word_counts.values()))
    score = dict((t, n / maxtf * idfs[t])
                 for t, n in doc_word_counts.items())

    # Combine the scores of all words covering each bond
    return dict((bond, sum(score[t] for t in word_ids))
                for bond, word_ids in bond_words.items())
def get_idfs(counts, docs):
    """Return the IDF score of every entry of ``counts``.

    Each score is ``log(docs / count)`` (natural logarithm), with both
    operands converted to ``float`` first. The result is a list in the
    same order as ``counts``.
    """
    idf_scores = []
    for count in counts:
        ratio = float(docs) / float(count)
        idf_scores.append(log(ratio))
    return idf_scores
def get_color_pairs(mol, bw, maxes, mins):
    """Build the colour entry of every bond of ``mol``.

    The first three entries of ``mins`` and of ``maxes`` are each mapped to
    a fixed hex colour string; both sequences must therefore hold at least
    three items (IndexError otherwise). When a value appears in both, the
    colour assigned through ``maxes`` wins.

    A bond's weight is looked up in ``bw`` under
    ``frozenset([begin_idx, end_idx])`` and defaults to 0.0; weights with
    no assigned colour get ``"000000"``.

    Returns:
        list with one dict per bond, in ``mol.GetBonds()`` order:
        ``{'a': begin_idx + 1, 'b': end_idx + 1, 'color': <hex string>}``.
    """
    low_colors = ("FF0000", "C76939", "FFDA45")
    high_colors = ("00FF00", "00AA00", "005500")

    # Fill the palette in the same order as the original literal
    # (mins first, then maxes) so that later entries override earlier ones.
    palette = {}
    for rank, color in enumerate(low_colors):
        palette[mins[rank]] = color
    for rank, color in enumerate(high_colors):
        palette[maxes[rank]] = color

    def describe(bond):
        first = bond.GetBeginAtomIdx()
        second = bond.GetEndAtomIdx()
        weight = bw.get(frozenset([first, second]), 0.0)
        return {'a': first + 1,
                'b': second + 1,
                'color': palette.get(weight, "000000")}

    return [describe(bond) for bond in mol.GetBonds()]
if __name__ == "__main__":
    # Command-line arguments:
    #   1) fragment dictionary, one "<count>\t<SMARTS>" entry per line
    #   2) number of documents (used for the IDF scores)
    #   3) input file (list of SMILES, one per line)
    #   4) output file (.json)
    smart_stats_fn = sys.argv[1]
    numdocs = float(sys.argv[2])

    # Load every fragment ("word") together with its count
    words = []
    counts = []
    with open(smart_stats_fn) as smart_stats:
        for line in smart_stats:
            line = line.strip()
            if not line:
                # A blank (e.g. trailing) line would otherwise abort the
                # whole run when unpacking the two fields.
                continue
            c, s = line.split('\t')
            words.append(Chem.MolFromSmarts(s, mergeHs=True))
            counts.append(float(c))

    # Compute the idf scores
    idfs = get_idfs(counts, numdocs)

    data = []
    with open(sys.argv[3], 'r') as inp:
        # i is the 1-based input line number; it also names the SVG file
        for i, line in enumerate(inp, 1):
            try:
                mol = Chem.MolFromSmiles(line)
                if mol is None:
                    raise ValueError("could not parse SMILES")
                bweights = weight_bonds(words, idfs, mol)
                # get_color_pairs indexes three minimum and three maximum
                # weights: fewer than 3 distinct weights raise IndexError
                # there, and fewer than 6 make the two groups overlap.
                allweights = sorted(set(bweights.values()))
                maxweights = allweights[-3:]  # top 3
                minweights = allweights[:3]   # bottom 3
                # Label the atoms
                label_mol(mol)
                smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
                data.append({
                    "format": {
                        'width': 500,
                        'height': 500,
                        'fname': str(i) + ".svg"
                    },
                    "data": {
                        'smiles': smiles,
                        'pairs': get_color_pairs(mol, bweights,
                                                 maxweights, minweights),
                    }
                })
            except Exception as e:
                # Best effort: a molecule that cannot be processed is
                # skipped, but reported on stderr instead of being dropped
                # silently (stdout carries the statistics lines).
                sys.stderr.write("line {0} ({1!r}) failed: {2}\n".format(
                    i, line.strip(), e))

    with open(sys.argv[4], 'w') as outfile:
        json.dump({'entries': data}, outfile, sort_keys=True,
                  indent=4, separators=(',', ': '))