Skip to content

Commit

Permalink
support set memory from ui
Browse files Browse the repository at this point in the history
  • Loading branch information
bdqfork committed Feb 24, 2024
1 parent 5132154 commit 88ff32e
Show file tree
Hide file tree
Showing 23 changed files with 575 additions and 20 deletions.
4 changes: 2 additions & 2 deletions libs/superagent/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ ENV PORT="8080"

COPY --from=builder /app/.venv /app/.venv

COPY . ./

# Improve grpc error messages
RUN pip install grpcio-status

COPY . ./

# Enable prisma migrations
RUN prisma generate

Expand Down
7 changes: 6 additions & 1 deletion libs/superagent/app/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from app.models.request import LLMParams
from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
from prisma.enums import AgentType
from prisma.models import Agent
from prisma.models import Agent, MemoryDb

DEFAULT_PROMPT = (
"You are a helpful AI Assistant, answer the users questions to "
Expand All @@ -21,6 +21,7 @@ def __init__(
callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
llm_params: Optional[LLMParams] = {},
agent_config: Agent = None,
memory_config: MemoryDb = None,
):
self.agent_id = agent_id
self.session_id = session_id
Expand All @@ -29,6 +30,7 @@ def __init__(
self.callbacks = callbacks
self.llm_params = llm_params
self.agent_config = agent_config
self.memory_config = memory_config

async def _get_tools(
self,
Expand Down Expand Up @@ -60,6 +62,7 @@ async def get_agent(self):
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
memory_config=self.memory_config,
)

elif self.agent_config.type == AgentType.LLM:
Expand All @@ -72,6 +75,7 @@ async def get_agent(self):
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
memory_config=self.memory_config,
)

else:
Expand All @@ -85,6 +89,7 @@ async def get_agent(self):
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
memory_config=self.memory_config,
)

return await agent.get_agent()
Expand Down
37 changes: 28 additions & 9 deletions libs/superagent/app/agents/langchain.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import json
import logging
import re
from typing import Any, List

Expand All @@ -23,8 +24,11 @@
from app.models.tools import DatasourceInput
from app.tools import TOOL_TYPE_MAPPING, create_pydantic_model_from_object, create_tool
from app.tools.datasource import DatasourceTool, StructuredDatasourceTool
from app.utils.helpers import get_first_non_null
from app.utils.llm import LLM_MAPPING
from prisma.models import LLM, Agent, AgentDatasource, AgentTool
from prisma.models import LLM, Agent, AgentDatasource, AgentTool, MemoryDb

logger = logging.getLogger(__name__)

DEFAULT_PROMPT = (
"You are a helpful AI Assistant, answer the users questions to "
Expand Down Expand Up @@ -193,33 +197,48 @@ async def _get_prompt(self, agent: Agent) -> str:
content = f"{content}" f"\n\n{datetime.datetime.now().strftime('%Y-%m-%d')}"
return SystemMessage(content=content)

async def _get_memory(self) -> List:
memory_type = config("MEMORY", "motorhead")
if memory_type == "redis":
async def _get_memory(self, memory_db: MemoryDb) -> List:
logger.debug(f"Use memory config: {memory_db}")
if memory_db is None:
memory_provider = config("MEMORY")
options = {}
else:
memory_provider = memory_db.provider
options = memory_db.options
if memory_provider == "REDIS" or memory_provider == "redis":
memory = ConversationBufferWindowMemory(
chat_memory=RedisChatMessageHistory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
url=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
url=get_first_non_null(
options.get("REDIS_MEMORY_URL"),
config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
),
key_prefix="superagent:",
),
memory_key="chat_history",
return_messages=True,
output_key="output",
k=config("REDIS_MEMORY_WINDOW", 10),
k=get_first_non_null(
options.get("REDIS_MEMORY_WINDOW"),
config("REDIS_MEMORY_WINDOW", 10),
),
)
else:
elif memory_provider == "MOTORHEAD" or memory_provider == "motorhead":
memory = MotorheadMemory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
memory_key="chat_history",
url=config("MEMORY_API_URL"),
url=get_first_non_null(
options.get("MEMORY_API_URL"),
config("MEMORY_API_URL"),
),
return_messages=True,
output_key="output",
)
Expand All @@ -235,7 +254,7 @@ async def get_agent(self):
agent_tools=self.agent_config.tools,
)
prompt = await self._get_prompt(agent=self.agent_config)
memory = await self._get_memory()
memory = await self._get_memory(memory_db=self.memory_config)

