From 6bf13f3c797f33a471f2f0c9ea1997230fa673e3 Mon Sep 17 00:00:00 2001 From: Ben Newman Date: Wed, 16 Aug 2023 15:47:54 -0700 Subject: [PATCH 1/2] Add first stab at Layer implementation --- papermage/magelib/__init__.py | 2 ++ papermage/magelib/annotation.py | 6 +++--- papermage/magelib/document.py | 19 ++++++++++--------- papermage/magelib/layer.py | 24 ++++++++++++++++++++++++ 4 files changed, 39 insertions(+), 12 deletions(-) create mode 100644 papermage/magelib/layer.py diff --git a/papermage/magelib/__init__.py b/papermage/magelib/__init__.py index 365ef84..86c473a 100644 --- a/papermage/magelib/__init__.py +++ b/papermage/magelib/__init__.py @@ -43,6 +43,7 @@ from .entity import Entity from .image import Image from .indexer import EntityBoxIndexer, EntitySpanIndexer +from .layer import Layer from .metadata import Metadata from .span import Span @@ -70,6 +71,7 @@ "ImagesFieldName", "KeywordsFieldName", "KeywordsFieldName", + "Layer", "ListsFieldName", "Metadata", "MetadataFieldName", diff --git a/papermage/magelib/annotation.py b/papermage/magelib/annotation.py index e9bd682..d5a62ae 100644 --- a/papermage/magelib/annotation.py +++ b/papermage/magelib/annotation.py @@ -63,7 +63,7 @@ def to_json(self) -> Union[Dict, List]: def from_json(cls, annotation_json: Union[Dict, List]) -> "Annotation": pass - def __getattr__(self, field: str) -> List["Annotation"]: + def __getattr__(self, field: str) -> "Layer": """This Overloading is convenient syntax since the `entity.layer` operation is intuitive for folks.""" try: return self.find_by_span(field=field) @@ -71,7 +71,7 @@ def __getattr__(self, field: str) -> List["Annotation"]: # maybe users just want some attribute of the Annotation object return self.__getattribute__(field) - def find_by_span(self, field: str) -> List["Annotation"]: + def find_by_span(self, field: str) -> "Layer": """This method allows you to access overlapping Annotations within the Document based on Span""" if self.doc is None: @@ -82,7 +82,7 @@ def find_by_span(self, field: str) -> List["Annotation"]: else: raise ValueError(f"Field {field} not found in Document") - def find_by_box(self, field: str) -> List["Annotation"]: + def find_by_box(self, field: str) -> "Layer": """This method allows you to access overlapping Annotations within the Document based on Box""" diff --git a/papermage/magelib/document.py b/papermage/magelib/document.py index 1f8ddef..d7c83f2 100644 --- a/papermage/magelib/document.py +++ b/papermage/magelib/document.py @@ -11,6 +11,7 @@ from .span import Span from .box import Box from .image import Image +from .layer import Layer from .metadata import Metadata from .entity import Entity from .indexer import EntitySpanIndexer, EntityBoxIndexer @@ -67,21 +68,21 @@ def __init__(self, symbols: str, metadata: Optional[Metadata] = None): def fields(self) -> List[str]: return list(self.__entity_span_indexers.keys()) + self.SPECIAL_FIELDS - def find(self, query: Union[Span, Box], field_name: str) -> List[Entity]: + def find(self, query: Union[Span, Box], field_name: str) -> Layer: if isinstance(query, Span): - return self.__entity_span_indexers[field_name].find(query=Entity(spans=[query])) + return Layer(self.__entity_span_indexers[field_name].find(query=Entity(spans=[query]))) elif isinstance(query, Box): - return self.__entity_box_indexers[field_name].find(query=Entity(boxes=[query])) + return Layer(self.__entity_box_indexers[field_name].find(query=Entity(boxes=[query]))) else: raise TypeError(f"Unsupported query type {type(query)}") - def find_by_span(self, query: Entity, field_name: str) -> List[Entity]: + def find_by_span(self, query: Entity, field_name: str) -> Layer: # TODO: will rename this to `intersect_by_span` - return self.__entity_span_indexers[field_name].find(query=query) + return Layer(self.__entity_span_indexers[field_name].find(query=query)) - def find_by_box(self, query: Entity, field_name: str) -> List[Entity]: + def find_by_box(self, query: Entity, field_name: str) -> Layer: # TODO: will rename this to `intersect_by_span` - return self.__entity_box_indexers[field_name].find(query=query) + return Layer(self.__entity_box_indexers[field_name].find(query=query)) def check_field_name_availability(self, field_name: str) -> None: if field_name in self.SPECIAL_FIELDS: @@ -91,7 +92,7 @@ def check_field_name_availability(self, field_name: str) -> None: if field_name in dir(self): raise AssertionError(f"{field_name} clashes with Document class properties.") - def get_entity(self, field_name: str) -> List[Entity]: + def get_entity(self, field_name: str) -> Layer: return getattr(self, field_name) def annotate(self, *predictions: Union[Prediction, Tuple[Prediction, ...]]) -> None: @@ -108,7 +109,7 @@ def annotate_entity(self, field_name: str, entities: List[Entity]) -> None: self.__entity_span_indexers[field_name] = EntitySpanIndexer(entities=entities) self.__entity_box_indexers[field_name] = EntityBoxIndexer(entities=entities) - setattr(self, field_name, entities) + setattr(self, field_name, Layer(entities)) def remove_entity(self, field_name: str): for entity in getattr(self, field_name): diff --git a/papermage/magelib/layer.py b/papermage/magelib/layer.py new file mode 100644 index 0000000..60e6982 --- /dev/null +++ b/papermage/magelib/layer.py @@ -0,0 +1,24 @@ +from collections import UserList +from typing import List, Union, Any + +from .entity import Entity + +class Layer(UserList): + """Wraps a list of entities""" + def __init__(self, entities: List[Entity] = None): + if entities is None: + super().__init__() + else: + super().__init__(entities) + + def __getitem__(self, index: Union[int, slice]) -> Union[Entity, "Layer"]: + if isinstance(index, int): + return self.data[index] + else: + return Layer(self.data[index]) + + def __getattr__(self, field: str) -> "Layer": + return Layer([ + getattr(entity, field) for entity in self.data + ]) + From 2d9bcb16a0db5fe719fb968abae3487e65586b21 Mon Sep 17 00:00:00 2001 From: Ben Newman Date: Wed, 16 Aug 2023 15:49:10 -0700 Subject: [PATCH 2/2] Change tests to not enforce list type --- tests/test_magelib/test_document.py | 82 ++++++++++++++--------------- tests/test_magelib/test_layer.py | 29 ++++++++++ 2 files changed, 70 insertions(+), 41 deletions(-) create mode 100644 tests/test_magelib/test_layer.py diff --git a/tests/test_magelib/test_document.py b/tests/test_magelib/test_document.py index 17399ac..a2b6dc9 100644 --- a/tests/test_magelib/test_document.py +++ b/tests/test_magelib/test_document.py @@ -111,43 +111,43 @@ def test_cross_referencing(self): doc.annotate_entity(field_name="chunks", entities=chunks) # find by span is the default overload of Entity.__attr__ - self.assertListEqual(doc.chunks[0].tokens, tokens[0:3]) - self.assertListEqual(doc.chunks[1].tokens, tokens[3:5]) - self.assertListEqual(doc.chunks[2].tokens, [tokens[5]]) + self.assertSequenceEqual(doc.chunks[0].tokens, tokens[0:3]) + self.assertSequenceEqual(doc.chunks[1].tokens, tokens[3:5]) + self.assertSequenceEqual(doc.chunks[2].tokens, [tokens[5]]) # backwards - self.assertListEqual(doc.tokens[0].chunks, [chunks[0]]) - self.assertListEqual(doc.tokens[1].chunks, [chunks[0]]) - self.assertListEqual(doc.tokens[2].chunks, [chunks[0]]) - self.assertListEqual(doc.tokens[3].chunks, [chunks[1]]) - self.assertListEqual(doc.tokens[4].chunks, [chunks[1]]) - self.assertListEqual(doc.tokens[5].chunks, [chunks[2]]) + self.assertSequenceEqual(doc.tokens[0].chunks, [chunks[0]]) + self.assertSequenceEqual(doc.tokens[1].chunks, [chunks[0]]) + self.assertSequenceEqual(doc.tokens[2].chunks, [chunks[0]]) + self.assertSequenceEqual(doc.tokens[3].chunks, [chunks[1]]) + self.assertSequenceEqual(doc.tokens[4].chunks, [chunks[1]]) + self.assertSequenceEqual(doc.tokens[5].chunks, [chunks[2]]) # find by span works fine - self.assertListEqual(doc.chunks[0].tokens, doc.find_by_span(query=doc.chunks[0], field_name="tokens")) - self.assertListEqual(doc.chunks[1].tokens, doc.find_by_span(query=doc.chunks[1], field_name="tokens")) - self.assertListEqual(doc.chunks[2].tokens, doc.find_by_span(query=doc.chunks[2], field_name="tokens")) + self.assertSequenceEqual(doc.chunks[0].tokens, doc.find_by_span(query=doc.chunks[0], field_name="tokens")) + self.assertSequenceEqual(doc.chunks[1].tokens, doc.find_by_span(query=doc.chunks[1], field_name="tokens")) + self.assertSequenceEqual(doc.chunks[2].tokens, doc.find_by_span(query=doc.chunks[2], field_name="tokens")) # backwards - self.assertListEqual(doc.tokens[0].chunks, doc.find_by_span(query=doc.tokens[0], field_name="chunks")) - self.assertListEqual(doc.tokens[1].chunks, doc.find_by_span(query=doc.tokens[1], field_name="chunks")) - self.assertListEqual(doc.tokens[2].chunks, doc.find_by_span(query=doc.tokens[2], field_name="chunks")) - self.assertListEqual(doc.tokens[3].chunks, doc.find_by_span(query=doc.tokens[3], field_name="chunks")) - self.assertListEqual(doc.tokens[4].chunks, doc.find_by_span(query=doc.tokens[4], field_name="chunks")) - self.assertListEqual(doc.tokens[5].chunks, doc.find_by_span(query=doc.tokens[5], field_name="chunks")) + self.assertSequenceEqual(doc.tokens[0].chunks, doc.find_by_span(query=doc.tokens[0], field_name="chunks")) + self.assertSequenceEqual(doc.tokens[1].chunks, doc.find_by_span(query=doc.tokens[1], field_name="chunks")) + self.assertSequenceEqual(doc.tokens[2].chunks, doc.find_by_span(query=doc.tokens[2], field_name="chunks")) + self.assertSequenceEqual(doc.tokens[3].chunks, doc.find_by_span(query=doc.tokens[3], field_name="chunks")) + self.assertSequenceEqual(doc.tokens[4].chunks, doc.find_by_span(query=doc.tokens[4], field_name="chunks")) + self.assertSequenceEqual(doc.tokens[5].chunks, doc.find_by_span(query=doc.tokens[5], field_name="chunks")) # find by box - self.assertListEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), doc.tokens[0:3]) - self.assertListEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), doc.tokens[3:6]) - self.assertListEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), []) + self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), doc.tokens[0:3]) + self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), doc.tokens[3:6]) + self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), []) # backwards - self.assertListEqual(doc.find_by_box(query=doc.tokens[0], field_name="chunks"), [chunks[0]]) - self.assertListEqual(doc.find_by_box(query=doc.tokens[1], field_name="chunks"), [chunks[0]]) - self.assertListEqual(doc.find_by_box(query=doc.tokens[2], field_name="chunks"), [chunks[0]]) - self.assertListEqual(doc.find_by_box(query=doc.tokens[3], field_name="chunks"), [chunks[1]]) - self.assertListEqual(doc.find_by_box(query=doc.tokens[4], field_name="chunks"), [chunks[1]]) - self.assertListEqual(doc.find_by_box(query=doc.tokens[5], field_name="chunks"), [chunks[1]]) + self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[0], field_name="chunks"), [chunks[0]]) + self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[1], field_name="chunks"), [chunks[0]]) + self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[2], field_name="chunks"), [chunks[0]]) + self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[3], field_name="chunks"), [chunks[1]]) + self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[4], field_name="chunks"), [chunks[1]]) + self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[5], field_name="chunks"), [chunks[1]]) def test_cross_referencing_with_missing_entity_fields(self): """What happens when annotate a Doc with entiites missing spans or boxes? @@ -169,18 +169,18 @@ def test_cross_referencing_with_missing_entity_fields(self): ] doc.annotate_entity(field_name="tokens", entities=tokens) doc.annotate_entity(field_name="chunks", entities=chunks) - self.assertListEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), []) - self.assertListEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), []) - self.assertListEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), []) - self.assertListEqual(doc.find_by_span(query=doc.chunks[0], field_name="tokens"), []) - self.assertListEqual(doc.find_by_span(query=doc.chunks[1], field_name="tokens"), []) - self.assertListEqual(doc.find_by_span(query=doc.chunks[2], field_name="tokens"), []) - self.assertListEqual(doc.find_by_box(query=doc.tokens[0], field_name="chunks"), []) - self.assertListEqual(doc.find_by_box(query=doc.tokens[1], field_name="chunks"), []) - self.assertListEqual(doc.find_by_box(query=doc.tokens[2], field_name="chunks"), []) - self.assertListEqual(doc.find_by_span(query=doc.tokens[0], field_name="chunks"), []) - self.assertListEqual(doc.find_by_span(query=doc.tokens[1], field_name="chunks"), []) - self.assertListEqual(doc.find_by_span(query=doc.tokens[2], field_name="chunks"), []) + self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[0], field_name="tokens"), []) + self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[1], field_name="tokens"), []) + self.assertSequenceEqual(doc.find_by_box(query=doc.chunks[2], field_name="tokens"), []) + self.assertSequenceEqual(doc.find_by_span(query=doc.chunks[0], field_name="tokens"), []) + self.assertSequenceEqual(doc.find_by_span(query=doc.chunks[1], field_name="tokens"), []) + self.assertSequenceEqual(doc.find_by_span(query=doc.chunks[2], field_name="tokens"), []) + self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[0], field_name="chunks"), []) + self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[1], field_name="chunks"), []) + self.assertSequenceEqual(doc.find_by_box(query=doc.tokens[2], field_name="chunks"), []) + self.assertSequenceEqual(doc.find_by_span(query=doc.tokens[0], field_name="chunks"), []) + self.assertSequenceEqual(doc.find_by_span(query=doc.tokens[1], field_name="chunks"), []) + self.assertSequenceEqual(doc.find_by_span(query=doc.tokens[2], field_name="chunks"), []) def test_query(self): doc = Document("This is a test document!") @@ -201,12 +201,12 @@ def test_query(self): doc.annotate_entity(field_name="chunks", entities=chunks) # test query by span - self.assertListEqual( + self.assertSequenceEqual( doc.find_by_span(query=doc.chunks[0], field_name="tokens"), doc.find(query=doc.chunks[0].spans[0], field_name="tokens"), ) # test query by box - self.assertListEqual( + self.assertSequenceEqual( doc.find_by_box(query=doc.chunks[0], field_name="tokens"), doc.find(query=doc.chunks[0].boxes[0], field_name="tokens"), ) diff --git a/tests/test_magelib/test_layer.py b/tests/test_magelib/test_layer.py new file mode 100644 index 0000000..820281e --- /dev/null +++ b/tests/test_magelib/test_layer.py @@ -0,0 +1,29 @@ +import unittest + +from papermage import Document, Entity, Layer + +class TestLayer(unittest.TestCase): + def test_layer_slice(self): + doc = Document("This is a test document!") + tokens = [ + Entity.from_json({"spans": [[0, 4]], "boxes": [[0, 0, 0.5, 0.5, 0]]}), + Entity.from_json({"spans": [[5, 7]], "boxes": [[1, 1, 0.5, 0.5, 0]]}), + Entity.from_json({"spans": [[8, 9]], "boxes": [[2, 2, 0.5, 0.5, 0]]}), + Entity.from_json({"spans": [[10, 14]], "boxes": [[3, 3, 0.5, 0.5, 0]]}), + Entity.from_json({"spans": [[15, 23]], "boxes": [[4, 4, 0.5, 0.5, 0]]}), + Entity.from_json({"spans": [[23, 24]], "boxes": [[5, 5, 0.5, 0.5, 0]]}), + ] + chunks = [ + Entity.from_json({"spans": [[0, 9]], "boxes": [[0, 0, 2.01, 2.01, 0]]}), + Entity.from_json({"spans": [[12, 23]], "boxes": [[3.0, 3.0, 4.0, 4.0, 0]]}), + Entity.from_json({"spans": [[23, 24]], "boxes": [[0, 0, 10.0, 10.0, 1]]}), + ] + doc.annotate_entity(field_name="tokens", entities=tokens) + doc.annotate_entity(field_name="chunks", entities=chunks) + + assert isinstance(doc.tokens, Layer) + assert isinstance(doc.chunks[1:3], Layer) + + self.assertSequenceEqual(doc.chunks[1:3], chunks[1:3]) + self.assertSequenceEqual(doc.chunks[1:3].text, ['st document', '!']) + assert isinstance(doc.chunks[:3].tokens, Layer) \ No newline at end of file