Skip to content

Commit

Permalink
allow for overriding all dropouts, as well as convenience method for …
Browse files Browse the repository at this point in the history
…dynamically setting target crop length
  • Loading branch information
lucidrains committed Dec 30, 2021
1 parent 2bf8213 commit 18614f7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
6 changes: 5 additions & 1 deletion enformer_pytorch/enformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ def __init__(
heads = 8,
output_heads = dict(human = 5313, mouse= 1643),
target_length = TARGET_LENGTH,
dropout_rate = 0.4,
num_alphabet = 4,
attn_dim_key = 64,
dropout_rate = 0.4,
attn_dropout = 0.05,
pos_dropout = 0.01
):
Expand Down Expand Up @@ -359,6 +359,10 @@ def __init__(
nn.Softplus()
), output_heads))

def set_target_length(self, target_length):
crop_module = self._trunk[-2]
crop_module.target_length = target_length

@property
def trunk(self):
return self._trunk
Expand Down
7 changes: 3 additions & 4 deletions enformer_pytorch/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ def remove_nones(d):
def load_pretrained_model(
slug,
force = False,
target_length = None,
dropout_rate = None,
model = None
model = None,
**kwargs
):
if slug not in CONFIG:
print(f'model {slug} not found among available choices: [{", ".join(CONFIG.keys())}]')
Expand All @@ -58,7 +57,7 @@ def load_pretrained_model(

# load

override_params = remove_nones({'target_length': target_length, 'dropout_rate': dropout_rate})
override_params = remove_nones(kwargs)
params = {**config['params'], **override_params}

if not exists(model):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'enformer-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.1.19',
version = '0.1.20',
license='MIT',
description = 'Enformer - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 18614f7

Please sign in to comment.