Skip to content

Commit

Permalink
make SymSpell accept an EditDistance argument in the constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
mammothb committed Aug 31, 2024
1 parent c789705 commit b20be97
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 39 deletions.
28 changes: 9 additions & 19 deletions symspellpy/symspellpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
max_dictionary_edit_distance: int = 2,
prefix_length: int = 7,
count_threshold: int = 1,
distance_comparer: Optional[EditDistance] = None,
) -> None:
if max_dictionary_edit_distance < 0:
raise ValueError("max_dictionary_edit_distance cannot be negative")
Expand All @@ -95,6 +96,10 @@ def __init__(
)
if count_threshold < 0:
raise ValueError("count_threshold cannot be negative")
if distance_comparer is None:
self.distance_comparer = EditDistance(DistanceAlgorithm.DAMERAU_OSA_FAST)
else:
self.distance_comparer = distance_comparer
self._words: Dict[str, int] = {}
self._below_threshold_words: Dict[str, int] = {}
self._bigrams: Dict[str, int] = {}
Expand All @@ -104,7 +109,7 @@ def __init__(
self._max_dictionary_edit_distance = max_dictionary_edit_distance
self._prefix_length = prefix_length
self._count_threshold = count_threshold
self._distance_algorithm = DistanceAlgorithm.DAMERAU_OSA_FAST
# self._distance_algorithm = DistanceAlgorithm.DAMERAU_OSA_FAST
self._max_length = 0

@property
Expand All @@ -129,19 +134,6 @@ def deletes(self) -> Dict[str, List[str]]:
"""
return self._deletes

@property
def distance_algorithm(self) -> DistanceAlgorithm:
"""The current distance algorithm."""
return self._distance_algorithm

@distance_algorithm.setter
def distance_algorithm(self, value: DistanceAlgorithm) -> None:
if not isinstance(value, DistanceAlgorithm):
raise TypeError(
"can only assign DistanceAlgorithm type values to distance_algorithm"
)
self._distance_algorithm = value

@property
def entry_count(self) -> int:
"""Number of unique correct spelling words."""
Expand Down Expand Up @@ -445,7 +437,6 @@ def early_exit():
candidates.append(phrase[:phrase_prefix_len])
else:
candidates.append(phrase)
distance_comparer = EditDistance(self._distance_algorithm)
while candidate_pointer < len(candidates):
candidate = candidates[candidate_pointer]
candidate_pointer += 1
Expand Down Expand Up @@ -577,7 +568,7 @@ def early_exit():
if suggestion in considered_suggestions:
continue
considered_suggestions.add(suggestion)
distance = distance_comparer.compare(
distance = self.distance_comparer.compare(
phrase, suggestion, max_edit_distance_2
)
if distance < 0:
Expand Down Expand Up @@ -683,7 +674,6 @@ def lookup_compound(
)
suggestions = []
suggestion_parts: List[SuggestItem] = []
distance_comparer = EditDistance(self._distance_algorithm)

# translate every item to its best suggestion, otherwise it remains
# unchanged
Expand Down Expand Up @@ -761,7 +751,7 @@ def lookup_compound(
continue
# select best suggestion for split pair
tmp_term = f"{suggestions_1[0].term} {suggestions_2[0].term}"
tmp_distance = distance_comparer.compare(
tmp_distance = self.distance_comparer.compare(
terms_1[i], tmp_term, max_edit_distance
)
if tmp_distance < 0:
Expand Down Expand Up @@ -858,7 +848,7 @@ def lookup_compound(
joined_term = helpers.case_transfer_similar(phrase, joined_term)
suggestion = SuggestItem(
joined_term,
distance_comparer.compare(phrase, joined_term, 2**31 - 1),
self.distance_comparer.compare(phrase, joined_term, 2**31 - 1),
int(joined_count),
)
return [suggestion]
Expand Down
33 changes: 13 additions & 20 deletions tests/test_symspellpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pytest

from symspellpy import SymSpell, Verbosity
from symspellpy.editdistance import DistanceAlgorithm
from symspellpy.abstract_distance_comparer import AbstractDistanceComparer
from symspellpy.editdistance import DistanceAlgorithm, EditDistance
from symspellpy.helpers import DictIO

FORTESTS_DIR = Path(__file__).resolve().parent / "fortests"
Expand Down Expand Up @@ -36,6 +37,11 @@ def get_dictionary_stream(request):
yield dict_stream, request.param


class CustomDistanceComparer(AbstractDistanceComparer):
    """Stub comparer for tests: reports a distance of 0 for every input."""

    def distance(self, string_1: str, string_2: str, max_distance: int) -> int:
        # None of the arguments are consulted; the result is always 0.
        return 0


class TestSymSpellPy:
def test_negative_max_dictionary_edit_distance(self):
with pytest.raises(ValueError) as excinfo:
Expand Down Expand Up @@ -64,26 +70,13 @@ def test_negative_count_threshold(self):
_ = SymSpell(1, 3, -1)
assert "count_threshold cannot be negative" == str(excinfo.value)

@pytest.mark.parametrize(
"algorithm",
[
DistanceAlgorithm.LEVENSHTEIN,
DistanceAlgorithm.DAMERAU_OSA,
DistanceAlgorithm.LEVENSHTEIN_FAST,
DistanceAlgorithm.DAMERAU_OSA_FAST,
],
)
def test_set_distance_algorithm(self, symspell_default, algorithm):
symspell_default.distance_algorithm = algorithm
assert algorithm == symspell_default.distance_algorithm

def test_set_invalid_distance_algorithm(self, symspell_default):
with pytest.raises(TypeError) as excinfo:
symspell_default.distance_algorithm = 1
assert (
"can only assign DistanceAlgorithm type values to distance_algorithm"
== str(excinfo.value)
def test_set_distance_comparer(self):
    """A comparer passed to the constructor is exposed as ``distance_comparer``."""
    # Wrap the stub comparer in an EditDistance configured as USER_PROVIDED.
    distance_comparer = EditDistance(
        DistanceAlgorithm.USER_PROVIDED, CustomDistanceComparer()
    )
    sym_spell = SymSpell(distance_comparer=distance_comparer)

    # The attribute must compare equal to the comparer that was passed in.
    assert distance_comparer == sym_spell.distance_comparer

@pytest.mark.parametrize("symspell_short", [None, 0], indirect=True)
def test_create_dictionary_entry_negative_count(self, symspell_short):
Expand Down

0 comments on commit b20be97

Please sign in to comment.