diff --git a/symspellpy/symspellpy.py b/symspellpy/symspellpy.py index 4b41b87..b46a342 100644 --- a/symspellpy/symspellpy.py +++ b/symspellpy/symspellpy.py @@ -84,6 +84,7 @@ def __init__( max_dictionary_edit_distance: int = 2, prefix_length: int = 7, count_threshold: int = 1, + distance_comparer: Optional[EditDistance] = None, ) -> None: if max_dictionary_edit_distance < 0: raise ValueError("max_dictionary_edit_distance cannot be negative") @@ -95,6 +96,10 @@ def __init__( ) if count_threshold < 0: raise ValueError("count_threshold cannot be negative") + if distance_comparer is None: + self.distance_comparer = EditDistance(DistanceAlgorithm.DAMERAU_OSA_FAST) + else: + self.distance_comparer = distance_comparer self._words: Dict[str, int] = {} self._below_threshold_words: Dict[str, int] = {} self._bigrams: Dict[str, int] = {} @@ -104,7 +109,7 @@ def __init__( self._max_dictionary_edit_distance = max_dictionary_edit_distance self._prefix_length = prefix_length self._count_threshold = count_threshold - self._distance_algorithm = DistanceAlgorithm.DAMERAU_OSA_FAST + # self._distance_algorithm = DistanceAlgorithm.DAMERAU_OSA_FAST self._max_length = 0 @property @@ -129,19 +134,6 @@ def deletes(self) -> Dict[str, List[str]]: """ return self._deletes - @property - def distance_algorithm(self) -> DistanceAlgorithm: - """The current distance algorithm.""" - return self._distance_algorithm - - @distance_algorithm.setter - def distance_algorithm(self, value: DistanceAlgorithm) -> None: - if not isinstance(value, DistanceAlgorithm): - raise TypeError( - "can only assign DistanceAlgorithm type values to distance_algorithm" - ) - self._distance_algorithm = value - @property def entry_count(self) -> int: """Number of unique correct spelling words.""" @@ -445,7 +437,6 @@ def early_exit(): candidates.append(phrase[:phrase_prefix_len]) else: candidates.append(phrase) - distance_comparer = EditDistance(self._distance_algorithm) while candidate_pointer < len(candidates): candidate = candidates[candidate_pointer] candidate_pointer += 1 @@ -577,7 +568,7 @@ def early_exit(): if suggestion in considered_suggestions: continue considered_suggestions.add(suggestion) - distance = distance_comparer.compare( + distance = self.distance_comparer.compare( phrase, suggestion, max_edit_distance_2 ) if distance < 0: @@ -683,7 +674,6 @@ def lookup_compound( ) suggestions = [] suggestion_parts: List[SuggestItem] = [] - distance_comparer = EditDistance(self._distance_algorithm) # translate every item to its best suggestion, otherwise it remains # unchanged @@ -761,7 +751,7 @@ def lookup_compound( continue # select best suggestion for split pair tmp_term = f"{suggestions_1[0].term} {suggestions_2[0].term}" - tmp_distance = distance_comparer.compare( + tmp_distance = self.distance_comparer.compare( terms_1[i], tmp_term, max_edit_distance ) if tmp_distance < 0: @@ -858,7 +848,7 @@ def lookup_compound( joined_term = helpers.case_transfer_similar(phrase, joined_term) suggestion = SuggestItem( joined_term, - distance_comparer.compare(phrase, joined_term, 2**31 - 1), + self.distance_comparer.compare(phrase, joined_term, 2**31 - 1), int(joined_count), ) return [suggestion] diff --git a/tests/test_symspellpy.py b/tests/test_symspellpy.py index 08cf2f4..ab48be6 100644 --- a/tests/test_symspellpy.py +++ b/tests/test_symspellpy.py @@ -4,7 +4,8 @@ import pytest from symspellpy import SymSpell, Verbosity -from symspellpy.editdistance import DistanceAlgorithm +from symspellpy.abstract_distance_comparer import AbstractDistanceComparer +from symspellpy.editdistance import DistanceAlgorithm, EditDistance from symspellpy.helpers import DictIO FORTESTS_DIR = Path(__file__).resolve().parent / "fortests" @@ -36,6 +37,11 @@ def get_dictionary_stream(request): yield dict_stream, request.param +class CustomDistanceComparer(AbstractDistanceComparer): + def distance(self, string_1: str, string_2: str, max_distance: int) -> int: + return 0 + + class TestSymSpellPy: def test_negative_max_dictionary_edit_distance(self): with pytest.raises(ValueError) as excinfo: @@ -64,26 +70,13 @@ def test_negative_count_threshold(self): _ = SymSpell(1, 3, -1) assert "count_threshold cannot be negative" == str(excinfo.value) - @pytest.mark.parametrize( - "algorithm", - [ - DistanceAlgorithm.LEVENSHTEIN, - DistanceAlgorithm.DAMERAU_OSA, - DistanceAlgorithm.LEVENSHTEIN_FAST, - DistanceAlgorithm.DAMERAU_OSA_FAST, - ], - ) - def test_set_distance_algorithm(self, symspell_default, algorithm): - symspell_default.distance_algorithm = algorithm - assert algorithm == symspell_default.distance_algorithm - - def test_set_invalid_distance_algorithm(self, symspell_default): - with pytest.raises(TypeError) as excinfo: - symspell_default.distance_algorithm = 1 - assert ( - "can only assign DistanceAlgorithm type values to distance_algorithm" - == str(excinfo.value) + def test_set_distance_comparer(self): + distance_comparer = EditDistance( + DistanceAlgorithm.USER_PROVIDED, CustomDistanceComparer() ) + sym_spell = SymSpell(distance_comparer=distance_comparer) + + assert distance_comparer == sym_spell.distance_comparer @pytest.mark.parametrize("symspell_short", [None, 0], indirect=True) def test_create_dictionary_entry_negative_count(self, symspell_short):