Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bnewm0609/layer slice #44

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions papermage/magelib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .entity import Entity
from .image import Image
from .indexer import EntityBoxIndexer, EntitySpanIndexer
from .layer import Layer
from .metadata import Metadata
from .span import Span

Expand Down Expand Up @@ -70,6 +71,7 @@
"ImagesFieldName",
"KeywordsFieldName",
"KeywordsFieldName",
"Layer",
"ListsFieldName",
"Metadata",
"MetadataFieldName",
Expand Down
6 changes: 3 additions & 3 deletions papermage/magelib/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ def to_json(self) -> Union[Dict, List]:
def from_json(cls, annotation_json: Union[Dict, List]) -> "Annotation":
pass

def __getattr__(self, field: str) -> List["Annotation"]:
def __getattr__(self, field: str) -> "Layer":
"""This Overloading is convenient syntax since the `entity.layer` operation is intuitive for folks."""
try:
return self.find_by_span(field=field)
except ValueError:
# maybe users just want some attribute of the Annotation object
return self.__getattribute__(field)

def find_by_span(self, field: str) -> List["Annotation"]:
def find_by_span(self, field: str) -> "Layer":
"""This method allows you to access overlapping Annotations
within the Document based on Span"""
if self.doc is None:
Expand All @@ -82,7 +82,7 @@ def find_by_span(self, field: str) -> List["Annotation"]:
else:
raise ValueError(f"Field {field} not found in Document")

def find_by_box(self, field: str) -> List["Annotation"]:
def find_by_box(self, field: str) -> "Layer":
"""This method allows you to access overlapping Annotations
within the Document based on Box"""

Expand Down
19 changes: 10 additions & 9 deletions papermage/magelib/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .span import Span
from .box import Box
from .image import Image
from .layer import Layer
from .metadata import Metadata
from .entity import Entity
from .indexer import EntitySpanIndexer, EntityBoxIndexer
Expand Down Expand Up @@ -67,21 +68,21 @@ def __init__(self, symbols: str, metadata: Optional[Metadata] = None):
def fields(self) -> List[str]:
return list(self.__entity_span_indexers.keys()) + self.SPECIAL_FIELDS

def find(self, query: Union[Span, Box], field_name: str) -> List[Entity]:
def find(self, query: Union[Span, Box], field_name: str) -> Layer:
if isinstance(query, Span):
return self.__entity_span_indexers[field_name].find(query=Entity(spans=[query]))
return Layer(self.__entity_span_indexers[field_name].find(query=Entity(spans=[query])))
elif isinstance(query, Box):
return self.__entity_box_indexers[field_name].find(query=Entity(boxes=[query]))
return Layer(self.__entity_box_indexers[field_name].find(query=Entity(boxes=[query])))
else:
raise TypeError(f"Unsupported query type {type(query)}")

def find_by_span(self, query: Entity, field_name: str) -> List[Entity]:
def find_by_span(self, query: Entity, field_name: str) -> Layer:
# TODO: will rename this to `intersect_by_span`
return self.__entity_span_indexers[field_name].find(query=query)
return Layer(self.__entity_span_indexers[field_name].find(query=query))

def find_by_box(self, query: Entity, field_name: str) -> List[Entity]:
def find_by_box(self, query: Entity, field_name: str) -> Layer:
# TODO: will rename this to `intersect_by_span`
return self.__entity_box_indexers[field_name].find(query=query)
return Layer(self.__entity_box_indexers[field_name].find(query=query))

def check_field_name_availability(self, field_name: str) -> None:
if field_name in self.SPECIAL_FIELDS:
Expand All @@ -91,7 +92,7 @@ def check_field_name_availability(self, field_name: str) -> None:
if field_name in dir(self):
raise AssertionError(f"{field_name} clashes with Document class properties.")

def get_entity(self, field_name: str) -> List[Entity]:
def get_entity(self, field_name: str) -> Layer:
return getattr(self, field_name)

