-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
116 lines (92 loc) · 3.56 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
from datetime import datetime, timezone
from enum import Enum

import databases
import ormar
import sqlalchemy
from dotenv import load_dotenv
# Load environment variables from a .env file (if one is found) so that
# LOCAL_DB_URL can be supplied without exporting it in the shell.
load_dotenv()

_local_db_url = os.getenv("LOCAL_DB_URL")
if _local_db_url is None:
    # Fail fast with a message that names the missing setting instead of
    # handing None to databases.Database(), whose resulting error does not
    # mention which configuration value is absent.
    raise RuntimeError(
        "LOCAL_DB_URL is not set; define it in the environment or in a .env file"
    )

# Shared async database handle used by every model below (via
# base_ormar_config).
database = databases.Database(
    _local_db_url,
    timeout=10,  # seconds
)
# SQLAlchemy metadata shared by all models (passed into base_ormar_config).
metadata = sqlalchemy.MetaData()
# Base OrmarConfig: each model calls base_ormar_config.copy(...) to add its own
# tablename/constraints, so every table shares this metadata and database.
base_ormar_config = ormar.OrmarConfig(
    metadata=metadata,
    database=database,
)
def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()`` (deprecated since Python
    3.12): the tzinfo is stripped so the timezone-naive ``DateTime`` columns
    keep receiving values of the same form as before.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class DefaultModel(ormar.Model):
    """Abstract base for all tables: integer primary key plus audit timestamps.

    ``created_at`` and ``updated_at`` default to the current UTC time (naive)
    when an instance is created; ``update()`` refreshes ``updated_at`` before
    persisting. Writes that do not go through this instance method (e.g.
    queryset-level updates) do not touch ``updated_at``.
    """

    ormar_config = base_ormar_config.copy(abstract=True)

    id: int = ormar.Integer(primary_key=True)
    created_at: datetime = ormar.DateTime(nullable=False, default=_utcnow_naive)
    updated_at: datetime = ormar.DateTime(nullable=False, default=_utcnow_naive)

    # Disables pydantic's protected-namespace check for all subclasses.
    # NOTE(review): presumably added so a field/attribute name does not trigger
    # pydantic's "model_" namespace warning (EvalRun declares `model`) — confirm.
    model_config = dict(protected_namespaces=())

    async def update(self, _columns=None, **kwargs):
        """Stamp ``updated_at`` with the current UTC time, then persist.

        Args:
            _columns: Optional subset of column names to write, forwarded to
                ``ormar.Model.update``. When a non-empty subset is given,
                ``"updated_at"`` is added to it so the new timestamp is
                actually saved instead of only changing in memory.
            **kwargs: Field values to set before saving (forwarded unchanged).

        Returns:
            Whatever ``ormar.Model.update`` returns (the updated instance).
        """
        self.updated_at = _utcnow_naive()
        # ormar treats an empty/None _columns as "all columns", so only a
        # non-empty subset needs the timestamp column appended.
        if _columns and "updated_at" not in _columns:
            _columns = [*_columns, "updated_at"]
        return await super().update(_columns=_columns, **kwargs)
class RunTerminalInteropProtocol(str, Enum):
    """Closed set of terminal-interop protocol names (a ``str``-valued enum).

    Every member's value is identical to its name, so a member compares equal
    to its plain-string spelling and can be looked up from it, e.g.
    ``RunTerminalInteropProtocol("XML_TAGS")``.

    NOTE(review): presumably selects how commands/output are framed between
    the model and the terminal; the formats themselves are implemented
    elsewhere — confirm with the consuming code.
    """

    # Member order is significant for iteration; keep it stable.
    CLOSING_ANGLE_BRACKET = "CLOSING_ANGLE_BRACKET"
    MARKDOWN_CODE_BLOCKS = "MARKDOWN_CODE_BLOCKS"
    JSON_BASIC = "JSON_BASIC"
    XML_TAGS = "XML_TAGS"
class TaskConfigSnapshot(DefaultModel):
    """Stored copy of a task configuration: a name plus its TOML text.

    Each distinct (name, toml_content) pair is stored at most once (see the
    UniqueColumns constraint). EvalRun rows reference a snapshot through
    ``task_config_snapshot``; the snapshot's runs are reachable as ``runs``.
    """

    ormar_config = base_ormar_config.copy(
        tablename="task_config_snapshots",
        # NOTE(review): a unique index over two columns of up to 1,000,000
        # characters may exceed the backend's index-entry size limit for large
        # TOML documents (e.g. PostgreSQL btree rows) — confirm on the target DB.
        constraints=[ormar.UniqueColumns("name", "toml_content")],
    )
    name: str = ormar.String(nullable=False, max_length=1_000_000)
    # must be normalized
    # (presumably so that equivalent configs produce identical text and are
    # deduplicated by the unique constraint — confirm the normalization used
    # by the code that writes this column)
    toml_content: str = ormar.String(nullable=False, max_length=1_000_000)
    # Nullable flag; presumably marks the latest snapshot — how it is set and
    # cleared is not visible here, verify against the writer code.
    is_most_recent: bool = ormar.Boolean(nullable=True)
