diff --git a/src/Chat.ts b/src/Chat.ts index 59fbda1..be0a360 100644 --- a/src/Chat.ts +++ b/src/Chat.ts @@ -4,16 +4,23 @@ import { PromptBuilder } from "./PromptBuilder"; import { ExtractArgs, ExtractChatArgs, ReplaceChatArgs } from "./types"; export class Chat< + ToolNames extends string, TMessages extends | [] | [...OpenAI.Chat.CreateChatCompletionRequestMessage[], OpenAI.Chat.CreateChatCompletionRequestMessage], - TSuppliedInputArgs extends ExtractChatArgs + TSuppliedInputArgs extends ExtractChatArgs, > { constructor( public messages: F.Narrow, - public args: F.Narrow + public args: F.Narrow, + public tools = {} as Record, + /// + public mustUseTool: boolean = false ) {} + toJSONSchema() { + } + toArray() { return (this.messages as TMessages).map((m) => ({ role: m.role, diff --git a/src/ToolBuilder.ts b/src/ToolBuilder.ts new file mode 100644 index 0000000..48fef2b --- /dev/null +++ b/src/ToolBuilder.ts @@ -0,0 +1,51 @@ +interface Tool { + name: string; + type: "query" | "mutation" + build: (input: I) => O; +} + +export class ToolBuilder { + private name: string; + private implementation?: (input: I) => O; + private type: TType; + + constructor(name: string, type: TType = "query" as TType) { + this.name = name; + this.type = type; + } + + addInputValidation(): ToolBuilder { + // Implementation here + return this as unknown as ToolBuilder; + } + + addOutputValidation(): ToolBuilder { + // Implementation here + return this as unknown as ToolBuilder; + } + + query(queryFunction: (input: I) => O): ToolBuilder<"query", I, O> { + + return { + ...this, + implementation: queryFunction, + type: "query" + }; + } + + mutation(mutationFunction: (input: I) => O): ToolBuilder<"mutation", I, O> { + return { + ...this, + implementation: mutationFunction, + type: "mutation" + }; + } + + build(): Tool { + return { + name: this.name, + build: this.implementation!, + type: this.type + }; + } + } \ No newline at end of file diff --git a/src/__tests__/Chat.test.ts b/src/__tests__/Chat.test.ts index dfb6afc..859fccc 100644 --- a/src/__tests__/Chat.test.ts +++ b/src/__tests__/Chat.test.ts @@ -2,10 +2,11 @@ import { strict as assert } from "node:assert"; import { Chat } from "../Chat"; import { system, user, assistant } from "../ChatHelpers"; import { Equal, Expect } from "./types.test"; +import { ToolBuilder } from "../ToolBuilder"; describe("Chat", () => { it("should allow empty array", () => { - const chat = new Chat([], {}).toArray(); + const chat = new Chat([], {}, {}).toArray(); type test = Expect>; assert.deepEqual(chat, []); }); @@ -43,7 +44,6 @@ describe("Chat", () => { it("should allow chat of all diffent types", () => { const chat = new Chat( [ - // ^? user(`Tell me a {{jokeType1}} joke`), assistant(`{{var2}} joke?`), system(`joke? {{var3}}`), @@ -64,21 +64,97 @@ describe("Chat", () => { }); it("should allow chat of all diffent types with no args", () => { - const chat = new Chat( - [ - // ^? - user(`Tell me a joke`), - assistant(`joke?`), - system(`joke?`), - ], - {} - ).toArray(); const usrMsg = user("Tell me a joke"); const astMsg = assistant("joke?"); const sysMsg = system("joke?"); + + const chat = new Chat([usrMsg, astMsg, sysMsg], {}).toArray(); type test = Expect< Equal >; assert.deepEqual(chat, [usrMsg, astMsg, sysMsg]); }); + + it("should allow me to pass in tools", () => { + const usrMsg = user("Tell me a joke"); + const astMsg = assistant("joke?"); + const sysMsg = system("joke?"); + const tools = { + google: new ToolBuilder("google") + .addInputValidation<{ query: string }>() + .addOutputValidation<{ results: string[] }>() + .query(({ query }) => { + return { + results: ["foo", "bar"], + }; + }), + wikipedia: new ToolBuilder("wikipedia") + .addInputValidation<{ page: string }>() + .addOutputValidation<{ results: string[] }>() + .query(({ page }) => { + return { + results: ["foo", "bar"], + }; + }), + sendEmail: new ToolBuilder("sendEmail") + .addInputValidation<{ to: string; subject: string; body: string }>() + .addOutputValidation<{ success: boolean }>() + .mutation(({ to, subject, body }) => { + return { + success: true, + }; + }), + } as const; + + const chat = new Chat([usrMsg, astMsg, sysMsg], {}, tools); + + type tests = [ + Expect< + Equal< + typeof chat, + Chat< + keyof typeof tools, + [typeof usrMsg, typeof astMsg, typeof sysMsg], + {} + > + > + >, + Expect< + Equal< + typeof tools, + { + readonly google: ToolBuilder< + "query", + { + query: string; + }, + { + results: string[]; + } + >; + readonly wikipedia: ToolBuilder< + "query", + { + page: string; + }, + { + results: string[]; + } + >; + readonly sendEmail: ToolBuilder< + "mutation", + { + to: string; + subject: string; + body: string; + }, + { + success: boolean; + } + >; + } + > + > + ]; + }); });