diff --git a/holonote/annotate/annotator.py b/holonote/annotate/annotator.py index cc58f5b..192e8fd 100644 --- a/holonote/annotate/annotator.py +++ b/holonote/annotate/annotator.py @@ -6,7 +6,7 @@ import param from bokeh.models.tools import BoxSelectTool, HoverTool from holoviews.element.selection import Selection1DExpr -from .connector import Connector, SQLiteDB +from .connector import Connector, SQLiteDB, AnnotationTable @@ -101,23 +101,30 @@ class AnnotatorInterface(param.Parameterized): connector = param.ClassSelector(class_=Connector, allow_None=False) + annotation_table = param.ClassSelector(class_=AnnotationTable, allow_None=False) + region_types = param.ListSelector(default=['Range'], objects=['Range', 'Point'], doc=""" Enabled region types for the current Annotator.""") connector_class = SQLiteDB - def __init__(self, **params): + def __init__(self, init=True, **params): + if "annotation_table" not in params: + params["annotation_table"] = AnnotationTable() + super().__init__(**params) self._region = {} self._last_region = None + self.annotation_table.register_annotator(self) + self.annotation_table.add_schema_to_conn(self.connector) - self.annotation_table.load(self.connector, fields=self.connector.fields) - - @property - def annotation_table(self): - return self.connector.annotation_table + if init: + self.load() + def load(self): + self.connector._initialize(self.connector.column_schema) + self.annotation_table.load(self.connector, fields=self.connector.fields) @property def df(self): @@ -193,6 +200,8 @@ def _get_range_indices_by_position(self, dim1_pos, dim2_pos=None): if dim2_pos is not None: mask_dim2 = ranges['value'].apply(lambda tpl: tpl[2] <= dim2_pos < tpl[3]) mask = mask & mask_dim2 + if not len(mask): + return [] return list(ranges[mask]['_id']) @@ -203,7 +212,6 @@ def get_indices_by_position(self, dim1_pos, dim2_pos=None): event_matches = [] # TODO: Needs hit testing or similar for point events return range_matches + event_matches - @property def region(self): return self._region @@ -300,6 +308,11 @@ def revert_to_snapshot(self): def snapshot(self): self.annotation_table.snapshot() + def commit(self, return_commits=False): + # self.annotation_table.initialize_table(self.connector) # Only if not in params + commits = self.annotation_table.commits(self.connector) + if return_commits: + return commits class Annotator(AnnotatorInterface): @@ -709,9 +722,6 @@ def revert_to_snapshot(self): super().revert_to_snapshot() self.refresh() - def commit(self): - self.connector.commit() - def set_range(self, startx, endx, starty=None, endy=None): super().set_range(startx, endx, starty, endy) self.show_region() diff --git a/holonote/annotate/connector.py b/holonote/annotate/connector.py index 9144082..3a82f39 100644 --- a/holonote/annotate/connector.py +++ b/holonote/annotate/connector.py @@ -174,10 +174,6 @@ class Connector(param.Parameterized): np.float64: 'REAL', } - def __init__(self, **params): - self.annotation_table = None - super().__init__(**params) - @classmethod def field_value_to_type(cls, value): if isinstance(value, list): @@ -222,31 +218,6 @@ def generate_schema(cls, primary_key, all_region_types, all_kdim_dtypes, field_t return dict(schemas, **cls.schema_from_field_types(field_types)) - def _initialize_annotation_table(self): - if self.annotation_table is None: - self.annotation_table = AnnotationTable() - - def commit(self): - "Applies the commit hook to the connector" - - if self.uninitialized: - self.initialize_table() - - if self.commit_hook is not None: - self.commit_hook() - else: - self.commit_default_schema() - - def commit_default_schema(self): - for commit in self.annotation_table.commits(): - operation = commit['operation'] - kwargs = self.transforms[operation](commit['kwargs']) - getattr(self,self.operation_mapping[operation])(**kwargs) - - for annotator in self.annotation_table._annotators.values(): - annotator.annotation_table.clear_edits() - - def _incompatible_schema_check(self, expected_keys, columns, fields, region_type): msg_prefix = ("Unable to read annotations that were stored with a " "schema inconsistent with the current settings: ") @@ -263,50 +234,6 @@ def _incompatible_schema_check(self, expected_keys, columns, fields, region_type + f'Missing {repr(region_type)} region columns {missing_region_columns}. ' + msg_suffix) - def load_annotation_table(self, annotation_table, fields): - df = self.transforms['load'](self.load_dataframe()) - fields_df = df[fields].copy() - annotation_table.define_fields(fields_df, {ind:ind for ind in fields_df.index}) - all_region_types = [an.region_types for an in annotation_table._annotators.values()] - all_kdim_dtypes = [an.kdim_dtypes for an in annotation_table._annotators.values()] - for region_types, kdim_dtypes in zip(all_region_types, all_kdim_dtypes): - assert all(el in ['Range', 'Point'] for el in region_types) - for region_type in region_types: - if len(kdim_dtypes)==1: - kdim = list(kdim_dtypes.keys())[0] - if region_type == 'Range': - expected_keys = [f'start_{kdim}', f'end_{kdim}'] - self._incompatible_schema_check(expected_keys, list(df.columns), fields, region_type) - annotation_table.define_ranges(kdim, df[f'start_{kdim}'], df[f'end_{kdim}']) - elif region_type == 'Point': - self._incompatible_schema_check([f'point_{kdim}'], list(df.columns), fields, region_type) - annotation_table.define_points(kdim, df[f'point_{kdim}']) - elif len(kdim_dtypes)==2: - kdim1, kdim2 = list(kdim_dtypes.keys()) - if region_type == 'Range': - self._incompatible_schema_check([f'start_{kdim1}', f'end_{kdim1}', - f'start_{kdim2}', f'end_{kdim2}'], - list(df.columns), fields, region_type) - annotation_table.define_ranges([kdim1, kdim2], - df[f'start_{kdim1}'], df[f'end_{kdim1}'], - df[f'start_{kdim2}'], df[f'end_{kdim2}']) - elif region_type == 'Point': - self._incompatible_schema_check([f'point_{kdim1}', f'point_{kdim2}'], - list(df.columns), fields, region_type) - annotation_table.define_points([kdim1, kdim2], df[f'point_{kdim1}'], df[f'point_{kdim2}']) - annotation_table.clear_edits() - - - def add_annotation(self, **fields): - "Primary key specification is optional. Used to works across Annotation instances." - if self.primary_key.field_name not in fields: - index_val = self.primary_key(self, list(self.annotation_table._field_df.index)) - fields[self.primary_key.field_name] = index_val - self.annotation_table.add_annotation('annotator-regions', **fields) - - - for annotator in self.annotation_table._annotators.values(): - annotator.refresh(clear=True) class SQLiteDB(Connector): """ @@ -327,15 +254,17 @@ class SQLiteDB(Connector): 'update':'update_row'} - def __init__(self, column_schema={}, connect=True, **params): + def __init__(self, column_schema=None, connect=True, **params): """ First key in column_schema is assumed to the primary key field if not explicitly specified. """ + if column_schema is None: + column_schema = {} + params['column_schema'] = column_schema self.con, self.cursor = None, None super().__init__(**params) - if connect: self._initialize(column_schema, create_table=False) @@ -344,12 +273,10 @@ def _initialize(self, column_schema, create_table=True): if self.con is None: self.con = sqlite3.connect(self.filename, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES) - self.cursor = self.con.cursor() + self.cursor = self.con.cursor() # should be context manager if create_table: self.create_table(column_schema=column_schema) - super()._initialize_annotation_table() - @property def uninitialized(self): if self.con is not None: @@ -394,13 +321,6 @@ def create_table(self, column_schema=None): self.cursor.execute(create_table_sql) self.con.commit() - def initialize_table(self): - field_dtypes = {col:str for col in self.fields} # FIXME - generalize - all_region_types = [an.region_types for an in self.annotation_table._annotators.values()] - all_kdim_dtypes = [an.kdim_dtypes for an in self.annotation_table._annotators.values()] - schema = self.generate_schema(self.primary_key, all_region_types, all_kdim_dtypes, field_dtypes) - self.create_table(schema) - def delete_table(self): self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}") self.con.commit() @@ -441,3 +361,7 @@ def update_row(self, **updates): # updates as a dictionary OR remove posarg? query = f"UPDATE {self.table_name} SET " + set_updates + f" WHERE \"{self.primary_key.field_name}\" = ?;" self.cursor.execute(query, list(updates.values()) + [id_val]) self.con.commit() + + def add_schema(self, schema): + # TODO: Check if schema don't overwrite existing columns + self.column_schema = {**self.column_schema, **schema} diff --git a/holonote/annotate/table.py b/holonote/annotate/table.py index 1294a81..dba19e7 100644 --- a/holonote/annotate/table.py +++ b/holonote/annotate/table.py @@ -34,10 +34,13 @@ def __init__(self, **params): self._annotators = weakref.WeakValueDictionary() - def load(self, connector=None, fields_df=None, primary_key_name=None, fields=[]): + def load(self, connector=None, fields_df=None, primary_key_name=None, fields=None): """ Load the AnnotationTable from a connector or a fields DataFrame. """ + if fields is None: + fields = [] + if [connector, primary_key_name] == [None,None]: raise ValueError('Either a connector instance must be supplied or the primary key name supplied') if len(fields) < 1: @@ -47,17 +50,13 @@ def load(self, connector=None, fields_df=None, primary_key_name=None, fields=[]) if fields_df: fields_df = fields_df[fields].copy() # Primary key/index for annotations self._field_df = fields_df - elif connector and not connector.uninitialized: - connector.load_annotation_table(self, fields) + elif connector: + self.load_annotation_table(connector, fields) elif fields_df is None: fields_df = pd.DataFrame(columns=[primary_key_name] + fields) fields_df = fields_df.set_index(primary_key_name) self._field_df = fields_df - # FIXME: Proper solution is to only load relevant columns - self._field_df = self._field_df.drop_duplicates() - self._region_df = self._region_df.drop_duplicates() - self.clear_edits() self._update_index() @@ -143,7 +142,7 @@ def _expand_commit_by_id(self, id_val, fields=None, region_fields=None): def _expand_save_commits(self, ids): return {'field_list':[self._expand_commit_by_id(id_val) for id_val in ids]} - def commits(self): + def _create_commits(self): "Expands out the commit history into commit operations" fields_inds = set(self._field_df.index) region_inds = set(self._region_df['_id'].unique()) @@ -174,6 +173,17 @@ def commits(self): return commits + def commits(self, connector): + commits = self._create_commits() + for commit in commits: + operation = commit['operation'] + kwargs = connector.transforms[operation](commit['kwargs']) + getattr(connector,connector.operation_mapping[operation])(**kwargs) + + for annotator in self._annotators.values(): + annotator.annotation_table.clear_edits() + + return commits def clear_edits(self, edit_type=None): "Clear edit state and index mapping" @@ -299,7 +309,6 @@ def define_points(self, dims, posx, posy=None): "_id":[self._index_mapping[ind] for ind in posx.index]}) self._region_df = pd.concat((self._region_df, additions), ignore_index=True) - def define_ranges(self, dims, startx, endx, starty=None, endy=None): if isinstance(dims, str): dims = (dims,) @@ -345,3 +354,42 @@ def _mask2D(self, kdims): self._region_df["dim1"] == dim1_name, self._region_df["dim2"] == dim2_name ) + def load_annotation_table(self, conn, fields): + df = conn.transforms['load'](conn.load_dataframe()) + fields_df = df[fields].copy() + self.define_fields(fields_df, {ind:ind for ind in fields_df.index}) + all_region_types = [an.region_types for an in self._annotators.values()] + all_kdim_dtypes = [an.kdim_dtypes for an in self._annotators.values()] + for region_types, kdim_dtypes in zip(all_region_types, all_kdim_dtypes): + assert all(el in ['Range', 'Point'] for el in region_types) + for region_type in region_types: + if len(kdim_dtypes)==1: + kdim = list(kdim_dtypes.keys())[0] + if region_type == 'Range': + expected_keys = [f'start_{kdim}', f'end_{kdim}'] + conn._incompatible_schema_check(expected_keys, list(df.columns), fields, region_type) + self.define_ranges(kdim, df[f'start_{kdim}'], df[f'end_{kdim}']) + elif region_type == 'Point': + conn._incompatible_schema_check([f'point_{kdim}'], list(df.columns), fields, region_type) + self.define_points(kdim, df[f'point_{kdim}']) + elif len(kdim_dtypes)==2: + kdim1, kdim2 = list(kdim_dtypes.keys()) + if region_type == 'Range': + conn._incompatible_schema_check([f'start_{kdim1}', f'end_{kdim1}', + f'start_{kdim2}', f'end_{kdim2}'], + list(df.columns), fields, region_type) + self.define_ranges([kdim1, kdim2], + df[f'start_{kdim1}'], df[f'end_{kdim1}'], + df[f'start_{kdim2}'], df[f'end_{kdim2}']) + elif region_type == 'Point': + conn._incompatible_schema_check([f'point_{kdim1}', f'point_{kdim2}'], + list(df.columns), fields, region_type) + self.define_points([kdim1, kdim2], df[f'point_{kdim1}'], df[f'point_{kdim2}']) + self.clear_edits() + + def add_schema_to_conn(self, conn): + field_dtypes = {col: str for col in conn.fields} # FIXME - generalize + all_region_types = [an.region_types for an in self._annotators.values()] + all_kdim_dtypes = [an.kdim_dtypes for an in self._annotators.values()] + schema = conn.generate_schema(conn.primary_key, all_region_types, all_kdim_dtypes, field_dtypes) + conn.add_schema(schema) diff --git a/holonote/tests/conftest.py b/holonote/tests/conftest.py index ee85496..5a23cc4 100644 --- a/holonote/tests/conftest.py +++ b/holonote/tests/conftest.py @@ -67,17 +67,33 @@ def annotator_point2d(conn_sqlite_uuid) -> Annotator: @pytest.fixture() -def multiple_region_annotator(annotator_range1d) -> Annotator: - annotator_range1d.region_types = ["Point", "Range"] - return annotator_range1d +def multiple_region_annotator(conn_sqlite_uuid) -> Annotator: + return Annotator( + {"TIME": np.datetime64}, + fields=["description"], + region_types=["Point", "Range"], + connector=conn_sqlite_uuid, + ) @pytest.fixture() -def multiple_annotators( - conn_sqlite_uuid, annotator_range1d, annotator_range2d -) -> dict[str, Annotator | SQLiteDB]: - annotator_range1d.connector = conn_sqlite_uuid - annotator_range2d.connector = conn_sqlite_uuid +def multiple_annotators(conn_sqlite_uuid) -> dict[str, Annotator | SQLiteDB]: + annotator_range1d = Annotator( + {"TIME": np.datetime64}, + fields=["description"], + region_types=["Range"], + connector=conn_sqlite_uuid, + init=False, + ) + annotator_range2d = Annotator( + {"x": float, "y": float}, + fields=["description"], + region_types=["Range"], + connector=conn_sqlite_uuid, + ) + + annotator_range1d.load() + annotator_range2d.load() output = { "annotation1d": annotator_range1d, "annotation2d": annotator_range2d, diff --git a/holonote/tests/test_annotators_advanced.py b/holonote/tests/test_annotators_advanced.py index eee86e6..40b73d0 100644 --- a/holonote/tests/test_annotators_advanced.py +++ b/holonote/tests/test_annotators_advanced.py @@ -1,8 +1,9 @@ import holoviews as hv import numpy as np import pandas as pd +import pytest -from holonote.annotate import Annotator +from holonote.annotate import Annotator, SQLiteDB def test_multipoint_range_commit_insertion(multiple_region_annotator): @@ -15,7 +16,7 @@ def test_multipoint_range_commit_insertion(multiple_region_annotator): multiple_region_annotator.set_range(start, end) multiple_region_annotator.add_annotation(description=descriptions[1]) - multiple_region_annotator.commit() + multiple_region_annotator.commit(return_commits=True) # FIXME! Index order is inverted? df = pd.DataFrame({'uuid': pd.Series(multiple_region_annotator.df.index[::-1], dtype=object), @@ -44,7 +45,10 @@ def test_infer_kdim_dtype_curve(): def test_multiplot_add_annotation(multiple_annotators): multiple_annotators["annotation1d"].set_range(np.datetime64('2005-02-13'), np.datetime64('2005-02-16')) multiple_annotators["annotation2d"].set_range(-0.25, 0.25, -0.1, 0.1) - multiple_annotators["conn"].add_annotation(description='Multi-plot annotation') + multiple_annotators["annotation1d"].add_annotation(description='Multi-plot annotation') + multiple_annotators["annotation2d"].add_annotation(description='Multi-plot annotation') + multiple_annotators["annotation1d"].commit() + multiple_annotators["annotation2d"].commit() class TestAnnotatorMultipleStringFields: @@ -53,7 +57,7 @@ def test_insertion_values(self, multiple_fields_annotator): start, end = np.datetime64('2022-06-06'), np.datetime64('2022-06-08') multiple_fields_annotator.set_range(start, end) multiple_fields_annotator.add_annotation(field1='A test field', field2='Another test field') - commits = multiple_fields_annotator.annotation_table.commits() + commits = multiple_fields_annotator.commit(return_commits=True) kwargs = commits[0]['kwargs'] assert len(commits)==1, 'Only one insertion commit made' assert 'uuid' in kwargs.keys(), 'Expected uuid primary key in kwargs' @@ -67,7 +71,7 @@ def test_commit_insertion(self, multiple_fields_annotator): field2 = 'Another test field' multiple_fields_annotator.set_range(start, end) multiple_fields_annotator.add_annotation(field1=field1, field2=field2) - multiple_fields_annotator.commit() + multiple_fields_annotator.commit(return_commits=True) df = pd.DataFrame({'uuid': pd.Series(multiple_fields_annotator.df.index[0], dtype=object), 'start_TIME':[start], @@ -87,8 +91,52 @@ def test_commit_update(self, multiple_fields_annotator): multiple_fields_annotator.add_annotation(field1='Field 1.1', field2='Field 1.2') multiple_fields_annotator.set_range(start2, end2) multiple_fields_annotator.add_annotation(field1='Field 2.1', field2='Field 2.2') - multiple_fields_annotator.commit() + multiple_fields_annotator.commit(return_commits=True) multiple_fields_annotator.update_annotation_fields(multiple_fields_annotator.df.index[0], field1='NEW Field 1.1') - multiple_fields_annotator.commit() + multiple_fields_annotator.commit(return_commits=True) sql_df = multiple_fields_annotator.connector.load_dataframe() assert set(sql_df['field1']) == {'NEW Field 1.1', 'Field 2.1'} + + +@pytest.mark.parametrize("method", ["new", "same"]) +def test_reconnect(method, tmp_path): + db_path = str(tmp_path / "test.db") + + if method == "new": + conn1 = SQLiteDB(filename=db_path) + conn2 = SQLiteDB(filename=db_path) + elif method == "same": + conn1 = conn2 = SQLiteDB(filename=db_path) + + # Create annotator with data and commit + a1 = Annotator( + spec={"TIME": np.datetime64}, + fields=["description"], + region_types=["Range"], + connector=conn1, + ) + times = pd.date_range("2022-06-09", "2022-06-13") + for t1, t2 in zip(times[:-1], times[1:]): + a1.set_range(t1, t2) + a1.add_annotation(description='A programmatically defined annotation') + a1.commit(return_commits=True) + + # Save internal dataframes + a1_df = a1.df.copy() + a1_region = a1.annotation_table._region_df.copy() + a1_field = a1.annotation_table._field_df.copy() + + # Add new connector + a2 = Annotator( + spec={"TIME": np.datetime64}, + fields=["description"], + region_types=["Range"], + connector=conn2, + ) + a2_df = a2.df.copy() + a2_region = a2.annotation_table._region_df.copy() + a2_field = a2.annotation_table._field_df.copy() + + pd.testing.assert_frame_equal(a1_df, a2_df) + pd.testing.assert_frame_equal(a1_region, a2_region) + pd.testing.assert_frame_equal(a1_field, a2_field) diff --git a/holonote/tests/test_annotators_basic.py b/holonote/tests/test_annotators_basic.py index 2c91c31..2338fef 100644 --- a/holonote/tests/test_annotators_basic.py +++ b/holonote/tests/test_annotators_basic.py @@ -24,9 +24,9 @@ def test_point_insertion_exception(self, annotator_range1d): def test_insertion_edit_table_columns(self, annotator_range1d): annotator_range1d.set_range(np.datetime64('2022-06-06'), np.datetime64('2022-06-08')) annotator_range1d.add_annotation(description='A test annotation!') - commits = annotator_range1d.annotation_table.commits() + commits = annotator_range1d.commit(return_commits=True) assert len(commits)==1, 'Only one insertion commit made ' - annotator_range1d.commit() + annotator_range1d.commit(return_commits=True) assert commits[0]['operation'] == 'insert' assert set(commits[0]['kwargs'].keys()) == set(annotator_range1d.connector.columns) @@ -34,7 +34,7 @@ def test_range_insertion_values(self, annotator_range1d) -> None: start, end = np.datetime64('2022-06-06'), np.datetime64('2022-06-08') annotator_range1d.set_range(start, end) annotator_range1d.add_annotation(description='A test annotation!') - commits = annotator_range1d.annotation_table.commits() + commits = annotator_range1d.commit(return_commits=True) assert len(commits)==1, 'Only one insertion commit made' kwargs = commits[0]['kwargs'] assert 'uuid' in kwargs.keys(), 'Expected uuid primary key in kwargs' @@ -46,7 +46,7 @@ def test_range_commit_insertion(self, annotator_range1d): description = 'A test annotation!' annotator_range1d.set_range(start, end) annotator_range1d.add_annotation(description=description) - annotator_range1d.commit() + annotator_range1d.commit(return_commits=True) df = pd.DataFrame({'uuid': pd.Series(annotator_range1d.df.index[0], dtype=object), 'start_TIME':[start], @@ -68,12 +68,12 @@ def test_range_addition_deletion_by_uuid(self, annotator_range1d): annotator_range1d.add_annotation(description='Annotation 2', uuid='08286429') annotator_range1d.set_range(start3, end3) annotator_range1d.add_annotation(description='Annotation 3') - annotator_range1d.commit() + annotator_range1d.commit(return_commits=True) sql_df = annotator_range1d.connector.load_dataframe() assert set(sql_df['description']) ==set(['Annotation 1', 'Annotation 2', 'Annotation 3']) deletion_index = sql_df.index[1] annotator_range1d.delete_annotation(deletion_index) - annotator_range1d.commit() + annotator_range1d.commit(return_commits=True) sql_df = annotator_range1d.connector.load_dataframe() assert set(sql_df['description']) == set(['Annotation 1', 'Annotation 3']) @@ -89,7 +89,7 @@ def test_range_define_preserved_index_mismatch(self, annotator_range1d): annotator_range1d.define_ranges(data['start'].iloc[:2], data['end'].iloc[:2]) msg = f"Following annotations have no associated region: {{{annotation_id[2]!r}}}" with pytest.raises(ValueError, match=msg): - annotator_range1d.commit() + annotator_range1d.commit(return_commits=True) def test_range_define_auto_index_mismatch(self, annotator_range1d): starts = [np.datetime64('2022-06-%.2d' % d) for d in range(6,15, 4)] @@ -103,7 +103,7 @@ def test_range_define_auto_index_mismatch(self, annotator_range1d): annotator_range1d.define_ranges(data['start'].iloc[:2], data['end'].iloc[:2]) with pytest.raises(ValueError, match="Following annotations have no associated region:"): - annotator_range1d.commit() + annotator_range1d.commit(return_commits=True) def test_range_define_unassigned_indices(self, annotator_range1d): starts = [np.datetime64('2022-06-%.2d' % d) for d in range(6,15, 4)] @@ -134,9 +134,8 @@ def test_point_insertion_exception(self, annotator_range2d): def test_insertion_edit_table_columns(self, annotator_range2d): annotator_range2d.set_range(-0.25, 0.25, -0.1, 0.1) annotator_range2d.add_annotation(description='A test annotation!') - commits = annotator_range2d.annotation_table.commits() + commits = annotator_range2d.commit(return_commits=True) assert len(commits)==1, 'Only one insertion commit made ' - annotator_range2d.commit() assert commits[0]['operation'] == 'insert' assert set(commits[0]['kwargs'].keys()) == set(annotator_range2d.connector.columns) @@ -144,7 +143,7 @@ def test_range_insertion_values(self, annotator_range2d): startx, endx, starty, endy = -0.25, 0.25, -0.1, 0.1 annotator_range2d.set_range(startx, endx, starty, endy) annotator_range2d.add_annotation(description='A test annotation!') - commits = annotator_range2d.annotation_table.commits() + commits = annotator_range2d.commit(return_commits=True) assert len(commits)==1, 'Only one insertion commit made' kwargs = commits[0]['kwargs'] assert 'uuid' in kwargs.keys(), 'Expected uuid primary key in kwargs' @@ -157,7 +156,7 @@ def test_range_commit_insertion(self, annotator_range2d): description = 'A test annotation!' annotator_range2d.set_range(startx, endx, starty, endy) annotator_range2d.add_annotation(description=description) - annotator_range2d.commit() + annotator_range2d.commit(return_commits=True) df = pd.DataFrame({'uuid': pd.Series(annotator_range2d.df.index[0], dtype=object), 'start_x':[startx], @@ -181,12 +180,12 @@ def test_range_addition_deletion_by_uuid(self, annotator_range2d): annotator_range2d.add_annotation(description='Annotation 2', uuid='08286429') annotator_range2d.set_range(startx3, endx3, starty3, endy3) annotator_range2d.add_annotation(description='Annotation 3') - annotator_range2d.commit() + annotator_range2d.commit(return_commits=True) sql_df = annotator_range2d.connector.load_dataframe() assert set(sql_df['description']) == set(['Annotation 1', 'Annotation 2', 'Annotation 3']) deletion_index = sql_df.index[1] annotator_range2d.delete_annotation(deletion_index) - annotator_range2d.commit() + annotator_range2d.commit(return_commits=True) sql_df = annotator_range2d.connector.load_dataframe() assert set(sql_df['description']) == set(['Annotation 1', 'Annotation 3']) @@ -206,7 +205,7 @@ def test_range_define_preserved_index_mismatch(self, annotator_range2d): msg = f"Following annotations have no associated region: {{{annotation_id[2]!r}}}" with pytest.raises(ValueError, match=msg): - annotator_range2d.commit() + annotator_range2d.commit(return_commits=True) def test_range_define_auto_index_mismatch(self, annotator_range2d): xstarts, xends = [-0.3, -0.2, -0.1], [0.3, 0.2, 0.1] @@ -221,7 +220,7 @@ def test_range_define_auto_index_mismatch(self, annotator_range2d): data['ystart'].iloc[:2], data['yend'].iloc[:2]) msg = "Following annotations have no associated region:" with pytest.raises(ValueError, match=msg): - annotator_range2d.commit() + annotator_range2d.commit(return_commits=True) def test_range_define_unassigned_indices(self, annotator_range2d): xstarts, xends = [-0.3, -0.2, -0.1], [0.3, 0.2, 0.1] @@ -249,9 +248,9 @@ class TestBasicPoint1DAnnotator: def test_insertion_edit_table_columns(self, annotator_point1d): annotator_point1d.set_point(np.datetime64('2022-06-06')) annotator_point1d.add_annotation(description='A test annotation!') - commits = annotator_point1d.annotation_table.commits() + commits = annotator_point1d.commit(return_commits=True) assert len(commits)==1, 'Only one insertion commit made ' - annotator_point1d.commit() + annotator_point1d.commit(return_commits=True) assert commits[0]['operation'] == 'insert' assert set(commits[0]['kwargs'].keys()) == set(annotator_point1d.connector.columns) @@ -265,7 +264,7 @@ def test_point_insertion_values(self, annotator_point1d): timestamp = np.datetime64('2022-06-06') annotator_point1d.set_point(timestamp) annotator_point1d.add_annotation(description='A test annotation!') - commits = annotator_point1d.annotation_table.commits() + commits = annotator_point1d.commit(return_commits=True) assert len(commits)==1, 'Only one insertion commit made' kwargs = commits[0]['kwargs'] assert 'uuid' in kwargs.keys(), 'Expected uuid primary key in kwargs' @@ -277,7 +276,7 @@ def test_point_commit_insertion(self, annotator_point1d): description = 'A test annotation!' annotator_point1d.set_point(timestamp) annotator_point1d.add_annotation(description=description) - annotator_point1d.commit() + annotator_point1d.commit(return_commits=True) df = pd.DataFrame({'uuid': pd.Series(annotator_point1d.df.index[0], dtype=object), 'point_TIME':[timestamp], @@ -298,12 +297,12 @@ def test_point_addition_deletion_by_uuid(self, annotator_point1d): annotator_point1d.add_annotation(description='Annotation 2', uuid='08286429') annotator_point1d.set_point(ts3) annotator_point1d.add_annotation(description='Annotation 3') - annotator_point1d.commit() + annotator_point1d.commit(return_commits=True) sql_df = annotator_point1d.connector.load_dataframe() assert set(sql_df['description']) == set(['Annotation 1', 'Annotation 2', 'Annotation 3']) deletion_index = sql_df.index[1] annotator_point1d.delete_annotation(deletion_index) - annotator_point1d.commit() + annotator_point1d.commit(return_commits=True) sql_df = annotator_point1d.connector.load_dataframe() assert set(sql_df['description']) == set(['Annotation 1', 'Annotation 3']) @@ -318,7 +317,7 @@ def test_point_define_preserved_index_mismatch(self, annotator_point1d): annotator_point1d.define_points(data['timestamps'].iloc[:2]) msg = f"Following annotations have no associated region: {{{annotation_id[2]!r}}}" with pytest.raises(ValueError, match=msg): - annotator_point1d.commit() + annotator_point1d.commit(return_commits=True) def test_point_define_auto_index_mismatch(self, annotator_point1d): timestamps = [np.datetime64('2022-06-%.2d' % d) for d in range(6,15, 4)] @@ -330,7 +329,7 @@ def test_point_define_auto_index_mismatch(self, annotator_point1d): annotator_point1d.define_fields(data[['description']], preserve_index=False) annotator_point1d.define_points(data['timestamps'].iloc[:2]) with pytest.raises(ValueError, match="Following annotations have no associated region:"): - annotator_point1d.commit() + annotator_point1d.commit(return_commits=True) def test_point_define_unassigned_indices(self, annotator_point1d): timestamps = [np.datetime64('2022-06-%.2d' % d) for d in range(6,15, 4)] @@ -354,9 +353,8 @@ class TestBasicPoint2DAnnotator: def test_insertion_edit_table_columns(self, annotator_point2d): annotator_point2d.set_point(-0.25, 0.1) annotator_point2d.add_annotation(description='A test annotation!') - commits = annotator_point2d.annotation_table.commits() + commits = annotator_point2d.commit(return_commits=True) assert len(commits)==1, 'Only one insertion commit made ' - annotator_point2d.commit() assert commits[0]['operation'] == 'insert' assert set(commits[0]['kwargs'].keys()) == set(annotator_point2d.connector.columns) @@ -370,7 +368,7 @@ def test_point_insertion_values(self, annotator_point2d): x,y = 0.5, 0.3 annotator_point2d.set_point(x,y) annotator_point2d.add_annotation(description='A test annotation!') - commits = annotator_point2d.annotation_table.commits() + commits = annotator_point2d.commit(return_commits=True) assert len(commits)==1, 'Only one insertion commit made' kwargs = commits[0]['kwargs'] assert 'uuid' in kwargs.keys(), 'Expected uuid primary key in kwargs' @@ -382,7 +380,7 @@ def test_point_commit_insertion(self, annotator_point2d): description = 'A test annotation!' annotator_point2d.set_point(x,y) annotator_point2d.add_annotation(description=description) - annotator_point2d.commit() + annotator_point2d.commit(return_commits=True) df = pd.DataFrame({'uuid': pd.Series(annotator_point2d.df.index[0], dtype=object), 'point_x':[x], @@ -404,12 +402,12 @@ def test_point_addition_deletion_by_uuid(self, annotator_point2d): annotator_point2d.add_annotation(description='Annotation 2', uuid='08286429') annotator_point2d.set_point(x3, y3) annotator_point2d.add_annotation(description='Annotation 3') - annotator_point2d.commit() + annotator_point2d.commit(return_commits=True) sql_df = annotator_point2d.connector.load_dataframe() assert set(sql_df['description']) == set(['Annotation 1', 'Annotation 2', 'Annotation 3']) deletion_index = sql_df.index[1] annotator_point2d.delete_annotation(deletion_index) - annotator_point2d.commit() + annotator_point2d.commit(return_commits=True) sql_df = annotator_point2d.connector.load_dataframe() assert set(sql_df['description']) == set(['Annotation 1', 'Annotation 3']) @@ -424,7 +422,7 @@ def test_point_define_preserved_index_mismatch(self, annotator_point2d): annotator_point2d.define_points(data['xs'].iloc[:2], data['ys'].iloc[:2]) msg = f"Following annotations have no associated region: {{{annotation_id[2]!r}}}" with pytest.raises(ValueError, match=msg): - annotator_point2d.commit() + annotator_point2d.commit(return_commits=True) def test_point_define_auto_index_mismatch(self, annotator_point2d): xs, ys = [-0.1,-0.2,-0.3], [0.1,0.2,0.3] @@ -437,7 +435,7 @@ def test_point_define_auto_index_mismatch(self, annotator_point2d): annotator_point2d.define_points(data['xs'].iloc[:2], data['ys'].iloc[:2]) msg = "Following annotations have no associated region:" with pytest.raises(ValueError, match=msg): - annotator_point2d.commit() + annotator_point2d.commit(return_commits=True) def test_point_define_unassigned_indices(self, annotator_point2d): xs, ys = [-0.1,-0.2,-0.3], [0.1,0.2,0.3]