Skip to content

Commit

Permalink
Duplicate behavior (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoxbro authored Aug 28, 2023
1 parent 0f6b3e3 commit 90169aa
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 152 deletions.
32 changes: 21 additions & 11 deletions holonote/annotate/annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import param
from bokeh.models.tools import BoxSelectTool, HoverTool
from holoviews.element.selection import Selection1DExpr
from .connector import Connector, SQLiteDB
from .connector import Connector, SQLiteDB, AnnotationTable



Expand Down Expand Up @@ -101,23 +101,30 @@ class AnnotatorInterface(param.Parameterized):

connector = param.ClassSelector(class_=Connector, allow_None=False)

annotation_table = param.ClassSelector(class_=AnnotationTable, allow_None=False)

region_types = param.ListSelector(default=['Range'], objects=['Range', 'Point'], doc="""
Enabled region types for the current Annotator.""")

connector_class = SQLiteDB

def __init__(self, **params):
def __init__(self, init=True, **params):
if "annotation_table" not in params:
params["annotation_table"] = AnnotationTable()

super().__init__(**params)
self._region = {}
self._last_region = None

self.annotation_table.register_annotator(self)
self.annotation_table.add_schema_to_conn(self.connector)

self.annotation_table.load(self.connector, fields=self.connector.fields)

@property
def annotation_table(self):
return self.connector.annotation_table
if init:
self.load()

def load(self):
    """Initialize the connector, then load the annotation table from it.

    The connector's ``_initialize`` is called with the connector's own
    ``column_schema``; afterwards ``annotation_table.load`` is called with
    the connector and the connector's ``fields``.
    """
    conn = self.connector
    conn._initialize(conn.column_schema)
    self.annotation_table.load(conn, fields=conn.fields)

@property
def df(self):
Expand Down Expand Up @@ -193,6 +200,8 @@ def _get_range_indices_by_position(self, dim1_pos, dim2_pos=None):
if dim2_pos is not None:
mask_dim2 = ranges['value'].apply(lambda tpl: tpl[2] <= dim2_pos < tpl[3])
mask = mask & mask_dim2
if not len(mask):
return []
return list(ranges[mask]['_id'])


Expand All @@ -203,7 +212,6 @@ def get_indices_by_position(self, dim1_pos, dim2_pos=None):
event_matches = [] # TODO: Needs hit testing or similar for point events
return range_matches + event_matches


@property
def region(self):
    """Return the region mapping currently stored in ``self._region``."""
    return self._region
Expand Down Expand Up @@ -300,6 +308,11 @@ def revert_to_snapshot(self):
def snapshot(self):
    """Delegate to ``annotation_table.snapshot()``; nothing is returned."""
    table = self.annotation_table
    table.snapshot()

def commit(self, return_commits=False):
    """Run ``annotation_table.commits`` against this annotator's connector.

    Parameters
    ----------
    return_commits : bool, default False
        When true, the value produced by ``annotation_table.commits`` is
        returned; otherwise the method returns ``None``.
    """
    # NOTE: table initialization is deliberately not triggered here; the
    # disabled call was:
    # self.annotation_table.initialize_table(self.connector) # Only if not in params
    applied = self.annotation_table.commits(self.connector)
    return applied if return_commits else None


class Annotator(AnnotatorInterface):
Expand Down Expand Up @@ -709,9 +722,6 @@ def revert_to_snapshot(self):
super().revert_to_snapshot()
self.refresh()

def commit(self):
self.connector.commit()

def set_range(self, startx, endx, starty=None, endy=None):
super().set_range(startx, endx, starty, endy)
self.show_region()
Expand Down
94 changes: 9 additions & 85 deletions holonote/annotate/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,6 @@ class Connector(param.Parameterized):
np.float64: 'REAL',
}

def __init__(self, **params):
self.annotation_table = None
super().__init__(**params)

