diff --git a/packages/syft/src/syft/client/datasite_client.py b/packages/syft/src/syft/client/datasite_client.py index 7553344ad5a..d72011f2c18 100644 --- a/packages/syft/src/syft/client/datasite_client.py +++ b/packages/syft/src/syft/client/datasite_client.py @@ -24,6 +24,7 @@ from ..service.dataset.dataset import Contributor from ..service.dataset.dataset import CreateAsset from ..service.dataset.dataset import CreateDataset +from ..service.dataset.dataset import Dataset from ..service.dataset.dataset import _check_asset_must_contain_mock from ..service.migration.object_migration_state import MigrationData from ..service.request.request import Request @@ -36,6 +37,7 @@ from ..service.user.user import UserView from ..types.blob_storage import BlobFile from ..types.errors import SyftException +from ..types.twin_object import TwinObject from ..types.uid import UID from ..util.misc_objs import HTMLObject from ..util.util import get_mb_size @@ -98,27 +100,24 @@ class DatasiteClient(SyftClient): def __repr__(self) -> str: return f"" - def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess: - # relative - from ..types.twin_object import TwinObject - + def upload_dataset( + self, dataset: CreateDataset, force_replace: bool = False + ) -> SyftSuccess | SyftError: if self.users is None: raise SyftException(public_message=f"can't get user service for {self}") user = self.users.get_current_user() + if user.role not in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]: return SyftError(message="You don't have permission to upload datasets.") + dataset = add_default_uploader(user, dataset) for i in range(len(dataset.asset_list)): asset = dataset.asset_list[i] dataset.asset_list[i] = add_default_uploader(user, asset) - # dataset._check_asset_must_contain_mock() - dataset_size: float = 0.0 - # TODO: Refactor so that object can also be passed to generate warnings - self.api.connection = cast(ServerConnection, self.api.connection) metadata = self.api.connection.get_server_metadata(self.api.signing_key) @@ -134,10 +133,27 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess: ) prompt_warning_message(message=message, confirm=True) - with tqdm( - total=len(dataset.asset_list), colour="green", desc="Uploading" - ) as pbar: - for asset in dataset.asset_list: + # check if the a dataset with the same name already exists + search_res = self.api.services.dataset.search(dataset.name) + dataset_exists: bool = len(search_res) > 0 + + if not dataset_exists: + return self._upload_new_dataset(dataset) + + existed_dataset: Dataset = search_res[0] + if not force_replace: + return SyftError( + message=f"Dataset with name the '{dataset.name}' already exists. " + "Please use `upload_dataset(dataset, force_replace=True)` to overwrite." 
+ ) + return self._replace_dataset(existed_dataset, dataset) + + def _upload_assets(self, assets: list[CreateAsset]) -> float | SyftError: + total_assets_size: float = 0.0 + + with tqdm(total=len(assets), colour="green", desc="Uploading") as pbar: + for asset in assets: + # create and save a twin object representing the asset to the blob store try: contains_empty: bool = asset.contains_empty() twin = TwinObject( @@ -163,16 +179,55 @@ def upload_dataset(self, dataset: CreateDataset) -> SyftSuccess: asset.action_id = twin.id asset.server_uid = self.id - dataset_size += get_mb_size(asset.data) + + total_assets_size += get_mb_size(asset.data) # Update the progress bar and set the dynamic description pbar.set_description(f"Uploading: {asset.name}") pbar.update(1) - dataset.mb_size = dataset_size + return total_assets_size + + def _upload_new_dataset(self, dataset: CreateDataset) -> SyftSuccess | SyftError: + # upload the assets + total_assets_size: float | SyftError = self._upload_assets(dataset.asset_list) + if isinstance(total_assets_size, SyftError): + return total_assets_size + + # check if the types of the assets are valid + dataset.mb_size = total_assets_size _check_asset_must_contain_mock(dataset.asset_list) - dataset.check() - return self.api.services.dataset.add(dataset=dataset) + valid = dataset.check() + if isinstance(valid, SyftError): + return valid + + # add the dataset object to the dataset store + try: + return self.api.services.dataset.add(dataset=dataset) + except Exception as e: + return SyftError(message=f"Failed to upload dataset. {e}") + + def _replace_dataset( + self, existed_dataset: Dataset, dataset: CreateDataset + ) -> SyftSuccess | SyftError: + # TODO: is there a way to check if the assets already exist and have not changed, + # since if uploading the assets will have different UIDs + total_assets_size: float | SyftError = self._upload_assets(dataset.asset_list) + if isinstance(total_assets_size, SyftError): + return total_assets_size + + # check if the types of the assets are valid + dataset.mb_size = total_assets_size + valid = dataset.check() + _check_asset_must_contain_mock(dataset.asset_list) + if isinstance(valid, SyftError): + return valid + try: + return self.api.services.dataset.replace( + existed_dataset_uid=existed_dataset.id, dataset=dataset + ) + except Exception as e: + return SyftError(message=f"Failed to replace dataset. {e}") def forgot_password(self, email: str) -> SyftSuccess | SyftError: return self.connection.forgot_password(email=email) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 0d295b81982..25edccdf52a 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -382,7 +382,7 @@ def launch( display( SyftInfo( message=f"You have launched a development server at http://{host}:{server_handle.port}." - + "It is intended only for local use." + + " It is intended only for local use." ) ) return server_handle diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 1ea9d1ae203..ab56f064551 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -77,12 +77,27 @@ def handle_union_type_klass_name(type_klass_name: str) -> str: return type_klass_name +def get_klass_or_canonical_name(arg: Any) -> str: + """Get the class name or canonical name of the object. 
+ + If the object is a subclass of SyftBaseObject, then use canonical name + to identify the object. + + """ + + return ( + arg.__canonical_name__ # If SyftBaseObject subclass, ignore class name + if hasattr(arg, "__canonical_name__") + else getattr(arg, "__name__", str(arg)) + ) + + def handle_annotation_repr_(annotation: type) -> str: """Handle typing representation.""" origin = typing.get_origin(annotation) args = typing.get_args(annotation) if origin and args: - args_repr = ", ".join(getattr(arg, "__name__", str(arg)) for arg in args) + args_repr = ", ".join(get_klass_or_canonical_name(arg) for arg in args) origin_repr = getattr(origin, "__name__", str(origin)) # Handle typing.Union and types.UnionType diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 5f9f6a8fab1..93f5df0006c 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -1,5 +1,16 @@ { "1": { "release_name": "0.9.1.json" + }, + "dev": { + "object_versions": { + "CreateAsset": { + "2": { + "version": 2, + "hash": "1637c1e35c8cb65c9d667ad91c824b2cc5cf5b281e93770915a68c4926e6a567", + "action": "add" + } + } + } } } diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 6c4a91e06c2..663c1ebd3a5 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -98,7 +98,6 @@ class Asset(SyftObject): shape: tuple | None = None created_at: DateTime = DateTime.now() uploader: Contributor | None = None - # _kwarg_name and _dataset_name are set by the UserCode.assets _kwarg_name: str | None = None _dataset_name: str | None = None @@ -314,7 +313,7 @@ def check_mock(data: Any, mock: Any) -> bool: @serializable() -class CreateAsset(SyftObject): +class CreateAssetV1(SyftObject): # version __canonical_name__ = "CreateAsset" __version__ = SYFT_OBJECT_VERSION_1 @@ -336,6 +335,30 @@ class CreateAsset(SyftObject): __repr_attrs__ = ["name"] model_config = ConfigDict(validate_assignment=True, extra="forbid") + +@serializable() +class CreateAsset(SyftObject): + # version + __canonical_name__ = "CreateAsset" + __version__ = SYFT_OBJECT_VERSION_2 + + id: UID | None = None # type:ignore[assignment] + name: str + description: MarkdownDescription | None = None + contributors: set[Contributor] = set() + data_subjects: list[DataSubjectCreate] = [] + server_uid: UID | None = None + action_id: UID | None = None + data: Any | None = None + mock: Any | None = None + shape: tuple | None = None + mock_is_real: bool = False + created_at: DateTime | None = None + uploader: Contributor | None = None + + __repr_attrs__ = ["name"] + model_config = ConfigDict(validate_assignment=True, extra="forbid") + def __init__(self, description: str | None = None, **data: Any) -> None: if isinstance(description, str): description = MarkdownDescription(text=description) @@ -514,9 +537,7 @@ def _repr_html_(self) -> Any: """ else: description_info_message = "" - if self.to_be_deleted: - return "This dataset has been marked for deletion. The underlying data may be not available." - return f""" + repr_html = f"""

            {self.name}
            Summary
@@ -529,9 +550,19 @@ def _repr_html_(self) -> Any:
            {self.url}
            Contributors: To see full details call dataset.contributors.
-
-            Assets
-
-            {self.assets._repr_html_()}
-        """
+        """
+        if self.to_be_deleted:
+            repr_html += (
+                "This dataset has been marked for deletion. "
+                "The underlying data may not be available."
+            )
+        else:
+            repr_html += f"""
+
+            Assets
+
+ {self.assets._repr_html_()} + """ + return repr_html def action_ids(self) -> list[UID]: return [asset.action_id for asset in self.asset_list if asset.action_id] @@ -874,3 +905,10 @@ class DatasetUpdate(PartialSyftObject): name: str to_be_deleted: bool + asset_list: list[Asset] + contributors: set[Contributor] + citation: str + url: str + description: MarkdownDescription + uploader: Contributor + summary: str diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index 43cbfacb117..b010f18c6cd 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -115,15 +115,16 @@ def get_all( context: AuthedServiceContext, page_size: int | None = 0, page_index: int | None = 0, + include_deleted: bool = False, ) -> DatasetPageView | DictTuple[str, Dataset]: """Get a Dataset""" - datasets = self.stash.get_all(context.credentials).unwrap() + datasets = self.stash.get_all( + context.credentials, include_deleted=include_deleted + ).unwrap() for dataset in datasets: if context.server is not None: dataset.server_uid = context.server.id - if dataset.to_be_deleted: - datasets.remove(dataset) return _paginate_dataset_collection( datasets=datasets, page_size=page_size, page_index=page_index @@ -141,9 +142,7 @@ def search( results = self.get_all(context) filtered_results = [ - dataset - for dataset_name, dataset in results.items() - if name in dataset_name and not dataset.to_be_deleted + dataset for dataset_name, dataset in results.items() if name in dataset_name ] return _paginate_dataset_collection( @@ -242,6 +241,28 @@ def delete( return_msg.append(f"Dataset with id '{uid}' successfully deleted.") return SyftSuccess(message="\n".join(return_msg)) + @service_method( + path="dataset.replace", + name="replace", + roles=DATA_OWNER_ROLE_LEVEL, + unwrap_on_success=False, + ) + def replace( + self, + context: AuthedServiceContext, + existed_dataset_uid: UID, + dataset: CreateDataset, + ) -> SyftSuccess: + dataset = dataset.to(Dataset, context=context) + dataset.id = existed_dataset_uid + self.stash.update( + credentials=context.credentials, dataset_update=dataset + ).unwrap() + # TODO: should we delete the existed dataset's asssets after force replace? + return SyftSuccess( + message=f"Dataset with id '{existed_dataset_uid}' successfully replaced." 
+ ) + TYPE_TO_SERVICE[Dataset] = DatasetService SERVICE_TO_TYPES[DatasetService].update({Dataset}) diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index 19fc33c5906..69541a1a752 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -32,6 +32,23 @@ def get_by_name(self, credentials: SyftVerifyKey, name: str) -> Dataset: qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)]) return self.query_one(credentials=credentials, qks=qks).unwrap() + @as_result(StashException) + def update( + self, + credentials: SyftVerifyKey, + dataset_update: DatasetUpdate | Dataset, + has_permission: bool = False, + ) -> Dataset: + return ( + super() + .update( + credentials=credentials, + obj=dataset_update, + has_permission=has_permission, + ) + .unwrap() + ) + @as_result(StashException) def search_action_ids(self, credentials: SyftVerifyKey, uid: UID) -> list[Dataset]: qks = QueryKeys(qks=[ActionIDsPartitionKey.with_obj(uid)]) @@ -43,22 +60,13 @@ def get_all( credentials: SyftVerifyKey, order_by: PartitionKey | None = None, has_permission: bool = False, + include_deleted: bool = False, ) -> list: result = super().get_all(credentials, order_by, has_permission).unwrap() - filtered_datasets = [dataset for dataset in result if not dataset.to_be_deleted] + if not include_deleted: + filtered_datasets = [ + dataset for dataset in result if not dataset.to_be_deleted + ] + else: + filtered_datasets = result return filtered_datasets - - # FIX: This shouldn't be the update method, it just marks the dataset for deletion - @as_result(StashException) - def update( - self, - credentials: SyftVerifyKey, - obj: DatasetUpdate, - has_permission: bool = False, - ) -> Dataset: - _obj = self.check_type(obj, DatasetUpdate).unwrap() - # FIX: This method needs a revamp - qk = self.partition.store_query_key(obj) - return self.partition.update( - credentials=credentials, qk=qk, obj=_obj, has_permission=has_permission - ).unwrap() diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index cc97802a08b..80f59b8653e 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -875,7 +875,6 @@ def update( current = self.find_one(credentials, id=obj.id).unwrap() obj.apply(to=current) obj = current - obj = self.check_type(obj, self.object_type).unwrap() qk = self.partition.store_query_key(obj) return self.partition.update( diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index 77bfbbb3297..a379e3b8963 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -378,6 +378,7 @@ def _remove_keys( for qk in uqks: pk_key, pk_value = qk.key, qk.value ck_col = self.unique_keys[pk_key] + # ck_col.pop(pk_value, None) ck_col.pop(store_key.value, None) self.unique_keys[pk_key] = ck_col @@ -480,7 +481,6 @@ def _update( ) store_query_key = self.settings.store_key.with_obj(_original_obj) - # remove old keys self._remove_keys( store_key=store_query_key, diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 20dedae88c6..1f88b814db0 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -136,8 +136,8 @@ class SyftBaseObject(pydantic.BaseModel, 
SyftHashableObject): model_config = ConfigDict(arbitrary_types_allowed=True) # the name which doesn't change even when there are multiple classes - __canonical_name__: str - __version__: int # data is always versioned + __canonical_name__: str = "SyftBaseObject" + __version__: int = SYFT_OBJECT_VERSION_1 # data is always versioned syft_server_location: UID | None = Field(default=None, exclude=True) syft_client_verify_key: SyftVerifyKey | None = Field(default=None, exclude=True) diff --git a/packages/syft/tests/syft/request/request_service_test.py b/packages/syft/tests/syft/request/request_service_test.py new file mode 100644 index 00000000000..0e6f06c7ab0 --- /dev/null +++ b/packages/syft/tests/syft/request/request_service_test.py @@ -0,0 +1,52 @@ +# third party +from faker import Faker + +# syft absolute +import syft +from syft.client.client import SyftClient +from syft.server.worker import Worker +from syft.service.action.action_object import ActionObject +from syft.service.request.request import Request +from syft.service.response import SyftError +from syft.service.response import SyftSuccess + + +def test_set_tags_delete_requests(faker: Faker, worker: Worker, ds_client: SyftClient): + """ " + Scneario: DS client submits a code request. Root client sets some wrong tags, then + delete the request. DS client then submit the request again, root client then set + the correct tags. + """ + root_client: SyftClient = worker.root_client + dummy_data = [1, 2, 3] + data = ActionObject.from_obj(dummy_data) + action_obj = data.send(root_client) + + @syft.syft_function( + input_policy=syft.ExactMatch(data=action_obj), + output_policy=syft.SingleExecutionExactOutput(), + ) + def simple_function(data): + return sum(data) + + result = ds_client.code.request_code_execution(simple_function) + assert not isinstance(result, SyftError) + + request = root_client.requests.get_all()[0] + set_tag_res = root_client.api.services.request.set_tags(request, ["tag1", "tag2"]) + assert isinstance(set_tag_res, Request) + assert set_tag_res.tags == ["tag1", "tag2"] + + del_res = root_client.api.services.request.delete_by_uid(request.id) + assert isinstance(del_res, SyftSuccess) + assert len(root_client.api.services.request.get_all()) == 0 + assert root_client.api.services.request.get_by_uid(request.id) is None + + result = ds_client.code.request_code_execution(simple_function) + assert not isinstance(result, SyftError) + request = root_client.requests.get_all()[0] + set_tag_res = root_client.api.services.request.set_tags( + request, ["computing", "sum"] + ) + assert isinstance(set_tag_res, Request) + assert set_tag_res.tags == ["computing", "sum"] diff --git a/packages/syft/tests/syft/service/dataset/dataset_service_test.py b/packages/syft/tests/syft/service/dataset/dataset_service_test.py index 4d73e35fa6f..3ab78af6e1a 100644 --- a/packages/syft/tests/syft/service/dataset/dataset_service_test.py +++ b/packages/syft/tests/syft/service/dataset/dataset_service_test.py @@ -19,6 +19,7 @@ from syft.service.dataset.dataset import CreateAsset as Asset from syft.service.dataset.dataset import CreateDataset as Dataset from syft.service.dataset.dataset import _ASSET_WITH_NONE_MOCK_ERROR_MESSAGE +from syft.service.response import SyftError from syft.service.response import SyftSuccess from syft.types.errors import SyftException @@ -220,7 +221,6 @@ def test_datasite_client_cannot_upload_dataset_with_non_mock(worker: Worker) -> dataset.asset_list[0].mock = None root_datasite_client = worker.root_client - with 
pytest.raises(ValueError) as excinfo: root_datasite_client.upload_dataset(dataset) @@ -310,7 +310,7 @@ def test_upload_dataset_with_assets_of_different_data_types( ) -def test_delete_small_datasets(worker: Worker, small_dataset: Dataset) -> None: +def test_upload_delete_small_dataset(worker: Worker, small_dataset: Dataset) -> None: root_client = worker.root_client assert not can_upload_to_blob_storage(small_dataset, root_client.metadata).unwrap() upload_res = root_client.upload_dataset(small_dataset) @@ -320,6 +320,7 @@ def test_delete_small_datasets(worker: Worker, small_dataset: Dataset) -> None: asset = dataset.asset_list[0] assert isinstance(asset.data, np.ndarray) assert isinstance(asset.mock, np.ndarray) + assert len(root_client.api.services.blob_storage.get_all()) == 0 # delete the dataset without deleting its assets del_res = root_client.api.services.dataset.delete( @@ -329,6 +330,7 @@ def test_delete_small_datasets(worker: Worker, small_dataset: Dataset) -> None: assert isinstance(asset.data, np.ndarray) assert isinstance(asset.mock, np.ndarray) assert len(root_client.api.services.dataset.get_all()) == 0 + # we can still get back the deleted dataset by uid deleted_dataset = root_client.api.services.dataset.get_by_id(uid=dataset.id) assert deleted_dataset.name == f"_deleted_{dataset.name}_{dataset.id}" @@ -345,7 +347,7 @@ def test_delete_small_datasets(worker: Worker, small_dataset: Dataset) -> None: assert len(root_client.api.services.dataset.get_all()) == 0 -def test_delete_big_datasets(worker: Worker, big_dataset: Dataset) -> None: +def test_upload_delete_big_dataset(worker: Worker, big_dataset: Dataset) -> None: root_client = worker.root_client assert can_upload_to_blob_storage(big_dataset, root_client.metadata).unwrap() upload_res = root_client.upload_dataset(big_dataset) @@ -384,3 +386,132 @@ def test_delete_big_datasets(worker: Worker, big_dataset: Dataset) -> None: print(asset.mock) assert len(root_client.api.services.blob_storage.get_all()) == 0 assert len(root_client.api.services.dataset.get_all()) == 0 + + +def test_reupload_dataset(worker: Worker, small_dataset: Dataset) -> None: + root_client = worker.root_client + + # upload a dataset + upload_res = root_client.upload_dataset(small_dataset) + assert isinstance(upload_res, SyftSuccess) + dataset = root_client.api.services.dataset.get_all()[0] + + # delete the dataset + del_res = root_client.api.services.dataset.delete(dataset.id) + assert isinstance(del_res, SyftSuccess) + assert len(root_client.api.services.dataset.get_all()) == 0 + assert len(root_client.api.services.dataset.get_all(include_deleted=True)) == 1 + search_res = root_client.api.services.dataset.search(small_dataset.name) + assert len(search_res) == 0 + # reupload a dataset with the same name should be successful + reupload_res = root_client.upload_dataset(small_dataset) + assert isinstance(reupload_res, SyftSuccess) + assert len(root_client.api.services.dataset.get_all()) == 1 + assert len(root_client.api.services.dataset.get_all(include_deleted=True)) == 2 + search_res = root_client.api.services.dataset.search(small_dataset.name) + assert len(search_res) == 1 + assert all(small_dataset.assets[0].data == search_res[0].assets[0].data) + assert all(small_dataset.assets[0].mock == search_res[0].assets[0].mock) + + +def test_upload_dataset_with_force_replace_small_dataset( + worker: Worker, small_dataset: Dataset +) -> None: + root_client = worker.root_client + + # upload a dataset + upload_res = root_client.upload_dataset(small_dataset) + assert 
isinstance(upload_res, SyftSuccess) + first_uploaded_dataset = root_client.api.services.dataset.get_all()[0] + + # upload again without the `force_replace` flag should fail + reupload_res = root_client.upload_dataset(small_dataset) + assert isinstance(reupload_res, SyftError) + + # change something about the dataset, then upload it again with `force_replace` + dataset = Dataset( + name=small_dataset.name, + asset_list=[ + sy.Asset( + name="small_dataset", + data=np.array([3, 2, 1]), + mock=np.array([2, 2, 2]), + ) + ], + description="This is my numpy data", + url="https://mydataset.com", + summary="contain some super secret data", + ) + force_replace_upload_res = root_client.upload_dataset(dataset, force_replace=True) + assert isinstance(force_replace_upload_res, SyftSuccess) + assert len(root_client.api.services.dataset.get_all()) == 1 + + updated_dataset = root_client.api.services.dataset.get_all()[0] + assert updated_dataset.id == first_uploaded_dataset.id + assert updated_dataset.name == small_dataset.name + assert updated_dataset.description.text == dataset.description.text + assert updated_dataset.summary == dataset.summary + assert updated_dataset.url == dataset.url + assert all(updated_dataset.assets[0].data == dataset.assets[0].data) + assert all(updated_dataset.assets[0].mock == dataset.assets[0].mock) + + +def test_upload_dataset_with_force_replace_big_dataset( + worker: Worker, big_dataset: Dataset +) -> None: + root_client = worker.root_client + assert can_upload_to_blob_storage(big_dataset, root_client.metadata) + + # upload a dataset + upload_res = root_client.upload_dataset(big_dataset) + assert isinstance(upload_res, SyftSuccess) + first_uploaded_dataset = root_client.api.services.dataset.get_all()[0] + + # change about the dataset metadata and also its data and mock, but keep its name, + # then upload it again with `force_replace=True` + updated_mock = big_dataset.assets[0].mock * 2 + updated_data = big_dataset.assets[0].data + 1 + + dataset = Dataset( + name=big_dataset.name, + asset_list=[ + sy.Asset( + name="big_dataset", + data=updated_data, + mock=updated_mock, + ) + ], + description="This is my numpy data", + url="https://mydataset.com", + summary="contain some super secret data", + ) + force_replace_upload_res = root_client.upload_dataset(dataset, force_replace=True) + assert isinstance(force_replace_upload_res, SyftSuccess) + # TODO: Old data were not removed from the blob storage after force replace. What to do? 
+ assert len(root_client.api.services.blob_storage.get_all()) == 4 + assert len(root_client.api.services.dataset.get_all()) == 1 + + updated_dataset = root_client.api.services.dataset.get_all()[0] + assert updated_dataset.id == first_uploaded_dataset.id + assert updated_dataset.name == big_dataset.name + assert updated_dataset.description.text == dataset.description.text + assert updated_dataset.summary == dataset.summary + assert updated_dataset.url == dataset.url + assert all(updated_dataset.assets[0].data == dataset.assets[0].data) + assert all(updated_dataset.assets[0].mock == dataset.assets[0].mock) + + mock_obj: ActionObject = root_client.api.services.action.get( + updated_dataset.assets[0].action_id, TwinMode.MOCK + ) + retrieved_mock = root_client.api.services.blob_storage.read( + mock_obj.syft_blob_storage_entry_id + ).read() + assert np.sum(retrieved_mock - updated_mock) == 0 + + data_obj: ActionObject = root_client.api.services.action.get( + updated_dataset.assets[0].action_id, TwinMode.PRIVATE + ) + retrieved_data = root_client.api.services.blob_storage.read( + data_obj.syft_blob_storage_entry_id + ).read() + assert np.sum(retrieved_data - updated_data) == 0
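
A minimal usage sketch of the upload flow this patch introduces, for reviewers. The server name, login credentials, and asset values are illustrative placeholders and not part of the change; only `force_replace` and `include_deleted` come from this diff.

```python
# third party
import numpy as np

# syft absolute
import syft as sy

# local development server (intended only for local use, as the orchestra message notes)
server = sy.orchestra.launch(name="force-replace-demo", reset=True)
client = server.login(email="info@openmined.org", password="changethis")

dataset = sy.Dataset(
    name="demo",
    asset_list=[
        sy.Asset(name="demo", data=np.array([1, 2, 3]), mock=np.array([0, 0, 0]))
    ],
)

client.upload_dataset(dataset)                       # first upload: SyftSuccess
client.upload_dataset(dataset)                       # same name again: SyftError
client.upload_dataset(dataset, force_replace=True)   # overwrite the existing dataset

# deleted datasets are hidden by default but can still be listed explicitly
client.api.services.dataset.get_all(include_deleted=True)
```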