Skip to content

Commit

Permalink
Refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
badgeir committed Aug 22, 2024
1 parent ae8fb96 commit f24f5fb
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 55 deletions.
55 changes: 26 additions & 29 deletions tests/test_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ async def test_gather_basic() -> None:
model="gpt-4o-mini",
)

on_message_called_with = []
on_message_called_with = set()

@agent.on_message
async def on_message(message: str) -> None:
on_message_called_with.append(message)
on_message_called_with.add(message)
pass

with utils.mocked_async_openai_lookup(
Expand All @@ -46,7 +46,7 @@ async def on_message(message: str) -> None:
]
assert response.messages == [f"A{i}"]

assert sorted(on_message_called_with) == sorted([f"A{i}" for i in range(100)])
assert on_message_called_with == {f"A{i}" for i in range(100)}


async def test_gather_tools() -> None:
Expand All @@ -62,26 +62,29 @@ async def test_gather_tools() -> None:
class User:
id: int

add_called_with = []
mul_called_with = []
def __hash__(self) -> int:
return hash(self.id)

add_called_with = set()
mul_called_with = set()

@agent.tool()
def add(num1: int, num2: int, _context: User) -> str:
add_called_with.append((num1, num2, User(id=_context.id)))
add_called_with.add((num1, num2, User(id=_context.id)))
return f"add: {num1 + num2}"

@agent.tool()
async def multiply(num1: int, num2: int) -> str:
mul_called_with.append((num1, num2))
mul_called_with.add((num1, num2))
return f"mul: {num1 * num2}"

on_message_async_called_with = []
on_message_async_called_with = set()
inspect_prompt_async_called_with = []
inspect_output_async_called_with = []

@agent.on_message
async def on_message(message: str, _context: User) -> None:
on_message_async_called_with.append((message, User(id=_context.id)))
on_message_async_called_with.add((message, User(id=_context.id)))

@agent.inspect_prompt
async def inspect_prompt_async(prompt: list[Message], _context: User) -> None:
Expand All @@ -91,13 +94,13 @@ async def inspect_prompt_async(prompt: list[Message], _context: User) -> None:
async def inspect_output_async(message: Message, _context: User) -> None:
inspect_output_async_called_with.append((message, User(id=_context.id)))

on_message_sync_called_with = []
on_message_sync_called_with = set()
inspect_prompt_sync_called_with = []
inspect_output_sync_called_with = []

@agent.on_message
def on_message_sync(message: str, _context: User) -> None:
on_message_sync_called_with.append((message, User(id=_context.id)))
on_message_sync_called_with.add((message, User(id=_context.id)))

@agent.inspect_prompt
def inspect_prompt_sync(prompt: list[Message], _context: User) -> None:
Expand Down Expand Up @@ -188,26 +191,20 @@ def inspect_output_sync(message: Message, _context: User) -> None:
f"Answer: {i + i} and {i * i}",
]

assert sorted(add_called_with) == sorted(
[(i, i, User(id=i)) for i in range(batch_size)]
)
assert sorted(on_message_async_called_with) == sorted(
[
(f"Calculating {i} + {i} and {i} * {i}...", User(id=i))
for i in range(batch_size)
]
+ [(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)]
)
assert sorted(mul_called_with) == sorted([(i, i) for i in range(batch_size)])
assert add_called_with == {(i, i, User(id=i)) for i in range(batch_size)}
assert on_message_async_called_with == {
(f"Calculating {i} + {i} and {i} * {i}...", User(id=i))
for i in range(batch_size)
} | {(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)}

assert mul_called_with == {(i, i) for i in range(batch_size)}
assert len(inspect_prompt_async_called_with) == batch_size * 2
assert len(inspect_output_async_called_with) == batch_size * 2

assert sorted(on_message_sync_called_with) == sorted(
[
(f"Calculating {i} + {i} and {i} * {i}...", User(id=i))
for i in range(batch_size)
]
+ [(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)]
)
assert on_message_sync_called_with == {
(f"Calculating {i} + {i} and {i} * {i}...", User(id=i))
for i in range(batch_size)
} | {(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)}

