Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added schemas and base classes for algos #4

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from pydantic import BaseModel
import abc
from chatsky_llm_autoconfig.graph import BaseGraph
Expand Down Expand Up @@ -35,7 +36,7 @@ class DialogueGenerator(BaseAlgorithm):
def __init__(self):
super().__init__()

def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = 0, topic: str = "") -> Dialogue:
def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = 0, topic: str = "") -> List[Dialogue]:
raise NotImplementedError


Expand All @@ -57,19 +58,41 @@ def invoke(self, dialogue: Dialogue, topic: str = "") -> Dialogue:
raise NotImplementedError


class GraphAugmentation(BaseAlgorithm):
    """Base class for algorithms that augment an existing dialogue graph guided by a topic.

    NOTE(review): the second parameter was previously misspelled ``Gaph``;
    renamed to ``graph`` (snake_case). Both methods are abstract stubs, so no
    caller could have relied on the old keyword name. The previous docstring
    ("works only with topics") was copy-pasted from TopicGraphGenerator and
    did not match this interface.
    """

    def invoke(self, topic: str, graph: BaseGraph) -> BaseGraph:
        """Return an augmented version of *graph* guided by *topic*. Subclasses implement this."""
        raise NotImplementedError

    async def ainvoke(self, topic: str, graph: BaseGraph) -> BaseGraph:
        """Async counterpart of :meth:`invoke`. Subclasses implement this."""
        raise NotImplementedError


class TopicGraphGenerator(BaseAlgorithm):
    """Base class for generators that build a dialogue graph from a topic string alone."""

    def invoke(self, topic: str) -> BaseGraph:
        # Synchronously generate a graph for `topic`; concrete subclasses implement this.
        raise NotImplementedError

    async def ainvoke(self, topic: str) -> BaseGraph:
        # Async counterpart of invoke; concrete subclasses implement this.
        raise NotImplementedError


class GraphGenerator(BaseAlgorithm):
"""
Base class for generating Graph objects.
"""Graph generator that works only with topics."""

This class is used to create a Graph based on a Dialogue, a specified topic, or an existing Graph.
def invoke(self, dialogue: Dialogue) -> BaseGraph:
raise NotImplementedError

:param dialogue: The Dialogue object used for generating the Graph.
:param graph: An existing Graph object to base the generation on (optional).
:param topic: The topic to guide the Graph generation process (optional).
"""
async def ainvoke(self, dialogue: Dialogue) -> BaseGraph:
raise NotImplementedError

def __init__(self):
super().__init__()

def invoke(self, dialogue: Dialogue = None, graph: BaseGraph = None, topic: str = "") -> BaseGraph:
class GraphExtender(BaseAlgorithm):
    """Base class for algorithms that extend an existing graph using a dialogue.

    NOTE(review): the previous docstring ("Graph generator that works only with
    topics") was copy-pasted from TopicGraphGenerator — this interface takes a
    dialogue and a graph, not a topic.
    """

    def invoke(self, dialogue: Dialogue, graph: BaseGraph) -> BaseGraph:
        # Extend `graph` with content derived from `dialogue`; subclasses implement this.
        raise NotImplementedError

    async def ainvoke(self, dialogue: Dialogue, graph: BaseGraph) -> BaseGraph:
        # Async counterpart of invoke; subclasses implement this.
        raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -1,3 +1,101 @@
import os
from typing import Any, List

from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI
from pydantic import BaseModel

from chatsky_llm_autoconfig.algorithms.base import DialogAugmentation
from chatsky_llm_autoconfig.dialogue import Dialogue
from chatsky_llm_autoconfig.graph import BaseGraph
from chatsky_llm_autoconfig.schemas import DialogueMessage

