Skip to content

Commit

Permalink
support set memory from db
Browse files Browse the repository at this point in the history
  • Loading branch information
bdqfork committed Feb 23, 2024
1 parent b7d9459 commit fac0271
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 8 deletions.
4 changes: 3 additions & 1 deletion libs/superagent/app/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from app.models.request import LLMParams
from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
from prisma.enums import AgentType
from prisma.models import Agent
from prisma.models import Agent, MemoryDb

DEFAULT_PROMPT = (
"You are a helpful AI Assistant, answer the users questions to "
Expand All @@ -21,6 +21,7 @@ def __init__(
callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
llm_params: Optional[LLMParams] = {},
agent_config: Agent = None,
memory_config: MemoryDb = None,
):
self.agent_id = agent_id
self.session_id = session_id
Expand All @@ -29,6 +30,7 @@ def __init__(
self.callbacks = callbacks
self.llm_params = llm_params
self.agent_config = agent_config
self.memory_config = memory_config

async def _get_tools(
self,
Expand Down
30 changes: 24 additions & 6 deletions libs/superagent/app/agents/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from app.models.tools import DatasourceInput
from app.tools import TOOL_TYPE_MAPPING, create_pydantic_model_from_object, create_tool
from app.tools.datasource import DatasourceTool, StructuredDatasourceTool
from app.utils.helpers import get_first_non_null
from app.utils.llm import LLM_MAPPING
from prisma.models import Agent, AgentDatasource, AgentLLM, AgentTool

Expand Down Expand Up @@ -194,32 +195,49 @@ async def _get_prompt(self, agent: Agent) -> str:
return SystemMessage(content=content)

async def _get_memory(self) -> List:
memory_type = config("MEMORY", "motorhead")
if memory_type == "redis":
memory_provider = get_first_non_null(
self.memory_config.provider if self.memory_config else None,
config("MEMORY"),
)
options = (
self.memory_config.options.data
if self.memory_config and self.memory_config.options
else {}
)
if memory_provider == "REDIS" or memory_provider == "redis":
memory = ConversationBufferWindowMemory(
chat_memory=RedisChatMessageHistory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
url=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
url=get_first_non_null(
options.get("REDIS_MEMORY_URL"),
config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
),
key_prefix="superagent:",
),
memory_key="chat_history",
return_messages=True,
output_key="output",
k=config("REDIS_MEMORY_WINDOW", 10),
k=get_first_non_null(
options.get("REDIS_MEMORY_WINDOW"),
config("REDIS_MEMORY_WINDOW", 10),
),
)
else:
elif memory_provider == "MOTORHEAD" or memory_provider == "motorhead":
memory = MotorheadMemory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
memory_key="chat_history",
url=config("MEMORY_API_URL"),
url=get_first_non_null(
options.get("MEMORY_API_URL"),
config("MEMORY_API_URL"),
),
return_messages=True,
output_key="output",
)
Expand Down
5 changes: 5 additions & 0 deletions libs/superagent/app/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,10 @@ async def invoke(
if not model and metadata.get("model"):
model = metadata.get("model")

memory_config = await prisma.memorydb.find_unique(
where={"provider": agent_config.memory, "apiUserId": api_user.id},
)

def track_agent_invocation(result):
intermediate_steps_to_obj = [
{
Expand Down Expand Up @@ -565,6 +569,7 @@ async def send_message(
callbacks=monitoring_callbacks,
llm_params=body.llm_params,
agent_config=agent_config,
memory_config=memory_config,
)
agent = await agent_base.get_agent()

Expand Down
104 changes: 104 additions & 0 deletions libs/superagent/app/api/memory_dbs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import json

import segment.analytics as analytics
from decouple import config
from fastapi import APIRouter, Depends

from app.models.request import MemoryDb as MemoryDbRequest
from app.models.response import MemoryDb as MemoryDbResponse
from app.models.response import MemoryDbList as MemoryDbListResponse
from app.utils.api import get_current_api_user, handle_exception
from app.utils.prisma import prisma
from prisma import Json

SEGMENT_WRITE_KEY = config("SEGMENT_WRITE_KEY", None)

router = APIRouter()
analytics.write_key = SEGMENT_WRITE_KEY


@router.post(
"/memory-db",
name="create",
description="Create a new Memory Database",
response_model=MemoryDbResponse,
)
async def create(body: MemoryDbRequest, api_user=Depends(get_current_api_user)):
"""Endpoint for creating a Memory Database"""
if SEGMENT_WRITE_KEY:
analytics.track(api_user.id, "Created Memory Database")

data = await prisma.memorydb.create(
{
**body.dict(),
"apiUserId": api_user.id,
"options": json.dumps(body.options),
}
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}


@router.get(
"/memory-dbs",
name="list",
description="List all Memory Databases",
response_model=MemoryDbListResponse,
)
async def list(api_user=Depends(get_current_api_user)):
"""Endpoint for listing all Memory Databases"""
try:
data = await prisma.memorydb.find_many(
where={"apiUserId": api_user.id}, order={"createdAt": "desc"}
)
# Convert options to string
for item in data:
item.options = json.dumps(item.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)


@router.get(
"/memory-dbs/{memory_db_id}",
name="get",
description="Get a single Memory Database",
response_model=MemoryDbResponse,
)
async def get(memory_db_id: str, api_user=Depends(get_current_api_user)):
"""Endpoint for getting a single Memory Database"""
try:
data = await prisma.memorydb.find_first(
where={"id": memory_db_id, "apiUserId": api_user.id}
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)


@router.patch(
"/memory-dbs/{memory_db_id}",
name="update",
description="Patch a Memory Database",
response_model=MemoryDbResponse,
)
async def update(
memory_db_id: str, body: MemoryDbRequest, api_user=Depends(get_current_api_user)
):
"""Endpoint for patching a Memory Database"""
try:
if SEGMENT_WRITE_KEY:
analytics.track(api_user.id, "Updated Vector Database")
data = await prisma.memorydb.update(
where={"id": memory_db_id},
data={
**body.dict(exclude_unset=True),
"apiUserId": api_user.id,
"options": Json(body.options),
},
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)
8 changes: 7 additions & 1 deletion libs/superagent/app/models/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from openai.types.beta.assistant_create_params import Tool as OpenAiAssistantTool
from pydantic import BaseModel

from prisma.enums import AgentType, LLMProvider, VectorDbProvider
from prisma.enums import AgentType, LLMProvider, MemoryDbProvider, VectorDbProvider


class ApiUser(BaseModel):
Expand Down Expand Up @@ -40,6 +40,7 @@ class AgentUpdate(BaseModel):
initialMessage: Optional[str]
prompt: Optional[str]
llmModel: Optional[str]
memory: Optional[str]
description: Optional[str]
avatar: Optional[str]
type: Optional[str]
Expand Down Expand Up @@ -132,3 +133,8 @@ class WorkflowInvoke(BaseModel):
class VectorDb(BaseModel):
provider: VectorDbProvider
options: Dict


class MemoryDb(BaseModel):
provider: MemoryDbProvider
options: Dict
13 changes: 13 additions & 0 deletions libs/superagent/app/models/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from prisma.models import (
Datasource as DatasourceModel,
)
from prisma.models import (
MemoryDb as MemoryDbModel,
)
from prisma.models import (
Tool as ToolModel,
)
Expand Down Expand Up @@ -141,3 +144,13 @@ class VectorDb(BaseModel):
class VectorDbList(BaseModel):
success: bool
data: Optional[List[VectorDbModel]]


class MemoryDb(BaseModel):
success: bool
data: Optional[MemoryDbModel]


class MemoryDbList(BaseModel):
success: bool
data: Optional[List[MemoryDbModel]]
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- CreateEnum
CREATE TYPE "MemoryDbProvider" AS ENUM ('MOTORHEAD', 'REDIS');

-- AlterTable
ALTER TABLE "Agent" ADD COLUMN "memory" "MemoryDbProvider" DEFAULT 'MOTORHEAD';

-- CreateTable
CREATE TABLE "MemoryDb" (
"id" TEXT NOT NULL,
"provider" "MemoryDbProvider" NOT NULL DEFAULT 'MOTORHEAD',
"options" JSONB NOT NULL,
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updatedAt" TIMESTAMP(3) NOT NULL,
"apiUserId" TEXT NOT NULL,

CONSTRAINT "MemoryDb_pkey" PRIMARY KEY ("id")
);

-- AddForeignKey
ALTER TABLE "MemoryDb" ADD CONSTRAINT "MemoryDb_apiUserId_fkey" FOREIGN KEY ("apiUserId") REFERENCES "ApiUser"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
17 changes: 17 additions & 0 deletions libs/superagent/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ enum VectorDbProvider {
SUPABASE
}

enum MemoryDbProvider {
MOTORHEAD
REDIS
}

model ApiUser {
id String @id @default(uuid())
token String?
Expand All @@ -106,6 +111,7 @@ model ApiUser {
workflows Workflow[]
vectorDb VectorDb[]
workflowConfigs WorkflowConfig[]
MemoryDb MemoryDb[]
}

model Agent {
Expand All @@ -120,6 +126,7 @@ model Agent {
updatedAt DateTime @updatedAt
llms AgentLLM[]
llmModel LLMModel? @default(GPT_3_5_TURBO_16K_0613)
memory MemoryDbProvider? @default(MOTORHEAD)
prompt String?
apiUserId String
apiUser ApiUser @relation(fields: [apiUserId], references: [id])
Expand Down Expand Up @@ -253,3 +260,13 @@ model VectorDb {
apiUserId String
apiUser ApiUser @relation(fields: [apiUserId], references: [id])
}

model MemoryDb {
id String @id @default(uuid())
provider MemoryDbProvider @default(MOTORHEAD)
options Json
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
apiUserId String
apiUser ApiUser @relation(fields: [apiUserId], references: [id])
}

0 comments on commit fac0271

Please sign in to comment.