Merge pull request #76 from small-thinking/add-codegen
Add code gen tool
yxjiang authored Apr 27, 2024
2 parents 2223f12 + a5d02d2 commit 19f94ac
Showing 3 changed files with 370 additions and 3 deletions.
234 changes: 233 additions & 1 deletion polymind/core/tool.py
@@ -3,6 +3,7 @@
import json
import os
import re
import subprocess
import sys
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
@@ -396,6 +397,55 @@ def _set_client(self):
"""Set the client for the language model."""
pass

def input_spec(self) -> List[Param]:
return [
Param(
name="input",
type="str",
required=True,
description="The prompt for the chat.",
example="hello, how are you?",
),
Param(
name="system_prompt",
type="str",
required=False,
example="You are a helpful AI assistant.",
description="The system prompt for the chat.",
),
Param(
name="max_tokens",
type="int",
required=False,
example="1500",
description="The maximum number of tokens for the chat.",
),
Param(
name="temperature",
type="float",
required=False,
example="0.7",
description="The temperature for the chat.",
),
Param(
name="top_p",
type="float",
required=False,
example="0.1",
description="The top p for the chat.",
),
]

def output_spec(self) -> List[Param]:
return [
Param(
name="output",
type="str",
required=True,
description="The response from the chat.",
),
]

@abstractmethod
async def _invoke(self, input: Message) -> Message:
"""Invoke the language model with the input message and return the response message.
@@ -421,7 +471,7 @@ async def _execute(self, input: Message) -> Message:

# Validate the input message.
prompt = input.get("input", "")
system_prompt = input.get("system_prompt", self.system_prompt)
system_prompt = input.get("system_prompt", "")
if not prompt:
raise ValueError("Prompt in the field 'input' cannot be empty.")
input.content.update(
@@ -602,3 +652,185 @@ async def _execute(self, input: Message) -> Message:
if self.enable_ranking: # Rank the retrieved results based on the query.
response_message = await self._refine(input=input, response=response_message)
return response_message


class CodeGenerationTool(BaseTool):
"""A tool that can generate code based on user requirements and execute it."""

tool_name: str = Field(default="code_generation_tool", description="The name of the tool.")
llm_tool: LLMTool = Field(default=None, description="The language model tool to generate the code.")
max_attempts: int = Field(default=3, description="The maximum number of attempts to generate the code.")
descriptions: List[str] = Field(
default=[
"This tool can generate code based on user requirements and then execute it to fulfill the requirement.",
"The tool will generate the code to solve the problem based on the requirement.",
"After generating the code, the tool can execute it and provide the output.",
"This tool can use libraries like matplotlib, pandas, yfinance, and numpy to solve problems.",
],
description="The descriptions of the tool.",
)

codegen_prompt_template: str = """
You are a programmer who generates code to solve the problem described in the requirement.
Please generate the code in Python and put it in a Python code block.
Note that you must save the result in a Dict[str, Any] variable named 'output'.
An example:
Requirement: Write a function to draw a pie chart based on the input data.
Code:
```python
import matplotlib.pyplot as plt
data = [10, 20, 30, 40] # Data in user input
plt.pie(data)
# Save the plot to a file
filepath = "pie_chart.png"
plt.savefig(filepath)
output = {{"filepath": filepath}}
```
The below is the actual user requirement:
------
{user_requirement}
------
The previous errors, if any:
------
{previous_error}
------
"""

output_extract_template: str = """
You are a checker and extractor that verifies whether the output (a Dict[str, Any])
actually solves the problem described in the requirement.
The output is generated by the code.
Please check carefully whether the result in the output really solves the problem, and always use double quotes.
If the output is correct, please extract it as a str and put it into a json blob:
{{
"status": "success",
"output": {output}
}}
If the output is not correct, please return a json blob with the error message:
{{
"status": "error",
"reason": "The user asks for the stock price of APPL, but the output has no relevant information.",
}}
The user requirement:
------
{requirement}
------
The output that is a string representation of Dict[str, Any]:
------
{output}
------
"""

def __init__(self, llm_tool: LLMTool, **kwargs):
super().__init__(**kwargs)
self.llm_tool = llm_tool
self._logger = Logger(__file__)

def input_spec(self) -> List[Param]:
return [
Param(
name="input",
type="str",
required=True,
description="A natural language description of the problem or requirement.",
example="Write a function that takes two numbers as input and returns their sum.",
),
]

def output_spec(self) -> List[Param]:
return [
Param(
name="code",
type="str",
required=True,
description="The generated code to solve the problem.",
),
Param(
name="output",
type="str",
required=True,
description="The output of running the generated code.",
),
]

async def _execute(self, input: Message) -> Message:
previous_errors = []
requirement = input.content["input"]
attempts = 0
while attempts < self.max_attempts:
code = await self._code_gen(requirement=requirement, previous_errors=previous_errors)
try:
output_dict: Dict[str, Any] = await self._code_run(code)
output = await self._output_parse(requirement=requirement, output=output_dict)
return Message(content={"code": code, "output": output})
except ValueError as e:
self._logger.warning(f"Failed to parse output:\n{output_dict}\nError: {e}. Retrying...")
previous_errors.append(str(e))
attempts += 1
except Exception as e:
self._logger.warning(f"Failed to execute code: {e}. Retrying...")
previous_errors.append(str(e))
attempts += 1
raise ValueError(f"Failed to generate code after {self.max_attempts} attempts.")

async def _code_gen(self, requirement: str, previous_errors: List[str]) -> str:
previous_error = "\n".join(previous_errors)
prompt = self.codegen_prompt_template.format(user_requirement=requirement, previous_error=previous_error)
input_message = Message(content={"input": prompt})
response_message = await self.llm_tool(input=input_message)
generated_text = response_message.content.get("output", "")
code = ""
code_block = re.search(r"```python(.*?)```", generated_text, re.DOTALL)
if code_block:
code = code_block.group(1).strip()
return code
raise ValueError(f"Failed to generate code: {generated_text}")

def _extract_required_packages(self, code: str) -> List[str]:
# Regex to capture simple imports, aliased imports, and from-imports
pattern = r"\bimport\s+([\w]+)|\bfrom\s+([\w]+)\b.*?import"
matches = re.findall(pattern, code)
# Extract non-empty matches and ensure only the package name is included
packages = {match[0] or match[1] for match in matches}
return list(packages)

async def _code_run(self, code: str) -> Dict[str, Any]:
# Ensure all required packages are installed before executing the code
packages = self._extract_required_packages(code)
# Install the required packages if they are not installed
for package in packages:
try:
__import__(package)
except ImportError:
subprocess.check_call([sys.executable, "-m", "pip", "install", package])

local = {"output": {}}
exec(code, globals(), local)
output = local.get("output", {})
return output

async def _output_parse(self, requirement: str, output: Dict[str, Any]) -> str:
"""Use LLM to parse the output based on the requirement.
Args:
requirement (str): The user requirement.
output (Dict[str, Any]): The output from the code execution.
Returns:
str: The parsed output. It should be a string representation of Dict[str, Any].
"""
prompt = self.output_extract_template.format(requirement=requirement, output=json.dumps(output, indent=4))
input_message = Message(content={"input": prompt})
response_message = await self.llm_tool(input=input_message)
parsed_output_json_str = response_message.content.get("output", "")
parsed_output_json = json.loads(parsed_output_json_str)
if parsed_output_json["status"] != "success":
raise ValueError(f"Generated output is incorrect: {parsed_output_json['reason']}")
json_str = json.dumps(parsed_output_json["output"], indent=4)
return json_str
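
For readers who want to try the new tool, here is a minimal usage sketch. It is not part of the diff: the concrete LLMTool implementation (OpenAIChatTool) and the import paths are assumptions for illustration, and the call convention mirrors how the diff awaits self.llm_tool(input=...). Only CodeGenerationTool's constructor and the "input", "code", and "output" message fields come from the code above.

```python
# Illustrative sketch only; OpenAIChatTool and the module paths are assumptions,
# while CodeGenerationTool and the "input"/"code"/"output" fields come from this commit.
import asyncio

from polymind.core.message import Message            # assumed location of Message
from polymind.core.tool import CodeGenerationTool
from polymind.tools.oai_tool import OpenAIChatTool   # hypothetical concrete LLMTool


async def main() -> None:
    codegen = CodeGenerationTool(llm_tool=OpenAIChatTool())
    # The natural-language requirement goes into the "input" field.
    request = Message(content={"input": "Compute the average of the numbers 3, 7, and 20."})
    response = await codegen(input=request)
    print(response.content["code"])    # the generated Python code
    print(response.content["output"])  # JSON string extracted from the code's 'output' dict


asyncio.run(main())
```

Internally the tool loops up to max_attempts times: it generates code, pip-installs any imported packages that are missing, executes the code, and asks the LLM to validate the resulting 'output' dict before returning it.
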
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "polymind"
version = "0.0.41" # Update this version before publishing to PyPI
version = "0.0.42" # Update this version before publishing to PyPI
description = "PolyMind is a customizable collaborative multi-agent framework for collective intelligence and distributed problem solving."
authors = ["TechTao"]
license = "MIT License"