def annotate(self, *predictions: Union[Prediction, Tuple[Prediction, ...]]) -> None:
Expand All @@ -108,7 +109,7 @@ def annotate_entity(self, field_name: str, entities: List[Entity]) -> None:

self.__entity_span_indexers[field_name] = EntitySpanIndexer(entities=entities)
self.__entity_box_indexers[field_name] = EntityBoxIndexer(entities=entities)
setattr(self, field_name, entities)
setattr(self, field_name, Layer(entities))

def remove_entity(self, field_name: str):
for entity in getattr(self, field_name):
Expand Down
24 changes: 24 additions & 0 deletions papermage/magelib/layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from collections import UserList
from typing import List, Union, Any

from .entity import Entity

class Layer(UserList):
"""Wraps a list of entities"""
def __init__(self, entities: List[Entity] = None):
if entities is None:
super().__init__()
else:
super().__init__(entities)

def __getitem__(self, index: Union[int, slice]) -> Union[Entity, "Layer"]:
if isinstance(index, int):
return self.data[index]
else:
return Layer(self.data[index])

def __getattr__(self, field: str) -> "Layer":
return Layer([
getattr(entity, field) for entity in self.data
])

82 changes: 41 additions & 41 deletions tests/test_magelib/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,43 +111,43 @@ def test_cross_referencing(self):
doc.annotate_entity(field_name="chunks", entities=chunks)

# find by span is the default overload of Entity.__attr__
self.assertListEqual(doc.chunks[0].tokens, tokens[0:3])
self.assertListEqual(doc.chunks[1].tokens, tokens[3:5])
self.assertListEqual(doc.chunks[2].tokens, [tokens[5]])
self.assertSequenceEqual(doc.chunks[0].tokens, tokens[0:3])
self.assertSequenceEqual(doc.chunks[1].tokens, tokens[3:5])
self.assertSequenceEqual(doc.chunks[2].tokens, [tokens[5]])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assertListEqual enforces that both arguments are list, while assertSequenceEqual does not. Because the first argument is a Layer, we use assertSequenceEqual.

# backwards
self.assertListEqual(doc.tokens[0].chunks, [chunks[0]])
self.assertListEqual(doc.tokens[1].chunks, [chunks[0]])
self.assertListEqual(doc.tokens[2].chunks, [chunks[0]])
self.assertListEqual(doc.tokens[3].chunks, [chunks[1]])
self.assertListEqual(doc.tokens[4].chunks, [chunks[1]])
self.assertListEqual(doc.tokens[5].chunks, [chunks[2]])
self.assertSequenceEqual(doc.tokens[0].chunks, [chunks[0]])
self.assertSequenceEqual(doc.tokens[1].chunks, [chunks[0]])
self.assertSequenceEqual(doc.tokens[2].chunks, [chunks[0]])
self.assertSequenceEqual(doc.tokens[3].chunks, [chunks[1]])
self.assertSequenceEqual(doc.tokens[4].chunks, [chunks[1]])
self.assertSequenceEqual(doc.tokens[5].chunks, [chunks[2]])

# find by span works fine
self.assertListEqual(doc.chunks[0].tokens, doc.find_by_span(query=doc.chunks[0], field_name="tokens"))
self.assertListEqual(doc.chunks[1].tokens, doc.find_by_span(query=doc.chunks[1], field_name="tokens"))
self.assertListEqual(doc.chunks[2].tokens, doc.find_by_span(query=doc.chunks[2], field_name="tokens"))
self.assertSequenceEqual(doc.chunks[0].tokens, doc.find_by_span(query=doc.chunks[0], field_name="tokens"))
self.assertSequenceEqual(doc.chunks[1].tokens, doc.find_by_span(query=doc.chunks[1], field_name="tokens"))
self.assertSequenceEqual(doc.chunks[2].tokens, doc.find_by_span(query=doc.chunks[2], field_name="tokens"))

