Skip to content

Commit

Permalink
ToolBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
BLamy committed Sep 24, 2023
1 parent 0ce8367 commit d408dd2
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 13 deletions.
11 changes: 9 additions & 2 deletions src/Chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,23 @@ import { PromptBuilder } from "./PromptBuilder";
import { ExtractArgs, ExtractChatArgs, ReplaceChatArgs } from "./types";

export class Chat<
ToolNames extends string,
TMessages extends
| []
| [...OpenAI.Chat.CreateChatCompletionRequestMessage[], OpenAI.Chat.CreateChatCompletionRequestMessage],
TSuppliedInputArgs extends ExtractChatArgs<TMessages, {}>
TSuppliedInputArgs extends ExtractChatArgs<TMessages, {}>,
> {
constructor(
public messages: F.Narrow<TMessages>,
public args: F.Narrow<TSuppliedInputArgs>
public args: F.Narrow<TSuppliedInputArgs>,
public tools = {} as Record<ToolNames, Tool>,
///
public mustUseTool: boolean = false
) {}

toJSONSchema() {
}

toArray() {
return (this.messages as TMessages).map((m) => ({
role: m.role,
Expand Down
51 changes: 51 additions & 0 deletions src/ToolBuilder.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
interface Tool<I = unknown, O = unknown> {
name: string;
type: "query" | "mutation"
build: (input: I) => O;
}

export class ToolBuilder<TType extends "query" | "mutation" = "query", I = unknown, O = unknown> {
private name: string;
private implementation?: (input: I) => O;
private type: TType;

constructor(name: string, type: TType = "query" as TType) {
this.name = name;
this.type = type;
}

addInputValidation<T = I>(): ToolBuilder<TType, T, O> {
// Implementation here
return this as unknown as ToolBuilder<TType, T, O>;
}

addOutputValidation<T = O>(): ToolBuilder<TType, I, T> {
// Implementation here
return this as unknown as ToolBuilder<TType, I, T>;
}

query(queryFunction: (input: I) => O): ToolBuilder<"query", I, O> {

return {
...this,
implementation: queryFunction,
type: "query"
};
}

mutation(mutationFunction: (input: I) => O): ToolBuilder<"mutation", I, O> {
return {
...this,
implementation: mutationFunction,
type: "mutation"
};
}

build(): Tool<I, O> {
return {
name: this.name,
build: this.implementation!,
type: this.type
};
}
}
98 changes: 87 additions & 11 deletions src/__tests__/Chat.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ import { strict as assert } from "node:assert";
import { Chat } from "../Chat";
import { system, user, assistant } from "../ChatHelpers";
import { Equal, Expect } from "./types.test";
import { ToolBuilder } from "../ToolBuilder";

describe("Chat", () => {
it("should allow empty array", () => {
const chat = new Chat([], {}).toArray();
const chat = new Chat([], {}, {}).toArray();
type test = Expect<Equal<typeof chat, []>>;
assert.deepEqual(chat, []);
});
Expand Down Expand Up @@ -43,7 +44,6 @@ describe("Chat", () => {
it("should allow chat of all diffent types", () => {
const chat = new Chat(
[
// ^?
user(`Tell me a {{jokeType1}} joke`),
assistant(`{{var2}} joke?`),
system(`joke? {{var3}}`),
Expand All @@ -64,21 +64,97 @@ describe("Chat", () => {
});

it("should allow chat of all diffent types with no args", () => {
const chat = new Chat(
[
// ^?
user(`Tell me a joke`),
assistant(`joke?`),
system(`joke?`),
],
{}
).toArray();
const usrMsg = user("Tell me a joke");
const astMsg = assistant("joke?");
const sysMsg = system("joke?");

const chat = new Chat([usrMsg, astMsg, sysMsg], {}).toArray();
type test = Expect<
Equal<typeof chat, [typeof usrMsg, typeof astMsg, typeof sysMsg]>
>;
assert.deepEqual(chat, [usrMsg, astMsg, sysMsg]);
});

it("should allow me to pass in tools", () => {
const usrMsg = user("Tell me a joke");
const astMsg = assistant("joke?");
const sysMsg = system("joke?");
const tools = {
google: new ToolBuilder("google")
.addInputValidation<{ query: string }>()
.addOutputValidation<{ results: string[] }>()
.query(({ query }) => {
return {
results: ["foo", "bar"],
};
}),
wikipedia: new ToolBuilder("wikipedia")
.addInputValidation<{ page: string }>()
.addOutputValidation<{ results: string[] }>()
.query(({ page }) => {
return {
results: ["foo", "bar"],
};
}),
sendEmail: new ToolBuilder("sendEmail")
.addInputValidation<{ to: string; subject: string; body: string }>()
.addOutputValidation<{ success: boolean }>()
.mutation(({ to, subject, body }) => {
return {
success: true,
};
}),
} as const;

const chat = new Chat([usrMsg, astMsg, sysMsg], {}, tools);

type tests = [
Expect<
Equal<
typeof chat,
Chat<
keyof typeof tools,
[typeof usrMsg, typeof astMsg, typeof sysMsg],
{}
>
>
>,
Expect<
Equal<
typeof tools,
{
readonly google: ToolBuilder<
"query",
{
query: string;
},
{
results: string[];
}
>;
readonly wikipedia: ToolBuilder<
"query",
{
page: string;
},
{
results: string[];
}
>;
readonly sendEmail: ToolBuilder<
"mutation",
{
to: string;
subject: string;
body: string;
},
{
success: boolean;
}
>;
}
>
>
];
});
});

0 comments on commit d408dd2

Please sign in to comment.