@classmethod
def field_value_to_type(cls, value):
if isinstance(value, list):
Expand Down Expand Up @@ -222,31 +218,6 @@ def generate_schema(cls, primary_key, all_region_types, all_kdim_dtypes, field_t
return dict(schemas, **cls.schema_from_field_types(field_types))


def _initialize_annotation_table(self):
if self.annotation_table is None:
self.annotation_table = AnnotationTable()

def commit(self):
"Applies the commit hook to the connector"

if self.uninitialized:
self.initialize_table()

if self.commit_hook is not None:
self.commit_hook()
else:
self.commit_default_schema()

def commit_default_schema(self):
for commit in self.annotation_table.commits():
operation = commit['operation']
kwargs = self.transforms[operation](commit['kwargs'])
getattr(self,self.operation_mapping[operation])(**kwargs)

for annotator in self.annotation_table._annotators.values():
annotator.annotation_table.clear_edits()


def _incompatible_schema_check(self, expected_keys, columns, fields, region_type):
msg_prefix = ("Unable to read annotations that were stored with a "
"schema inconsistent with the current settings: ")
Expand All @@ -263,50 +234,6 @@ def _incompatible_schema_check(self, expected_keys, columns, fields, region_type
+ f'Missing {repr(region_type)} region columns {missing_region_columns}. '
+ msg_suffix)

def load_annotation_table(self, annotation_table, fields):
df = self.transforms['load'](self.load_dataframe())
fields_df = df[fields].copy()
annotation_table.define_fields(fields_df, {ind:ind for ind in fields_df.index})
all_region_types = [an.region_types for an in annotation_table._annotators.values()]
all_kdim_dtypes = [an.kdim_dtypes for an in annotation_table._annotators.values()]
for region_types, kdim_dtypes in zip(all_region_types, all_kdim_dtypes):
assert all(el in ['Range', 'Point'] for el in region_types)
for region_type in region_types:
if len(kdim_dtypes)==1:
kdim = list(kdim_dtypes.keys())[0]
if region_type == 'Range':
expected_keys = [f'start_{kdim}', f'end_{kdim}']
self._incompatible_schema_check(expected_keys, list(df.columns), fields, region_type)
annotation_table.define_ranges(kdim, df[f'start_{kdim}'], df[f'end_{kdim}'])
elif region_type == 'Point':
self._incompatible_schema_check([f'point_{kdim}'], list(df.columns), fields, region_type)
annotation_table.define_points(kdim, df[f'point_{kdim}'])
elif len(kdim_dtypes)==2:
kdim1, kdim2 = list(kdim_dtypes.keys())
if region_type == 'Range':
self._incompatible_schema_check([f'start_{kdim1}', f'end_{kdim1}',
f'start_{kdim2}', f'end_{kdim2}'],
list(df.columns), fields, region_type)
annotation_table.define_ranges([kdim1, kdim2],
df[f'start_{kdim1}'], df[f'end_{kdim1}'],
df[f'start_{kdim2}'], df[f'end_{kdim2}'])
elif region_type == 'Point':
self._incompatible_schema_check([f'point_{kdim1}', f'point_{kdim2}'],
list(df.columns), fields, region_type)
annotation_table.define_points([kdim1, kdim2], df[f'point_{kdim1}'], df[f'point_{kdim2}'])
annotation_table.clear_edits()


def add_annotation(self, **fields):
"Primary key specification is optional. Used to works across Annotation instances."
if self.primary_key.field_name not in fields:
index_val = self.primary_key(self, list(self.annotation_table._field_df.index))
fields[self.primary_key.field_name] = index_val
self.annotation_table.add_annotation('annotator-regions', **fields)


for annotator in self.annotation_table._annotators.values():
annotator.refresh(clear=True)

class SQLiteDB(Connector):
"""
Expand All @@ -327,15 +254,17 @@ class SQLiteDB(Connector):
'update':'update_row'}


def __init__(self, column_schema={}, connect=True, **params):
def __init__(self, column_schema=None, connect=True, **params):
"""
First key in column_schema is assumed to be the primary key field if not explicitly specified.
"""
if column_schema is None:
column_schema = {}

params['column_schema'] = column_schema
self.con, self.cursor = None, None
super().__init__(**params)


if connect:
self._initialize(column_schema, create_table=False)

Expand All @@ -344,12 +273,10 @@ def _initialize(self, column_schema, create_table=True):
if self.con is None:
self.con = sqlite3.connect(self.filename, detect_types=sqlite3.PARSE_DECLTYPES |
sqlite3.PARSE_COLNAMES)
self.cursor = self.con.cursor()
self.cursor = self.con.cursor() # should be context manager
if create_table:
self.create_table(column_schema=column_schema)

super()._initialize_annotation_table()

@property
def uninitialized(self):
if self.con is not None:
Expand Down Expand Up @@ -394,13 +321,6 @@ def create_table(self, column_schema=None):
self.cursor.execute(create_table_sql)
self.con.commit()

def initialize_table(self):
field_dtypes = {col:str for col in self.fields} # FIXME - generalize
all_region_types = [an.region_types for an in self.annotation_table._annotators.values()]
all_kdim_dtypes = [an.kdim_dtypes for an in self.annotation_table._annotators.values()]
schema = self.generate_schema(self.primary_key, all_region_types, all_kdim_dtypes, field_dtypes)
self.create_table(schema)

def delete_table(self):
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
self.con.commit()
Expand Down Expand Up @@ -441,3 +361,7 @@ def update_row(self, **updates): # updates as a dictionary OR remove posarg?
query = f"UPDATE {self.table_name} SET " + set_updates + f" WHERE \"{self.primary_key.field_name}\" = ?;"
self.cursor.execute(query, list(updates.values()) + [id_val])
self.con.commit()

def add_schema(self, schema):
    """Merge ``schema`` into ``column_schema``.

    A new dict is assigned to ``column_schema``; for keys present in both,
    the value from ``schema`` replaces the existing one.
    """
    # TODO: Check if schema don't overwrite existing columns
    combined = dict(self.column_schema)
    combined.update(schema)
    self.column_schema = combined
66 changes: 57 additions & 9 deletions holonote/annotate/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ def __init__(self, **params):

self._annotators = weakref.WeakValueDictionary()

def load(self, connector=None, fields_df=None, primary_key_name=None, fields=[]):
def load(self, connector=None, fields_df=None, primary_key_name=None, fields=None):
"""
Load the AnnotationTable from a connector or a fields DataFrame.
"""
if fields is None:
fields = []

if [connector, primary_key_name] == [None,None]:
raise ValueError('Either a connector instance must be supplied or the primary key name supplied')
if len(fields) < 1:
Expand All @@ -47,17 +50,13 @@ def load(self, connector=None, fields_df=None, primary_key_name=None, fields=[])
if fields_df:
fields_df = fields_df[fields].copy() # Primary key/index for annotations
self._field_df = fields_df
elif connector and not connector.uninitialized:
connector.load_annotation_table(self, fields)
elif connector:
self.load_annotation_table(connector, fields)
elif fields_df is None:
fields_df = pd.DataFrame(columns=[primary_key_name] + fields)
fields_df = fields_df.set_index(primary_key_name)
self._field_df = fields_df

# FIXME: Proper solution is to only load relevant columns
self._field_df = self._field_df.drop_duplicates()
self._region_df = self._region_df.drop_duplicates()

self.clear_edits()
self._update_index()

Expand Down Expand Up @@ -143,7 +142,7 @@ def _expand_commit_by_id(self, id_val, fields=None, region_fields=None):
def _expand_save_commits(self, ids):
return {'field_list':[self._expand_commit_by_id(id_val) for id_val in ids]}

def commits(self):
def _create_commits(self):
"Expands out the commit history into commit operations"
fields_inds = set(self._field_df.index)
region_inds = set(self._region_df['_id'].unique())
Expand Down Expand Up @@ -174,6 +173,17 @@ def commits(self):

return commits

def commits(self, connector):
    """Apply this table's commit operations to ``connector`` and return them.

    Each commit produced by ``_create_commits`` names an operation. Its
    kwargs are passed through ``connector.transforms[operation]`` and then
    handed to the connector method named by
    ``connector.operation_mapping[operation]``. Once all commits have been
    applied, ``clear_edits`` is called on the annotation table of every
    registered annotator.
    """
    pending = self._create_commits()
    for entry in pending:
        op_name = entry['operation']
        payload = connector.transforms[op_name](entry['kwargs'])
        method_name = connector.operation_mapping[op_name]
        getattr(connector, method_name)(**payload)

    for registered in self._annotators.values():
        registered.annotation_table.clear_edits()

    return pending

def clear_edits(self, edit_type=None):
"Clear edit state and index mapping"
Expand Down Expand Up @@ -299,7 +309,6 @@ def define_points(self, dims, posx, posy=None):
"_id":[self._index_mapping[ind] for ind in posx.index]})
self._region_df = pd.concat((self._region_df, additions), ignore_index=True)


def define_ranges(self, dims, startx, endx, starty=None, endy=None):
if isinstance(dims, str):
dims = (dims,)
Expand Down Expand Up @@ -345,3 +354,42 @@ def _mask2D(self, kdims):
self._region_df["dim1"] == dim1_name, self._region_df["dim2"] == dim2_name
)

def load_annotation_table(self, conn, fields):
    """Populate this table from the dataframe provided by ``conn``.

    The dataframe returned by ``conn.load_dataframe()`` is passed through
    the connector's ``'load'`` transform. Its ``fields`` columns define the
    field table. Then, for every registered annotator and each of its
    region types, the matching ``start_*``/``end_*`` (``'Range'``) or
    ``point_*`` (``'Point'``) columns are schema-checked by the connector
    and registered via ``define_ranges`` / ``define_points``. Annotators
    whose number of key dimensions is neither one nor two contribute no
    regions. Edits are cleared once loading is finished.
    """
    df = conn.transforms['load'](conn.load_dataframe())
    fields_df = df[fields].copy()
    self.define_fields(fields_df, {ind: ind for ind in fields_df.index})

    # Collect every annotator's settings before any region is defined.
    annotators = list(self._annotators.values())
    all_region_types = [an.region_types for an in annotators]
    all_kdim_dtypes = [an.kdim_dtypes for an in annotators]

    for region_types, kdim_dtypes in zip(all_region_types, all_kdim_dtypes):
        assert all(el in ['Range', 'Point'] for el in region_types)
        for region_type in region_types:
            kdims = list(kdim_dtypes.keys())
            if len(kdims) not in (1, 2):
                continue
            # A single key dimension is passed by name, two as a list.
            dims = kdims[0] if len(kdims) == 1 else kdims
            if region_type == 'Range':
                # Column order per dimension: start then end.
                expected = [f'{edge}_{kdim}' for kdim in kdims for edge in ('start', 'end')]
                definer = self.define_ranges
            elif region_type == 'Point':
                expected = [f'point_{kdim}' for kdim in kdims]
                definer = self.define_points
            else:
                continue
            conn._incompatible_schema_check(expected, list(df.columns), fields, region_type)
            definer(dims, *(df[key] for key in expected))
    self.clear_edits()

def add_schema_to_conn(self, conn):
    """Generate a schema from the registered annotators and add it to ``conn``.

    Every field listed in ``conn.fields`` is typed as ``str``. The region
    types and key-dimension dtypes of all registered annotators are passed
    to ``conn.generate_schema`` together with ``conn.primary_key``, and the
    result is handed to ``conn.add_schema``.
    """
    field_dtypes = dict.fromkeys(conn.fields, str)  # FIXME - generalize
    annotators = list(self._annotators.values())
    region_types_per_annotator = [an.region_types for an in annotators]
    kdim_dtypes_per_annotator = [an.kdim_dtypes for an in annotators]
    conn.add_schema(
        conn.generate_schema(
            conn.primary_key,
            region_types_per_annotator,
            kdim_dtypes_per_annotator,
            field_dtypes,
        )
    )
32 changes: 24 additions & 8 deletions holonote/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,33 @@ def annotator_point2d(conn_sqlite_uuid) -> Annotator:


@pytest.fixture()
def multiple_region_annotator(annotator_range1d) -> Annotator:
annotator_range1d.region_types = ["Point", "Range"]
return annotator_range1d
def multiple_region_annotator(conn_sqlite_uuid) -> Annotator:
return Annotator(
{"TIME": np.datetime64},
fields=["description"],
region_types=["Point", "Range"],
connector=conn_sqlite_uuid,
)


@pytest.fixture()
def multiple_annotators(
conn_sqlite_uuid, annotator_range1d, annotator_range2d
) -> dict[str, Annotator | SQLiteDB]:
annotator_range1d.connector = conn_sqlite_uuid
annotator_range2d.connector = conn_sqlite_uuid
def multiple_annotators(conn_sqlite_uuid) -> dict[str, Annotator | SQLiteDB]:
annotator_range1d = Annotator(
{"TIME": np.datetime64},
fields=["description"],
region_types=["Range"],
connector=conn_sqlite_uuid,
init=False,
)
annotator_range2d = Annotator(
{"x": float, "y": float},
fields=["description"],
region_types=["Range"],
connector=conn_sqlite_uuid,
)

annotator_range1d.load()
annotator_range2d.load()
output = {
"annotation1d": annotator_range1d,
"annotation2d": annotator_range2d,
Expand Down
Loading

0 comments on commit 90169aa

Please sign in to comment.