Skip to content

Commit

Permalink
[Feat] Add Sequential and AsyncSequential agents (#270)
Browse files Browse the repository at this point in the history
* add sequential agents

* display agent hierarchy

* update

* simplify arguments
  • Loading branch information
braisedpork1964 authored Nov 12, 2024
1 parent 5726f95 commit e72713a
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
4 changes: 2 additions & 2 deletions lagent/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .agent import Agent, AgentDict, AgentList, AsyncAgent
from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential
from .react import AsyncReAct, ReAct
from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder

__all__ = [
'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM',
'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct',
'AsyncReAct'
'AsyncReAct', 'Sequential', 'AsyncSequential'
]
96 changes: 93 additions & 3 deletions lagent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from collections import OrderedDict, UserDict, UserList, abc
from functools import wraps
from itertools import chain, repeat
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

from lagent.agents.aggregator import DefaultAggregator
Expand Down Expand Up @@ -169,7 +170,22 @@ def reset(self, session_id=0):
self.memory.reset(session_id=session_id)

def __repr__(self):
return f"{self.__class__.__name__}(name='{self.name}', description='{self.description or ''}')"

def _rcsv_repr(agent, n_indent=1):
res = agent.__class__.__name__ + (f"(name='{agent.name}')"
if agent.name else '')
modules = [
f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}"
for name, agent in getattr(agent, '_agents', {}).items()
]
if modules:
res += '(\n' + '\n'.join(
modules) + f'\n{(n_indent - 1) * " "})'
elif not res.endswith(')'):
res += '()'
return res

return _rcsv_repr(self)


class AsyncAgent(Agent):
Expand Down Expand Up @@ -225,6 +241,78 @@ async def forward(self,
return llm_response


class Sequential(Agent):
"""Sequential is an agent container that forwards messages to each agent
in the order they are added."""

def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
super().__init__(**kwargs)
self._agents = OrderedDict()
if not agents:
raise ValueError('At least one agent should be provided')
if isinstance(agents[0],
Iterable) and not isinstance(agents[0], Agent):
if not agents[0]:
raise ValueError('At least one agent should be provided')
agents = agents[0]
for key, agent in enumerate(agents):
if isinstance(agents, Mapping):
key, agent = agent, agents[agent]
elif isinstance(agent, tuple):
key, agent = agent
self.add_agent(key, agent)

def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
assert isinstance(
agent, (Agent, AsyncAgent
)), f'{type(agent)} is not an Agent or AsyncAgent subclass'
self._agents[str(name)] = agent

def forward(self,
*message: AgentMessage,
session_id=0,
exit_at: Optional[int] = None,
**kwargs) -> AgentMessage:
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
if exit_at is None:
exit_at = len(self) - 1
iterator = chain.from_iterable(repeat(self._agents.values()))
for _ in range(exit_at + 1):
agent = next(iterator)
if isinstance(message, AgentMessage):
message = (message, )
message = agent(*message, session_id=session_id, **kwargs)
return message

def __getitem__(self, key):
if isinstance(key, int) and key < 0:
assert key >= -len(self), 'index out of range'
key = len(self) + key
return self._agents[str(key)]

def __len__(self):
return len(self._agents)


class AsyncSequential(Sequential, AsyncAgent):

async def forward(self,
*message: AgentMessage,
session_id=0,
exit_at: Optional[int] = None,
**kwargs) -> AgentMessage:
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
if exit_at is None:
exit_at = len(self) - 1
iterator = chain.from_iterable(repeat(self._agents.values()))
for _ in range(exit_at + 1):
agent = next(iterator)
if isinstance(message, AgentMessage):
message = (message, )
message = await agent(*message, session_id=session_id, **kwargs)
return message


class AgentContainerMixin:

def __init_subclass__(cls):
Expand Down Expand Up @@ -276,18 +364,20 @@ def _backup(d):
setattr(cls, method, wrap_api(getattr(cls, method)))


class AgentList(UserList, Agent, AgentContainerMixin):
class AgentList(Agent, UserList, AgentContainerMixin):

def __init__(self,
agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
Agent.__init__(self, memory=None)
UserList.__init__(self, agents)
self.name = None


class AgentDict(UserDict, Agent, AgentContainerMixin):
class AgentDict(Agent, UserDict, AgentContainerMixin):

def __init__(self,
agents: Optional[Mapping[str, Union[Agent,
AsyncAgent]]] = None):
Agent.__init__(self, memory=None)
UserDict.__init__(self, agents)
self.name = None

0 comments on commit e72713a

Please sign in to comment.