Skip to content

Commit

Permalink
Merge pull request #46 from badgeir/develop
Browse files Browse the repository at this point in the history
Remove use of deprecated pydantic functions, and improve tests
  • Loading branch information
badgeir authored Nov 2, 2024
2 parents cc75151 + b1b07eb commit 5522502
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12", "3.13"]
os: [ubuntu-latest]
name: Python ${{ matrix.python-version }}

Expand Down
18 changes: 11 additions & 7 deletions llmio/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,17 @@ async def execute(
kwargs[_CONTEXT_ARG_NAME] = context

if iscoroutinefunction(self.function):
result = await self.function(**params.dict(), **kwargs)
result = await self.function(**params.model_dump(), **kwargs)
else:
result = self.function(**params.dict(), **kwargs)
result = self.function(**params.model_dump(), **kwargs)

return str(result)

def parse_args(self, args: str) -> pydantic.BaseModel:
"""
Parses the arguments received from the OpenAI API using the Pydantic model.
"""
return self.params.parse_raw(args)
return self.params.model_validate_json(args)

@property
def function_definition(self) -> T.FunctionDefinition:
Expand All @@ -89,12 +89,15 @@ def function_definition(self) -> T.FunctionDefinition:

if self.strict:
schema["additionalProperties"] = False
return T.FunctionDefinition(
definition = T.FunctionDefinition(
name=self.name,
description=self.description,
parameters=schema,
strict=self.strict,
)
if self.strict:
definition["strict"] = True

return definition


_ResponseFormatT = TypeVar("_ResponseFormatT", bound=pydantic.BaseModel)
Expand Down Expand Up @@ -533,7 +536,8 @@ async def speak(
assert self._response_format is not None
response = await self._speak(message, history=history, _context=_context)
parsed_messages = [
self._response_format.parse_raw(message) for message in response.messages
self._response_format.model_validate_json(message)
for message in response.messages
]
return StructuredAgentResponse(
messages=parsed_messages,
Expand All @@ -554,4 +558,4 @@ def response_format(self) -> dict[str, Any]:
}

def _parse_message_inspector_content(self, message: str) -> _ResponseFormatT:
return self._response_format.parse_raw(message)
return self._response_format.model_validate_json(message)
134 changes: 133 additions & 1 deletion tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import Iterable
from unittest.mock import call

import pytest

Expand Down Expand Up @@ -96,8 +97,139 @@ async def inspect_prompt(prompt: list[T.Message]) -> None:
content="The answer is 60",
),
]
with mocked_async_openai_replies(mocks):
with mocked_async_openai_replies(mocks) as mocked:
response = await agent.speak("What is (10 + 20) * 2?")
mocked.assert_has_calls(
[
call(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": "You are a calculator.\n\nvalue1 value2",
},
{
"role": "user",
"content": "What is (10 + 20) * 2?",
},
],
tools=agent._tool_definitions,
response_format=None,
),
call(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": "You are a calculator.\n\nvalue1 value2",
},
{
"role": "user",
"content": "What is (10 + 20) * 2?",
},
{
"role": "assistant",
"content": "Ok! I'll calculate the answer of (10 + 20) * 2",
"tool_calls": [
{
"id": "add_1",
"type": "function",
"function": {
"name": "add",
"arguments": '{"num1": 10, "num2": 20}',
},
}
],
},
{"role": "tool", "content": "30.0", "tool_call_id": "add_1"},
],
tools=agent._tool_definitions,
response_format=None,
),
call(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": "You are a calculator.\n\nvalue1 value2",
},
{
"role": "user",
"content": "What is (10 + 20) * 2?",
},
{
"role": "assistant",
"content": "Ok! I'll calculate the answer of (10 + 20) * 2",
"tool_calls": [
{
"id": "add_1",
"type": "function",
"function": {
"name": "add",
"arguments": '{"num1": 10, "num2": 20}',
},
}
],
},
{"role": "tool", "content": "30.0", "tool_call_id": "add_1"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "multiply_1",
"type": "function",
"function": {
"name": "multiply",
"arguments": '{"num1": 30, "num2": 2}',
},
}
],
},
{
"role": "tool",
"content": "60.0",
"tool_call_id": "multiply_1",
},
],
tools=[
{
"function": {
"name": "add",
"description": "",
"parameters": {
"properties": {
"num1": {"type": "number"},
"num2": {"type": "number"},
},
"required": ["num1", "num2"],
"type": "object",
},
},
"type": "function",
},
{
"function": {
"name": "multiply",
"description": "",
"parameters": {
"properties": {
"num1": {"type": "number"},
"num2": {"type": "number"},
},
"required": ["num1", "num2"],
"type": "object",
"additionalProperties": False,
},
"strict": True,
},
"type": "function",
},
],
response_format=None,
),
]
)

assert response.messages == [mocks[0].content, mocks[2].content]
assert response.history == [
Expand Down
4 changes: 1 addition & 3 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ async def var2() -> str:
'parameters': {'properties': {'num1': {'type': 'number'},
'num2': {'type': 'number'}},
'required': ['num1', 'num2'],
'type': 'object'},
'strict': False}
'type': 'object'}}
- multiply
Schema:
Expand Down Expand Up @@ -86,7 +85,6 @@ async def var2() -> str:
],
"type": "object",
},
"strict": False,
},
"type": "function",
},
Expand Down
16 changes: 8 additions & 8 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import contextlib
from typing import Any
from unittest.mock import patch
from typing import Any, Iterator
from unittest.mock import patch, MagicMock

from llmio import types as T, models


@contextlib.contextmanager
def mocked_async_openai_replies(
replies: list[models.ChatCompletionMessage],
):
) -> Iterator[MagicMock]:
with patch(
"llmio.clients.BaseClient.get_chat_completion",
side_effect=[
Expand All @@ -17,14 +17,14 @@ def mocked_async_openai_replies(
)
for reply in replies
],
):
yield replies
) as patched:
yield patched


@contextlib.contextmanager
def mocked_async_openai_lookup(
replies: dict[str, models.ChatCompletionMessage],
):
) -> Iterator[MagicMock]:
def mock_function(
model: str,
messages: list[T.Message],
Expand All @@ -46,5 +46,5 @@ def mock_function(
with patch(
"llmio.clients.BaseClient.get_chat_completion",
side_effect=mock_function,
):
yield replies
) as patched:
yield patched

0 comments on commit 5522502

Please sign in to comment.