Skip to content

Commit

Permalink
Invalidate only cached entries related to scores. (#763)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Jun 15, 2021
1 parent 327b084 commit dd3da2a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
22 changes: 20 additions & 2 deletions k2/python/k2/fsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Union

import os
import re
import shutil
import torch

Expand Down Expand Up @@ -215,16 +216,33 @@ def __init__(
# the FSA is valid.
_ = self.properties

def _invalidate_cache_(self):
def _invalidate_cache_(self, scores_only: bool = True) -> None:
'''Intended for internal use only so its
name begins with an underline.
Also, it changes `self` in-place.
Currently, it is used only when the `scores` field
are re-assigned.
Args:
scores_only:
It True, it invalidates only cached entries related
to scores. If False, the whole cache is invalidated.
'''
self.__dict__['_cache'] = dict()
if scores_only is False:
self.__dict__['_cache'] = dict()
else:
pattern = re.compile('score')
to_remove = []

for key in self.__dict__['_cache']:
if pattern.search(key):
to_remove.append(key)

for key in to_remove:
del self.__dict__['_cache'][key]

def to_str(self, openfst: bool = False) -> str:
extra_labels = []
Expand Down
18 changes: 18 additions & 0 deletions k2/python/tests/fsa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,24 @@ def test_convert_attr_to_ragged(self):
expected = k2.RaggedInt('[ [1] [] [-1] ]')
assert str(fsa.tensor_attr2) == str(expected)

def test_invalidate_cache(self):
s = '''
0 1 1 0.1
1 2 -1 0.2
2
'''
fsa = k2.Fsa.from_str(s)
fsa = k2.create_fsa_vec([fsa])
fsa.get_tot_scores(True, True)

assert 'forward_scores_double_log' in fsa._cache
assert 'state_batches' in fsa._cache

fsa.scores *= 2

assert 'forward_scores_double_log' not in fsa._cache
assert 'state_batches' in fsa._cache


if __name__ == '__main__':
unittest.main()

0 comments on commit dd3da2a

Please sign in to comment.