Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

working on generation #5

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,6 @@ def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = 0, topic
raise NotImplementedError


class DialogAugmentation(BaseAlgorithm):
    """
    Base class for augmenting Dialogues.

    Subclasses take a Dialogue as input and return an augmented Dialogue
    as output. Designed for data augmentation or other manipulations of
    Dialogues.

    :param dialogue: The Dialogue object to be augmented.
    :param topic: The topic to guide the augmentation process (optional).
    """

    # NOTE(review): the original defined an ``__init__`` that only called
    # ``super().__init__()``; that is redundant and has been removed —
    # the inherited constructor is used unchanged.

    def invoke(self, dialogue: Dialogue, topic: str = "") -> Dialogue:
        """Return an augmented copy of ``dialogue``; must be overridden."""
        raise NotImplementedError


class GraphAugmentation(BaseAlgorithm):
"""Graph generator that works only with topics."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,29 @@
from chatsky_llm_autoconfig.schemas import DialogueMessage
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from typing import List

from typing import List, Any
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from pydantic import BaseModel, SecretStr
import os
from chatsky_llm_autoconfig.algorithms.base import BaseAlgorithm
from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry
from langchain_core.runnables import RunnableSerializable


class DialogueSequence(BaseModel):
    """Pydantic schema for the JSON the LLM returns: a list of dialogue messages."""

    # Augmented messages parsed from the model's JSON output
    # (used as ``pydantic_object`` for ``JsonOutputParser``).
    result: List[DialogueMessage]


@AlgorithmRegistry.register(input_type=Dialogue, output_type=Dialogue)
class DialogAugmentation(BaseAlgorithm):
"""Base class for augmenting Dialogues."""

augmentation_prompt: PromptTemplate = ""

augmentation_prompt = PromptTemplate.from_template(
"""
def __init__(self):
super().__init__()
self.augmentation_prompt = PromptTemplate.from_template(
"""
You are tasked with augmenting a dialogue by adding variations to existing utterances while maintaining the original dialogue flow and intent.

THEME: {topic}
Expand Down Expand Up @@ -46,21 +61,7 @@
{{"text": "What kind of package is it?", "participant": "assistant"}}
]
"""
)


class DialogueSequence(BaseModel):
    """Pydantic schema for the JSON the LLM returns: a list of dialogue messages."""

    # Augmented messages parsed from the model's JSON output
    # (used as ``pydantic_object`` for ``JsonOutputParser``).
    result: List[DialogueMessage]


class DialogAugmentation(BaseModel):
"""Base class for augmenting Dialogues."""

def __init__(self, **data):
super().__init__(**data)
self.parser = JsonOutputParser(pydantic_object=DialogueSequence)
self.model = ChatOpenAI(model="gpt-4o-mini", api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL"), temperature=0.7)
self.chain = augmentation_prompt | self.model | self.parser
)

def invoke(self, *, dialogue: Dialogue, topic: str = "") -> Dialogue:
"""
Expand All @@ -73,14 +74,15 @@ def invoke(self, *, dialogue: Dialogue, topic: str = "") -> Dialogue:
Returns:
Dialogue: Augmented dialogue object
"""
# Convert dialogue to string format for prompt
# Предполагая, что у Dialogue есть str представление
dialogue_str = str(dialogue)
parser: JsonOutputParser = JsonOutputParser(pydantic_object=DialogueSequence)
model: ChatOpenAI = ChatOpenAI(
model="gpt-4o-mini", api_key=SecretStr(os.getenv("OPENAI_API_KEY") or ""), base_url=os.getenv("OPENAI_BASE_URL"), temperature=0.7
)
chain: RunnableSerializable[Any, Any] = self.augmentation_prompt | model | parser

# Get augmented messages
result = self.chain.invoke({"topic": topic, "dialogue": dialogue_str})
dialogue_str: str = str(dialogue)
result: List[DialogueMessage] = chain.invoke({"topic": topic, "dialogue": dialogue_str})

# Create new Dialogue object with augmented messages
return Dialogue(messages=result, topic=topic)

