Skip to content

Commit

Permalink
feat: supported proxy mode to the same instance of DIAL Core (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Aug 8, 2024
1 parent f49a25c commit 54e9980
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 50 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ Copy `.env.example` to `.env` and customize it for your environment:
|---|---|---|
|LOG_LEVEL|INFO|Log level. Use DEBUG for dev purposes and INFO in prod|
|WEB_CONCURRENCY|1|Number of workers for the server|
|DIAL_URL||URL of the local DIAL Core server used for development|
|DIAL_URL||URL of the **local** DIAL Core server used for development|

### Docker

Expand Down
63 changes: 41 additions & 22 deletions aidial_adapter_dial/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ class Config:
arbitrary_types_allowed = True

@classmethod
async def parse(
cls, request: Request, deployment_id: str, endpoint_name: str
) -> "AzureClient":
async def parse(cls, request: Request, endpoint_name: str) -> "AzureClient":

body = await request.json()
headers = request.headers.mutablecopy()
Expand All @@ -80,34 +78,55 @@ async def parse(
message="The 'api-key' request header is missing",
)

remote_dial_api_key = headers.get(UPSTREAM_KEY_HEADER, None)
if not remote_dial_api_key:
raise HTTPException(
status_code=400,
message=f"The {UPSTREAM_KEY_HEADER!r} request header is missing",
)

upstream_endpoint = headers.get(UPSTREAM_ENDPOINT_HEADER, None)
if not upstream_endpoint:
raise HTTPException(
status_code=400,
message=f"The {UPSTREAM_ENDPOINT_HEADER!r} request header is missing",
)

# NOTE: it's not really necessary for the endpoint to point to the same deployment id.
# Here we just follow the convention used in OpenAI adapter.
endpoint_suffix = f"/{deployment_id}/{endpoint_name}"
remote_dial_url = get_hostname(upstream_endpoint)
remote_dial_api_key = headers.get(UPSTREAM_KEY_HEADER, None)

if not remote_dial_api_key:
if remote_dial_url != LOCAL_DIAL_URL:
raise HTTPException(
status_code=400,
message=(
f"Given that {UPSTREAM_KEY_HEADER!r} header is missing, "
f"it's expected that hostname of upstream endpoint ({upstream_endpoint!r}) is "
f"the same as the local DIAL URL ({LOCAL_DIAL_URL!r}) "
),
)

local_dial_api_key = request.headers.get("api-key")
if not local_dial_api_key:
raise HTTPException(
status_code=400,
message="The 'api-key' request header is missing",
)

remote_dial_api_key = local_dial_api_key

endpoint_suffix = f"/{endpoint_name}"
if not upstream_endpoint.endswith(endpoint_suffix):
raise HTTPException(
status_code=400,
message=f"The {UPSTREAM_ENDPOINT_HEADER!r} request header must end with {endpoint_suffix!r}",
)
upstream_endpoint = upstream_endpoint.removesuffix(f"/{endpoint_name}")
upstream_endpoint = upstream_endpoint.removesuffix(endpoint_suffix)

client = AsyncAzureOpenAI(
base_url=upstream_endpoint,
api_key=remote_dial_api_key,
api_version=query_params.get("api-version"),
# NOTE: defaulting missing api-version to an empty string, because
# 1. openai library doesn't allow for a missing api-version
# and a workaround for it would be a recreation of AsyncAzureOpenAI with a check disabled:
# https://gitlab.deltixhub.com/Deltix/openai-apps/dial-interceptor-example/-/blob/62760a4c7a7be740b1c2bc60f14a0a568f31a0bc/aidial_interceptor_example/utils/azure.py#L1-5
# 2. OpenAI adapter treats a missing api-version in the same way as an empty string and that's the only
# place where api-version has any meaning, so the query param modification is safe.
# https://github.com/epam/ai-dial-adapter-openai/blob/b462d1c26ce8f9d569b9c085a849206aad91becf/aidial_adapter_openai/app.py#L93
api_version=query_params.get("api-version") or "",
http_client=get_http_client(),
)

Expand All @@ -117,7 +136,7 @@ async def parse(
api_key=local_dial_api_key,
),
remote_storage=FileStorage(
dial_url=get_hostname(upstream_endpoint),
dial_url=remote_dial_url,
api_key=remote_dial_api_key,
),
)
Expand All @@ -128,11 +147,12 @@ async def parse(
)


@app.post("/embeddings")
@app.post("/openai/deployments/{deployment_id:path}/embeddings")
@dial_exception_decorator
async def embeddings_proxy(request: Request, deployment_id: str):
async def embeddings_proxy(request: Request):
body = await request.json()
az_client = await AzureClient.parse(request, deployment_id, "embeddings")
az_client = await AzureClient.parse(request, "embeddings")

