feat: add nebullvm as backend #697

Open · wants to merge 15 commits into `main`
1 change: 1 addition & 0 deletions .github/workflows/cd.yml
@@ -44,6 +44,7 @@ jobs:
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
pip install --no-cache-dir "server/[nebullvm]"
- name: Test
id: test
run: |
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -116,6 +116,7 @@ jobs:
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
pip install --no-cache-dir "server/[nebullvm]"
- name: Test
id: test
run: |
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
  - repo: https://github.com/ambv/black
-    rev: 22.3.0
+    rev: 22.10.0
    hooks:
      - id: black
        types: [python]
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -600,7 +600,7 @@

- [[```8600286c```](https://github.com/jina-ai/clip-as-service/commit/8600286cf53755c127cf258af918b6bdf3e86691)] __-__ update readme (*Han Xiao*)
- [[```5e1dd607```](https://github.com/jina-ai/clip-as-service/commit/5e1dd607e47a94265f48cbb2a70406c5057b86fa)] __-__ __version__: the next version will be 0.2.4 (*Jina Dev Bot*)

<a name=release-note-0-3-1></a>
## Release Note (`0.3.1`)

@@ -1449,4 +1449,4 @@
- [[```d520ebb8```](https://github.com/jina-ai/clip-as-service/commit/d520ebb835e2814f7696148a0dcabbbf8bdadc76)] __-__ remove unused md (*numb3r3*)
- [[```2c3c61f9```](https://github.com/jina-ai/clip-as-service/commit/2c3c61f9d6f5a351f235dbad45879f0c7c4fd986)] __-__ __version__: the next version will be 0.7.1 (*Jina Dev Bot*)
- [[```53636cea```](https://github.com/jina-ai/clip-as-service/commit/53636cea63bf8063bcfd744aae4577df8e0eab2e)] __-__ bump version to 0.7.0 (*numb3r3*)

11 changes: 10 additions & 1 deletion README.md
@@ -20,7 +20,7 @@

CLIP-as-service is a low-latency high-scalability service for embedding images and text. It can be easily integrated as a microservice into neural search solutions.

- ⚡ **Fast**: Serve CLIP models with TensorRT, ONNX runtime and PyTorch w/o JIT with 800QPS<sup>[*]</sup>. Non-blocking duplex streaming on requests and responses, designed for large data and long-running tasks.
+ ⚡ **Fast**: Serve CLIP models with Nebullvm, TensorRT, ONNX runtime and PyTorch w/o JIT with 800QPS<sup>[*]</sup>. Non-blocking duplex streaming on requests and responses, designed for large data and long-running tasks.

🫐 **Elastic**: Horizontally scale up and down multiple CLIP models on single GPU, with automatic load balancing.

@@ -225,13 +225,15 @@ gives:

CLIP-as-service consists of two Python packages `clip-server` and `clip-client` that can be installed _independently_. Both require Python 3.7+.


### Install server

<table>
<tr>
<td> Pytorch Runtime ⚡ </td>
<td> ONNX Runtime ⚡⚡</td>
<td> TensorRT Runtime ⚡⚡⚡ </td>
<td> Nebullvm ⚡⚡⚡⚡</td>
</tr>
<tr>
<td>
@@ -254,6 +256,13 @@ pip install "clip-server[onnx]"
pip install nvidia-pyindex
pip install "clip-server[tensorrt]"
```
</td>
<td>

```bash
pip install "clip-server[nebullvm]"
```

</td>
</tr>
</table>
35 changes: 34 additions & 1 deletion docs/user-guides/server.md
@@ -42,6 +42,15 @@ pip install "clip_server[onnx]"
python -m clip_server onnx-flow.yml
```

### Start a Nebullvm-backed server

Recently we added support for another AI-accelerator backend: Nebullvm.
It can be used in much the same way as `onnxruntime`, by running:
```bash
pip install "clip_server[nebullvm]"

python -m clip_server nebullvm-flow.yml
```
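
Once the server is up, any `clip-client` can talk to it as usual. A minimal sketch (the address assumes the default port `51000` from the built-in YAML configs):

```python
from clip_client import Client

# connect to the Nebullvm-backed server started above
c = Client('grpc://0.0.0.0:51000')

# text and/or image URIs are embedded in a single call
r = c.encode(['a photo of a dog', 'a photo of a cat'])
print(r.shape)  # e.g. (2, 512) for ViT-B/32
```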

### Start a TensorRT-backed server

@@ -137,7 +146,7 @@ cat my.yml | python -m clip_server -i
This can be very useful when using `clip_server` in a Docker container.
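
A sketch of that pattern (the image name `jinaai/clip-server` and its entrypoint are assumptions for illustration, not something this PR pins down):

```bash
# pipe a local Flow config into a containerized clip_server via stdin;
# -i tells clip_server to read the YAML from stdin
cat my.yml | docker run -i --rm -p 51000:51000 jinaai/clip-server -i
```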

In case you're wondering: `clip_server` ships four built-in YAML configs as part of the package resources. When you do `python -m clip_server` it loads the PyTorch config, and when you do `python -m clip_server onnx-flow.yml` it loads the ONNX config.
- In the same way, when you do `python -m clip_server tensorrt-flow.yml` it loads the TensorRT config.
+ Likewise, `python -m clip_server tensorrt-flow.yml` and `python -m clip_server nebullvm-flow.yml` load the TensorRT and Nebullvm configs, respectively.

Let's look at these four built-in YAML configs:

@@ -175,6 +184,22 @@ executors:
```
````

````{tab} nebullvm-flow.yml

```yaml
jtype: Flow
version: '1'
with:
  port: 51000
executors:
  - name: clip_n
    uses:
      jtype: CLIPEncoder
      metas:
        py_modules:
          - executors/clip_nebullvm.py
```
````

````{tab} tensorrt-flow.yml

@@ -297,6 +322,14 @@ There are also runtime-specific parameters listed below:

````

For the Nebullvm backend, you only need to set `name` and `minibatch_size`:

| Parameter | Description |
|------------------|--------------------------------------------------------------------------------------------------------------------------|
| `name` | Model weights, default is `ViT-B/32`. Supports all OpenAI released pretrained models. |
| `minibatch_size` | The size of a minibatch for CPU preprocessing and GPU encoding, default 64. Reduce it if you encounter OOM on the GPU. |
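
For example, a sketch that points the Nebullvm flow at `ViT-B/16` with a smaller minibatch (same layout as `nebullvm-flow.yml` above; the override values are illustrative):

```yaml
jtype: Flow
version: '1'
with:
  port: 51000
executors:
  - name: clip_n
    uses:
      jtype: CLIPEncoder
      with:
        name: ViT-B/16        # any OpenAI-released CLIP weight name
        minibatch_size: 32    # lower this if you hit GPU OOM
      metas:
        py_modules:
          - executors/clip_nebullvm.py
```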


For example, to turn on JIT and force PyTorch to run on CPU, one can do:

```{code-block} yaml
jtype: Flow
version: '1'
with:
  port: 51000
executors:
  - name: clip_t
    uses:
      jtype: CLIPEncoder
      with:
        jit: True
        device: cpu
      metas:
        py_modules:
          - executors/clip_torch.py
```
91 changes: 91 additions & 0 deletions server/clip_server/executors/clip_nebullvm.py
@@ -0,0 +1,91 @@
import os
import warnings
from functools import partial
from multiprocessing.pool import ThreadPool
from typing import Optional, Dict

import torch
from clip_server.executors.helper import (
    split_img_txt_da,
    preproc_image,
    preproc_text,
    set_rank,
)
from clip_server.model import clip
from clip_server.model.clip_nebullvm import CLIPNebullvmModel, EnvRunner
from jina import Executor, requests, DocumentArray


class CLIPEncoder(Executor):
    def __init__(
        self,
        name: str = 'ViT-B/32',
        device: Optional[str] = None,
        num_worker_preprocess: int = 4,
        minibatch_size: int = 64,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self._preprocess_tensor = clip._transform_ndarray(clip.MODEL_SIZE[name])
        self._pool = ThreadPool(processes=num_worker_preprocess)

        self._minibatch_size = minibatch_size
        if not device:
            self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self._device = device

        # On CPU, split the available torch threads evenly across replicas so
        # that concurrent replicas do not oversubscribe the cores.
        if not self._device.startswith('cuda') and (
            'NEBULLVM_THREADS_PER_MODEL' not in os.environ
            and hasattr(self.runtime_args, 'replicas')
        ):
            replicas = getattr(self.runtime_args, 'replicas', 1)
            num_threads = max(1, torch.get_num_threads() // replicas)
            if num_threads < 2:
                warnings.warn(
                    f'Too many replicas ({replicas}) vs too few threads {num_threads} may result in '
                    f'sub-optimal performance.'
                )
        else:
            num_threads = None

        # Optimize both CLIP towers once at startup; EnvRunner scopes the
        # device visibility and thread-count environment for that step.
        self._model = CLIPNebullvmModel(name, pixel_size=clip.MODEL_SIZE[name])
        with EnvRunner(self._device, num_threads):
            self._model.optimize_models(batch_size=minibatch_size)

    @requests(on='/rank')
    async def rank(self, docs: 'DocumentArray', parameters: Dict, **kwargs):
        await self.encode(docs['@r,m'])
        set_rank(docs)

    @requests
    async def encode(self, docs: 'DocumentArray', **kwargs):
        _img_da = DocumentArray()
        _txt_da = DocumentArray()
        for d in docs:
            split_img_txt_da(d, _img_da, _txt_da)

        # for image
        if _img_da:
            for minibatch in _img_da.map_batch(
                partial(
                    preproc_image, preprocess_fn=self._preprocess_tensor, return_np=True
                ),
                batch_size=self._minibatch_size,
                pool=self._pool,
            ):
                minibatch.embeddings = self._model.encode_image(minibatch.tensors)

        # for text
        if _txt_da:
            for minibatch, _texts in _txt_da.map_batch(
                partial(preproc_text, return_np=True),
                batch_size=self._minibatch_size,
                pool=self._pool,
            ):
                minibatch.embeddings = self._model.encode_text(minibatch.tensors)
                minibatch.texts = _texts

        # drop tensors to keep the response lightweight
        docs.tensors = None

        return docs
154 changes: 154 additions & 0 deletions server/clip_server/model/clip_nebullvm.py
@@ -0,0 +1,154 @@
import os

import numpy as np
import torch

from clip_server.model.pretrained_models import (
    download_model,
    _OPENCLIP_MODELS,
    _MULTILINGUALCLIP_MODELS,
)
from clip_server.model.clip_model import BaseCLIPModel
from clip_server.model.clip_onnx import _MODELS, _S3_BUCKET, _S3_BUCKET_V2


class CLIPNebullvmModel(BaseCLIPModel):
    def __init__(self, name: str, pixel_size: int = 224, model_path: str = None):
        super().__init__(name)
        # pixel_size is used to build the dummy inputs in optimize_models();
        # the executor passes clip.MODEL_SIZE[name] here, kept separate from
        # model_path so a resolution is never mistaken for a filesystem path
        self.pixel_size = pixel_size
        if name in _MODELS:
            if not model_path:
                # reuse the ONNX exports of the model as nebullvm's input
                cache_dir = os.path.expanduser(
                    f'~/.cache/clip/{name.replace("/", "-").replace("::", "-")}'
                )
                textual_model_name, textual_model_md5 = _MODELS[name][0]
                self._textual_path = download_model(
                    url=_S3_BUCKET_V2 + textual_model_name,
                    target_folder=cache_dir,
                    md5sum=textual_model_md5,
                    with_resume=True,
                )
                visual_model_name, visual_model_md5 = _MODELS[name][1]
                self._visual_path = download_model(
                    url=_S3_BUCKET_V2 + visual_model_name,
                    target_folder=cache_dir,
                    md5sum=visual_model_md5,
                    with_resume=True,
                )
            else:
                if os.path.isdir(model_path):
                    self._textual_path = os.path.join(model_path, 'textual.onnx')
                    self._visual_path = os.path.join(model_path, 'visual.onnx')
                    if not os.path.isfile(self._textual_path) or not os.path.isfile(
                        self._visual_path
                    ):
                        raise RuntimeError(
                            f'The given model path {model_path} does not contain `textual.onnx` and `visual.onnx`'
                        )
                else:
                    raise RuntimeError(
                        f'The given model path {model_path} should be a folder containing both '
                        f'`textual.onnx` and `visual.onnx`.'
                    )
        else:
            raise RuntimeError(
                'CLIP model {} not found or does not support the ONNX backend; below is a list of all available models:\n{}'.format(
                    name,
                    ''.join(['\t- {}\n'.format(i) for i in list(_MODELS.keys())]),
                )
            )

    def optimize_models(
        self,
        **kwargs,
    ):
        from nebullvm.api.functions import optimize_model

        general_kwargs = {}
        general_kwargs.update(kwargs)

        # The visual tower takes NCHW float32 images; every axis may vary at
        # runtime, hence the dynamic-shape hints.
        dynamic_info = {
            "inputs": [
                {0: 'batch', 1: 'num_channels', 2: 'pixel_size', 3: 'pixel_size'}
            ],
            "outputs": [{0: 'batch'}],
        }

        self._visual_model = optimize_model(
            self._visual_path,
            input_data=[
                (
                    (
                        np.random.randn(1, 3, self.pixel_size, self.pixel_size).astype(
                            np.float32
                        ),
                    ),
                    0,
                )
            ],
            dynamic_info=dynamic_info,
            **general_kwargs,
        )

        # The textual tower takes token ids; CLIP's context length is 77.
        dynamic_info = {
            "inputs": [
                {0: 'batch', 1: 'num_tokens'},
            ],
            "outputs": [
                {0: 'batch'},
            ],
        }

        self._textual_model = optimize_model(
            self._textual_path,
            input_data=[((np.random.randint(0, 100, (1, 77)),), 0)],
            dynamic_info=dynamic_info,
            **general_kwargs,
        )

    @staticmethod
    def get_model_name(name: str):
        if name in _OPENCLIP_MODELS:
            from clip_server.model.openclip_model import OpenCLIPModel

            return OpenCLIPModel.get_model_name(name)
        elif name in _MULTILINGUALCLIP_MODELS:
            from clip_server.model.mclip_model import MultilingualCLIPModel

            return MultilingualCLIPModel.get_model_name(name)

        return name

    def encode_image(self, onnx_image):
        (visual_output,) = self._visual_model(onnx_image)
        return visual_output

    def encode_text(self, onnx_text):
        (textual_output,) = self._textual_model(onnx_text)
        return textual_output


class EnvRunner:
    """Context manager that scopes device visibility and thread count for
    the nebullvm optimization step."""

    def __init__(self, device: str, num_threads: int = None):
        self.device = device
        self.cuda_str = None
        self.rm_cuda_flag = False
        self.num_threads = num_threads

    def __enter__(self):
        if self.device == "cpu" and torch.cuda.is_available():
            # hide the GPUs while optimizing so nebullvm compiles for CPU
            # even on a CUDA machine
            self.cuda_str = os.environ.get("CUDA_VISIBLE_DEVICES")
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
            self.rm_cuda_flag = self.cuda_str is None
        if self.num_threads is not None:
            os.environ["NEBULLVM_THREADS_PER_MODEL"] = f"{self.num_threads}"

    def __exit__(self, exc_type, exc_val, exc_tb):
        # restore the previous environment
        if self.cuda_str is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_str
        elif self.rm_cuda_flag:
            os.environ.pop("CUDA_VISIBLE_DEVICES")
        if self.num_threads is not None:
            os.environ.pop("NEBULLVM_THREADS_PER_MODEL")
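
Taken together, a minimal sketch of how the executor drives this module (names and signatures as in the listing above; standalone usage outside a Jina Flow is an assumption for illustration):

```python
import numpy as np
from clip_server.model.clip_nebullvm import CLIPNebullvmModel, EnvRunner

model = CLIPNebullvmModel('ViT-B/32', pixel_size=224)
with EnvRunner(device='cpu', num_threads=4):  # scopes NEBULLVM_THREADS_PER_MODEL
    model.optimize_models(batch_size=8)  # kwargs forwarded to nebullvm's optimize_model

images = np.random.randn(8, 3, 224, 224).astype(np.float32)
print(model.encode_image(images).shape)  # e.g. (8, 512) for ViT-B/32
```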