async def ainvoke(self, *, dialogue: Dialogue, topic: str = "") -> Dialogue:
Expand All @@ -94,8 +96,13 @@ async def ainvoke(self, *, dialogue: Dialogue, topic: str = "") -> Dialogue:
Returns:
Dialogue: Augmented dialogue object
"""
dialogue_str = str(dialogue)

result = await self.chain.ainvoke({"topic": topic, "dialogue": dialogue_str})
parser: JsonOutputParser = JsonOutputParser(pydantic_object=DialogueSequence)
model: ChatOpenAI = ChatOpenAI(
model="gpt-4o-mini", api_key=SecretStr(os.getenv("OPENAI_API_KEY") or ""), base_url=os.getenv("OPENAI_BASE_URL"), temperature=0.7
)
chain: RunnableSerializable[Any, Any] = self.augmentation_prompt | model | parser

dialogue_str: str = str(dialogue)
result: List[DialogueMessage] = await chain.ainvoke({"topic": topic, "dialogue": dialogue_str})

return Dialogue(messages=result, topic=topic)
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def run_all_algorithms():
elif algorithms[class_]["input_type"] is str and algorithms[class_]["output_type"] is BaseGraph:
metrics = {"is_theme_valid": [], "are_triplets_valid": []}
for case in topic_to_graph:
test_topic = case['topic']
test_topic = case["topic"]
result = class_instance.invoke(test_topic)

metrics["are_triplets_valid"].append(are_triplets_valid(result, model, topic=test_topic)["value"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
This module contains functions that check Graphs and Dialogues for various metrics using LLM calls.
"""

from langchain.chat_models.base import BaseChatModel
from langchain.output_parsers import PydanticOutputParser
from typing import TypedDict, List
from chatsky_llm_autoconfig.graph import BaseGraph
from typing import List, Tuple
from typing import List
from langchain_core.language_models.chat_models import BaseChatModel
from langchain.prompts import PromptTemplate
from pydantic import BaseModel, Field
Expand All @@ -18,7 +21,12 @@
logging.basicConfig(level=logging.INFO)


def are_triplets_valid(G: BaseGraph, model: BaseChatModel, topic: str) -> dict[str]:
# Result dictionary produced by the validation helpers:
#   value       -- overall verdict of the check
#   description -- human-readable explanation of any failures
ValidationResult = TypedDict("ValidationResult", {"value": bool, "description": str})


def are_triplets_valid(G: BaseGraph, model: BaseChatModel, topic: str) -> ValidationResult:
"""
Validates the dialog graph structure and logical transitions between nodes.

Expand All @@ -28,7 +36,7 @@ def are_triplets_valid(G: BaseGraph, model: BaseChatModel, topic: str) -> dict[s
topic (str): The topic of the dialog

Returns:
dict: {'value': bool, 'description': str}
ValidationResult: Dictionary with structure {'value': bool, 'description': str}
"""
# Define prompt template and parser inside the function since they're only used here
triplet_validate_prompt_template = """
Expand Down Expand Up @@ -74,7 +82,7 @@ class TransitionValidationResult(BaseModel):
# Create a mapping from node IDs to node data for quick access
node_map = {node["id"]: node for node in graph["nodes"]}
overall_valid = True
descriptions = []
descriptions: List[str] = []

for edge in graph["edges"]:
source_id = edge["source"]
Expand All @@ -83,13 +91,15 @@ class TransitionValidationResult(BaseModel):

# Check if source and target nodes exist
if source_id not in node_map:
description = f"Invalid edge: source node {source_id} does not exist."
description = f"Invalid edge: source node {
source_id} does not exist."
logging.info(description)
overall_valid = False
descriptions.append(description)
continue
if target_id not in node_map:
description = f"Invalid edge: target node {target_id} does not exist."
description = f"Invalid edge: target node {
target_id} does not exist."
logging.info(description)
overall_valid = False
descriptions.append(description)
Expand All @@ -115,11 +125,12 @@ class TransitionValidationResult(BaseModel):

if not response.isValid:
overall_valid = False
description = f"Invalid transition from {source_utterances} to {target_utterances} via edge '{edge_utterances}': {response.description}"
description = f"Invalid transition from {source_utterances} to {
target_utterances} via edge '{edge_utterances}': {response.description}"
logging.info(description)
descriptions.append(description)

result = {"value": overall_valid, "description": " ".join(descriptions) if descriptions else "All transitions are valid."}
result: ValidationResult = {"value": overall_valid, "description": " ".join(descriptions) if descriptions else "All transitions are valid."}
return result


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "dff-llm-integration-VcuUrJCU-py3.12",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
Expand All @@ -381,7 +381,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.7"
}
},
"nbformat": 4,
Expand Down
Loading