LLM offsets logic consolidate w/ checks and test case fix (pytorch#1422)
Summary:
Pull Request resolved: pytorch#1422

Consolidate the offsets logic, with extra checks, into one function. It may later be used to group data in gradient LLM attribution. A test case is fixed as a result of the new checks.
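
The "extra checks" boil down to a length assertion between input_ids and offset_mapping plus a fallback for degenerate spans where start == end. A minimal standalone sketch of that fallback behavior (the text and offsets below are made up for illustration, not taken from this diff):

# Degenerate offsets: every span has start == end, so each token's end is
# recovered from the start of the next span (or, for the last token, the end of the text).
txt = "hello world!"
offset_mapping = [(0, 0), (5, 5), (11, 11)]

corrected = []
for i, (start, end) in enumerate(offset_mapping):
    if start == end:
        end = offset_mapping[i + 1][0] if (i + 1) < len(offset_mapping) else len(txt)
    corrected.append((start, end))

print(corrected)  # [(0, 5), (5, 11), (11, 12)]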

Reviewed By: cyrjano

Differential Revision: D65010820

fbshipit-source-id: a88cde9decf1c850dcd16dc2c5aacf5c4e8cd4f2
craymichael authored and facebook-github-bot committed Oct 29, 2024
1 parent 638b920 commit 492ae0e
Showing 2 changed files with 47 additions and 16 deletions.
61 changes: 45 additions & 16 deletions captum/attr/_core/llm_attr.py
@@ -226,7 +226,42 @@ def _clean_up_pretty_token(token: str) -> str:
    return token.replace("\n", "\\n").strip()


def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List[str]:
def _encode_with_offsets(
    txt: str,
    tokenizer: TokenizerLike,
    add_special_tokens: bool = True,
    **kwargs: Any,
) -> Tuple[List[int], List[Tuple[int, int]]]:
    enc = tokenizer(
        txt,
        return_offsets_mapping=True,
        add_special_tokens=add_special_tokens,
        **kwargs,
    )
    input_ids = cast(List[int], enc["input_ids"])
    offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])
    assert len(input_ids) == len(offset_mapping), (
        f"{len(input_ids)} != {len(offset_mapping)}: {txt} -> "
        f"{input_ids}, {offset_mapping}"
    )
    # For the case where offsets are not set properly (the end and start are
    # equal for all tokens - fall back on the start of the next span in the
    # offset mapping)
    offset_mapping_corrected = []
    for i, (start, end) in enumerate(offset_mapping):
        if start == end:
            if (i + 1) < len(offset_mapping):
                end = offset_mapping[i + 1][0]
            else:
                end = len(txt)
        offset_mapping_corrected.append((start, end))
    return input_ids, offset_mapping_corrected


def _convert_ids_to_pretty_tokens(
    ids: Tensor,
    tokenizer: TokenizerLike,
) -> List[str]:
    """
    Convert ids to tokens without ugly unicode characters (e.g., Ġ). See:
    https://github.com/huggingface/transformers/issues/4786 and
@@ -241,32 +276,26 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
    > used spaces in its process
    """
    txt = tokenizer.decode(ids)
    input_ids: Optional[List[int]] = None
    # Don't add special tokens (they're either already there, or we don't want them)
    enc = tokenizer(txt, return_offsets_mapping=True, add_special_tokens=False)
    input_ids = cast(List[int], enc["input_ids"])
    offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])
    input_ids, offset_mapping = _encode_with_offsets(
        txt, tokenizer, add_special_tokens=False
    )

    pretty_tokens = []
    end_prev = -1
    idx = 0
    for i, (input_id, offset) in enumerate(zip(input_ids, offset_mapping)):
    for i, offset in enumerate(offset_mapping):
        start, end = offset
        if start == end:
            # For the case where offsets are not set properly (the end and start are
            # equal for all tokens - fall back on the start of the next span in the
            # offset mapping)
            if (i + 1) < len(input_ids):
                end = offset_mapping[i + 1][0]
            else:
                end = len(txt)
        if input_id != ids[idx]:
        if input_ids[i] != ids[idx]:
            # When the re-encoded string doesn't match the original encoding we skip
            # this token and hope for the best, falling back on a naive method. This
            # can happen when a tokenizer might add a token that corresponds to
            # a space only when add_special_tokens=False.
            warnings.warn(
                f"(i={i}) input_id {input_id} != ids[idx] {ids[idx]} (corresponding "
                f"to text: {repr(txt[start:end])}). Skipping this token.",
                f"(i={i}, idx={idx}) input_ids[i] {input_ids[i]} != ids[idx] "
                f"{ids[idx]} (corresponding to text: {repr(txt[start:end])}). "
                "Skipping this token.",
                stacklevel=2,
            )
            continue
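
For readers trying the consolidated helper: a hypothetical usage sketch, assuming a Hugging Face fast tokenizer (slow tokenizers do not return offset_mapping). The model name and text are examples only, and the underscore-prefixed function is private API, shown here purely for illustration.

from transformers import AutoTokenizer

from captum.attr._core.llm_attr import _encode_with_offsets

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
text = "Hello world"
input_ids, offsets = _encode_with_offsets(text, tokenizer, add_special_tokens=False)

# The corrected mapping has exactly one (start, end) character span per token,
# so token ids and text spans can be zipped safely.
for token_id, (start, end) in zip(input_ids, offsets):
    print(token_id, repr(text[start:end]))
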
2 changes: 2 additions & 0 deletions tests/attr/test_llm_attr.py
@@ -127,6 +127,8 @@ def __call__(

        if return_offsets_mapping:
            offset_mapping = []
            if add_special_tokens:
                offset_mapping.append((0, 0))
            idx = 0
            for token in text.split(" "):
                offset_mapping.append((idx - (0 if idx == 0 else 1), idx + len(token)))
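
Why the test's mock tokenizer needs the new (0, 0) entry: Hugging Face fast tokenizers typically report special tokens (e.g. BOS/CLS) with an empty (0, 0) offset span, and the length assertion added in _encode_with_offsets requires one offset per token when add_special_tokens=True. A rough standalone sketch of the mock's offset logic after this change; the idx increment below is an assumption, since that part of the test is collapsed above.

def mock_offsets(text: str, add_special_tokens: bool = True):
    offset_mapping = []
    if add_special_tokens:
        # the special token does not correspond to any span in the text
        offset_mapping.append((0, 0))
    idx = 0
    for token in text.split(" "):
        offset_mapping.append((idx - (0 if idx == 0 else 1), idx + len(token)))
        idx += len(token) + 1  # assumed: advance past the token and one space
    return offset_mapping

print(mock_offsets("hello world"))         # [(0, 0), (0, 5), (5, 11)]
print(mock_offsets("hello world", False))  # [(0, 5), (5, 11)]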
