Skip to content

Commit

Permalink
Replace tool parsing status with a new enumeration to avoid confusion (
Browse files Browse the repository at this point in the history
…#266)

* add tool parsing status enumeration

* update
  • Loading branch information
braisedpork1964 authored Oct 30, 2024
1 parent e8af4cc commit 5dabf3d
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 25 deletions.
5 changes: 2 additions & 3 deletions lagent/agents/aggregator/tool_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from lagent.agents.aggregator.default_aggregator import DefaultAggregator
from lagent.memory.base_memory import Memory
from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser
from lagent.schema import AgentStatusCode
from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser, ToolStatusCode


class InternLMToolAggregator(DefaultAggregator):
Expand Down Expand Up @@ -84,7 +83,7 @@ def aggregate(self,
if message.sender == name:
if isinstance(message.formatted, dict):
parsed = message.formatted
if parsed['status'] == AgentStatusCode.SESSION_INVALID_ARG:
if parsed['status'] == ToolStatusCode.PARSING_ERROR:
continue
_message.append(
dict(
Expand Down
12 changes: 6 additions & 6 deletions lagent/agents/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from lagent.hooks import InternLMActionProcessor
from lagent.llms import BaseLLM
from lagent.memory import Memory
from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser
from lagent.schema import AgentMessage, AgentStatusCode
from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser, ToolStatusCode
from lagent.schema import AgentMessage
from lagent.utils import create_object

API_PREFIX = (
Expand Down Expand Up @@ -81,7 +81,7 @@ def __init__(
action_hooks: List = [dict(type=InternLMActionProcessor)],
finish_condition: Callable[
[AgentMessage],
bool] = lambda m: m.formatted['status'] == AgentStatusCode.END,
bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
max_turn: int = 4,
**kwargs,
):
Expand Down Expand Up @@ -165,7 +165,7 @@ def __init__(
action_hooks: List = [dict(type=InternLMActionProcessor)],
finish_condition: Callable[
[AgentMessage],
bool] = lambda m: m.formatted['status'] == AgentStatusCode.END,
bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
max_turn: int = 6,
**kwargs,
):
Expand Down Expand Up @@ -205,7 +205,7 @@ def __init__(
action_hooks: List = [dict(type=InternLMActionProcessor)],
finish_condition: Callable[
[AgentMessage],
bool] = lambda m: m.formatted['status'] == AgentStatusCode.END,
bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
max_turn: int = 4,
**kwargs,
):
Expand Down Expand Up @@ -289,7 +289,7 @@ def __init__(
action_hooks: List = [dict(type=InternLMActionProcessor)],
finish_condition: Callable[
[AgentMessage],
bool] = lambda m: m.formatted['status'] == AgentStatusCode.END,
bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
max_turn: int = 6,
**kwargs,
):
Expand Down
4 changes: 2 additions & 2 deletions lagent/prompts/parsers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .custom_parser import CustomFormatParser
from .json_parser import JSONParser
from .str_parser import StrParser
from .tool_parser import InterpreterParser, MixedToolParser, PluginParser, ToolParser
from .tool_parser import InterpreterParser, MixedToolParser, PluginParser, ToolParser, ToolStatusCode

__all__ = [
'CustomFormatParser', 'JSONParser', 'StrParser', 'ToolParser',
'InterpreterParser', 'PluginParser', 'MixedToolParser'
'InterpreterParser', 'PluginParser', 'MixedToolParser', 'ToolStatusCode'
]
24 changes: 11 additions & 13 deletions lagent/prompts/parsers/tool_parser.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
from enum import IntEnum

# import re
from typing import Any, Callable, List, Optional

from lagent.prompts.parsers import StrParser
from lagent.schema import AgentStatusCode
from lagent.utils import create_object, load_class_from_string


Expand All @@ -15,6 +15,12 @@ def default_plugin_validate(plugin: str):
return json.loads(plugin)


class ToolStatusCode(IntEnum):
NO_TOOL = 0
VALID_TOOL = 1
PARSING_ERROR = -1


class ToolParser(StrParser):

def __init__(self,
Expand All @@ -34,28 +40,20 @@ def __init__(self,
validate, str) else validate

def parse_response(self, data: str) -> dict:
# match = self.pattern.search(data)
# if not match:
# return dict(
# tool_type=None,
# thought=data,
# action=None,
# status=AgentStatusCode.END)
# thought, action = match.group(1), match.group(2).strip()
if self.format_field['begin'] not in data:
return dict(
tool_type=None,
thought=data,
action=None,
status=AgentStatusCode.END)
status=ToolStatusCode.NO_TOOL)
thought, action, *_ = data.split(self.format_field["begin"])
action = action.split(self.format_field['end'])[0]
status = AgentStatusCode.STREAM_ING
status = ToolStatusCode.VALID_TOOL
if self.validate:
try:
action = self.validate(action)
except Exception:
status = AgentStatusCode.SESSION_INVALID_ARG
status = ToolStatusCode.PARSING_ERROR
return dict(
tool_type=self.tool_type,
thought=thought,
Expand Down Expand Up @@ -131,7 +129,7 @@ def parse_response(self, data: str) -> dict:
tool_type=None,
thought=data,
action=None,
status=AgentStatusCode.END)
status=ToolStatusCode.NO_TOOL)
for name, parser in self.parsers.items():
res = parser.parse_response(data)
if res['tool_type'] == name:
Expand Down
1 change: 0 additions & 1 deletion lagent/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ class AgentStatusCode(IntEnum):


class AgentMessage(BaseModel):

content: Any
sender: str = 'user'
formatted: Optional[Any] = None
Expand Down

0 comments on commit 5dabf3d

Please sign in to comment.