# Prompt for the augmentation chain: instructs the LLM to rephrase every
# utterance of the input dialogue (same intent, same participant, same order)
# and return ONLY a JSON array of {"text", "participant"} objects, which
# JsonOutputParser downstream parses into DialogueMessage-shaped dicts.
augmentation_prompt = PromptTemplate.from_template(
"""
You are tasked with augmenting a dialogue by adding variations to existing utterances while maintaining the original dialogue flow and intent.

THEME: {topic}

INPUT DIALOGUE:
{dialogue}

INSTRUCTIONS:
1. For each message in the dialogue:
- Keep the same structure (participant, source, target if present)
- Create variation of the 'text' field that:
* Express the same meaning/intent
* Use different wording and phrasing
* Match the given theme
* Sound natural and conversational

2. The output must be a list of dictionaries, where each dictionary has:
- 'text': string
- 'participant': either 'user' or 'assistant'

3. Ensure all utterance variations:
- Are appropriate for the theme
- Maintain consistency in tone and style
- Make sense in the conversation flow

Return ONLY a valid JSON array containing the augmented dialogue messages. Each message should be in this exact format:
For assistant messages: {{"text": "utterance text", "participant": "assistant"}}
For user messages: {{"text": "utterance text", "participant": "user"}}

Example format:
[
{{"text": "How may I assist you today?", "participant": "assistant"}},
{{"text": "I need help with a package", "participant": "user"}},
{{"text": "What kind of package is it?", "participant": "assistant"}}
]
"""
)


class DialogueSequence(BaseModel):
    """Schema the JSON output parser validates against: an ordered list of messages."""

    # Augmented dialogue messages in conversation order.
    result: List[DialogueMessage]


class DialogAugmentation(BaseModel):
    """Augments a Dialogue by rephrasing each utterance with an LLM while keeping flow and intent.

    NOTE(review): this class shadows the ``DialogAugmentation`` imported from
    ``chatsky_llm_autoconfig.algorithms.base`` at the top of this module — it
    likely should subclass that base instead of ``pydantic.BaseModel``; confirm
    the intended design.

    The runtime attributes are declared as model fields: pydantic's
    ``BaseModel`` raises on assignment of *undeclared* attributes, so the
    original ad-hoc ``self.parser = ...`` assignments in ``__init__`` would
    fail at construction time.
    """

    # Allow non-pydantic types (langchain parser/model/runnable) as field values.
    model_config = {"arbitrary_types_allowed": True}

    parser: Any = None  # JsonOutputParser validating against DialogueSequence
    model: Any = None   # ChatOpenAI client used by the chain
    chain: Any = None   # prompt | model | parser runnable

    def __init__(self, **data):
        super().__init__(**data)
        self.parser = JsonOutputParser(pydantic_object=DialogueSequence)
        self.model = ChatOpenAI(
            model="gpt-4o-mini",
            api_key=os.getenv("OPENAI_API_KEY"),
            base_url=os.getenv("OPENAI_BASE_URL"),
            temperature=0.7,
        )
        self.chain = augmentation_prompt | self.model | self.parser

    def invoke(self, *, dialogue: Dialogue, topic: str = "") -> Dialogue:
        """
        Augment the input dialogue with variations.

        Args:
            dialogue: The input Dialogue object to augment
            topic: Optional topic to guide the augmentation

        Returns:
            Dialogue: Augmented dialogue object
        """
        # Convert dialogue to string format for the prompt
        # (assumes Dialogue has a meaningful __str__ — TODO confirm).
        dialogue_str = str(dialogue)

        # Get augmented messages from the LLM chain.
        result = self.chain.invoke({"topic": topic, "dialogue": dialogue_str})

        # Create a new Dialogue object with the augmented messages.
        return Dialogue(messages=result, topic=topic)

    async def ainvoke(self, *, dialogue: Dialogue, topic: str = "") -> Dialogue:
        """
        Async version of dialogue augmentation.

        Args:
            dialogue: The input Dialogue object to augment
            topic: Optional topic to guide the augmentation

        Returns:
            Dialogue: Augmented dialogue object
        """
        dialogue_str = str(dialogue)

        result = await self.chain.ainvoke({"topic": topic, "dialogue": dialogue_str})

        return Dialogue(messages=result, topic=topic)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topi
