-
Notifications
You must be signed in to change notification settings - Fork 0
/
masking_functions.py
66 lines (58 loc) · 2.31 KB
/
masking_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import numpy as np
# Shorthand alias for np.asarray (converts array-like input to a NumPy array)
arr = np.asarray
def mutate_at_random(vec, conservation_score, masking_index, mutprob = torch.tensor(.05), **kwargs):
    """
    Mask indices uniformly at random

    Every position (row) of ``vec`` is selected independently with
    probability ``mutprob``. A selected row is zeroed out and then gets a 1
    written at column ``masking_index``.

    Args:
        vec: tensor with at least two dimensions; the first axis indexes
            positions. It is not modified, a detached clone is edited.
        conservation_score: unused here; accepted so the signature lines up
            with ``mutate_with_conservation_score``.
        masking_index: column set to 1 for every masked row.
        mutprob: scalar masking probability (float or tensor).
        **kwargs: ignored.

    Returns:
        Tuple ``(mutvec, weightvec)``: the masked clone of ``vec`` and the
        per-position masking probabilities that were sampled from.
    """
    n = len(vec)
    mutvec = vec.detach().clone()
    # Add a uniform mutation probability
    weightvec = torch.ones(n)*mutprob
    # The mean of an empty tensor is NaN, so the sanity check only makes
    # sense when there is at least one position.
    if n > 0:
        # as_tensor reuses an existing tensor instead of copy-constructing
        # it, which torch.tensor() warns about for the tensor-valued default.
        expected_mean = torch.as_tensor(mutprob, dtype=torch.float)
        assert weightvec.mean().isclose(expected_mean), (
            f"mean mutation weight {weightvec.mean().item()} does not match mutprob"
        )
    # Indices to mask
    to_mutate = torch.bernoulli(weightvec).bool()
    mutvec[to_mutate, :] = 0
    mutvec[to_mutate, masking_index] = 1
    return mutvec, weightvec
def mutate_with_conservation_score(vec, conservation_scores, masking_index, mutprob = .05, beta = 1, **kwargs):
    """
    Mask indices weighted by conservation score

    Each position (row) of ``vec`` gets a masking probability that grows with
    its conservation score. The probabilities are rescaled so that their mean
    equals ``mutprob``. A selected row is zeroed out and then gets a 1
    written at column ``masking_index``.

    Args:
        vec: tensor with at least two dimensions; the first axis indexes
            positions. It is not modified, a detached clone is edited.
        conservation_scores: array-like of per-position scores, expected to
            lie in the range 0 to 5.
        masking_index: column set to 1 for every masked row.
        mutprob: scalar target mean masking probability.
        beta: scalar weight of the conservation term relative to the uniform
            baseline.
        **kwargs: ignored.

    Returns:
        Tuple ``(mutvec, weightvec)``: the masked clone of ``vec`` and the
        float32 per-position masking probabilities that were sampled from.

    Raises:
        AssertionError: if the rescaled weights do not average to
            ``mutprob`` or if any weight is not below 1.
    """
    n = len(vec)
    mutvec = vec.detach().clone()
    if n == 0:
        # Nothing to mask; the mean/max checks below are undefined on empty input
        return mutvec, torch.zeros(0)
    mutprob = float(mutprob)
    beta = float(beta)
    # Convert the scores to a tensor explicitly so all arithmetic below is
    # tensor-with-tensor rather than mixed tensor/ndarray arithmetic.
    scores = torch.tensor(np.asarray(conservation_scores, dtype=np.float64))
    # Conservation scores range from 0 <-> 5, normalize to get to 0 <-> 1-mutprob
    scores = (1 - mutprob) * scores / 5
    # Uniform baseline weight
    uniform_weightvec = torch.full((n,), mutprob, dtype=torch.float64)
    # Add conservation probability and uniform mutation probability
    unnormalized_weightvec = ((uniform_weightvec + beta * scores) / (beta + mutprob)).float()
    # Renormalize sum of probabilities to have same overall mutprob
    weightvec = (unnormalized_weightvec / unnormalized_weightvec.mean()) * mutprob
    mean_weight = weightvec.mean()
    assert mean_weight.isclose(torch.tensor(mutprob, dtype=torch.float)), (
        f"mean mutation weight {mean_weight.item()} does not match mutprob {mutprob}"
    )
    # torch.bernoulli needs probabilities; strongly skewed scores can push a
    # single weight past 1 after renormalization.
    max_weight = torch.max(weightvec)
    assert max_weight < 1, f"maximum mutation weight {max_weight.item()} is not below 1"
    # Indices to mask
    to_mutate = torch.bernoulli(weightvec).bool()
    mutvec[to_mutate, :] = 0
    mutvec[to_mutate, masking_index] = 1
    return mutvec, weightvec
import random
def get_crop(mutvec, seqlength):
    """
    Pick a window of at most seqlength positions from a protein

    Returns the tuple (cropvec, n, start, end): the selected rows, the
    original number of rows, and the half-open row range [start, end) that
    was kept.
    """
    total = mutvec.shape[0]
    if total <= seqlength:
        # Protein already fits: return it as-is (no padding is applied here)
        return mutvec, total, 0, total
    # Protein is too long: choose a random window start; randint includes
    # both endpoints, so the window may end exactly at the last row
    offset = random.randint(0, total - seqlength)
    stop = offset + seqlength
    window = mutvec[offset:stop, ]
    return window, total, offset, stop