response: CreateEmbeddingResponse = await call_with_extra_body(
az_client.client.embeddings.create, body
Expand All @@ -141,13 +161,12 @@ async def embeddings_proxy(request: Request, deployment_id: str):
return response.to_dict()


@app.post("/chat/completions")
@app.post("/openai/deployments/{deployment_id:path}/chat/completions")
@dial_exception_decorator
async def chat_completions_proxy(request: Request, deployment_id: str):
async def chat_completions_proxy(request: Request):

az_client = await AzureClient.parse(
request, deployment_id, "chat/completions"
)
az_client = await AzureClient.parse(request, "chat/completions")

transformer = az_client.attachment_transformer

Expand Down
79 changes: 61 additions & 18 deletions aidial_adapter_dial/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
class AttachmentTransformer(BaseModel):
local_storage: FileStorage
local_user_bucket: str
local_app_data: str
local_appdata: str

remote_storage: FileStorage
remote_user_bucket: str

proxy_mode: bool

@classmethod
async def create(
cls, remote_storage: FileStorage, local_storage: FileStorage
Expand All @@ -34,22 +36,42 @@ async def create(
local_user_bucket = AppData.parse(local_appdata).user_bucket

remote = await remote_storage.get_bucket(session)
remote_appdata = remote.get("appdata")
if remote_appdata is None:
remote_user_bucket = remote["bucket"]
else:
remote_user_bucket = AppData.parse(remote_appdata).user_bucket

proxy_mode = (
remote_storage.dial_url == local_storage.dial_url
and remote_storage.api_key == local_storage.api_key
)

log.debug(f"proxy_mode: {proxy_mode}")

return cls(
remote_storage=remote_storage,
remote_user_bucket=remote["bucket"],
remote_user_bucket=remote_user_bucket,
local_storage=local_storage,
local_user_bucket=local_user_bucket,
local_app_data=local_appdata,
local_appdata=local_appdata,
proxy_mode=proxy_mode,
)

def get_remote_url(self, local_url: str) -> str:
"""
user/app files:
< files/LOCAL_USER_BUCKET/PATH
> files/REMOTE_USER_BUCKET/LOCAL_USER_BUCKET/PATH
if proxy_mode:
< files/LOCAL_USER_BUCKET/PATH
> files/REMOTE_USER_BUCKET/PATH
else:
< files/LOCAL_USER_BUCKET/PATH
> files/REMOTE_USER_BUCKET/LOCAL_USER_BUCKET/PATH
"""

if self.proxy_mode:
return local_url

if not local_url.startswith(f"files/{self.local_user_bucket}/"):
raise ValueError(f"Unexpected local URL: {local_url!r}")

Expand All @@ -58,8 +80,12 @@ def get_remote_url(self, local_url: str) -> str:
def get_local_url(self, remote_url: str) -> str:
"""
user/app files uploaded from local to remote earlier (reverse of get_remote_url):
< files/REMOTE_USER_BUCKET/LOCAL_USER_BUCKET/PATH
> files/LOCAL_USER_BUCKET/PATH
if proxy_mode:
< files/REMOTE_USER_BUCKET/PATH
> files/LOCAL_USER_BUCKET/PATH
else:
< files/REMOTE_USER_BUCKET/LOCAL_USER_BUCKET/PATH
> files/LOCAL_USER_BUCKET/PATH
created by remote (user):
< files/REMOTE_USER_BUCKET/appdata/REMOTE_APP_NAME/PATH
Expand All @@ -82,19 +108,26 @@ def get_local_url(self, remote_url: str) -> str:
f"files/{self.remote_user_bucket}/"
)

if remote_path.startswith(f"{self.local_user_bucket}/"):
path = remote_path.removeprefix(f"{self.local_user_bucket}/")
return f"files/{self.local_user_bucket}/{path}"
else:
if remote_path.startswith("appdata/"):
regex = r"appdata/([^/]+)/(.+)"
match = re.match(regex, remote_path)
if match is None:
raise ValueError(
f"The remote file ({remote_url!r}) is expected to be uploaded to a remote appdata path"
)
raise ValueError(f"Invalid remote appdata path: {remote_url!r}")
_remote_app_name, path = match.groups()
return f"files/{self.local_appdata}/{path}"

if not self.proxy_mode:
if remote_path.startswith(f"{self.local_user_bucket}/"):
path = remote_path.removeprefix(f"{self.local_user_bucket}/")
return f"files/{self.local_user_bucket}/{path}"

return f"files/{self.local_app_data}/{path}"
raise ValueError(
f"The remote file ({remote_url!r}) is expected to be uploaded either "
"to remote appdata path or "
"to a local user bucket subpath of remote user bucket."
)
else:
return remote_url

async def modify_request_attachment(self, attachment: dict) -> None:
if (ref_url := attachment.get("reference_url")) and (
Expand Down Expand Up @@ -188,9 +221,15 @@ async def download_and_upload_file(
dest_url: str,
content_type: str | None,
):
async with aiohttp.ClientSession() as session:
content = await src_storage.download(src_url, session)
await dest_storage.upload(dest_url, content_type, content, session)
log.debug(f"downloading from {src_url!r} and uploading to {dest_url!r}")

if src_url != dest_url:
if _is_directory(src_url):
raise ValueError("Directories aren't yet supported")

async with aiohttp.ClientSession() as session:
content = await src_storage.download(src_url, session)
await dest_storage.upload(dest_url, content_type, content, session)


async def modify_message(
Expand All @@ -205,3 +244,7 @@ async def modify_message(
return
for attachment in attachments:
await modify_attachment(attachment)


def _is_directory(url: str) -> bool:
return url[-1] == "/"
9 changes: 7 additions & 2 deletions aidial_adapter_dial/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,15 @@ def create_error(
def to_dial_exception(e: Exception) -> HTTPException | FastAPIException:
if isinstance(e, APIStatusError):
r = e.response
headers = r.headers

if "Content-Length" in headers:
del headers["Content-Length"]

return FastAPIException(
detail=r.json(),
detail=r.text,
status_code=r.status_code,
headers=dict(r.headers),
headers=dict(headers),
)

if isinstance(e, APITimeoutError):
Expand Down
26 changes: 22 additions & 4 deletions aidial_adapter_dial/utils/sse_stream.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import json
import logging
from typing import Any, AsyncIterator, Mapping

from aidial_adapter_dial.utils.exceptions import create_error
from aidial_adapter_dial.utils.exceptions import (
create_error,
to_dial_exception,
to_starlette_exception,
)

DATA_PREFIX = "data: "
OPENAI_END_MARKER = "[DONE]"
Expand Down Expand Up @@ -54,9 +59,22 @@ async def parse_openai_sse_stream(
yield chunk


log = logging.getLogger(__name__)


async def to_openai_sse_stream(
stream: AsyncIterator[dict],
) -> AsyncIterator[str]:
async for chunk in stream:
yield format_chunk(chunk)
yield END_CHUNK
try:
async for chunk in stream:
yield format_chunk(chunk)
yield END_CHUNK
except Exception as e:
log.exception(
f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}"
)

dial_exception = to_dial_exception(e)
starlette_exception = to_starlette_exception(dial_exception)

yield format_chunk(starlette_exception.detail)
4 changes: 4 additions & 0 deletions aidial_adapter_dial/utils/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ async def upload(
content: bytes,
session: aiohttp.ClientSession,
) -> FileMetadata:
log.debug(f"uploading file {url!r}")

if self.to_dial_url(url) is None:
raise ValueError(f"URL isn't DIAL url: {url!r}")
url = self.to_abs_url(url)
Expand Down Expand Up @@ -158,6 +160,8 @@ def to_abs_url(self, link: str) -> str:
return ret

async def download(self, url: str, session: aiohttp.ClientSession) -> bytes:
log.debug(f"downloading file {url!r}")

if self.to_dial_url(url) is None:
raise ValueError(f"URL isn't DIAL url: {url!r}")
url = self.to_abs_url(url)
Expand Down
6 changes: 3 additions & 3 deletions docker-compose/local/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ services:
chat:
ports:
- "3000:3000"
image: epam/ai-dial-chat:0.10.0
image: epam/ai-dial-chat:0.14.2
depends_on:
- themes
- core
Expand All @@ -22,7 +22,7 @@ services:
THEMES_CONFIG_HOST: "http://themes:8080"
DIAL_API_HOST: "http://core:8080"
DIAL_API_KEY: "dial_api_key"
ENABLED_FEATURES: "conversations-section,prompts-section,top-settings,top-clear-conversation,top-chat-info,top-chat-model-settings,empty-chat-settings,header,footer,request-api-key,report-an-issue,likes,input-files,attachments-manager"
ENABLED_FEATURES: "conversations-section,prompts-section,top-settings,top-clear-conversation,top-chat-info,top-chat-model-settings,empty-chat-settings,header,footer,request-api-key,report-an-issue,likes,conversations-sharing,prompts-sharing,input-files,attachments-manager,conversations-publishing,prompts-publishing,custom-logo,input-links"

redis:
image: redis:7.2.4-alpine3.19
Expand All @@ -42,7 +42,7 @@ services:
user: ${UID:-root}
ports:
- "8080:8080"
image: epam/ai-dial-core:0.9.0
image: epam/ai-dial-core:0.14.0
environment:
'AIDIAL_SETTINGS': '/opt/settings/settings.json'
'JAVA_OPTS': '-Dgflog.config=/opt/settings/gflog.xml'
Expand Down

0 comments on commit 54e9980

Please sign in to comment.