if len(tools) > 0:
agent = initialize_agent(
Expand Down
5 changes: 5 additions & 0 deletions libs/superagent/app/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ async def invoke(
if not model and metadata.get("model"):
model = metadata.get("model")

memory_config = await prisma.memorydb.find_first(
where={"provider": agent_config.memory, "apiUserId": api_user.id},
)

def track_agent_invocation(result):
intermediate_steps_to_obj = [
{
Expand Down Expand Up @@ -571,6 +575,7 @@ async def send_message(
callbacks=monitoring_callbacks,
llm_params=body.llm_params,
agent_config=agent_config,
memory_config=memory_config,
)
agent = await agent_base.get_agent()

Expand Down
104 changes: 104 additions & 0 deletions libs/superagent/app/api/memory_dbs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import json

import segment.analytics as analytics
from decouple import config
from fastapi import APIRouter, Depends

from app.models.request import MemoryDb as MemoryDbRequest
from app.models.response import MemoryDb as MemoryDbResponse
from app.models.response import MemoryDbList as MemoryDbListResponse
from app.utils.api import get_current_api_user, handle_exception
from app.utils.prisma import prisma
from prisma import Json

SEGMENT_WRITE_KEY = config("SEGMENT_WRITE_KEY", None)

router = APIRouter()
analytics.write_key = SEGMENT_WRITE_KEY


@router.post(
"/memory-db",
name="create",
description="Create a new Memory Database",
response_model=MemoryDbResponse,
)
async def create(body: MemoryDbRequest, api_user=Depends(get_current_api_user)):
"""Endpoint for creating a Memory Database"""
if SEGMENT_WRITE_KEY:
analytics.track(api_user.id, "Created Memory Database")

data = await prisma.memorydb.create(
{
**body.dict(),
"apiUserId": api_user.id,
"options": json.dumps(body.options),
}
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}


@router.get(
"/memory-dbs",
name="list",
description="List all Memory Databases",
response_model=MemoryDbListResponse,
)
async def list(api_user=Depends(get_current_api_user)):
"""Endpoint for listing all Memory Databases"""
try:
data = await prisma.memorydb.find_many(
where={"apiUserId": api_user.id}, order={"createdAt": "desc"}
)
# Convert options to string
for item in data:
item.options = json.dumps(item.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)


@router.get(
"/memory-dbs/{memory_db_id}",
name="get",
description="Get a single Memory Database",
response_model=MemoryDbResponse,
)
async def get(memory_db_id: str, api_user=Depends(get_current_api_user)):
"""Endpoint for getting a single Memory Database"""
try:
data = await prisma.memorydb.find_first(
where={"id": memory_db_id, "apiUserId": api_user.id}
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)


@router.patch(
"/memory-dbs/{memory_db_id}",
name="update",
description="Patch a Memory Database",
response_model=MemoryDbResponse,
)
async def update(
memory_db_id: str, body: MemoryDbRequest, api_user=Depends(get_current_api_user)
):
"""Endpoint for patching a Memory Database"""
try:
if SEGMENT_WRITE_KEY:
analytics.track(api_user.id, "Updated Memory Database")
data = await prisma.memorydb.update(
where={"id": memory_db_id},
data={
**body.dict(exclude_unset=True),
"apiUserId": api_user.id,
"options": Json(body.options),
},
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)
3 changes: 2 additions & 1 deletion libs/superagent/app/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import time

import colorlog
Expand Down Expand Up @@ -26,7 +27,7 @@
console_handler.setFormatter(formatter)

logging.basicConfig(
level=logging.INFO,
level=os.environ.get("LOG_LEVEL", "INFO"),
format="%(levelname)s: %(message)s",
handlers=[console_handler],
force=True,
Expand Down
8 changes: 7 additions & 1 deletion libs/superagent/app/models/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from openai.types.beta.assistant_create_params import Tool as OpenAiAssistantTool
from pydantic import BaseModel

from prisma.enums import AgentType, LLMProvider, VectorDbProvider
from prisma.enums import AgentType, LLMProvider, MemoryDbProvider, VectorDbProvider


class ApiUser(BaseModel):
Expand Down Expand Up @@ -40,6 +40,7 @@ class AgentUpdate(BaseModel):
initialMessage: Optional[str]
prompt: Optional[str]
llmModel: Optional[str]
memory: Optional[str]
description: Optional[str]
avatar: Optional[str]
type: Optional[str]
Expand Down Expand Up @@ -132,3 +133,8 @@ class WorkflowInvoke(BaseModel):
class VectorDb(BaseModel):
provider: VectorDbProvider
options: Dict


class MemoryDb(BaseModel):
provider: MemoryDbProvider
options: Dict
13 changes: 13 additions & 0 deletions libs/superagent/app/models/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from prisma.models import (
Datasource as DatasourceModel,
)
from prisma.models import (
MemoryDb as MemoryDbModel,
)
from prisma.models import (
Tool as ToolModel,
)
Expand Down Expand Up @@ -141,3 +144,13 @@ class VectorDb(BaseModel):
class VectorDbList(BaseModel):
success: bool
data: Optional[List[VectorDbModel]]


class MemoryDb(BaseModel):
success: bool
data: Optional[MemoryDbModel]


class MemoryDbList(BaseModel):
success: bool
data: Optional[List[MemoryDbModel]]
2 changes: 2 additions & 0 deletions libs/superagent/app/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
api_user,
datasources,
llms,
memory_dbs,
tools,
vector_dbs,
workflows,
Expand All @@ -24,3 +25,4 @@
workflow_configs.router, tags=["Workflow Config"], prefix=api_prefix
)
router.include_router(vector_dbs.router, tags=["Vector Database"], prefix=api_prefix)
router.include_router(memory_dbs.router, tags=["Memory Database"], prefix=api_prefix)
Loading

0 comments on commit 88ff32e

Please sign in to comment.