class ElicitationSnapshot(DefaultModel):
    """Stored copy of an elicitation configuration as TOML text.

    Each distinct ``toml_content`` is stored at most once (UniqueColumns
    constraint). EvalRun rows reference a snapshot through
    ``elicitation_snapshot``; the snapshot's runs are reachable as ``runs``.
    """

    ormar_config = base_ormar_config.copy(
        tablename="elicitation_snapshots",
        # NOTE(review): unique index over a column of up to 1,000,000
        # characters — large values may exceed the backend's index-entry size
        # limit; confirm on the target DB.
        constraints=[ormar.UniqueColumns("toml_content")],
    )
    # must be normalized
    # (presumably so that equivalent configs produce identical text and are
    # deduplicated by the unique constraint — confirm with the writer code)
    toml_content: str = ormar.String(nullable=False, max_length=1_000_000)
    # Nullable flag; presumably marks the latest snapshot — verify how it is
    # maintained by the code that writes this table.
    is_most_recent: bool = ormar.Boolean(nullable=True)
class EvalRun(DefaultModel):
    """One evaluation run, tied to a task-config snapshot and an elicitation snapshot.

    Both foreign keys use ``ondelete=CASCADE``, so deleting a referenced
    snapshot row deletes its runs as well. Each snapshot exposes its runs via
    the reverse relation ``runs``; a run's ChatMessage rows are reachable via
    ``messages``. ``status`` records the outcome (see RunStatus) and
    ``exception_stacktrace`` is an optional free-text field.
    """

    ormar_config = base_ormar_config.copy(tablename="eval_runs_v2")
    task_config_snapshot: TaskConfigSnapshot = ormar.ForeignKey(
        TaskConfigSnapshot,
        nullable=False,
        related_name="runs",
        ondelete=ormar.ReferentialAction.CASCADE,
    )
    elicitation_snapshot: ElicitationSnapshot = ormar.ForeignKey(
        ElicitationSnapshot,
        nullable=False,
        related_name="runs",
        ondelete=ormar.ReferentialAction.CASCADE,
    )
    # Free-form string; presumably the identifier of the evaluated model —
    # confirm with the code that creates runs.
    model: str = ormar.String(nullable=False, max_length=1_000_000)

    class RunStatus(str, Enum):
        """Lifecycle/outcome states of a run, stored by name in ``status``."""

        # unfinished, result not available
        RUNNING = "RUNNING"
        # found the flag within the allotted turns
        SUCCESS = "SUCCESS"
        # ran out of time or broke the env
        FAILURE = "FAILURE"
        # system error, shouldn't count
        ERROR = "ERROR"
        # e.g. o1 refusal
        REFUSED = "REFUSED"

    status: RunStatus = ormar.Enum(enum_class=RunStatus, nullable=False)
    # Optional text; presumably populated when a run ends in an exception
    # (e.g. status ERROR) — confirm with the writer code.
    exception_stacktrace: str = ormar.String(nullable=True, max_length=1_000_000)
class ChatMessage(DefaultModel):
    """One message of an EvalRun's conversation, reachable as ``run.messages``.

    Deleting the parent EvalRun cascades to its messages (``ondelete=CASCADE``).
    """

    ormar_config = base_ormar_config.copy(tablename="chat_messages_v2")
    run: EvalRun = ormar.ForeignKey(
        EvalRun,
        nullable=False,
        related_name="messages",
        ondelete=ormar.ReferentialAction.CASCADE,
        # Index on the FK column; presumably because messages are fetched per run.
        index=True,
    )
    # Integer set by the writer — presumably the message's position within the
    # run. NOTE(review): no uniqueness over (run, ordinal) is declared here.
    ordinal: int = ormar.Integer(nullable=False)
    # Free-form string; presumably a chat role such as "user"/"assistant" —
    # not constrained by the schema.
    role: str = ormar.String(nullable=False, max_length=1_000_000)
    content: str = ormar.String(nullable=False, max_length=1_000_000)
    # Required flag; presumably marks messages supplied up front rather than
    # generated during the run — confirm with the writer code.
    is_prefilled: bool = ormar.Boolean(nullable=False)
    # Optional text; its exact contents are not visible here — presumably the
    # raw exchange behind this message. Verify against the writer code.
    underlying_communication: str = ormar.String(nullable=True, max_length=1_000_000)