From c22a7951b196d01a3217066783e5b2323a0538af Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Mon, 29 Apr 2024 19:01:43 +0400 Subject: [PATCH 01/10] refactor: remove unused superagent agent type --- libs/superagent/app/agents/superagent.py | 31 ------------- libs/superagent/app/memory/base.py | 58 ------------------------ 2 files changed, 89 deletions(-) delete mode 100644 libs/superagent/app/agents/superagent.py diff --git a/libs/superagent/app/agents/superagent.py b/libs/superagent/app/agents/superagent.py deleted file mode 100644 index 496e51ae4..000000000 --- a/libs/superagent/app/agents/superagent.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Any, List, Tuple - -from decouple import config - -from app.agents.base import AgentBase -from app.memory.base import Memory -from prisma.models import Agent, AgentDatasource, AgentTool - - -class SuperagentAgent(AgentBase): - async def _get_tools( - self, agent_datasources: List[AgentDatasource], agent_tools: List[AgentTool] - ) -> List: - tools = [] - return tools - - async def _get_memory(self) -> Tuple[str, List[Any]]: - memory = Memory( - session_id=f"{self.agent_id}-{self.session_id}" - if self.session_id - else f"{self.agent_id}", - url=config("MEMORY_API_URL"), - ) - return memory - - async def get_agent(self, config: Agent) -> Any: - # memory = await self._get_memory() - # tools = await self._get_tools( - # agent_datasources=config.datasources, agent_tools=config.tools - # ) - pass diff --git a/libs/superagent/app/memory/base.py b/libs/superagent/app/memory/base.py index e61ce5496..e69de29bb 100644 --- a/libs/superagent/app/memory/base.py +++ b/libs/superagent/app/memory/base.py @@ -1,58 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -import requests -from decouple import config - -MANAGED_URL = config("MEMORY_API_URL") - - -class Memory: - """Assistant memory""" - - def __init__( - self, - session_id: str, - url: str = MANAGED_URL, - timeout: int = 3000, - context: Optional[str] = None, - ): - self.url = url - self.timeout = timeout - self.session_id = session_id - self.context = context - self.chat_memory = [] - - def __get_headers(self) -> Dict[str, str]: - headers = { - "Content-Type": "application/json", - } - return headers - - async def init(self) -> Tuple[str, List[Any]]: - res = requests.get( - f"{self.url}/sessions/{self.session_id}/memory", - timeout=self.timeout, - headers=self.__get_headers(), - ) - res_data = res.json() - res_data = res_data.get("data", res_data) - messages = res_data.get("messages", []) - context = res_data.get("context", "NONE") - return (context, list(reversed(messages))) - - def save_context(self, input: str, output: str) -> None: - requests.post( - f"{self.url}/sessions/{self.session_id}/memory", - timeout=self.timeout, - json={ - "messages": [ - {"role": "Human", "content": f"{input}"}, - {"role": "AI", "content": f"{output}"}, - ] - }, - headers=self.__get_headers(), - ) - - def delete_session(self) -> None: - """Delete a session""" - requests.delete(f"{self.url}/sessions/{self.session_id}/memory") From cd398ede5dc65f4a9bc6d1fb9b58ee256aec146f Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Tue, 30 Apr 2024 10:47:56 +0400 Subject: [PATCH 02/10] bump redis version to 5.0.4 --- libs/superagent/poetry.lock | 10 +++++----- libs/superagent/pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/libs/superagent/poetry.lock b/libs/superagent/poetry.lock index 9857cdffe..ebd321186 100644 --- a/libs/superagent/poetry.lock +++ b/libs/superagent/poetry.lock @@ -4084,17 +4084,17 @@ setuptools = ">=41.0" [[package]] name = "redis" -version = "5.0.1" +version = "5.0.4" description = "Python client for Redis database and key-value store" optional = false python-versions = ">=3.7" files = [ - {file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"}, - {file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"}, + {file = "redis-5.0.4-py3-none-any.whl", hash = "sha256:7adc2835c7a9b5033b7ad8f8918d09b7344188228809c98df07af226d39dec91"}, + {file = "redis-5.0.4.tar.gz", hash = "sha256:ec31f2ed9675cc54c21ba854cfe0462e6faf1d83c8ce5944709db8a4700b9c61"}, ] [package.dependencies] -async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""} +async-timeout = {version = ">=4.0.3", markers = "python_full_version < \"3.11.3\""} [package.extras] hiredis = ["hiredis (>=1.0.0)"] @@ -6025,4 +6025,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1, <3.12" -content-hash = "9049d2eda40cf7a7809de8eeac32efac0753e9006d50b8ef98aca4ef75f0e703" +content-hash = "c1d87a3c0e460cc17f2f2dec3510ff5fa058e50eca86e5c678799decb359fa7f" diff --git a/libs/superagent/pyproject.toml b/libs/superagent/pyproject.toml index 7ea34bbb4..aa9452451 100644 --- a/libs/superagent/pyproject.toml +++ b/libs/superagent/pyproject.toml @@ -61,7 +61,7 @@ langchain-openai = "^0.0.5" python-docx = "^1.1.0" prisma = "^0.12.0" stripe = "^8.2.0" -redis = "^5.0.1" +redis = "5.0.4" langsmith = "^0.1.9" langfuse = "2.21.3" tavily-python = "^0.3.1" From d91d50fa78845017425463797423fcac2250f927 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Tue, 30 Apr 2024 21:26:15 +0400 Subject: [PATCH 03/10] feat: add buffer memory --- libs/superagent/app/memory/base.py | 29 ++++++++++ libs/superagent/app/memory/buffer_memory.py | 53 +++++++++++++++++++ .../app/memory/memory_stores/base.py | 22 ++++++++ .../app/memory/memory_stores/redis.py | 31 +++++++++++ 4 files changed, 135 insertions(+) create mode 100644 libs/superagent/app/memory/buffer_memory.py create mode 100644 libs/superagent/app/memory/memory_stores/base.py create mode 100644 libs/superagent/app/memory/memory_stores/redis.py diff --git a/libs/superagent/app/memory/base.py b/libs/superagent/app/memory/base.py index e69de29bb..e7f56d443 100644 --- a/libs/superagent/app/memory/base.py +++ b/libs/superagent/app/memory/base.py @@ -0,0 +1,29 @@ +from abc import ABC, abstractmethod +from typing import List + +from app.memory.memory_stores.base import BaseMemoryStore +from app.memory.message import BaseMessage + + +class BaseMemory(ABC): + memory_store: BaseMemoryStore + + @abstractmethod + def add_message(self, message: BaseMessage) -> None: + ... + + @abstractmethod + async def aadd_message(self, message: BaseMessage) -> None: + ... + + @abstractmethod + def get_messages(self) -> List[BaseMessage]: + """ + List all the messages stored in the memory. + Messages are returned in the descending order of their creation. + """ + ... + + @abstractmethod + def clear(self) -> None: + ... diff --git a/libs/superagent/app/memory/buffer_memory.py b/libs/superagent/app/memory/buffer_memory.py new file mode 100644 index 000000000..96c2467fa --- /dev/null +++ b/libs/superagent/app/memory/buffer_memory.py @@ -0,0 +1,53 @@ +from typing import Optional + +from litellm import model_cost + +from app.memory.base import BaseMemory +from app.memory.memory_stores.base import BaseMemoryStore +from app.memory.message import BaseMessage + +DEFAULT_TOKEN_LIMIT_RATIO = 0.75 +DEFAULT_TOKEN_LIMIT = 3000 + + +class BufferMemory(BaseMemory): + def __init__( + self, + memory_store: BaseMemoryStore, + tokenizer_fn: callable, + model: str, + max_tokens: Optional[int] = None, + ): + self.memory_store = memory_store + self.tokenizer_fn = tokenizer_fn + self.model = model + context_window = model_cost.get(self.model, {}).get("max_input_tokens") + self.context_window = max_tokens or context_window * DEFAULT_TOKEN_LIMIT_RATIO + + def add_message(self, message: BaseMessage) -> None: + self.memory_store.add_message(message) + + async def aadd_message(self, message: BaseMessage) -> None: + await self.memory_store.aadd_message(message) + + def get_messages( + self, + inital_token_usage: int = 0, + ) -> list[BaseMessage]: + messages = self.memory_store.get_messages() + + index = 0 + token_usage = inital_token_usage + while index < len(messages): + message = messages[index] + curr_token_usage = self.tokenizer_fn(text=message.content) + if token_usage + curr_token_usage > self.context_window: + break + + token_usage += curr_token_usage + index += 1 + + return messages[:index] + + def clear(self) -> None: + self.memory_store.clear() diff --git a/libs/superagent/app/memory/memory_stores/base.py b/libs/superagent/app/memory/memory_stores/base.py new file mode 100644 index 000000000..f7fe11987 --- /dev/null +++ b/libs/superagent/app/memory/memory_stores/base.py @@ -0,0 +1,22 @@ +from abc import ABC, abstractmethod +from typing import List + +from app.memory.message import BaseMessage + + +class BaseMemoryStore(ABC): + @abstractmethod + def get_messages(self) -> List[BaseMessage]: + ... # noqa + + @abstractmethod + def add_message(self, value: BaseMessage): + ... # noqa + + @abstractmethod + async def aadd_message(self, value: BaseMessage): + ... # noqa + + @abstractmethod + def clear(self): + ... diff --git a/libs/superagent/app/memory/memory_stores/redis.py b/libs/superagent/app/memory/memory_stores/redis.py new file mode 100644 index 000000000..5ba86cacf --- /dev/null +++ b/libs/superagent/app/memory/memory_stores/redis.py @@ -0,0 +1,31 @@ +from asyncio import get_event_loop + +from redis import Redis + +from app.memory.memory_stores.base import BaseMemoryStore +from app.memory.message import BaseMessage + + +class RedisMemoryStore(BaseMemoryStore): + key_prefix: str = "message_history:" + + def __init__(self, uri: str, session_id: str): + self.redis = Redis.from_url(uri) + self.session_id = session_id + + @property + def key(self): + return self.key_prefix + self.session_id + + def add_message(self, message: BaseMessage): + self.redis.lpush(self.key, message.json()) + + async def aadd_message(self, message: BaseMessage): + loop = get_event_loop() + await loop.run_in_executor(None, self.add_message, message) + + def get_messages(self) -> list[BaseMessage]: + return [BaseMessage.parse_raw(m) for m in self.redis.lrange(self.key, 0, -1)] + + def clear(self): + self.redis.delete(self.key) From ec9c30a0aafa4b0d911fb1f3f2b6edfe749ae194 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Tue, 30 Apr 2024 21:27:30 +0400 Subject: [PATCH 04/10] add message type --- libs/superagent/app/memory/message.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 libs/superagent/app/memory/message.py diff --git a/libs/superagent/app/memory/message.py b/libs/superagent/app/memory/message.py new file mode 100644 index 000000000..45490aca8 --- /dev/null +++ b/libs/superagent/app/memory/message.py @@ -0,0 +1,15 @@ +from enum import Enum + +from pydantic import BaseModel + + +class MessageType(str, Enum): + HUMAN = "human" + AI = "ai" + TOOL_CALL = "tool_call" + TOOL_RESULT = "tool_result" + + +class BaseMessage(BaseModel): + type: MessageType + content: str From a863272e4edea0767284618e479ee4e0d3f1ad47 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Tue, 30 Apr 2024 21:27:46 +0400 Subject: [PATCH 05/10] remove unused import --- libs/superagent/app/tools/google_search.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/libs/superagent/app/tools/google_search.py b/libs/superagent/app/tools/google_search.py index 7dfcf4a7d..9770b2c2e 100644 --- a/libs/superagent/app/tools/google_search.py +++ b/libs/superagent/app/tools/google_search.py @@ -1,11 +1,9 @@ -import aiohttp -import requests import json -from decouple import config +import aiohttp +import requests from langchain_community.tools import BaseTool - url = "https://google.serper.dev/search" From a7527d78b06aadb41422704060ed0377691c39e9 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Tue, 30 Apr 2024 22:12:39 +0400 Subject: [PATCH 06/10] add memory to llm agent --- libs/superagent/app/agents/base.py | 6 + libs/superagent/app/agents/langchain.py | 8 +- libs/superagent/app/agents/llm.py | 213 +++++++++++++++++++----- 3 files changed, 184 insertions(+), 43 deletions(-) diff --git a/libs/superagent/app/agents/base.py b/libs/superagent/app/agents/base.py index 305a9cb8a..b9e614d6e 100644 --- a/libs/superagent/app/agents/base.py +++ b/libs/superagent/app/agents/base.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from functools import cached_property from typing import Any, List, Optional from langchain.agents import AgentExecutor @@ -78,6 +79,11 @@ def prompt(self) -> Any: def tools(self) -> Any: ... + # TODO: Set a proper return type when we remove Langchain agent type + @cached_property + async def memory(self) -> Any: + ... + @abstractmethod def get_agent(self) -> AgentExecutor: ... diff --git a/libs/superagent/app/agents/langchain.py b/libs/superagent/app/agents/langchain.py index c5fc39e95..63209af6a 100644 --- a/libs/superagent/app/agents/langchain.py +++ b/libs/superagent/app/agents/langchain.py @@ -1,4 +1,5 @@ import datetime +from functools import cached_property from decouple import config from langchain.agents import AgentType, initialize_agent @@ -62,9 +63,8 @@ def _get_llm(self): max_tokens=llm_data.params.max_tokens, ) - async def _get_memory( - self, - ) -> None | MotorheadMemory | ConversationBufferWindowMemory: + @cached_property + async def memory(self) -> None | MotorheadMemory | ConversationBufferWindowMemory: # if memory is already set, in the main agent base class, return it if not self.session_id: raise ValueError("Session ID is required to initialize memory") @@ -95,7 +95,7 @@ async def _get_memory( async def get_agent(self): llm = self._get_llm() - memory = await self._get_memory() + memory = await self.memory tools = self.tools prompt = self.prompt diff --git a/libs/superagent/app/agents/llm.py b/libs/superagent/app/agents/llm.py index 476a3a8f6..fbb8f089a 100644 --- a/libs/superagent/app/agents/llm.py +++ b/libs/superagent/app/agents/llm.py @@ -1,6 +1,9 @@ +import asyncio import datetime import json import logging +from dataclasses import dataclass +from functools import cached_property, partial from typing import Any from decouple import config @@ -12,9 +15,14 @@ get_llm_provider, get_supported_openai_params, stream_chunk_builder, + token_counter, ) from app.agents.base import AgentBase +from app.memory.base import BaseMessage +from app.memory.buffer_memory import BufferMemory +from app.memory.memory_stores.redis import RedisMemoryStore +from app.memory.message import MessageType from app.tools import get_tools from app.utils.callbacks import CustomAsyncIteratorCallbackHandler from app.utils.prisma import prisma @@ -27,30 +35,19 @@ logger = logging.getLogger(__name__) +@dataclass +class ToolCallResponse: + action_log: AgentActionMessageLog + result: Any + return_direct: bool = False + success: bool = True + + async def call_tool( agent_data: Agent, session_id: str, function: Any -) -> tuple[AgentActionMessageLog, Any]: +) -> ToolCallResponse: name = function.get("name") - try: - args = json.loads(function.get("arguments")) - except Exception as e: - logger.error(f"Error parsing function arguments for {name}: {e}") - raise e - - tools = get_tools( - agent_data=agent_data, - session_id=session_id, - ) - tool_to_call = None - for tool in tools: - if tool.name == name: - tool_to_call = tool - break - if not tool_to_call: - raise Exception(f"Function {name} not found in tools") - - logging.info(f"Calling tool {name} with arguments {args}") - + args = function.get("arguments") action_log = AgentActionMessageLog( tool=name, tool_input=args, @@ -67,15 +64,60 @@ async def call_tool( ) ], ) - tool_res = None + try: - tool_res = await tool_to_call._arun(**args) - logging.info(f"Tool {name} returned {tool_res}") + args = json.loads(args) except Exception as e: - tool_res = f"Error calling {tool_to_call.name} tool with arguments {args}: {e}" - logging.error(f"Error calling tool {name}: {e}") + msg = f"Error parsing function arguments for {name}: {e}" + logger.error(msg) + return ToolCallResponse( + action_log=action_log, + result=msg, + return_direct=False, + success=False, + ) + + tools = get_tools( + agent_data=agent_data, + session_id=session_id, + ) + + logging.info(f"Calling tool {name} with arguments {args}") + + tool_to_call = None + for tool in tools: + if tool.name == name: + tool_to_call = tool + break + + if not tool_to_call: + msg = f"Function {name} not found in tools, avaliable tool names: {', '.join([tool.name for tool in tools])}" + logger.error(msg) + return ToolCallResponse( + action_log=action_log, + result=msg, + return_direct=False, + success=False, + ) - return (action_log, tool_res, tool_to_call.return_direct) + try: + result = await tool_to_call._arun(**args) + logging.info(f"Tool {name} returned {result}") + return ToolCallResponse( + action_log=action_log, + result=result, + return_direct=tool_to_call.return_direct, + success=True, + ) + except Exception as e: + msg = f"Error calling {tool_to_call.name} tool with arguments {args}: {e}" + logger.error(msg) + return ToolCallResponse( + action_log=action_log, + result=msg, + return_direct=False, + success=False, + ) class LLMAgent(AgentBase): @@ -107,6 +149,22 @@ def tools(self): for tool in tools ] + @cached_property + def memory(self) -> BufferMemory: + redisMemoryStore = RedisMemoryStore( + uri=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"), + session_id=self.session_id, + ) + tokenizer_fn = partial(token_counter, model=self.llm_data.model) + + bufferMemory = BufferMemory( + memory_store=redisMemoryStore, + model=self.llm_data.model, + tokenizer_fn=tokenizer_fn, + ) + + return bufferMemory + @property def prompt(self): base_prompt = self.agent_data.prompt or DEFAULT_PROMPT @@ -123,6 +181,15 @@ def prompt(self): else: prompt = base_prompt + messages = self.memory.get_messages( + inital_token_usage=len(prompt), + ) + if len(messages) > 0: + prompt += "\n\n Previous messages: \n" + for message in messages: + prompt += ( + f"""{message.type.value.capitalize()}: {message.content}\n\n""" + ) return prompt @property @@ -178,22 +245,23 @@ async def _execute_tool_calls(self, tool_calls: list[dict], **kwargs): session_id=self.session_id, function=tool_call.get("function"), ) - (action_log, tool_res, return_direct) = intermediate_step - self.intermediate_steps.append((action_log, tool_res)) + self.intermediate_steps.append( + (intermediate_step.action_log, intermediate_step.result) + ) new_message = { "role": "tool", "name": tool_call.get("function").get("name"), - "content": tool_res, + "content": intermediate_step.result, } if tool_call.get("id"): new_message["tool_call_id"] = tool_call.get("id") messages.append(new_message) - if return_direct: + if intermediate_step.return_direct: if self.enable_streaming: - await self._stream_by_lines(tool_res) + await self._stream_by_lines(intermediate_step.result) self.streaming_callback.done.set() - return tool_res + return intermediate_step.result self.messages = messages kwargs["messages"] = self.messages @@ -278,7 +346,7 @@ async def _process_completion_response(self, res): if content: output += content if self._stream_directly: - await self.streaming_callback.on_llm_new_token(content) + await self._stream_by_lines(content) return (tool_calls, new_messages, output) @@ -306,6 +374,27 @@ async def _acompletion(self, **kwargs) -> Any: tool_calls, new_messages, output = result + if output: + await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.AI, + content=output, + ) + ) + + if tool_calls: + await asyncio.gather( + *[ + self.memory.aadd_message( + message=BaseMessage( + type=MessageType.TOOL_CALL, + content=json.dumps(tool_call), + ) + ) + for tool_call in tool_calls + ] + ) + self.messages = new_messages if tool_calls: @@ -334,6 +423,13 @@ async def ainvoke(self, input, *_, **kwargs): }, ] + await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.HUMAN, + content=self.input, + ) + ) + if self.enable_streaming: self._set_streaming_callback(kwargs.get("config", {}).get("callbacks", [])) @@ -396,6 +492,13 @@ async def ainvoke(self, input, *_, **kwargs): self.input = input tool_results = [] + await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.HUMAN, + content=self.input, + ) + ) + if self.enable_streaming: self._set_streaming_callback(kwargs.get("config", {}).get("callbacks", [])) @@ -423,22 +526,47 @@ async def ainvoke(self, input, *_, **kwargs): ) tool_calls = res.choices[0].message.get("tool_calls", []) + await asyncio.gather( + *[ + self.memory.aadd_message( + message=BaseMessage( + type=MessageType.TOOL_CALL, + content=json.dumps(tool_call), + ) + ) + for tool_call in tool_calls + ] + ) + for tool_call in tool_calls: - (action_log, tool_res, return_direct) = await call_tool( + intermediate_step = await call_tool( agent_data=self.agent_data, session_id=self.session_id, function=tool_call.function.dict(), ) - tool_results.append((action_log, tool_res)) - if return_direct: + + # TODO: handle the failure in tool call case + # if not intermediate_step.success: + # self.memory.add_message( + # message=BaseMessage( + # type=MessageType.TOOL_RESULT, + # content=intermediate_step.result, + # ) + # ) + + tool_results.append( + (intermediate_step.action_log, intermediate_step.result) + ) + + if intermediate_step.return_direct: if self.enable_streaming: - await self._stream_by_lines(tool_res) + await self._stream_by_lines(intermediate_step.result) self.streaming_callback.done.set() return { "intermediate_steps": tool_results, "input": self.input, - "output": tool_res, + "output": intermediate_step.result, } if len(tool_results) > 0: @@ -473,6 +601,13 @@ async def ainvoke(self, input, *_, **kwargs): else: output = res.choices[0].message.content + await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.AI, + content=output, + ) + ) + return { "intermediate_steps": tool_results, "input": self.input, From 6cefeb7f3ae792592e171593866c7cd75393e973 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Tue, 30 Apr 2024 22:33:43 +0400 Subject: [PATCH 07/10] fix openai tool calling object's serialization --- libs/superagent/app/agents/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/superagent/app/agents/llm.py b/libs/superagent/app/agents/llm.py index fbb8f089a..5f0b19f15 100644 --- a/libs/superagent/app/agents/llm.py +++ b/libs/superagent/app/agents/llm.py @@ -531,7 +531,7 @@ async def ainvoke(self, input, *_, **kwargs): self.memory.aadd_message( message=BaseMessage( type=MessageType.TOOL_CALL, - content=json.dumps(tool_call), + content=tool_call.json(), ) ) for tool_call in tool_calls From 19f1e9a9fc0959344d3285e94aa727d61a06104a Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Wed, 1 May 2024 18:37:44 +0400 Subject: [PATCH 08/10] use xml syntax in the prompt for chat history --- libs/superagent/app/agents/llm.py | 221 ++++++++++++++++-------------- libs/superagent/poetry.lock | 2 +- 2 files changed, 118 insertions(+), 105 deletions(-) diff --git a/libs/superagent/app/agents/llm.py b/libs/superagent/app/agents/llm.py index 96f0b4d86..4296c0cc9 100644 --- a/libs/superagent/app/agents/llm.py +++ b/libs/superagent/app/agents/llm.py @@ -98,7 +98,7 @@ async def call_tool( break if not tool_to_call: - msg = f"Function {function.name} not found in tools, avaliable tool names: {', '.join([tool.name for tool in tools])}" + msg = f"Tool {function.name} not found in tools, avaliable tool names: {', '.join([tool.name for tool in tools])}" logger.error(msg) return ToolCallResponse( action_log=action_log, @@ -193,11 +193,11 @@ def prompt(self): inital_token_usage=len(prompt), ) if len(messages) > 0: - prompt += "\n\n Previous messages: \n" + prompt += "\n\n Here's the previous conversation: \n" for message in messages: - prompt += ( - f"""{message.type.value.capitalize()}: {message.content}\n\n""" - ) + prompt += f"""<{message.type.value}> {message.content} \n""" + prompt += " \n" + return prompt @property @@ -267,7 +267,7 @@ async def _execute_tools( ) new_message = { "role": "tool", - "name": tool_call.get("function").get("name"), + "name": tool_call.function.name, "content": tool_call_res.result, } if tool_call.id: @@ -402,7 +402,7 @@ async def _acompletion(self, depth: int = 0, **kwargs) -> Any: self.memory.aadd_message( message=BaseMessage( type=MessageType.TOOL_CALL, - content=json.dumps(tool_call), + content=tool_call.json(), ) ) for tool_call in tool_calls @@ -512,115 +512,128 @@ def messages(self): ] async def ainvoke(self, input, *_, **kwargs): - self.input = input - tool_results = [] + output = "" + try: + self.input = input + tool_results = [] - await self.memory.aadd_message( - message=BaseMessage( - type=MessageType.HUMAN, - content=self.input, - ) - ) + if self.enable_streaming: + self._set_streaming_callback( + kwargs.get("config", {}).get("callbacks", []) + ) - if self.enable_streaming: - self._set_streaming_callback(kwargs.get("config", {}).get("callbacks", [])) + if len(self.tools) > 0: + openai_llm = await prisma.llm.find_first( + where={ + "provider": LLMProvider.OPENAI.value, + "apiUserId": self.agent_data.apiUserId, + } + ) + if openai_llm: + openai_api_key = openai_llm.apiKey + else: + openai_api_key = config("OPENAI_API_KEY") + logger.warn( + "OpenAI API Key not found in database, using environment variable" + ) - if len(self.tools) > 0: - openai_llm = await prisma.llm.find_first( - where={ - "provider": LLMProvider.OPENAI.value, - "apiUserId": self.agent_data.apiUserId, - } - ) - if openai_llm: - openai_api_key = openai_llm.apiKey - else: - openai_api_key = config("OPENAI_API_KEY") - logger.warn( - "OpenAI API Key not found in database, using environment variable" + res = await acompletion( + api_key=openai_api_key, + model="gpt-3.5-turbo-0125", + messages=self.messages_function_calling, + tools=self.tools, + stream=False, ) - res = await acompletion( - api_key=openai_api_key, - model="gpt-3.5-turbo-0125", - messages=self.messages_function_calling, - tools=self.tools, - stream=False, - ) + tool_calls = [] + if ( + hasattr(res.choices[0].message, "tool_calls") + and res.choices[0].message.tool_calls + ): + tool_calls = res.choices[0].message.tool_calls + + for tool_call in tool_calls: + tool_call_res = await call_tool( + agent_data=self.agent_data, + session_id=self.session_id, + function=tool_call.function, + ) - tool_calls = res.choices[0].message.get("tool_calls", []) - for tool_call in tool_calls: - tool_call_res = await call_tool( - agent_data=self.agent_data, - session_id=self.session_id, - function=tool_call.function, - ) + # TODO: handle the failure in tool call case + # if not intermediate_step.success: + # self.memory.add_message( + # message=BaseMessage( + # type=MessageType.TOOL_RESULT, + # content=intermediate_step.result, + # ) + # ) + + tool_results.append( + (tool_call_res.action_log, tool_call_res.result) + ) - # TODO: handle the failure in tool call case - # if not intermediate_step.success: - # self.memory.add_message( - # message=BaseMessage( - # type=MessageType.TOOL_RESULT, - # content=intermediate_step.result, - # ) - # ) - - tool_results.append((tool_call_res.action_log, tool_call_res.result)) - - if tool_call_res.return_direct: - if self.enable_streaming: - await self._stream_text_by_lines(tool_call_res.result) - self.streaming_callback.done.set() - - return { - "intermediate_steps": tool_results, - "input": self.input, - "output": tool_call_res.result, - } + if tool_call_res.return_direct: + if self.enable_streaming: + await self._stream_text_by_lines(tool_call_res.result) + self.streaming_callback.done.set() + + output = tool_call_res.result + + return { + "intermediate_steps": tool_results, + "input": self.input, + "output": output, + } + + if len(tool_results) > 0: + INPUT_TEMPLATE = "{input}\n Context: {context}\n" + self.input = INPUT_TEMPLATE.format( + input=self.input, + context="\n\n".join( + [tool_response for (_, tool_response) in tool_results] + ), + ) - if len(tool_results) > 0: - INPUT_TEMPLATE = "{input}\n Context: {context}\n" - self.input = INPUT_TEMPLATE.format( - input=self.input, - context="\n\n".join( - [tool_response for (_, tool_response) in tool_results] - ), + params = self.llm_data.params.dict(exclude_unset=True) + second_res = await acompletion( + api_key=self.llm_data.llm.apiKey, + model=self.llm_data.model, + messages=self.messages, + stream=self.enable_streaming, + **params, ) - params = self.llm_data.params.dict(exclude_unset=True) - second_res = await acompletion( - api_key=self.llm_data.llm.apiKey, - model=self.llm_data.model, - messages=self.messages, - stream=self.enable_streaming, - **params, - ) - - output = "" - if self.enable_streaming: - await self.streaming_callback.on_llm_start() - second_res = cast(CustomStreamWrapper, second_res) + if self.enable_streaming: + await self.streaming_callback.on_llm_start() + second_res = cast(CustomStreamWrapper, second_res) - async for chunk in second_res: - token = chunk.choices[0].delta.content - if token: - output += token - await self.streaming_callback.on_llm_new_token(token) + async for chunk in second_res: + token = chunk.choices[0].delta.content + if token: + output += token + await self.streaming_callback.on_llm_new_token(token) - self.streaming_callback.done.set() - else: - second_res = cast(ModelResponse, second_res) - output = second_res.choices[0].message.content + self.streaming_callback.done.set() + else: + second_res = cast(ModelResponse, second_res) + output = second_res.choices[0].message.content - await self.memory.aadd_message( - message=BaseMessage( - type=MessageType.AI, - content=output, + return { + "intermediate_steps": tool_results, + "input": self.input, + "output": output, + } + finally: + await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.HUMAN, + content=self.input, + ) ) - ) - return { - "intermediate_steps": tool_results, - "input": self.input, - "output": output, - } + await self.memory.aadd_message( + message=BaseMessage( + type=MessageType.AI, + content=output, + ) + ) diff --git a/libs/superagent/poetry.lock b/libs/superagent/poetry.lock index 0b4f3d228..433a94c84 100644 --- a/libs/superagent/poetry.lock +++ b/libs/superagent/poetry.lock @@ -6025,4 +6025,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8.1, <3.12" -content-hash = "c1d87a3c0e460cc17f2f2dec3510ff5fa058e50eca86e5c678799decb359fa7f" +content-hash = "36479548d83988d37ef0ac239f97aa9be5f61c4fddbd240e3eda3f0dbcc37358" From abbbf23709bf708a946a0c403e5faafa70059a3d Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Thu, 2 May 2024 10:26:02 +0400 Subject: [PATCH 09/10] rename call_tool to execute_tool --- libs/superagent/app/agents/llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libs/superagent/app/agents/llm.py b/libs/superagent/app/agents/llm.py index 4296c0cc9..c2ad06c31 100644 --- a/libs/superagent/app/agents/llm.py +++ b/libs/superagent/app/agents/llm.py @@ -52,7 +52,7 @@ class ToolCallResponse: success: bool = True -async def call_tool( +async def execute_tool( agent_data: Agent, session_id: str, function: Function ) -> ToolCallResponse: action_log = AgentActionMessageLog( @@ -257,7 +257,7 @@ async def _execute_tools( ): messages: list = kwargs.get("messages") for tool_call in tool_calls: - tool_call_res = await call_tool( + tool_call_res = await execute_tool( agent_data=self.agent_data, session_id=self.session_id, function=tool_call.function, @@ -553,7 +553,7 @@ async def ainvoke(self, input, *_, **kwargs): tool_calls = res.choices[0].message.tool_calls for tool_call in tool_calls: - tool_call_res = await call_tool( + tool_call_res = await execute_tool( agent_data=self.agent_data, session_id=self.session_id, function=tool_call.function, From 45aa969dff7ea5943a1eeec05c1df85f449f4781 Mon Sep 17 00:00:00 2001 From: alisalim17 Date: Thu, 2 May 2024 10:53:29 +0400 Subject: [PATCH 10/10] list messages in the ascending order of their creation --- libs/superagent/app/agents/llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/superagent/app/agents/llm.py b/libs/superagent/app/agents/llm.py index c2ad06c31..63856b348 100644 --- a/libs/superagent/app/agents/llm.py +++ b/libs/superagent/app/agents/llm.py @@ -193,6 +193,7 @@ def prompt(self): inital_token_usage=len(prompt), ) if len(messages) > 0: + messages.reverse() prompt += "\n\n Here's the previous conversation: \n" for message in messages: prompt += f"""<{message.type.value}> {message.content} \n"""