Commit

Merge pull request #71 from databricks-industry-solutions/feature/model_serving_endpoint

Custom model creation and hosting to Serving Endpoint + optimizations
erinaldidb authored Nov 14, 2024
2 parents 291e0e1 + 079b8f0 commit e614f69
Showing 67 changed files with 5,878 additions and 141 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -14,7 +14,7 @@ dist/
downloads/
eggs/
.eggs/
lib/
./lib/
lib64/
parts/
sdist/
@@ -135,3 +135,4 @@ dmypy.json

.databricks
dbx/pixels/resources/ohif/app-config-custom.js
dbx/pixels/resources/lakehouse_app/app.yaml
33 changes: 28 additions & 5 deletions 05-MONAILabel.py
@@ -15,11 +15,8 @@
# COMMAND ----------

# DBTITLE 1,Install MONAILabel_Pixels and databricks-sdk
# MAGIC %pip install git+https://github.com/erinaldidb/MONAILabel_Pixels.git databricks-sdk --upgrade

# COMMAND ----------

dbutils.library.restartPython()
# MAGIC %pip install git+https://github.com/erinaldidb/MONAILabel_Pixels databricks-sdk --upgrade -q
# MAGIC dbutils.library.restartPython()

# COMMAND ----------

@@ -37,12 +34,38 @@

# COMMAND ----------

# MAGIC %md
# MAGIC ### Download the radiology app to the cluster
# MAGIC
# MAGIC The following command downloads the radiology app from the MONAILabel GitHub repository and saves it to the cluster's local disk.

# COMMAND ----------

# DBTITLE 1,Downloading Radiology Apps with MonaiLabel
# MAGIC %sh
# MAGIC monailabel apps --download --name radiology --output /local_disk0/monai/apps/

# COMMAND ----------

# MAGIC %md
# MAGIC ### Starting the MONAILabel server
# MAGIC
# MAGIC The next command starts the MONAILabel server with the radiology app downloaded above, using the pre-trained auto-segmentation model.

# COMMAND ----------

# DBTITLE 1,Monailabel Radiology Segmentation
# MAGIC %sh
# MAGIC monailabel start_server --app /local_disk0/monai/apps/radiology --studies $DATABRICKS_HOST --conf models segmentation --table $DATABRICKS_PIXELS_TABLE

# COMMAND ----------

# MAGIC %md
# MAGIC ### Create a segmentation model with user-defined labels and start the MONAILabel server
# MAGIC
# MAGIC The next command copies the segmentation model to the radiology app directory and starts the MONAILabel server with the specified configuration: custom labels and no pre-trained model.

# COMMAND ----------

# MAGIC %sh
# MAGIC monailabel start_server --app /local_disk0/monai/apps/radiology --studies $DATABRICKS_HOST --conf models segmentation --conf labels '{"lung_left":1,"lung_right":2}' --conf use_pretrained_model false --table $DATABRICKS_PIXELS_TABLE
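A quick way to confirm the server came up is to query its info endpoint from another cell (the `%sh` cell above blocks while the server runs). This is a hedged sketch, not part of this commit: it assumes MONAILabel's default port (8000) and the `/info/` REST route; check the server logs for the actual address.

```python
import requests

# Assumes the MONAILabel server started above is listening on its default port (8000).
resp = requests.get("http://127.0.0.1:8000/info/", timeout=30)
resp.raise_for_status()
info = resp.json()
print("app:", info.get("name"), "| models:", list(info.get("models", {}).keys()))
```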
7 changes: 5 additions & 2 deletions 06-OHIF-Viewer.py
@@ -71,8 +71,11 @@ async def _reverse_proxy_statements(request: Request):
url = httpx.URL(path=request.url.path.replace("/sqlwarehouse/",""))

#Replace SQL Warehouse parameter
body = await request.json()
body['warehouse_id'] = os.environ['DATABRICKS_WAREHOUSE_ID']
if request.method == "POST":
body = await request.json()
body['warehouse_id'] = os.environ['DATABRICKS_WAREHOUSE_ID']
else:
body = {}

rp_req = client.build_request(request.method, url,
headers={
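For context, the new guard matters because only POST requests to the proxied Statements API carry a JSON body whose `warehouse_id` needs to be overridden; GET requests (e.g. result polling) have no body, so parsing them would fail. A minimal standalone sketch of the same pattern, assuming FastAPI and httpx as used in the notebook; authentication headers are omitted for brevity, and `DATABRICKS_HOST` / `DATABRICKS_WAREHOUSE_ID` are assumed to be set in the environment:

```python
import os
import httpx
from fastapi import FastAPI, Request, Response

app = FastAPI()
client = httpx.AsyncClient(base_url=os.environ["DATABRICKS_HOST"])

@app.api_route("/sqlwarehouse/{path:path}", methods=["GET", "POST"])
async def reverse_proxy_statements(path: str, request: Request):
    # Strip the proxy prefix so the request maps onto the upstream API path.
    url = httpx.URL(path=request.url.path.replace("/sqlwarehouse/", ""))

    # Only POST requests carry a JSON body whose warehouse_id must be overridden.
    if request.method == "POST":
        body = await request.json()
        body["warehouse_id"] = os.environ["DATABRICKS_WAREHOUSE_ID"]
    else:
        body = {}

    rp_req = client.build_request(request.method, url, json=body or None)
    rp_resp = await client.send(rp_req)
    return Response(content=rp_resp.content, status_code=rp_resp.status_code)
```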
31 changes: 21 additions & 10 deletions 07-OHIF-Lakehouse-App.py
@@ -26,6 +26,9 @@
init_env()

app_name = "pixels-ohif-viewer"
serving_endpoint = "pixels-monai-uc"

w = WorkspaceClient()

# COMMAND ----------

@@ -52,38 +55,42 @@

# COMMAND ----------

from databricks.sdk.service.apps import AppResource, AppResourceSqlWarehouse, AppResourceSqlWarehouseSqlWarehousePermission
from databricks.sdk.service.apps import AppResource, AppResourceSqlWarehouse, AppResourceSqlWarehouseSqlWarehousePermission, AppResourceServingEndpoint, AppResourceServingEndpointServingEndpointPermission

from pathlib import Path
import dbx.pixels.resources

w = WorkspaceClient()

path = Path(dbx.pixels.__file__).parent
lha_path = (f"{path}/resources/lakehouse_app")

with open(f"{lha_path}/app-config.yaml", "r") as config_input:
with open(f"{lha_path}/app.yaml", "w") as config_custom:
config_custom.write(
config_input.read()
.replace("{PIXELS_TABLE}",os.environ["DATABRICKS_PIXELS_TABLE"])
.replace("{PIXELS_TABLE}", table)
)

sql_resource = AppResource(
name="sql_warehouse",
sql_warehouse=AppResourceSqlWarehouse(
id=os.environ["DATABRICKS_WAREHOUSE_ID"],
id=sql_warehouse_id,
permission=AppResourceSqlWarehouseSqlWarehousePermission.CAN_USE
)
)

serving_endpoint = AppResource(
name="serving_endpoint",
serving_endpoint=AppResourceServingEndpoint(
name=serving_endpoint,
permission=AppResourceServingEndpointServingEndpointPermission.CAN_QUERY
)
)

print(f"Creating Lakehouse App with name {app_name}, this step will require few minutes to complete")

app_created = w.apps.create_and_wait(name=app_name, resources=[sql_resource])
app_created = w.apps.create_and_wait(name=app_name, resources=[sql_resource, serving_endpoint])
app_deploy = w.apps.deploy_and_wait(app_name=app_name, source_code_path=lha_path)

service_principal_id = app_deploy.deployment_artifacts.source_code_path.split("/")[3]

print(app_deploy.status.message)
print(app_created.url)

@@ -100,6 +107,10 @@

from databricks.sdk.service import catalog

app_instance = w.apps.get(app_name)
last_deployment = w.apps.get_deployment(app_name, app_instance.active_deployment.deployment_id)
service_principal_id = last_deployment.deployment_artifacts.source_code_path.split("/")[3]

#Grant USE CATALOG permissions on CATALOG
w.grants.update(full_name=table.split(".")[0],
securable_type=catalog.SecurableType.CATALOG,
@@ -127,7 +138,7 @@
securable_type=catalog.SecurableType.TABLE,
changes=[
catalog.PermissionsChange(
add=[catalog.Privilege.SELECT],
add=[catalog.Privilege.ALL_PRIVILEGES],
principal=service_principal_id
)
]
@@ -138,7 +149,7 @@
securable_type=catalog.SecurableType.VOLUME,
changes=[
catalog.PermissionsChange(
add=[catalog.Privilege.READ_VOLUME],
add=[catalog.Privilege.ALL_PRIVILEGES],
principal=service_principal_id
)
]
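To double-check that the grants took effect, here is a hedged sketch using the same databricks-sdk client; `w`, `table`, and `service_principal_id` are the variables defined in the cells above.

```python
from databricks.sdk.service import catalog  # already imported in the notebook above

# List privileges on the pixels table and filter for the app's service principal.
grants = w.grants.get(securable_type=catalog.SecurableType.TABLE, full_name=table)
for assignment in grants.privilege_assignments or []:
    if assignment.principal == service_principal_id:
        print(assignment.principal, assignment.privileges)
```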
10 changes: 10 additions & 0 deletions config/proxy_prep.py
@@ -25,6 +25,16 @@ def init_widgets(show_volume=False):

# COMMAND ----------

def init_model_serving_widgets():
dbutils.widgets.text("model_uc_name", "main.pixels_solacc.monai_pixels_model", label="3.0 Model name stored in UC")
model_uc_name = dbutils.widgets.get("model_uc_name")
dbutils.widgets.text("serving_endpoint_name", "pixels-monai-uc", label="4.0 Serving Endpoint name")
serving_endpoint_name = dbutils.widgets.get("serving_endpoint_name")

return model_uc_name, serving_endpoint_name

# COMMAND ----------

def init_env():
sql_warehouse_id = dbutils.widgets.get("sqlWarehouseID")
table = dbutils.widgets.get("table")
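A hedged usage sketch for the new helper, assuming `config/proxy_prep.py` has been `%run` in a notebook so `dbutils` is available:

```python
# Registers the two widgets (with their defaults) and returns their current values.
model_uc_name, serving_endpoint_name = init_model_serving_widgets()
print(model_uc_name, serving_endpoint_name)
```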
30 changes: 21 additions & 9 deletions dbx/pixels/catalog.py
@@ -3,7 +3,7 @@
from pyspark.sql.streaming.query import StreamingQuery

from dbx.pixels.logging import LoggerProvider
from dbx.pixels.utils import unzip_pandas_udf, identify_type_udf
from dbx.pixels.utils import identify_type_udf, unzip_pandas_udf

# dfZipWithIndex helper function

@@ -76,6 +76,23 @@ def __init__(
"spark.databricks.delta.optimizeWrite.enabled": False,
}

def _init_tables(self):
import os
import os.path
from pathlib import Path

import dbx.pixels

path = Path(dbx.pixels.__file__).parent
sql_base_path = f"{path}/resources/sql"

files = os.listdir(sql_base_path)
for file_name in files:
file_path = os.path.join(sql_base_path, file_name)
with open(file_path, "r") as file:
sql_command = file.read().replace("{UC_TABLE}", self._table)
self._spark.sql(sql_command)

def __repr__(self):
return f'Catalog(spark, table="{self._table}")'

@@ -138,13 +155,15 @@ def catalog(
Returns:
DataFrame: A DataFrame of the cataloged data, with metadata and optionally extracted contents from zip files.
"""

assert self._spark is not None
assert self._spark.version is not None

self._anon = self._is_anon(path)
self._spark

self._init_tables()

# Used only for streaming
self._queryName = f"pixels_{path}_{self._table}"
self._isStreaming = streaming
@@ -180,13 +199,6 @@
if extractZip:
logger.info("Started unzip process")

self._spark.sql(
f"""
CREATE TABLE IF NOT EXISTS {self._table}_unzip
TBLPROPERTIES ('delta.targetFileSize' = '1mb')
"""
)

unzip_stream = (
df.withColumn(
"path", f.explode(unzip_pandas_udf("path", f.lit(extractZipBasePath)))
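The new `_init_tables` helper takes over table setup from the inline `CREATE TABLE ... _unzip` statement removed from `catalog()` earlier in this file's diff: it reads every file under `dbx/pixels/resources/sql`, substitutes `{UC_TABLE}` with the configured catalog table, and executes the result. A minimal sketch of the substitution step; the file name, template text, and table name below are illustrative assumptions, not taken from this diff.

```python
# Hypothetical template file, e.g. dbx/pixels/resources/sql/unzip_table.sql
template = (
    "CREATE TABLE IF NOT EXISTS {UC_TABLE}_unzip "
    "TBLPROPERTIES ('delta.targetFileSize' = '1mb')"
)
# _init_tables substitutes the placeholder with the configured UC table name.
print(template.replace("{UC_TABLE}", "main.pixels_solacc.object_catalog"))  # placeholder table
```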
9 changes: 7 additions & 2 deletions dbx/pixels/dicom/dicom_meta_extractor.py
@@ -11,11 +11,16 @@ class DicomMetaExtractor(Transformer):
"""

# Day extractor inherit of property of Transformer
def __init__(self, catalog, inputCol="local_path", outputCol="meta", basePath="dbfs:/"):
def __init__(
self, catalog, inputCol="local_path", outputCol="meta", basePath="dbfs:/", deep=True
):
self.inputCol = inputCol # the name of your columns
self.outputCol = outputCol # the name of your output column
self.basePath = basePath
self.catalog = catalog
self.deep = (
deep  # If True, also analyze pixel_array data (may impact performance when enabled)
)

def check_input_type(self, schema):
field = schema[self.inputCol]
@@ -43,5 +48,5 @@ def _transform(self, df):
"""
self.check_input_type(df.schema)
return df.withColumn("is_anon", lit(self.catalog.is_anon())).withColumn(
self.outputCol, dicom_meta_udf(col(self.inputCol), lit("True"), col("is_anon"))
self.outputCol, dicom_meta_udf(col(self.inputCol), lit(self.deep), col("is_anon"))
)
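A hedged usage sketch of the new `deep` flag, assuming the class is importable from the module path shown above; `catalog` and `df` are taken to come from the earlier cataloging steps in the solution notebooks.

```python
from dbx.pixels.dicom.dicom_meta_extractor import DicomMetaExtractor

# deep=False skips pixel-data analysis and only extracts header metadata,
# which is faster on large series; deep=True keeps the previous behavior.
extractor = DicomMetaExtractor(catalog, deep=False)
df_meta = extractor.transform(df)
```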
7 changes: 5 additions & 2 deletions dbx/pixels/dicom/dicom_udfs.py
@@ -2,8 +2,9 @@
import json
import os

from pyspark.sql.functions import udf
from pydicom import Dataset
from pyspark.sql.functions import udf


def cloud_open(path: str, anon: bool = False):
try:
@@ -22,6 +23,7 @@ def cloud_open(path: str, anon: bool = False):
except Exception as e:
raise Exception(f"path: {path} is_anon: {anon} exception: {e} exception.args: {e.args}")


def check_pixel_data(ds: Dataset) -> Dataset | None:
"""Check if pixel data exists before attempting to access it.
pydicom.Dataset.pixel_array will throw an exception if the
@@ -35,6 +37,7 @@
return None
return a


@udf
def dicom_meta_udf(path: str, deep: bool = True, anon: bool = False) -> dict:
"""Extract metadata from header of dicom image file
@@ -79,4 +82,4 @@ def dicom_meta_udf(path: str, deep: bool = True, anon: bool = False) -> dict:
except_str = str(
{"udf": "dicom_meta_udf", "error": str(err), "args": str(err.args), "path": path}
)
return except_str
return except_str
3 changes: 2 additions & 1 deletion dbx/pixels/logging.py
@@ -1,11 +1,12 @@
import logging
import sys


class LoggerProvider:
"""
This class provides a logger instance for logging messages.
"""

def __new__(self):
"""
This method is a constructor that creates a new instance of the LoggerProvider class.
Empty file.
39 changes: 39 additions & 0 deletions dbx/pixels/modelserving/client.py
@@ -0,0 +1,39 @@
from typing import Iterator

import pandas as pd
from pyspark.ml.pipeline import Transformer
from pyspark.sql.functions import col, pandas_udf

from dbx.pixels.modelserving.serving_endpoint_client import MONAILabelClient


class MONAILabelTransformer(Transformer):
"""
Transformer class to generate autosegmentations of DICOM files using MONAILabel serving endpoint.
"""

def __init__(self, endpoint_name="pixels-monai-uc", inputCol="meta"):
self.inputCol = inputCol
self.endpoint_name = endpoint_name

def _transform(self, df):
@pandas_udf("result string, error string")
def autosegm_monai_udf(iterator: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
client = MONAILabelClient(self.endpoint_name)

for s in iterator:
results, errors = [], []
for series_uid in s:
result, error = client.predict(series_uid)
results.append(result)
errors.append(error)

yield pd.DataFrame({"result": results, "error": errors})

return (
df.selectExpr(f"{self.inputCol}:['0020000E'].Value[0] as series_uid")
.filter("contains(meta:['00080008'], 'AXIAL')")
.distinct()
.withColumn("segmentation_result", autosegm_monai_udf(col("series_uid")))
.selectExpr("series_uid", "segmentation_result.*")
)
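A hedged usage sketch of the transformer on Databricks: the table name is a placeholder for the pixels catalog table, and the endpoint name matches the default used above.

```python
from dbx.pixels.modelserving.client import MONAILabelTransformer

catalog_df = spark.table("main.pixels_solacc.object_catalog")  # placeholder catalog table
segmentations = MONAILabelTransformer(endpoint_name="pixels-monai-uc").transform(catalog_df)
display(segmentations)  # columns: series_uid, result, error
```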
30 changes: 30 additions & 0 deletions dbx/pixels/modelserving/serving_endpoint_client.py
@@ -0,0 +1,30 @@
import json
import os

from mlflow.deployments import get_deploy_client


class MONAILabelClient:
def __init__(self, endpoint_name, max_retries=3, request_timeout_sec=300):
os.environ["MLFLOW_HTTP_REQUEST_MAX_RETRIES"] = str(max_retries)
os.environ["MLFLOW_HTTP_REQUEST_TIMEOUT"] = str(request_timeout_sec)
os.environ["MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT"] = str(request_timeout_sec)

self.client = get_deploy_client("databricks")
self.endpoint = endpoint_name
self.max_retries = max_retries

def predict(self, series_uid, iteration=0, prev_error=None):
if iteration > self.max_retries:
return ("", str(prev_error))

try:
response = self.client.predict(
endpoint=self.endpoint,
inputs={"inputs": {"series_uid": series_uid}},
)
return (json.loads(response.predictions)["file_path"], "")
except Exception as e:
if "torch.OutOfMemoryError: CUDA out of memory" in str(e):
return ("", str(e))
return self.predict(series_uid, iteration + 1, prev_error=str(e))
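A hedged sketch of calling the client directly for a single series; the UID below is a placeholder.

```python
from dbx.pixels.modelserving.serving_endpoint_client import MONAILabelClient

client = MONAILabelClient("pixels-monai-uc")
# predict() returns a (file_path, error) tuple and retries up to max_retries on transient errors.
file_path, error = client.predict("1.2.826.0.1.3680043.2.1125.1")  # placeholder SeriesInstanceUID
print(file_path or error)
```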