Skip to content

Commit

Permalink
support set memory from ui
Browse files Browse the repository at this point in the history
  • Loading branch information
bdqfork committed Feb 24, 2024
1 parent 5132154 commit 742d88e
Show file tree
Hide file tree
Showing 21 changed files with 560 additions and 16 deletions.
4 changes: 2 additions & 2 deletions libs/superagent/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ ENV PORT="8080"

COPY --from=builder /app/.venv /app/.venv

COPY . ./

# Improve grpc error messages
RUN pip install grpcio-status

COPY . ./

# Enable prisma migrations
RUN prisma generate

Expand Down
7 changes: 6 additions & 1 deletion libs/superagent/app/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from app.models.request import LLMParams
from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
from prisma.enums import AgentType
from prisma.models import Agent
from prisma.models import Agent, MemoryDb

DEFAULT_PROMPT = (
"You are a helpful AI Assistant, answer the users questions to "
Expand All @@ -21,6 +21,7 @@ def __init__(
callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
llm_params: Optional[LLMParams] = {},
agent_config: Agent = None,
memory_config: MemoryDb = None,
):
self.agent_id = agent_id
self.session_id = session_id
Expand All @@ -29,6 +30,7 @@ def __init__(
self.callbacks = callbacks
self.llm_params = llm_params
self.agent_config = agent_config
self.memory_config = memory_config

async def _get_tools(
self,
Expand Down Expand Up @@ -60,6 +62,7 @@ async def get_agent(self):
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
memory_config=self.memory_config,
)

elif self.agent_config.type == AgentType.LLM:
Expand All @@ -72,6 +75,7 @@ async def get_agent(self):
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
memory_config=self.memory_config,
)

else:
Expand All @@ -85,6 +89,7 @@ async def get_agent(self):
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
memory_config=self.memory_config,
)

return await agent.get_agent()
Expand Down
35 changes: 27 additions & 8 deletions libs/superagent/app/agents/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@
from app.models.tools import DatasourceInput
from app.tools import TOOL_TYPE_MAPPING, create_pydantic_model_from_object, create_tool
from app.tools.datasource import DatasourceTool, StructuredDatasourceTool
from app.utils.helpers import get_first_non_null
from app.utils.llm import LLM_MAPPING
<<<<<<< HEAD
from prisma.models import LLM, Agent, AgentDatasource, AgentTool
=======
from prisma.models import Agent, AgentDatasource, AgentLLM, AgentTool, MemoryDb
>>>>>>> 52ab6b65 (support set memory from db)

DEFAULT_PROMPT = (
"You are a helpful AI Assistant, answer the users questions to "
Expand Down Expand Up @@ -193,33 +198,47 @@ async def _get_prompt(self, agent: Agent) -> str:
content = f"{content}" f"\n\n{datetime.datetime.now().strftime('%Y-%m-%d')}"
return SystemMessage(content=content)

async def _get_memory(self) -> List:
memory_type = config("MEMORY", "motorhead")
if memory_type == "redis":
async def _get_memory(self, memory_db: MemoryDb) -> List:
if memory_db is None:
memory_provider = config("MEMORY")
options = {}
else:
memory_provider = memory_db.provider
options = memory_db.options
if memory_provider == "REDIS" or memory_provider == "redis":
memory = ConversationBufferWindowMemory(
chat_memory=RedisChatMessageHistory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
url=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
url=get_first_non_null(
options.get("REDIS_MEMORY_URL"),
config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
),
key_prefix="superagent:",
),
memory_key="chat_history",
return_messages=True,
output_key="output",
k=config("REDIS_MEMORY_WINDOW", 10),
k=get_first_non_null(
options.get("REDIS_MEMORY_WINDOW"),
config("REDIS_MEMORY_WINDOW", 10),
),
)
else:
elif memory_provider == "MOTORHEAD" or memory_provider == "motorhead":
memory = MotorheadMemory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
memory_key="chat_history",
url=config("MEMORY_API_URL"),
url=get_first_non_null(
options.get("MEMORY_API_URL"),
config("MEMORY_API_URL"),
),
return_messages=True,
output_key="output",
)
Expand All @@ -235,7 +254,7 @@ async def get_agent(self):
agent_tools=self.agent_config.tools,
)
prompt = await self._get_prompt(agent=self.agent_config)
memory = await self._get_memory()
memory = await self._get_memory(memory_db=self.memory_config)