# backwards
self.assertListEqual(doc.tokens[0].chunks, doc.find_by_span(query=doc.tokens[0], field_name="chunks"))
self.assertListEqual(doc.tokens[1].chunks, doc.find_by_span(query=doc.tokens[1], field_name="chunks"))
self.assertListEqual(doc.tokens[2].chunks, doc.find_by_span(query=doc.tokens[2], field_name="chunks"))
self.assertListEqual(doc.tokens[3].chunks, doc.find_by_span(query=doc.tokens[3], field_name="chunks"))
self.assertListEqual(doc.tokens[4].chunks, doc.find_by_span(query=doc.tokens[4], field_name="chunks"))
self.assertListEqual(doc.tokens[5].chunks, doc.find_by_span(query=doc.tokens[5], field_name="chunks"))
self.assertSequenceEqual(doc.tokens[0].chunks, doc.find_by_span(query=doc.tokens[0], field_name="chunks"))
self.assertSequenceEqual(doc.tokens[1].chunks, doc.find_by_span(query=doc.tokens[1], field_name="chunks"))
self.assertSequenceEqual(doc.tokens[2].chunks, doc.find_by_span(query=doc.tokens[2], field_name="chunks"))
self.assertSequenceEqual(doc.tokens[3].chunks, doc.find_by_span(query=doc.tokens[3], field_name="chunks"))
self.assertSequenceEqual(doc.tokens[4].chunks, doc.find_by_span(query=doc.tokens[4], field_name="chunks"))
self.assertSequenceEqual(doc.tokens[5].chunks, doc.find_by_span(query=doc.tokens[5], field_name="chunks"))

# find by box
self.assertListEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), doc.tokens[0:3])
self.assertListEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), doc.tokens[3:6])
self.assertListEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), [])
self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), doc.tokens[0:3])
self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), doc.tokens[3:6])
self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), [])

# backwards
self.assertListEqual(doc.find_by_box(query=doc.tokens[0], field_name="chunks"), [chunks[0]])
self.assertListEqual(doc.find_by_box(query=doc.tokens[1], field_name="chunks"), [chunks[0]])
self.assertListEqual(doc.find_by_box(query=doc.tokens[2], field_name="chunks"), [chunks[0]])
self.assertListEqual(doc.find_by_box(query=doc.tokens[3], field_name="chunks"), [chunks[1]])
self.assertListEqual(doc.find_by_box(query=doc.tokens[4], field_name="chunks"), [chunks[1]])
self.assertListEqual(doc.find_by_box(query=doc.tokens[5], field_name="chunks"), [chunks[1]])
self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[0], field_name="chunks"), [chunks[0]])
self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[1], field_name="chunks"), [chunks[0]])
self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[2], field_name="chunks"), [chunks[0]])
self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[3], field_name="chunks"), [chunks[1]])
self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[4], field_name="chunks"), [chunks[1]])
self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[5], field_name="chunks"), [chunks[1]])