assert len(inspect_prompt_sync_called_with) == batch_size * 2
assert len(inspect_output_sync_called_with) == batch_size * 2
49 changes: 23 additions & 26 deletions tests/test_structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,29 @@ class OutputFormat(BaseModel):
class User:
id: int

add_called_with = []
mul_called_with = []
def __hash__(self):
return hash(self.id)

add_called_with = set()
mul_called_with = set()

@agent.tool()
def add(num1: int, num2: int, _context: User) -> str:
add_called_with.append((num1, num2, User(id=_context.id)))
add_called_with.add((num1, num2, User(id=_context.id)))
return f"add: {num1 + num2}"

@agent.tool()
async def multiply(num1: int, num2: int) -> str:
mul_called_with.append((num1, num2))
mul_called_with.add((num1, num2))
return f"mul: {num1 * num2}"

on_message_async_called_with = []
on_message_async_called_with = set()
inspect_prompt_async_called_with = []
inspect_output_async_called_with = []

@agent.on_message
async def on_message(message: OutputFormat, _context: User) -> None:
on_message_async_called_with.append((message.message, User(id=_context.id)))
on_message_async_called_with.add((message.message, User(id=_context.id)))

@agent.inspect_prompt
async def inspect_prompt_async(prompt: list[Message], _context: User) -> None:
Expand All @@ -65,13 +68,13 @@ async def inspect_prompt_async(prompt: list[Message], _context: User) -> None:
async def inspect_output_async(message: Message, _context: User) -> None:
inspect_output_async_called_with.append((message, User(id=_context.id)))

on_message_sync_called_with = []
on_message_sync_called_with = set()
inspect_prompt_sync_called_with = []
inspect_output_sync_called_with = []

@agent.on_message
def on_message_sync(message: OutputFormat, _context: User) -> None:
on_message_sync_called_with.append((message.message, User(id=_context.id)))
on_message_sync_called_with.add((message.message, User(id=_context.id)))

@agent.inspect_prompt
def inspect_prompt_sync(prompt: list[Message], _context: User) -> None:
Expand Down Expand Up @@ -186,26 +189,20 @@ def inspect_output_sync(message: Message, _context: User) -> None:
OutputFormat(message=f"Answer: {i + i} and {i * i}", i=i),
]

assert sorted(add_called_with) == sorted(
[(i, i, User(id=i)) for i in range(batch_size)]
)
assert sorted(on_message_async_called_with) == sorted(
[
(f"Calculating {i} + {i} and {i} * {i}...", User(id=i))
for i in range(batch_size)
]
+ [(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)]
)
assert sorted(mul_called_with) == sorted([(i, i) for i in range(batch_size)])
assert add_called_with == {(i, i, User(id=i)) for i in range(batch_size)}
assert on_message_async_called_with == {
(f"Calculating {i} + {i} and {i} * {i}...", User(id=i))
for i in range(batch_size)
} | {(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)}

assert mul_called_with == {(i, i) for i in range(batch_size)}
assert len(inspect_prompt_async_called_with) == batch_size * 2
assert len(inspect_output_async_called_with) == batch_size * 2

assert sorted(on_message_sync_called_with) == sorted(
[
(f"Calculating {i} + {i} and {i} * {i}...", User(id=i))
for i in range(batch_size)
]
+ [(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)]
)
assert on_message_sync_called_with == {
(f"Calculating {i} + {i} and {i} * {i}...", User(id=i))
for i in range(batch_size)
} | {(f"Answer: {i + i} and {i * i}", User(id=i)) for i in range(batch_size)}

assert len(inspect_prompt_sync_called_with) == batch_size * 2
assert len(inspect_output_sync_called_with) == batch_size * 2

0 comments on commit f24f5fb

Please sign in to comment.