Skip to content

Commit

Permalink
fix: protect model repository creation with a lock to avoid a race co…
Browse files Browse the repository at this point in the history
…ndition (#5095)

Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming authored Nov 22, 2024
1 parent dafbb98 commit 6455e19
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions src/bentoml/_internal/cloud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
from concurrent.futures import ThreadPoolExecutor
from tempfile import NamedTemporaryFile
from threading import Lock

import attrs
import fs
Expand Down Expand Up @@ -47,6 +48,7 @@
class ModelAPI:
_client: RestApiClient = attrs.field(repr=False)
spinner: Spinner = attrs.field(repr=False, factory=Spinner)
_lock: Lock = attrs.field(repr=False, init=False, factory=Lock)

def push(
self,
Expand Down Expand Up @@ -88,17 +90,22 @@ def _do_push_model(
if version is None:
raise BentoMLException(f'Model "{model}" version cannot be None')

with self.spinner.spin(text=f'Fetching model repository "{name}"'):
model_repository = rest_client.v1.get_model_repository(
model_repository_name=name
)
if not model_repository:
with self.spinner.spin(
text=f'Model repository "{name}" not found, creating now..'
):
model_repository = rest_client.v1.create_model_repository(
req=CreateModelRepositorySchema(name=name, description="")
with self._lock:
# Models might be pushed by multiple threads at the same time
# when they are under the same model repository, race condition
# might happen when creating the model repository. So we need to
# protect it with a lock.
with self.spinner.spin(text=f'Fetching model repository "{name}"'):
model_repository = rest_client.v1.get_model_repository(
model_repository_name=name
)
if not model_repository:
with self.spinner.spin(
text=f'Model repository "{name}" not found, creating now..'
):
model_repository = rest_client.v1.create_model_repository(
req=CreateModelRepositorySchema(name=name, description="")
)
with self.spinner.spin(
text=f'Try fetching model "{model}" from remote model store..'
):
Expand Down

0 comments on commit 6455e19

Please sign in to comment.