Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove multiple annotators in annotation table #31

Merged
merged 5 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 6 additions & 22 deletions holonote/annotate/annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class AnnotatorInterface(param.Parameterized):

connector_class = SQLiteDB

def __init__(self, spec, *, init=True, **params):
def __init__(self, spec, **params):
connector_kws = {'fields': params.get('fields')} if 'fields' in params else {}
connector = params.pop('connector') if 'connector' in params else self.connector_class(**connector_kws)

Expand All @@ -125,12 +125,11 @@ def __init__(self, spec, *, init=True, **params):
connector.fields = self.fields
self._region = {}
self._last_region = None
self.annotation_table = AnnotationTable()
self.annotation_table.register_annotator(self)
self.annotation_table.add_schema_to_conn(self.connector)

if init:
self.load()
self.annotation_table = AnnotationTable()
self.connector._create_column_schema(self.spec, self.fields)
self.connector._initialize(self.connector.column_schema)
self.annotation_table.load(self.connector, fields=self.connector.fields, spec=self.spec)

@classmethod
def normalize_spec(self, input_spec: dict[str, Any]) -> SpecDict:
Expand Down Expand Up @@ -170,24 +169,13 @@ def normalize_spec(self, input_spec: dict[str, Any]) -> SpecDict:

return new_spec

def load(self):
self.connector._initialize(self.connector.column_schema)
self.annotation_table.load(self.connector, fields=self.connector.fields, spec=self.spec)

@property
def df(self):
return self.annotation_table.dataframe

def refresh(self, clear=False):
"Method to update display state of the annotator and optionally clear stale visual state"

def set_annotation_table(self, annotation_table) -> None: # FIXME! Won't work anymore, set_connector??
self._region = {}
self.annotation_table = annotation_table
self.annotation_table.register_annotator(self)
self.annotation_table._update_index()
self.snapshot()

# Selecting annotations

def select_by_index(self, *inds):
Expand Down Expand Up @@ -306,18 +294,14 @@ def _add_annotation(self, **fields):

# Don't do anything if self.region is an empty dict
if self.region and self.region != self._last_region:
if len(self.annotation_table._annotators)>1:
msg = 'Multiple annotation instances attached to the connector: Call add_annotation directly from the associated connector.'
raise AssertionError(msg)
self.annotation_table.add_annotation(self._region, spec=self.spec, **fields)
self._last_region = self._region.copy()

def add_annotation(self, **fields):
self._add_annotation(**fields)

def update_annotation_region(self, index):
self.annotation_table.update_annotation_region(index)

self.annotation_table.update_annotation_region(self._region, index)

def update_annotation_fields(self, index, **fields):
self.annotation_table.update_annotation_fields(index, **fields)
Expand Down
12 changes: 8 additions & 4 deletions holonote/annotate/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import pandas as pd
import param

from holonote.annotate.typing import SpecDict

try:
import sqlalchemy
except ModuleNotFoundError:
Expand Down Expand Up @@ -232,6 +234,12 @@ def _incompatible_schema_check(self, expected_keys, columns, fields, region_type
+ f'Missing {region_type!r} region columns {missing_region_columns}. '
+ msg_suffix)

def _create_column_schema(self, spec: SpecDict, fields: list[str]) -> None:
field_dtypes = {col: str for col in fields} # FIXME - generalize
all_region_types = [{v["region"] for v in spec.values()}]
all_kdim_dtypes = [{k: v["type"] for k, v in spec.items()} ]
schema = self.generate_schema(self.primary_key, all_region_types, all_kdim_dtypes, field_dtypes)
self.column_schema = schema

class SQLiteDB(Connector):
"""
Expand Down Expand Up @@ -360,7 +368,3 @@ def update_row(self, **updates): # updates as a dictionary OR remove posarg?
query = f"UPDATE {self.table_name} SET " + set_updates + f" WHERE \"{self.primary_key.field_name}\" = ?;"
self.cursor.execute(query, [*updates.values(), id_val])
self.con.commit()

def add_schema(self, schema):
# TODO: Check if schema don't overwrite existing columns
self.column_schema = {**self.column_schema, **schema}
32 changes: 3 additions & 29 deletions holonote/annotate/table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import weakref
from typing import TYPE_CHECKING, Any

import numpy as np
Expand All @@ -22,7 +21,7 @@ class AnnotationTable(param.Parameterized):

index = param.List(default=[])

def __init__(self, **params):
def __init__(self, **params):
"""
Either specify annotation fields with filled field columns
(via connector or dataframe) or declare the expected
Expand All @@ -39,8 +38,6 @@ def __init__(self, **params):
self._update_index()
self._field_df_snapshot, self._region_df_snapshot = None, None

self._annotators = weakref.WeakValueDictionary()

def load(self, connector=None, fields_df=None, primary_key_name=None, fields=None, spec=None):
"""
Load the AnnotationTable from a connector or a fields DataFrame.
Expand Down Expand Up @@ -69,17 +66,7 @@ def load(self, connector=None, fields_df=None, primary_key_name=None, fields=Non
self.clear_edits()
self._update_index()

def register_annotator(self, annotator):
self._annotators[id(annotator)] = annotator


# FIXME: Multiple region updates
def update_annotation_region(self, index):
region = next(iter(self._annotators.values()))._region
if region == {}:
print('No new region selected. Skipping')
return

def update_annotation_region(self, region, index):
value = region['value']
mask = self._region_df[self._region_df._id == index]
assert len(mask)==1, 'TODO: Handle multiple region updates for single index'
Expand Down Expand Up @@ -175,9 +162,7 @@ def commits(self, connector):
kwargs = connector.transforms[operation](commit['kwargs'])
getattr(connector,connector.operation_mapping[operation])(**kwargs)

for annotator in self._annotators.values():
annotator.annotation_table.clear_edits()

self.clear_edits()
return commits

def clear_edits(self, edit_type=None):
Expand Down Expand Up @@ -207,10 +192,6 @@ def add_annotation(self, regions: dict[str, Any], spec: SpecDict, **fields):
self._edits.append({'operation':'insert', 'id':index_value})
self._update_index()

# def refresh_annotators(self, clear=False):
# for annotator in self._annotators.values():
# annotator.refresh(clear=clear)

def _add_annotation_fields(self, index_value, fields=None):

index_name_set = set() if self._field_df.index.name is None else {self._field_df.index.name}
Expand Down Expand Up @@ -433,10 +414,3 @@ def load_annotation_table(self, conn: Connector, fields: list[str], spec: SpecDi

self._update_index()
self.clear_edits()

def add_schema_to_conn(self, conn: Connector) -> None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to connector

field_dtypes = {col: str for col in conn.fields} # FIXME - generalize
all_region_types = [{v["region"] for v in an.spec.values()} for an in self._annotators.values()]
all_kdim_dtypes = [{k: v["type"] for k, v in an.spec.items()} for an in self._annotators.values()]
schema = conn.generate_schema(conn.primary_key, all_region_types, all_kdim_dtypes, field_dtypes)
conn.add_schema(schema)
Loading