# Check if the last node has edges and add the last edge utterances
edges = list(nx_graph.edges(current_node, data=True))
if edges:
last_edge_data = edges[-1][2] # Get the last edge's data
# Get the last edge's data
last_edge_data = edges[-1][2]
last_edge_utterance = (
random.choice(last_edge_data["utterances"])
if isinstance(last_edge_data["utterances"], list)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from chatsky_llm_autoconfig.algorithms.base import TopicGraphGenerator
from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry
from chatsky_llm_autoconfig.schemas import DialogueGraph
from langchain_openai import ChatOpenAI

from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser

from chatsky_llm_autoconfig.graph import BaseGraph, Graph
import os

from pydantic import SecretStr

# Prompt instructing the LLM to emit a cyclic dialogue graph as bare JSON
# (nodes + edges, with the last node looping back to an existing one), parsed
# downstream by JsonOutputParser against the DialogueGraph schema.
# NOTE(review): stray review-UI text ("martynov-dm marked this conversation as
# resolved. Show resolved Hide resolved") had been pasted between the call and
# the template string, breaking the file; removed to restore valid Python.
cycle_graph_generation_prompt = PromptTemplate.from_template(
"""
Create a cyclic dialogue graph where the conversation MUST return to an existing node.

**CRITICAL: Response Specificity**
Responses must acknowledge and build upon what the user has already specified:

INCORRECT flow:
- User: "I'd like to order a coffee"
- Staff: "What would you like to order?" (TOO GENERAL - ignores that they specified coffee)

CORRECT flow:
- User: "I'd like to order a coffee"
- Staff: "What kind of coffee would you like?" (GOOD - acknowledges they want coffee)

Example of a CORRECT cyclic graph for a coffee shop:
"edges": [
{{ "source": 1, "target": 2, "utterances": ["Hi, I'd like to order a coffee"] }},
{{ "source": 2, "target": 3, "utterances": ["A large latte please"] }},
{{ "source": 3, "target": 4, "utterances": ["Yes, that's correct"] }},
{{ "source": 4, "target": 5, "utterances": ["Here's my payment"] }},
{{ "source": 5, "target": 2, "utterances": ["I'd like to order another coffee"] }}
],
"nodes": [
{{ "id": 1, "label": "welcome", "is_start": true, "utterances": ["Welcome! How can I help you today?"] }},
{{ "id": 2, "label": "ask_coffee_type", "is_start": false, "utterances": ["What kind of coffee would you like?"] }},
{{ "id": 3, "label": "confirm", "is_start": false, "utterances": ["That's a large latte. Is this correct?"] }},
{{ "id": 4, "label": "payment", "is_start": false, "utterances": ["Great! That'll be $5. Please proceed with payment."] }},
{{ "id": 5, "label": "completed", "is_start": false, "utterances": ["Thank you! Would you like another coffee?"] }}
]

**Rules:**
1) Responses must acknowledge what the user has already specified
2) The final node MUST connect back to an existing node
3) Each node must have clear purpose
4) Return ONLY the JSON without commentary
5) Graph must be cyclic - no dead ends
6) All edges must connect to existing nodes
7) The cycle point should make logical sense

**Your task is to create a cyclic dialogue graph about the following topic:** {topic}.
"""
)


@AlgorithmRegistry.register(input_type=str, output_type=Graph)
class CycleGraphGenerator(TopicGraphGenerator):
    """Generator specifically for topic-based cyclic graphs."""

    def __init__(self):
        super().__init__()

    def _build_chain(self):
        """Assemble the prompt | model | parser runnable.

        Built per call (as in the original) so changes to the API-key /
        base-URL environment variables take effect without recreating the
        generator.
        """
        parser = JsonOutputParser(pydantic_object=DialogueGraph)
        model = ChatOpenAI(
            model="gpt-4o",
            api_key=SecretStr(os.getenv("OPENAI_API_KEY") or ""),
            base_url=os.getenv("OPENAI_BASE_URL"),
            temperature=0,
        )
        return cycle_graph_generation_prompt | model | parser

    def invoke(self, topic: str) -> BaseGraph:
        """
        Generate a cyclic dialogue graph based on the topic input.

        :param topic: Topic the generated dialogue graph should cover.
        :return: Generated Graph object with cyclic structure.
        """
        generated_graph = self._build_chain().invoke({"topic": topic})
        return Graph(generated_graph)

    async def ainvoke(self, topic: str) -> BaseGraph:
        """
        Async version of :meth:`invoke`.

        NOTE(review): previously this was ``async def ainvoke(self, *args,
        **kwargs): pass`` — it silently returned ``None``. It now mirrors the
        sync path via the chain's async API, matching the base-class contract.
        """
        generated_graph = await self._build_chain().ainvoke({"topic": topic})
        return Graph(generated_graph)
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry
import chatsky_llm_autoconfig.algorithms.dialogue_generation
import chatsky_llm_autoconfig.algorithms.dialogue_augmentation
import chatsky_llm_autoconfig.algorithms.graph_generation

import json
from chatsky_llm_autoconfig.graph import Graph, BaseGraph
Expand Down
Loading