literal workaround #777

Open · wants to merge 7 commits into base: main
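
Works around structured-output failures for table names that contain quotes, periods, or other punctuation: the `Literal` response fields offered to the LLM are now built from sanitized aliases (via a new `create_aliases` utility in `lumen/ai/utils.py`), and the alias the model picks is mapped back to the original table name.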
10 changes: 6 additions & 4 deletions lumen/ai/agents.py
@@ -39,8 +39,8 @@
 )
 from .translate import param_to_pydantic
 from .utils import (
-    clean_sql, describe_data, gather_table_sources, get_data, get_pipeline,
-    get_schema, report_error, retry_llm_output,
+    clean_sql, create_aliases, describe_data, gather_table_sources, get_data,
+    get_pipeline, get_schema, report_error, retry_llm_output,
 )
 from .views import AnalysisOutput, LumenOutput, SQLOutput
 
@@ -465,6 +465,7 @@ async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, Ba
         available_sources = self._memory["available_sources"]
         tables_to_source, tables_schema_str = await gather_table_sources(available_sources)
         tables = tuple(tables_to_source)
+
         if messages and messages[-1]["content"].startswith("Show the table: '"):
             # Handle the case where explicitly requested a table
             table = messages[-1]["content"].replace("Show the table: '", "")[:-1]
@@ -479,15 +480,16 @@
             elif len(tables) > FUZZY_TABLE_LENGTH:
                 tables = await self._get_closest_tables(messages, tables)
             system_prompt = self._render_prompt("select_table", tables_schema_str=tables_schema_str)
-            table_model = make_table_model(tables)
+            tables_aliases = create_aliases(tables)
+            table_model = make_table_model(tuple(tables_aliases))
             result = await self.llm.invoke(
                 messages,
                 system=system_prompt,
                 response_model=table_model,
                 allow_partial=False,
                 max_retries=3,
             )
-            table = result.relevant_table
+            table = tables_aliases[result.relevant_table]
             step.stream(f"{result.chain_of_thought}\n\nSelected table: {table}")
         else:
             table = tables[0]
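
For context, a minimal sketch of the alias round-trip introduced above. The table names are made up, and pydantic's `create_model` stands in for `make_table_model` (defined in `lumen/ai/models.py`):

```python
import re
from typing import Literal

from pydantic import create_model

def create_aliases(names):
    # alias -> original name, mirroring the helper added in lumen/ai/utils.py
    return {re.sub(r'[^a-zA-Z0-9_]', '_', name): name for name in names}

tables = ("read_csv('windturbines.csv')", "duckdb.stations")
tables_aliases = create_aliases(tables)

# The Literal now contains only "safe" strings such as
# 'read_csv__windturbines_csv__', which the LLM can echo back verbatim.
TableModel = create_model(
    "Table", relevant_table=(Literal[tuple(tables_aliases)], ...)
)

# Mapping the LLM's pick (an alias) back to the real table name:
result = TableModel(relevant_table="read_csv__windturbines_csv__")
print(tables_aliases[result.relevant_table])  # read_csv('windturbines.csv')
```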
15 changes: 9 additions & 6 deletions lumen/ai/coordinator.py
@@ -25,7 +25,7 @@
 from .logs import ChatLogs
 from .memory import _Memory, memory
 from .models import Validity, make_agent_model, make_plan_models
-from .utils import get_schema, retry_llm_output
+from .utils import create_aliases, get_schema, retry_llm_output
 
 if TYPE_CHECKING:
     from panel.chat.step import ChatStep
@@ -577,7 +577,8 @@ async def _make_plan(
         reason_model: type[BaseModel],
         plan_model: type[BaseModel],
         step: ChatStep,
-        schemas: dict[str, dict] | None = None
+        schemas: dict[str, dict] | None = None,
+        tables_aliases: dict[str, str] | None = None
     ) -> BaseModel:
         user_msg = messages[-1]
         info = ''
@@ -601,13 +602,14 @@
             messages=messages,
             system=system,
             response_model=reason_model,
+            allow_partial=False,
             max_retries=3,
         )
         if reasoning.chain_of_thought:  # do not replace with empty string
             step.stream(reasoning.chain_of_thought, replace=True)
         requested = [
-            t for t in getattr(reasoning, 'tables', [])
-            if t and t not in provided
+            tables_aliases[t] for t in getattr(reasoning, 'tables', [])
+            if t and tables_aliases[t] not in provided
         ]
         new_msg = dict(
             role=user_msg['role'],
@@ -662,7 +664,8 @@ async def _compute_execution_graph(self, messages: list[Message], agents: dict[s
             for table in src.get_tables():
                 tables[table] = src
 
-        reason_model, plan_model = make_plan_models(agent_names, list(tables))
+        tables_aliases = create_aliases(tables)
+        reason_model, plan_model = make_plan_models(agent_names, tuple(tables_aliases))
         planned = False
         unmet_dependencies = set()
         schemas = {}
@@ -671,7 +674,7 @@
         while not planned:
             try:
                 plan = await self._make_plan(
-                    messages, agents, tables, unmet_dependencies, reason_model, plan_model, istep, schemas
+                    messages, agents, tables, unmet_dependencies, reason_model, plan_model, istep, schemas, tables_aliases
                 )
             except Exception as e:
                 if self.interface.callback_exception not in ('raise', 'verbose'):
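
A small, self-contained illustration of the de-aliasing step in `_make_plan`. All values below are made up; in the PR itself, `reasoning.tables` is constrained by the `Literal` built from the alias keys, so the lookups cannot miss:

```python
provided = {"duckdb.stations"}  # tables already loaded into memory
tables_aliases = {
    "read_csv__windturbines_csv__": "read_csv('windturbines.csv')",
    "duckdb_stations": "duckdb.stations",
}
reasoning_tables = ["read_csv__windturbines_csv__", "duckdb_stations"]

# De-alias the LLM's picks and drop tables that are already provided:
requested = [
    tables_aliases[t] for t in reasoning_tables
    if t and tables_aliases[t] not in provided
]
print(requested)  # ["read_csv('windturbines.csv')"]
```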
2 changes: 1 addition & 1 deletion lumen/ai/models.py
@@ -102,7 +102,7 @@ def make_plan_models(agent_names: list[str], tables: list[str]):
     extras['tables'] = (
         list[Literal[tuple(tables)]],
         FieldInfo(
-            description="A list of tables to load into memory before coming up with a plan. NOTE: Simple queries asking to list the tables/datasets do not require loading the tables. Table names MUST match verbatim including the quotations, apostrophes, periods, or lack thereof."
+            description="A list of tables to load into memory before coming up with a plan. NOTE: Simple queries asking to list the tables/datasets do not require loading the tables. Table names MUST match verbatim."
         )
     )
     reasoning = create_model(
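
The shortened description works because the `Literal` type, not the prompt prose, now enforces exact matches. A minimal sketch of how the field is built, assuming pydantic v2 (other fields and the `agent_names` handling are omitted):

```python
from typing import Literal

from pydantic import create_model
from pydantic.fields import FieldInfo

tables = ("read_csv__windturbines_csv__", "duckdb_stations")
extras = {}
extras['tables'] = (
    list[Literal[tuple(tables)]],
    FieldInfo(description="A list of tables to load into memory before coming up with a plan."),
)
Reasoning = create_model("Reasoning", **extras)

print(Reasoning(tables=["duckdb_stations"]).tables)  # ['duckdb_stations']
# Reasoning(tables=["duckdb.stations"]) would fail validation, since only
# the sanitized aliases are valid Literal values.
```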
10 changes: 10 additions & 0 deletions lumen/ai/utils.py
@@ -3,6 +3,7 @@
 import asyncio
 import inspect
 import math
+import re
 import time
 
 from functools import wraps
@@ -318,3 +319,12 @@ async def gather_table_sources(available_sources: list[Source]) -> tuple[dict[st
         else:
             tables_schema_str += f"### {table}\n"
     return tables_to_source, tables_schema_str
+
+
+def create_aliases(names: list[str]) -> dict[str, str]:
+    """
+    Map each sanitized alias (non-alphanumeric characters replaced with `_`) to its original name.
+    """
+    return {
+        re.sub(r'[^a-zA-Z0-9_]', '_', name): name for name in names
+    }
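
Example round-trip with illustrative names (the returned dict is keyed on the alias so the original name can be recovered):

```python
aliases = create_aliases(["read_csv('windturbines.csv')", "duckdb.stations"])
# {'read_csv__windturbines_csv__': "read_csv('windturbines.csv')",
#  'duckdb_stations': 'duckdb.stations'}
```

Note that names differing only in punctuation (e.g. `a.b` and `a-b`) sanitize to the same alias, so one would silently overwrite the other in the returned dict; the diff does not guard against such collisions.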