diff --git a/polymind/core/tool.py b/polymind/core/tool.py index 1fd1e3d..e3d67cf 100644 --- a/polymind/core/tool.py +++ b/polymind/core/tool.py @@ -9,6 +9,7 @@ from collections.abc import Mapping, Sequence from typing import Any, Dict, List, Union, get_origin +import dspy from dotenv import load_dotenv from dspy import Module, Predict, Retrieve from pydantic import BaseModel, Field, field_validator @@ -95,7 +96,7 @@ def check_type(cls, v: str) -> str: ) -class BaseTool(BaseModel, ABC): +class AbstractTool(BaseModel, ABC): """The base class of the tool. In an agent system, a tool is an object that can be used to perform a task. For example, search for information from the internet, query a database, @@ -138,33 +139,14 @@ def check_descriptions(cls, v: List[str]) -> List[str]: def get_descriptions(self) -> List[str]: return self.descriptions - async def __call__(self, input: Message) -> Message: - """Makes the instance callable, delegating to the execute method. - This allows the instance to be used as a callable object, simplifying the syntax for executing the tool. - - Args: - input (Message): The input message to the tool. - - Returns: - Message: The output message from the tool. - """ - self._validate_input_message(input) - output_message = await self._execute(input) - self._validate_output_message(output_message) - return output_message - def get_spec(self) -> str: """Return the input and output specification of the tool. Returns: Tuple[List[Param], List[Param]]: The input and output specification of the tool. """ - input_json_obj = [] - for param in self.input_spec(): - input_json_obj.append(param.to_json_obj()) - output_json_obj = [] - for param in self.output_spec(): - output_json_obj.append(param.to_json_obj()) + input_json_obj = [param.to_json_obj() for param in self.input_spec()] + output_json_obj = [param.to_json_obj() for param in self.output_spec()] spec_json_obj = { "input_message": input_json_obj, "output_message": output_json_obj, @@ -223,10 +205,8 @@ def _validate_input_message(self, input_message: Message) -> None: for param in input_spec: if param.name not in input_message.content and param.required: raise ValueError(f"The input message must contain the field '{param.name}'.") - if param.name in input_message.content and param.required: - # Extract the base type for generics (e.g., List or Dict) or use the type directly + if param.name in input_message.content: base_type = get_origin(eval(param.type)) if get_origin(eval(param.type)) else eval(param.type) - # Map the typing module types to their concrete types for isinstance checks type_mapping = { Sequence: list, # Assuming to treat any sequence as a list Mapping: dict, # Assuming to treat any mapping as a dict @@ -264,8 +244,7 @@ def _validate_output_message(self, output_message: Message) -> None: for param in output_spec: if param.name not in output_message.content and param.required: raise ValueError(f"The output message must contain the field '{param.name}'.") - if param.name in output_message.content and param.required: - # Extract the base type for generics (e.g., List or Dict) or use the type directly + if param.name in output_message.content: base_type = get_origin(eval(param.type)) if get_origin(eval(param.type)) else eval(param.type) type_mapping = { Sequence: list, # Assuming to treat any sequence as a list @@ -278,7 +257,14 @@ def _validate_output_message(self, output_message: Message) -> None: f" but is '{type(output_message.content[param.name])}'." ) - @abstractmethod + +class BaseTool(AbstractTool): + async def __call__(self, input: Message) -> Message: + self._validate_input_message(input) + output_message = await self._execute(input) + self._validate_output_message(output_message) + return output_message + async def _execute(self, input: Message) -> Message: """Execute the tool and return the result. The derived class must implement this method to define the behavior of the tool. @@ -292,6 +278,26 @@ async def _execute(self, input: Message) -> Message: pass +class OptimizableBaseTool(AbstractTool, dspy.Predict): + def __call__(self, input: Message) -> Message: + self._validate_input_message(input) + output_message = self.forward(**input.content) + self._validate_output_message(output_message) + return output_message + + def forward(self, **kwargs) -> Message: + """Execute the tool and return the result synchronously. + The derived class must implement this method to define the behavior of the tool. + + Args: + **kwargs: The input parameters for the tool. + + Returns: + Message: The result of the tool carried in a message. + """ + pass + + class ToolManager: """Tool manager is able to load the tools from the given folder and initialize them. All the tools will be indexed in the dict keyed by the tool name. diff --git a/pyproject.toml b/pyproject.toml index 40db143..adc96cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "polymind" -version = "0.0.51" # Update this version before publishing to PyPI +version = "0.0.52" # Update this version before publishing to PyPI description = "PolyMind is a customizable collaborative multi-agent framework for collective intelligence and distributed problem solving." authors = ["TechTao"] license = "MIT License" diff --git a/tests/polymind/core/test_tool.py b/tests/polymind/core/test_tool.py index 4b5a93d..5cf5f0a 100644 --- a/tests/polymind/core/test_tool.py +++ b/tests/polymind/core/test_tool.py @@ -13,7 +13,8 @@ from polymind.core.codegen import CodeGenerationTool from polymind.core.message import Message -from polymind.core.tool import BaseTool, Param, ToolManager +from polymind.core.tool import (BaseTool, OptimizableBaseTool, Param, + ToolManager) class TestParam: @@ -447,6 +448,95 @@ def test_to_open_function_format(self, tool_instance, expected_spec): ), "The generated open function format specification should match the expected specification" +class ExampleOptimizableTool(OptimizableBaseTool): + tool_name: str = "example_optimizable_tool" + descriptions: List[str] = ["Performs an example optimization task"] + + def input_spec(self) -> List[Param]: + return [ + Param(name="input1", type="str", required=True, description="First input parameter", example="example1"), + Param(name="input2", type="int", required=False, description="Second input parameter", example="2"), + ] + + def output_spec(self) -> List[Param]: + return [ + Param(name="output", type="str", required=True, description="Output parameter", example="result"), + ] + + def forward(self, **kwargs) -> Message: + # Example implementation of forward method + output_content = {"output": "result"} + return Message(content=output_content) + + +class TestOptimizableBaseTool: + @pytest.mark.parametrize( + "tool_instance, input_message, expected_output", + [ + ( + ExampleOptimizableTool(), + Message(content={"input1": "example1", "input2": 2}), + Message(content={"output": "result"}), + ), + ( + ExampleOptimizableTool(), + Message(content={"input1": "example1"}), + Message(content={"output": "result"}), + ), + ], + ) + def test_call(self, tool_instance, input_message, expected_output): + output = tool_instance(input_message) + assert output == expected_output, "The output message should match the expected output" + + @pytest.mark.parametrize( + "tool_instance, expected_spec", + [ + ( + ExampleOptimizableTool(), + { + "type": "function", + "function": { + "name": "example_optimizable_tool", + "description": "Performs an example optimization task", + "parameters": { + "type": "object", + "properties": { + "input1": { + "type": "string", + "example": "example1", + "description": "First input parameter", + }, + "input2": { + "type": "integer", + "example": "2", + "description": "Second input parameter", + }, + }, + "required": ["input1"], + }, + "responses": { + "type": "object", + "properties": { + "output": { + "type": "string", + "example": "result", + "description": "Output parameter", + }, + }, + }, + }, + }, + ), + ], + ) + def test_to_open_function_format(self, tool_instance, expected_spec): + spec = tool_instance.to_open_function_format() + assert ( + spec == expected_spec + ), "The generated open function format specification should match the expected specification" + + class TestToolManager: @pytest.fixture def manager(self):