diff --git a/src/bentoml/_internal/cloud/model.py b/src/bentoml/_internal/cloud/model.py index f2f5a145e99..6a02f1f30e5 100644 --- a/src/bentoml/_internal/cloud/model.py +++ b/src/bentoml/_internal/cloud/model.py @@ -6,6 +6,7 @@ import warnings from concurrent.futures import ThreadPoolExecutor from tempfile import NamedTemporaryFile +from threading import Lock import attrs import fs @@ -47,6 +48,7 @@ class ModelAPI: _client: RestApiClient = attrs.field(repr=False) spinner: Spinner = attrs.field(repr=False, factory=Spinner) + _lock: Lock = attrs.field(repr=False, init=False, factory=Lock) def push( self, @@ -88,17 +90,22 @@ def _do_push_model( if version is None: raise BentoMLException(f'Model "{model}" version cannot be None') - with self.spinner.spin(text=f'Fetching model repository "{name}"'): - model_repository = rest_client.v1.get_model_repository( - model_repository_name=name - ) - if not model_repository: - with self.spinner.spin( - text=f'Model repository "{name}" not found, creating now..' - ): - model_repository = rest_client.v1.create_model_repository( - req=CreateModelRepositorySchema(name=name, description="") + with self._lock: + # Models might be pushed by multiple threads at the same time + # when they are under the same model repository, race condition + # might happen when creating the model repository. So we need to + # protect it with a lock. + with self.spinner.spin(text=f'Fetching model repository "{name}"'): + model_repository = rest_client.v1.get_model_repository( + model_repository_name=name ) + if not model_repository: + with self.spinner.spin( + text=f'Model repository "{name}" not found, creating now..' + ): + model_repository = rest_client.v1.create_model_repository( + req=CreateModelRepositorySchema(name=name, description="") + ) with self.spinner.spin( text=f'Try fetching model "{model}" from remote model store..' ):