if len(tools) > 0:
agent = initialize_agent(
Expand Down
5 changes: 5 additions & 0 deletions libs/superagent/app/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ async def invoke(
if not model and metadata.get("model"):
model = metadata.get("model")

memory_config = await prisma.memorydb.find_first(
where={"provider": agent_config.memory, "apiUserId": api_user.id},
)

def track_agent_invocation(result):
intermediate_steps_to_obj = [
{
Expand Down Expand Up @@ -571,6 +575,7 @@ async def send_message(
callbacks=monitoring_callbacks,
llm_params=body.llm_params,
agent_config=agent_config,
memory_config=memory_config,
)
agent = await agent_base.get_agent()

Expand Down
104 changes: 104 additions & 0 deletions libs/superagent/app/api/memory_dbs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import json

import segment.analytics as analytics
from decouple import config
from fastapi import APIRouter, Depends

from app.models.request import MemoryDb as MemoryDbRequest
from app.models.response import MemoryDb as MemoryDbResponse
from app.models.response import MemoryDbList as MemoryDbListResponse
from app.utils.api import get_current_api_user, handle_exception
from app.utils.prisma import prisma
from prisma import Json

SEGMENT_WRITE_KEY = config("SEGMENT_WRITE_KEY", None)

router = APIRouter()
analytics.write_key = SEGMENT_WRITE_KEY


@router.post(
"/memory-db",
name="create",
description="Create a new Memory Database",
response_model=MemoryDbResponse,
)
async def create(body: MemoryDbRequest, api_user=Depends(get_current_api_user)):
"""Endpoint for creating a Memory Database"""
if SEGMENT_WRITE_KEY:
analytics.track(api_user.id, "Created Memory Database")

data = await prisma.memorydb.create(
{
**body.dict(),
"apiUserId": api_user.id,
"options": json.dumps(body.options),
}
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}


@router.get(
"/memory-dbs",
name="list",
description="List all Memory Databases",
response_model=MemoryDbListResponse,
)
async def list(api_user=Depends(get_current_api_user)):
"""Endpoint for listing all Memory Databases"""
try:
data = await prisma.memorydb.find_many(
where={"apiUserId": api_user.id}, order={"createdAt": "desc"}
)
# Convert options to string
for item in data:
item.options = json.dumps(item.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)


@router.get(
"/memory-dbs/{memory_db_id}",
name="get",
description="Get a single Memory Database",
response_model=MemoryDbResponse,
)
async def get(memory_db_id: str, api_user=Depends(get_current_api_user)):
"""Endpoint for getting a single Memory Database"""
try:
data = await prisma.memorydb.find_first(
where={"id": memory_db_id, "apiUserId": api_user.id}
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)


@router.patch(
"/memory-dbs/{memory_db_id}",
name="update",
description="Patch a Memory Database",
response_model=MemoryDbResponse,
)
async def update(
memory_db_id: str, body: MemoryDbRequest, api_user=Depends(get_current_api_user)
):
"""Endpoint for patching a Memory Database"""
try:
if SEGMENT_WRITE_KEY:
analytics.track(api_user.id, "Updated Memory Database")
data = await prisma.memorydb.update(
where={"id": memory_db_id},
data={
**body.dict(exclude_unset=True),
"apiUserId": api_user.id,
"options": Json(body.options),
},
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)
8 changes: 7 additions & 1 deletion libs/superagent/app/models/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from openai.types.beta.assistant_create_params import Tool as OpenAiAssistantTool
from pydantic import BaseModel

from prisma.enums import AgentType, LLMProvider, VectorDbProvider
from prisma.enums import AgentType, LLMProvider, MemoryDbProvider, VectorDbProvider


class ApiUser(BaseModel):
Expand Down Expand Up @@ -40,6 +40,7 @@ class AgentUpdate(BaseModel):
initialMessage: Optional[str]
prompt: Optional[str]
llmModel: Optional[str]
memory: Optional[str]
description: Optional[str]
avatar: Optional[str]
type: Optional[str]
Expand Down Expand Up @@ -132,3 +133,8 @@ class WorkflowInvoke(BaseModel):
class VectorDb(BaseModel):
provider: VectorDbProvider
options: Dict


class MemoryDb(BaseModel):
provider: MemoryDbProvider
options: Dict
13 changes: 13 additions & 0 deletions libs/superagent/app/models/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from prisma.models import (
Datasource as DatasourceModel,
)
from prisma.models import (
MemoryDb as MemoryDbModel,
)
from prisma.models import (
Tool as ToolModel,
)
Expand Down Expand Up @@ -141,3 +144,13 @@ class VectorDb(BaseModel):
class VectorDbList(BaseModel):
success: bool
data: Optional[List[VectorDbModel]]


class MemoryDb(BaseModel):
success: bool
data: Optional[MemoryDbModel]


class MemoryDbList(BaseModel):
success: bool
data: Optional[List[MemoryDbModel]]
2 changes: 2 additions & 0 deletions libs/superagent/app/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
api_user,
datasources,
llms,
memory_dbs,
tools,
vector_dbs,
workflows,
Expand All @@ -24,3 +25,4 @@
workflow_configs.router, tags=["Workflow Config"], prefix=api_prefix
)
router.include_router(vector_dbs.router, tags=["Vector Database"], prefix=api_prefix)
router.include_router(memory_dbs.router, tags=["Memory Database"], prefix=api_prefix)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- CreateEnum
CREATE TYPE "MemoryDbProvider" AS ENUM ('MOTORHEAD', 'REDIS');

-- AlterTable
ALTER TABLE "Agent" ADD COLUMN "memory" "MemoryDbProvider" DEFAULT 'MOTORHEAD';

-- CreateTable
CREATE TABLE "MemoryDb" (
"id" TEXT NOT NULL,
"provider" "MemoryDbProvider" NOT NULL DEFAULT 'MOTORHEAD',
"options" JSONB NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"apiUserId" TEXT NOT NULL,

CONSTRAINT "MemoryDb_pkey" PRIMARY KEY ("id")
);

-- AddForeignKey
ALTER TABLE "MemoryDb" ADD CONSTRAINT "MemoryDb_apiUserId_fkey" FOREIGN KEY ("apiUserId") REFERENCES "ApiUser"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
17 changes: 17 additions & 0 deletions libs/superagent/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ enum VectorDbProvider {
SUPABASE
}

enum MemoryDbProvider {
MOTORHEAD
REDIS
}

model ApiUser {
id String @id @default(uuid())
token String?
Expand All @@ -106,6 +111,7 @@ model ApiUser {
workflows Workflow[]
vectorDb VectorDb[]
workflowConfigs WorkflowConfig[]
MemoryDb MemoryDb[]
}

model Agent {
Expand All @@ -120,6 +126,7 @@ model Agent {
updatedAt DateTime @updatedAt
llms AgentLLM[]
llmModel LLMModel? @default(GPT_3_5_TURBO_16K_0613)
memory MemoryDbProvider? @default(MOTORHEAD)
prompt String?
apiUserId String
apiUser ApiUser @relation(fields: [apiUserId], references: [id])
Expand Down Expand Up @@ -253,3 +260,13 @@ model VectorDb {
apiUserId String
apiUser ApiUser @relation(fields: [apiUserId], references: [id])
}

model MemoryDb {
id String @id @default(uuid())
provider MemoryDbProvider @default(MOTORHEAD)
options Json
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
apiUserId String
apiUser ApiUser @relation(fields: [apiUserId], references: [id])
}
Loading

0 comments on commit 742d88e

Please sign in to comment.