Skip to content

Commit

Permalink
Merge pull request #91 from small-thinking/add-sync-tool
Browse files Browse the repository at this point in the history
Add sync tool that inherit from dspy.Predict
  • Loading branch information
yxjiang authored Jul 3, 2024
2 parents 0881196 + 8f5bd7d commit 7109adb
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 30 deletions.
62 changes: 34 additions & 28 deletions polymind/core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, Dict, List, Union, get_origin

import dspy
from dotenv import load_dotenv
from dspy import Module, Predict, Retrieve
from pydantic import BaseModel, Field, field_validator
Expand Down Expand Up @@ -95,7 +96,7 @@ def check_type(cls, v: str) -> str:
)


class BaseTool(BaseModel, ABC):
class AbstractTool(BaseModel, ABC):
"""The base class of the tool.
In an agent system, a tool is an object that can be used to perform a task.
For example, search for information from the internet, query a database,
Expand Down Expand Up @@ -138,33 +139,14 @@ def check_descriptions(cls, v: List[str]) -> List[str]:
def get_descriptions(self) -> List[str]:
return self.descriptions

async def __call__(self, input: Message) -> Message:
"""Makes the instance callable, delegating to the execute method.
This allows the instance to be used as a callable object, simplifying the syntax for executing the tool.
Args:
input (Message): The input message to the tool.
Returns:
Message: The output message from the tool.
"""
self._validate_input_message(input)
output_message = await self._execute(input)
self._validate_output_message(output_message)
return output_message

def get_spec(self) -> str:
"""Return the input and output specification of the tool.
Returns:
Tuple[List[Param], List[Param]]: The input and output specification of the tool.
"""
input_json_obj = []
for param in self.input_spec():
input_json_obj.append(param.to_json_obj())
output_json_obj = []
for param in self.output_spec():
output_json_obj.append(param.to_json_obj())
input_json_obj = [param.to_json_obj() for param in self.input_spec()]
output_json_obj = [param.to_json_obj() for param in self.output_spec()]
spec_json_obj = {
"input_message": input_json_obj,
"output_message": output_json_obj,
Expand Down Expand Up @@ -223,10 +205,8 @@ def _validate_input_message(self, input_message: Message) -> None:
for param in input_spec:
if param.name not in input_message.content and param.required:
raise ValueError(f"The input message must contain the field '{param.name}'.")
if param.name in input_message.content and param.required:
# Extract the base type for generics (e.g., List or Dict) or use the type directly
if param.name in input_message.content:
base_type = get_origin(eval(param.type)) if get_origin(eval(param.type)) else eval(param.type)
# Map the typing module types to their concrete types for isinstance checks
type_mapping = {
Sequence: list, # Assuming to treat any sequence as a list
Mapping: dict, # Assuming to treat any mapping as a dict
Expand Down Expand Up @@ -264,8 +244,7 @@ def _validate_output_message(self, output_message: Message) -> None:
for param in output_spec:
if param.name not in output_message.content and param.required:
raise ValueError(f"The output message must contain the field '{param.name}'.")
if param.name in output_message.content and param.required:
# Extract the base type for generics (e.g., List or Dict) or use the type directly
if param.name in output_message.content:
base_type = get_origin(eval(param.type)) if get_origin(eval(param.type)) else eval(param.type)
type_mapping = {
Sequence: list, # Assuming to treat any sequence as a list
Expand All @@ -278,7 +257,14 @@ def _validate_output_message(self, output_message: Message) -> None:
f" but is '{type(output_message.content[param.name])}'."
)

@abstractmethod

class BaseTool(AbstractTool):
async def __call__(self, input: Message) -> Message:
self._validate_input_message(input)
output_message = await self._execute(input)
self._validate_output_message(output_message)
return output_message

async def _execute(self, input: Message) -> Message:
"""Execute the tool and return the result.
The derived class must implement this method to define the behavior of the tool.
Expand All @@ -292,6 +278,26 @@ async def _execute(self, input: Message) -> Message:
pass


class OptimizableBaseTool(AbstractTool, dspy.Predict):
def __call__(self, input: Message) -> Message:
self._validate_input_message(input)
output_message = self.forward(**input.content)
self._validate_output_message(output_message)
return output_message

def forward(self, **kwargs) -> Message:
"""Execute the tool and return the result synchronously.
The derived class must implement this method to define the behavior of the tool.
Args:
**kwargs: The input parameters for the tool.
Returns:
Message: The result of the tool carried in a message.
"""
pass


class ToolManager:
"""Tool manager is able to load the tools from the given folder and initialize them.
All the tools will be indexed in the dict keyed by the tool name.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "polymind"
version = "0.0.51" # Update this version before publishing to PyPI
version = "0.0.52" # Update this version before publishing to PyPI
description = "PolyMind is a customizable collaborative multi-agent framework for collective intelligence and distributed problem solving."
authors = ["TechTao"]
license = "MIT License"
Expand Down
92 changes: 91 additions & 1 deletion tests/polymind/core/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from polymind.core.codegen import CodeGenerationTool
from polymind.core.message import Message
from polymind.core.tool import BaseTool, Param, ToolManager
from polymind.core.tool import (BaseTool, OptimizableBaseTool, Param,
ToolManager)


class TestParam:
Expand Down Expand Up @@ -447,6 +448,95 @@ def test_to_open_function_format(self, tool_instance, expected_spec):
), "The generated open function format specification should match the expected specification"


class ExampleOptimizableTool(OptimizableBaseTool):
tool_name: str = "example_optimizable_tool"
descriptions: List[str] = ["Performs an example optimization task"]

def input_spec(self) -> List[Param]:
return [
Param(name="input1", type="str", required=True, description="First input parameter", example="example1"),
Param(name="input2", type="int", required=False, description="Second input parameter", example="2"),
]

def output_spec(self) -> List[Param]:
return [
Param(name="output", type="str", required=True, description="Output parameter", example="result"),
]

def forward(self, **kwargs) -> Message:
# Example implementation of forward method
output_content = {"output": "result"}
return Message(content=output_content)


class TestOptimizableBaseTool:
@pytest.mark.parametrize(
"tool_instance, input_message, expected_output",
[
(
ExampleOptimizableTool(),
Message(content={"input1": "example1", "input2": 2}),
Message(content={"output": "result"}),
),
(
ExampleOptimizableTool(),
Message(content={"input1": "example1"}),
Message(content={"output": "result"}),
),
],
)
def test_call(self, tool_instance, input_message, expected_output):
output = tool_instance(input_message)
assert output == expected_output, "The output message should match the expected output"

@pytest.mark.parametrize(
"tool_instance, expected_spec",
[
(
ExampleOptimizableTool(),
{
"type": "function",
"function": {
"name": "example_optimizable_tool",
"description": "Performs an example optimization task",
"parameters": {
"type": "object",
"properties": {
"input1": {
"type": "string",
"example": "example1",
"description": "First input parameter",
},
"input2": {
"type": "integer",
"example": "2",
"description": "Second input parameter",
},
},
"required": ["input1"],
},
"responses": {
"type": "object",
"properties": {
"output": {
"type": "string",
"example": "result",
"description": "Output parameter",
},
},
},
},
},
),
],
)
def test_to_open_function_format(self, tool_instance, expected_spec):
spec = tool_instance.to_open_function_format()
assert (
spec == expected_spec
), "The generated open function format specification should match the expected specification"


class TestToolManager:
@pytest.fixture
def manager(self):
Expand Down

0 comments on commit 7109adb

Please sign in to comment.