def test_cross_referencing_with_missing_entity_fields(self):
"""What happens when annotate a Doc with entiites missing spans or boxes?
Expand All @@ -169,18 +169,18 @@ def test_cross_referencing_with_missing_entity_fields(self):
]
doc.annotate_entity(field_name="tokens", entities=tokens)
doc.annotate_entity(field_name="chunks", entities=chunks)
self.assertListEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), [])
self.assertListEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), [])
self.assertListEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), [])
self.assertListEqual(doc.find_by_span(query=doc.chunks[0], field_name="tokens"), [])
self.assertListEqual(doc.find_by_span(query=doc.chunks[1], field_name="tokens"), [])
self.assertListEqual(doc.find_by_span(query=doc.chunks[2], field_name="tokens"), [])
self.assertListEqual(doc.find_by_box(query=doc.tokens[0], field_name="chunks"), [])
self.assertListEqual(doc.find_by_box(query=doc.tokens[1], field_name="chunks"), [])
self.assertListEqual(doc.find_by_box(query=doc.tokens[2], field_name="chunks"), [])
self.assertListEqual(doc.find_by_span(query=doc.tokens[0], field_name="chunks"), [])
self.assertListEqual(doc.find_by_span(query=doc.tokens[1], field_name="chunks"), [])
self.assertListEqual(doc.find_by_span(query=doc.tokens[2], field_name="chunks"), [])
self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), [])
self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), [])
self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), [])
self.assertSequenceEqual(doc.find_by_span(query=doc.chunks[0], field_name="tokens"), [])
self.assertSequenceEqual(doc.find_by_span(query=doc.chunks[1], field_name="tokens"), [])
self.assertSequenceEqual(doc.find_by_span(query=doc.chunks[2], field_name="tokens"), [])
self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[0], field_name="chunks"), [])
self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[1], field_name="chunks"), [])
self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[2], field_name="chunks"), [])
self.assertSequenceEqual(doc.find_by_span(query=doc.tokens[0], field_name="chunks"), [])
self.assertSequenceEqual(doc.find_by_span(query=doc.tokens[1], field_name="chunks"), [])
self.assertSequenceEqual(doc.find_by_span(query=doc.tokens[2], field_name="chunks"), [])

def test_query(self):
doc = Document("This is a test document!")
Expand All @@ -201,12 +201,12 @@ def test_query(self):
doc.annotate_entity(field_name="chunks", entities=chunks)

# test query by span
self.assertListEqual(
self.assertSequenceEqual(
doc.find_by_span(query=doc.chunks[0], field_name="tokens"),
doc.find(query=doc.chunks[0].spans[0], field_name="tokens"),
)
# test query by box
self.assertListEqual(
self.assertSequenceEqual(
doc.find_by_box(query=doc.chunks[0], field_name="tokens"),
doc.find(query=doc.chunks[0].boxes[0], field_name="tokens"),
)
Expand Down
29 changes: 29 additions & 0 deletions tests/test_magelib/test_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import unittest

from papermage import Document, Entity, Layer

class TestLayer(unittest.TestCase):
def test_layer_slice(self):
doc = Document("This is a test document!")
tokens = [
Entity.from_json({"spans": [[0, 4]], "boxes": [[0, 0, 0.5, 0.5, 0]]}),
Entity.from_json({"spans": [[5, 7]], "boxes": [[1, 1, 0.5, 0.5, 0]]}),
Entity.from_json({"spans": [[8, 9]], "boxes": [[2, 2, 0.5, 0.5, 0]]}),
Entity.from_json({"spans": [[10, 14]], "boxes": [[3, 3, 0.5, 0.5, 0]]}),
Entity.from_json({"spans": [[15, 23]], "boxes": [[4, 4, 0.5, 0.5, 0]]}),
Entity.from_json({"spans": [[23, 24]], "boxes": [[5, 5, 0.5, 0.5, 0]]}),
]
chunks = [
Entity.from_json({"spans": [[0, 9]], "boxes": [[0, 0, 2.01, 2.01, 0]]}),
Entity.from_json({"spans": [[12, 23]], "boxes": [[3.0, 3.0, 4.0, 4.0, 0]]}),
Entity.from_json({"spans": [[23, 24]], "boxes": [[0, 0, 10.0, 10.0, 1]]}),
]
doc.annotate_entity(field_name="tokens", entities=tokens)
doc.annotate_entity(field_name="chunks", entities=chunks)

assert isinstance(doc.tokens, Layer)
assert isinstance(doc.chunks[1:3], Layer)

self.assertSequenceEqual(doc.chunks[1:3], chunks[1:3])
self.assertSequenceEqual(doc.chunks[1:3].text, ['st document', '!'])
assert isinstance(doc.chunks[:3].tokens, Layer)
Loading