Allow using kwargs in CLI #787

Open · wants to merge 2 commits into main
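At a high level, this PR forwards LLM configuration from the `lumen-ai serve` command line into the generated app, e.g. `lumen-ai serve penguins.csv --provider openai --temperature 0.7 --model-kwargs '{"default": {"model": "gpt-4o-mini"}}'`. Rendered through the new app.py.jinja2 template (the template itself is not part of this diff), such an invocation corresponds roughly to the sketch below; the constructor keywords accepted by ExplorerUI and the Llm subclasses are assumptions here, not confirmed by the diff.

import lumen.ai as lmai

# Illustrative only: the real source is rendered from app.py.jinja2.
llm = lmai.llm.OpenAI(
    temperature=0.7,  # assumed parameter name
    model_kwargs={"default": {"model": "gpt-4o-mini"}},
)
lmai.ExplorerUI(['penguins.csv'], llm=llm).servable()  # llm= keyword is assumed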
16 changes: 8 additions & 8 deletions lumen/ai/llm.py
@@ -278,7 +278,7 @@ class OpenAI(Llm):

     api_key = param.String(doc="The OpenAI API key.")

-    base_url = param.String(doc="The OpenAI base.")
+    provider_endpoint = param.String(doc="The OpenAI base.")

     mode = param.Selector(default=Mode.TOOLS)

@@ -303,8 +303,8 @@ def get_client(self, model_key: MODEL_TYPE, response_model: BaseModel | None = N

         model_kwargs = self._get_model_kwargs(model_key)
         model = model_kwargs.pop("model")
-        if self.base_url:
-            model_kwargs["base_url"] = self.base_url
+        if self.provider_endpoint:
+            model_kwargs["base_url"] = self.provider_endpoint
         if self.api_key:
             model_kwargs["api_key"] = self.api_key
         if self.organization:
@@ -339,7 +339,7 @@ class AzureOpenAI(Llm):

     api_version = param.String(doc="The Azure AI Studio API version.")

-    azure_endpoint = param.String(doc="The Azure AI Studio endpoint.")
+    provider_endpoint = param.String(doc="The Azure AI Studio endpoint.")

     mode = param.Selector(default=Mode.TOOLS)
@@ -358,8 +358,8 @@ def get_client(self, model_key: str, response_model: BaseModel | None = None, **
         model_kwargs["api_version"] = self.api_version
         if self.api_key:
             model_kwargs["api_key"] = self.api_key
-        if self.azure_endpoint:
-            model_kwargs["azure_endpoint"] = self.azure_endpoint
+        if self.provider_endpoint:
+            model_kwargs["azure_endpoint"] = self.provider_endpoint
         llm = openai.AsyncAzureOpenAI(**model_kwargs)

         if self.interceptor:
@@ -461,7 +461,7 @@ class AzureMistralAI(MistralAI):

     api_key = param.String(default=os.getenv("AZURE_API_KEY"), doc="The Azure API key")

-    azure_endpoint = param.String(default=os.getenv("AZURE_ENDPOINT"), doc="The Azure endpoint to invoke.")
+    provider_endpoint = param.String(default=os.getenv("AZURE_ENDPOINT"), doc="The Azure endpoint to invoke.")

     model_kwargs = param.Dict(default={
         "default": {"model": "azureai"},

@@ -476,7 +476,7 @@ async def llm_chat_non_stream_async(*args, **kwargs):

         model_kwargs = self._get_model_kwargs(model_key)
         model_kwargs["api_key"] = self.api_key
-        model_kwargs["azure_endpoint"] = self.azure_endpoint
+        model_kwargs["azure_endpoint"] = self.provider_endpoint
         model = model_kwargs.pop("model")
         llm = MistralAzure(**model_kwargs)
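The net effect of the rename is a single, uniformly named parameter across providers. Pointing the OpenAI wrapper at any OpenAI-compatible endpoint would then look roughly like this (a minimal sketch; the URL and key are placeholders, and direct instantiation assumes these classes are public API):

from lumen.ai.llm import OpenAI

# provider_endpoint is forwarded to the client as base_url in get_client().
llm = OpenAI(
    provider_endpoint="http://localhost:11434/v1",  # placeholder OpenAI-compatible server
    api_key="sk-placeholder",                       # placeholder key
)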
6 changes: 3 additions & 3 deletions lumen/ai/utils.py
@@ -28,12 +28,12 @@
 from panel.chat.step import ChatStep


-def render_template(template_path: Path, prompt_overrides: dict, **context):
+def render_template(template_path: Path, prompt_overrides: dict, relative_to: Path = PROMPTS_DIR, **context):
     try:
-        template_path = template_path.relative_to(PROMPTS_DIR).as_posix()
+        template_path = template_path.relative_to(relative_to).as_posix()
     except ValueError:
         pass
-    fs_loader = FileSystemLoader(PROMPTS_DIR)
+    fs_loader = FileSystemLoader(relative_to)

     if prompt_overrides:
         # Dynamically create block definitions based on dictionary keys with proper escaping
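The new relative_to parameter generalizes render_template from the hard-coded PROMPTS_DIR to any template root, which is what lets the command module render its own app.py.jinja2 below. A minimal sketch of the underlying Jinja2 mechanics, with illustrative paths:

from pathlib import Path

from jinja2 import Environment, FileSystemLoader

root = Path("lumen/command")            # hypothetical template root
template_path = root / "app.py.jinja2"  # the template added by this PR

# Resolve the path relative to the loader root, then render it by name,
# mirroring what render_template(..., relative_to=root) now does.
rel = template_path.relative_to(root).as_posix()
env = Environment(loader=FileSystemLoader(root))
print(env.get_template(rel).render(tables=["'penguins.csv'"]))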
245 changes: 199 additions & 46 deletions lumen/command/ai.py
@@ -1,85 +1,238 @@
from __future__ import annotations

import argparse
import inspect
import json
import os
import sys

import bokeh.command.util # type: ignore
from textwrap import dedent

from bokeh.application.handlers.code import CodeHandler # type: ignore
from bokeh.command.util import die # type: ignore
import bokeh.command.subcommands.serve

from bokeh.application.handlers.code import CodeHandler
from bokeh.command.util import die
from panel.command import Serve, transform_cmds
from panel.io.application import Application

SOURCE_CODE = """
import lumen.ai as lmai
lmai.ExplorerUI([{tables}]).servable()"""

VALID_EXTENSIONS = ['.parq', '.parquet', '.csv', '.json']
from lumen.ai.config import THIS_DIR

from ..ai import agents as lumen_agents # Aliased here
from ..ai.utils import render_template

VALID_EXTENSIONS = [".parq", ".parquet", ".csv", ".json"]
CMD_DIR = THIS_DIR / ".." / "command"


class LLMConfig:
    """Configuration handler for LLM providers"""

    PROVIDER_ENV_VARS = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "mistral": "MISTRAL_API_KEY",
        "azure-mistral": "AZURE_API_KEY",
        "azure-openai": "AZURE_API_KEY",
    }

    @classmethod
    def detect_provider(cls) -> str | None:
        """Detect available LLM provider based on environment variables"""
        for provider, env_var in cls.PROVIDER_ENV_VARS.items():
            if env_var and os.environ.get(env_var):
                return provider
        return None

    @classmethod
    def get_api_key(cls, provider: str) -> str | None:
        """Get API key for specified provider"""
        env_var = cls.PROVIDER_ENV_VARS.get(provider)
        if env_var:
            return os.environ.get(env_var)
        return None


class LumenAIServe(Serve):
    """Extended Serve command that handles both Panel/Bokeh and Lumen AI arguments"""

    def __init__(self, parser: argparse.ArgumentParser) -> None:
        super().__init__(parser=parser)
        self.add_lumen_arguments(parser)

    def add_lumen_arguments(self, parser: argparse.ArgumentParser) -> None:
        """Add Lumen AI specific arguments to the parser"""
        group = parser.add_argument_group("Lumen AI Configuration")
        group.add_argument(
            "--provider",
            choices=["openai", "azure-openai", "anthropic", "mistral", "azure-mistral", "llama"],
            help="LLM provider (auto-detected from environment variables if not specified)",
        )
        group.add_argument("--api-key", help="API key for the LLM provider")
        group.add_argument("--temperature", type=float, help="Temperature for the LLM")
        group.add_argument(
            "--provider-endpoint", help="Custom endpoint for the LLM provider"
        )
        group.add_argument("--agents", nargs="+", help="Additional agents to include")
        group.add_argument(
            "--model-kwargs",
            type=str,
            help="JSON string of model keyword arguments for the LLM. Example: --model-kwargs '{\"repo\": \"abcdef\"}'",
        )

    def invoke(self, args: argparse.Namespace) -> bool:
        """Override invoke to handle both sets of arguments"""
        provider = args.provider
        api_key = args.api_key
        temperature = args.temperature
        provider_endpoint = args.provider_endpoint
        agents = args.agents

        if not provider:
            provider = LLMConfig.detect_provider()

        if not api_key and provider:
            api_key = LLMConfig.get_api_key(provider)

        model_kwargs = None
        if args.model_kwargs:
            try:
                model_kwargs = json.loads(args.model_kwargs)
            except json.JSONDecodeError as e:
                die(f"Invalid JSON format for --model-kwargs: {e}\n"
                    f"Ensure the argument is properly escaped. Example: --model-kwargs '{{\"key\": \"value\"}}'")

        agent_classes = [
            (name, cls) for name, cls in inspect.getmembers(lumen_agents, inspect.isclass)
            if issubclass(cls, lumen_agents.Agent) and cls is not lumen_agents.Agent
        ]
        agent_class_names = {name.lower(): name for name, cls in agent_classes}

        if agents:
            # Adjust agent names to match the class names, case-insensitively
            agents = [
                agent_class_names.get(agent.lower()) or
                agent_class_names.get(f"{agent.lower()}agent") or
                agent
                for agent in agents
            ]

        def build_single_handler_applications(
            paths: list[str], argvs: dict[str, list[str]] | None = None
        ) -> dict[str, Application]:
            handler = AIHandler(
                paths,
                provider=provider,
                api_key=api_key,
                temperature=temperature,
                provider_endpoint=provider_endpoint,
                agents=agents,
                model_kwargs=model_kwargs,
            )
            if handler.failed:
                raise RuntimeError(
                    f"Error loading {paths}:\n\n{handler.error}\n{handler.error_detail}"
                )
            return {"/lumen_ai": Application(handler)}

        bokeh.command.subcommands.serve.build_single_handler_applications = (
            build_single_handler_applications
        )

        return super().invoke(args)


class AIHandler(CodeHandler):
    ''' Modify Bokeh documents by using Lumen AI on a dataset.

    '''

    def __init__(self, tables: list[str], **kwargs) -> None:
        tables = list({table for table in tables if any(table.endswith(ext) for ext in VALID_EXTENSIONS)})
        source = SOURCE_CODE.format(tables=','.join([repr(t) for t in tables]))
        super().__init__(filename='lumen_ai.py', source=source, **kwargs)


def build_single_handler_applications(paths: list[str], argvs: dict[str, list[str]] | None = None) -> dict[str, Application]:
    ''' Custom to allow for standalone `lumen-ai` command to launch without data'''
    handler = AIHandler(paths)
    if handler.failed:
        raise RuntimeError(f"Error loading {paths}:\n\n{handler.error}\n{handler.error_detail} ")
    return {'/lumen_ai': Application(handler)}

bokeh.command.subcommands.serve.build_single_handler_applications = build_single_handler_applications
    """Handler for Lumen AI applications"""

    def __init__(
        self,
        tables: list[str],
        provider: str | None = None,
        api_key: str | None = None,
        temperature: float | None = None,
        provider_endpoint: str | None = None,
        agents: list[str] | None = None,
        model_kwargs: dict | None = None,
        **kwargs,
    ) -> None:
        source = self._build_source_code(
            tables=tables,
            provider=provider,
            api_key=api_key,
            temperature=temperature,
            provider_endpoint=provider_endpoint,
            agents=agents,
            model_kwargs=model_kwargs,
        )
        super().__init__(filename="lumen_ai.py", source=source, **kwargs)

    def _build_source_code(self, tables: list[str], **config) -> str:
        """Build source code with configuration"""
        tables = [
            table
            for table in tables
            if any(table.endswith(ext) for ext in VALID_EXTENSIONS)
        ]

        context = {
            "tables": [repr(t) for t in tables],
            "provider": config.get("provider"),
            "api_key": config.get("api_key"),
            "provider_endpoint": config.get("provider_endpoint"),
            "agents": config.get("agents"),
            "temperature": config.get("temperature"),
            "model_kwargs": config.get('model_kwargs') or {},
        }
        context = {k: v for k, v in context.items() if v is not None}

        source = render_template(
            CMD_DIR / "app.py.jinja2", {}, relative_to=CMD_DIR, **context
        ).replace("\n\n", "\n").strip()
        return source


def main(args=None):
    parser = argparse.ArgumentParser(
        prog="lumen-ai",
        description="""
        Lumen AI - Launch Lumen AI applications easily.\n\n To start the application without any
        data, simply run 'lumen-ai' with no additional arguments. You can upload data through
        the chat interface afterwards.
        """,
        epilog="See '<command> --help' to read about a specific subcommand."
        description=dedent(
            """
            Lumen AI - Launch Lumen AI applications with customizable LLM configuration.
            To start the application without any data, simply run 'lumen-ai' with no additional arguments.
            """
        ),
        epilog="See '<command> --help' to read about a specific subcommand.",
    )

    parser.add_argument(
        '-v', '--version', action='version', version='Lumen AI 1.0.0'
    )
    parser.add_argument("-v", "--version", action="version", version="Lumen AI 1.0.0")

    subs = parser.add_subparsers(help="Sub-commands", dest="command")

    subs = parser.add_subparsers(help="Sub-commands")
    serve_parser = subs.add_parser(
        Serve.name, help="""
        Run a bokeh server to serve the Lumen AI application.
        This command should be followed by dataset paths or directories
        to add to the chat memory, which can be a .parq, .parquet, .csv,
        or .json file. run `lumen-ai serve --help` for more options)
        """)
    serve_command = Serve(parser=serve_parser)
        Serve.name, help="Run a bokeh server to serve the Lumen AI application."
    )
    serve_command = LumenAIServe(parser=serve_parser)
    serve_parser.set_defaults(invoke=serve_command.invoke)

    if len(sys.argv) > 1 and sys.argv[1] in ('--help', '-h'):
    if len(sys.argv) > 1 and sys.argv[1] in ("--help", "-h"):
        args = parser.parse_args(sys.argv[1:])
        args.invoke(args)
        sys.exit()

    if len(sys.argv) == 1:
        # If no command is specified, start the server with an empty application
        sys.argv.extend(['serve', 'no_data'])
        sys.argv.extend(["serve", "no_data"])

    sys.argv = transform_cmds(sys.argv)
    args = parser.parse_args(sys.argv[1:])

    if not hasattr(args, "invoke"):
        parser.print_help()
        sys.exit(1)

    try:
        ret = args.invoke(args)
    except Exception as e:
        import traceback

        traceback.print_exc()
        die("ERROR: " + str(e))

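Taken together, the command now detects a provider from the environment, parses --model-kwargs as JSON, and resolves agent names case-insensitively before rendering the app. A condensed, self-contained sketch of that argument handling (the agent names are stand-ins, not the actual classes in lumen.ai.agents):

from __future__ import annotations

import json
import os

PROVIDER_ENV_VARS = {"openai": "OPENAI_API_KEY", "anthropic": "ANTHROPIC_API_KEY"}

def detect_provider() -> str | None:
    # First provider whose API-key variable is set wins, as in LLMConfig.
    for provider, env_var in PROVIDER_ENV_VARS.items():
        if os.environ.get(env_var):
            return provider
    return None

def parse_model_kwargs(raw: str | None) -> dict | None:
    # --model-kwargs arrives as a JSON string; malformed JSON fails loudly.
    if raw is None:
        return None
    try:
        return json.loads(raw)
    except json.JSONDecodeError as e:
        raise SystemExit(f"Invalid JSON for --model-kwargs: {e}")

def resolve_agents(requested: list[str], available: list[str]) -> list[str]:
    # Case-insensitive lookup, with or without the "Agent" suffix,
    # mirroring the adjustment in LumenAIServe.invoke().
    names = {name.lower(): name for name in available}
    return [
        names.get(a.lower()) or names.get(f"{a.lower()}agent") or a
        for a in requested
    ]

assert resolve_agents(["sql", "chat"], ["SQLAgent", "ChatAgent"]) == ["SQLAgent", "ChatAgent"]
assert parse_model_kwargs('{"default": {"model": "gpt-4o-mini"}}')["default"]["model"] == "gpt-4o-mini"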