From f24f5fbd687ffe56dbd80e69450a23a0b34210a3 Mon Sep 17 00:00:00 2001 From: Peter Leupi Date: Fri, 23 Aug 2024 01:49:36 +0200 Subject: [PATCH] Refactor tests --- tests/test_gather.py | 55 ++++++++++++++++----------------- tests/test_structured_output.py | 49 ++++++++++++++--------------- 2 files changed, 49 insertions(+), 55 deletions(-) diff --git a/tests/test_gather.py b/tests/test_gather.py index 15bf0a4..1e37923 100644 --- a/tests/test_gather.py +++ b/tests/test_gather.py @@ -24,11 +24,11 @@ async def test_gather_basic() -> None: model="gpt-4o-mini", ) - on_message_called_with = [] + on_message_called_with = set() @agent.on_message async def on_message(message: str) -> None: - on_message_called_with.append(message) + on_message_called_with.add(message) pass with utils.mocked_async_openai_lookup( @@ -46,7 +46,7 @@ async def on_message(message: str) -> None: ] assert response.messages == [f"A{i}"] - assert sorted(on_message_called_with) == sorted([f"A{i}" for i in range(100)]) + assert on_message_called_with == {f"A{i}" for i in range(100)} async def test_gather_tools() -> None: @@ -62,26 +62,29 @@ async def test_gather_tools() -> None: class User: id: int - add_called_with = [] - mul_called_with = [] + def __hash__(self) -> int: + return hash(self.id) + + add_called_with = set() + mul_called_with = set() @agent.tool() def add(num1: int, num2: int, _context: User) -> str: - add_called_with.append((num1, num2, User(id=_context.id))) + add_called_with.add((num1, num2, User(id=_context.id))) return f"add: {num1 + num2}" @agent.tool() async def multiply(num1: int, num2: int) -> str: - mul_called_with.append((num1, num2)) + mul_called_with.add((num1, num2)) return f"mul: {num1 * num2}" - on_message_async_called_with = [] + on_message_async_called_with = set() inspect_prompt_async_called_with = [] inspect_output_async_called_with = [] @agent.on_message async def on_message(message: str, _context: User) -> None: - on_message_async_called_with.append((message, User(id=_context.id))) + on_message_async_called_with.add((message, User(id=_context.id))) @agent.inspect_prompt async def inspect_prompt_async(prompt: list[Message], _context: User) -> None: @@ -91,13 +94,13 @@ async def inspect_prompt_async(prompt: list[Message], _context: User) -> None: async def inspect_output_async(message: Message, _context: User) -> None: inspect_output_async_called_with.append((message, User(id=_context.id))) - on_message_sync_called_with = [] + on_message_sync_called_with = set() inspect_prompt_sync_called_with = [] inspect_output_sync_called_with = [] @agent.on_message def on_message_sync(message: str, _context: User) -> None: - on_message_sync_called_with.append((message, User(id=_context.id))) + on_message_sync_called_with.add((message, User(id=_context.id))) @agent.inspect_prompt def inspect_prompt_sync(prompt: list[Message], _context: User) -> None: @@ -188,26 +191,20 @@ def inspect_output_sync(message: Message, _context: User) -> None: f"Answer: {i + i} and {i * i}", ] - assert sorted(add_called_with) == sorted( - [(i, i, User(id=i)) for i in range(batch_size)] - ) - assert sorted(on_message_async_called_with) == sorted( - [ - (f"Calculating {i} + {i} and {i} * {i}...", User(id=i)) - for i in range(batch_size) - ] - + [(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)] - ) - assert sorted(mul_called_with) == sorted([(i, i) for i in range(batch_size)]) + assert add_called_with == {(i, i, User(id=i)) for i in range(batch_size)} + assert on_message_async_called_with == { + (f"Calculating {i} + {i} and {i} * {i}...", User(id=i)) + for i in range(batch_size) + } | {(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)} + + assert mul_called_with == {(i, i) for i in range(batch_size)} assert len(inspect_prompt_async_called_with) == batch_size * 2 assert len(inspect_output_async_called_with) == batch_size * 2 - assert sorted(on_message_sync_called_with) == sorted( - [ - (f"Calculating {i} + {i} and {i} * {i}...", User(id=i)) - for i in range(batch_size) - ] - + [(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)] - ) + assert on_message_sync_called_with == { + (f"Calculating {i} + {i} and {i} * {i}...", User(id=i)) + for i in range(batch_size) + } | {(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)} + assert len(inspect_prompt_sync_called_with) == batch_size * 2 assert len(inspect_output_sync_called_with) == batch_size * 2 diff --git a/tests/test_structured_output.py b/tests/test_structured_output.py index f0292e4..dc2a33b 100644 --- a/tests/test_structured_output.py +++ b/tests/test_structured_output.py @@ -36,26 +36,29 @@ class OutputFormat(BaseModel): class User: id: int - add_called_with = [] - mul_called_with = [] + def __hash__(self): + return hash(self.id) + + add_called_with = set() + mul_called_with = set() @agent.tool() def add(num1: int, num2: int, _context: User) -> str: - add_called_with.append((num1, num2, User(id=_context.id))) + add_called_with.add((num1, num2, User(id=_context.id))) return f"add: {num1 + num2}" @agent.tool() async def multiply(num1: int, num2: int) -> str: - mul_called_with.append((num1, num2)) + mul_called_with.add((num1, num2)) return f"mul: {num1 * num2}" - on_message_async_called_with = [] + on_message_async_called_with = set() inspect_prompt_async_called_with = [] inspect_output_async_called_with = [] @agent.on_message async def on_message(message: OutputFormat, _context: User) -> None: - on_message_async_called_with.append((message.message, User(id=_context.id))) + on_message_async_called_with.add((message.message, User(id=_context.id))) @agent.inspect_prompt async def inspect_prompt_async(prompt: list[Message], _context: User) -> None: @@ -65,13 +68,13 @@ async def inspect_prompt_async(prompt: list[Message], _context: User) -> None: async def inspect_output_async(message: Message, _context: User) -> None: inspect_output_async_called_with.append((message, User(id=_context.id))) - on_message_sync_called_with = [] + on_message_sync_called_with = set() inspect_prompt_sync_called_with = [] inspect_output_sync_called_with = [] @agent.on_message def on_message_sync(message: OutputFormat, _context: User) -> None: - on_message_sync_called_with.append((message.message, User(id=_context.id))) + on_message_sync_called_with.add((message.message, User(id=_context.id))) @agent.inspect_prompt def inspect_prompt_sync(prompt: list[Message], _context: User) -> None: @@ -186,26 +189,20 @@ def inspect_output_sync(message: Message, _context: User) -> None: OutputFormat(message=f"Answer: {i + i} and {i * i}", i=i), ] - assert sorted(add_called_with) == sorted( - [(i, i, User(id=i)) for i in range(batch_size)] - ) - assert sorted(on_message_async_called_with) == sorted( - [ - (f"Calculating {i} + {i} and {i} * {i}...", User(id=i)) - for i in range(batch_size) - ] - + [(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)] - ) - assert sorted(mul_called_with) == sorted([(i, i) for i in range(batch_size)]) + assert add_called_with == {(i, i, User(id=i)) for i in range(batch_size)} + assert on_message_async_called_with == { + (f"Calculating {i} + {i} and {i} * {i}...", User(id=i)) + for i in range(batch_size) + } | {(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)} + + assert mul_called_with == {(i, i) for i in range(batch_size)} assert len(inspect_prompt_async_called_with) == batch_size * 2 assert len(inspect_output_async_called_with) == batch_size * 2 - assert sorted(on_message_sync_called_with) == sorted( - [ - (f"Calculating {i} + {i} and {i} * {i}...", User(id=i)) - for i in range(batch_size) - ] - + [(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)] - ) + assert on_message_sync_called_with == { + (f"Calculating {i} + {i} and {i} * {i}...", User(id=i)) + for i in range(batch_size) + } | {(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)} + assert len(inspect_prompt_sync_called_with) == batch_size * 2 assert len(inspect_output_sync_called_with) == batch_size * 2