Commit bba8203
* Fix an issue where llama could hang when the context window overflowed
* Fixes for smol so it doesn't output control tokens
* Fixes for token counting so we use the actual count method from the model (usage sketch after the AIPrompter.ts diff below)
* Fixes for the languageModel context window so that we can more effectively drop old messages (sketch after the AILanguageModelHandler.ts diff below)
* Support multiple calls to the language model APIs by adding queuing to the native binary (a minimal sketch of this pattern follows the file summary below)
Thomas101 committed Nov 5, 2024
1 parent 9d47b0a commit bba8203
Showing 13 changed files with 442 additions and 289 deletions.
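The native binary itself is not among the files shown in this diff, so the new queuing is not visible below. For context: the queue previously lived in AIPrompter.ts (see its diff) and drained with queue.pop(), which serviced the newest request first. The sketch below is a minimal, hypothetical FIFO version of the serialize-one-request-at-a-time pattern the commit message describes; all names are illustrative, and this TypeScript only approximates what the native side might do.

    type Job<T> = () => Promise<T>

    class RequestQueue {
      #inflight = false
      #queue: Array<{
        job: Job<unknown>
        resolve: (value: unknown) => void
        reject: (ex: Error) => void
      }> = []

      // Enqueue a job; jobs run strictly one at a time, in arrival order
      enqueue<T> (job: Job<T>): Promise<T> {
        return new Promise<T>((resolve, reject) => {
          this.#queue.push({
            job: job as Job<unknown>,
            resolve: resolve as (value: unknown) => void,
            reject
          })
          void this.#drain()
        })
      }

      async #drain () {
        if (this.#inflight) { return }
        const next = this.#queue.shift() // FIFO: oldest request first
        if (!next) { return }

        this.#inflight = true
        try {
          next.resolve(await next.job())
        } catch (ex) {
          next.reject(ex as Error)
        } finally {
          this.#inflight = false
          void this.#drain() // service anything that queued while we were busy
        }
      }
    }

Using shift() rather than the removed code's pop() makes the order first-come, first-served, which matches what "queuing" usually implies for concurrent API calls.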
2 changes: 1 addition & 1 deletion examples/package.json
@@ -14,8 +14,8 @@
"license": "MPL-2.0",
"description": "",
"dependencies": {
"@popperjs/core": "^2.11.8",
"@aibrow/extension": "file:../out/extlib",
"@popperjs/core": "^2.11.8",
"bootstrap": "^5.3.3",
"camelcase": "^8.0.0",
"fast-deep-equal": "^3.1.3"
94 changes: 31 additions & 63 deletions src/extension/background/AI/AIPrompter.ts
@@ -1,13 +1,11 @@
import { AICapabilityGpuEngine } from '#Shared/API/AI'
import { AICapabilityGpuEngine, AIRootModelProps } from '#Shared/API/AI'
import NativeIPC from '../NativeIPC'
import {
kPrompterGetSupportedGpuEngines,
kPrompterExecPromptSession,
kPrompterDisposePromptSession,
PromptOptions
kPrompterCountPromptTokens,
kPrompterDisposePromptSession
} from '#Shared/NativeAPI/PrompterIPC'
import { AIModelTokenCountMethod } from '#Shared/AIModelManifest'
import { kGpuEngineNotSupported } from '#Shared/Errors'

type SupportedEngines = {
@@ -16,19 +14,13 @@ type SupportedEngines = {
callbacks: Array<(engines: AICapabilityGpuEngine[]) => void>
}

type PromptStreamOptions = {
type CountTokensRequestOptions = {
signal?: AbortSignal
stream: (chunk: string) => void
}

type PromptQueue = {
inflight: boolean
queue: Array<{
options: PromptOptions
streamOptions: PromptStreamOptions
resolve: (value: unknown) => void
reject: (ex: Error) => void
}>
type PromptStreamOptions = {
signal?: AbortSignal
stream: (chunk: string) => void
}

class AIPrompter {
@@ -37,7 +29,6 @@ class AIPrompter {
/* **************************************************************************/

#supportedEngines: SupportedEngines
#promptQueue: PromptQueue = { inflight: false, queue: [] }

/* **************************************************************************/
// MARK: Lifecycle
@@ -83,48 +74,23 @@ class AIPrompter {

/**
* Adds a new language model prompt to the queue
* @param options: the prompt options
* @param sessionId: the id of the session
* @param prompt: the prompt to execute
* @param props: the prompt model props
* @param streamOptions: the options for the return stream
*/
async prompt (options: PromptOptions, streamOptions: PromptStreamOptions) {
if (options.gpuEngine && !(await this.getSupportedGpuEngines()).includes(options.gpuEngine)) {
async prompt (sessionId: string, prompt: string, props: AIRootModelProps, streamOptions: PromptStreamOptions) {
if (props.gpuEngine && !(await this.getSupportedGpuEngines()).includes(props.gpuEngine)) {
throw new Error(kGpuEngineNotSupported)
}

return new Promise((resolve, reject) => {
this.#promptQueue.queue.push({
options,
streamOptions,
resolve,
reject
})
setTimeout(this.#drainPromptQueue, 1)
})
}

/**
* Drains the next item in the prompt queue and executes it
*/
#drainPromptQueue = async () => {
if (this.#promptQueue.inflight) { return }
if (this.#promptQueue.queue.length === 0) { return }

this.#promptQueue.inflight = true
const { options, streamOptions, resolve, reject } = this.#promptQueue.queue.pop()
try {
const res = await NativeIPC.stream(
kPrompterExecPromptSession,
options,
(chunk: string) => streamOptions.stream(chunk),
{ signal: streamOptions.signal }
)
resolve(res)
} catch (ex) {
reject(ex)
} finally {
this.#promptQueue.inflight = false
setTimeout(this.#drainPromptQueue, 1)
}
const res = await NativeIPC.stream(
kPrompterExecPromptSession,
{ props, prompt, sessionId },
(chunk: string) => streamOptions.stream(chunk),
{ signal: streamOptions.signal }
)
return res
}

/**
@@ -141,20 +107,22 @@ class AIPrompter {

/**
* Counts the tokens in a string
* @param input: the string to count the tokens from
* @param method: the method to use for counting
* @param input: the input string
* @param props: the model props
* @param requestOptions: the request options
* @return the token count
*/
async countTokens (input: string, method: AIModelTokenCountMethod | AIModelTokenCountMethod[]) {
for (const m of Array.isArray(method) ? method : [method]) {
switch (m) {
case AIModelTokenCountMethod.Divide4:
return Math.ceil(input.length / 4)
}
async countTokens (input: string, props: AIRootModelProps, requestOptions: CountTokensRequestOptions) {
if (props.gpuEngine && !(await this.getSupportedGpuEngines()).includes(props.gpuEngine)) {
throw new Error(kGpuEngineNotSupported)
}

console.warn('Unknown token count method, defaulting to Divide4:', method)
return Math.ceil(input.length / 4)
const res = await NativeIPC.request(
kPrompterCountPromptTokens,
{ input, props },
{ signal: requestOptions.signal }
)
return res
}
}

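With the AIPrompter.ts change above, countTokens no longer applies the extension-local Divide4 heuristic (input.length / 4); it forwards the input to the native binary via kPrompterCountPromptTokens, so the count comes from the model's own counting method. A hedged usage sketch follows; the exact fields of AIRootModelProps beyond gpuEngine and contextSize are assumptions here.

    const controller = new AbortController()

    // props: AIRootModelProps for the loaded model (shape assumed for illustration)
    const tokenCount = (await AIPrompter.countTokens(
      'How many tokens am I?',
      props,
      { signal: controller.signal } // CountTokensRequestOptions; signal is optional
    )) as number

    // Callers compare the real count against the model's context window,
    // e.g. to decide whether old messages need to be dropped
    if (tokenCount > props.contextSize) {
      // trim the history and re-render the prompt
    }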
17 changes: 6 additions & 11 deletions src/extension/background/APIHandler/AICoreModelHandler.ts
@@ -58,10 +58,7 @@
) => {
return {
sessionId: nanoid(),
props: {
...props,
grammar: payload.getAny('grammar')
}
props
} as AICoreModelData
})
}
@@ -78,17 +75,15 @@
return await APIHelper.handleStandardPromptPreflight(channel, async (
manifest,
payload,
options
props
) => {
const sessionId = payload.getNonEmptyString('sessionId')
const prompt = payload.getString('prompt')
const grammar = payload.getAny('props.grammar')

await AIPrompter.prompt(
{
...options,
prompt,
grammar
},
sessionId,
prompt,
props,
{
signal: channel.abortSignal,
stream: (chunk: string) => channel.emit(chunk)
106 changes: 60 additions & 46 deletions src/extension/background/APIHandler/AILanguageModelHandler.ts
@@ -15,16 +15,14 @@ import {
AILanguageModelPrompt,
AILanguageModelPromptRole
} from '#Shared/API/AILanguageModel/AILanguageModelTypes'
import PermissionProvider from '../PermissionProvider'
import { getNonEmptyString } from '#Shared/API/Untrusted/UntrustedParser'
import { kModelPromptAborted, kModelPromptTypeNotSupported } from '#Shared/Errors'
import APIHelper from './APIHelper'
import AIModelFileSystem from '../AI/AIModelFileSystem'
import AIPrompter from '../AI/AIPrompter'
import { AIModelManifest } from '#Shared/AIModelManifest'
import { nanoid } from 'nanoid'
import { Template } from '@huggingface/jinja'
import { AICapabilityPromptType } from '#Shared/API/AI'
import { AICapabilityPromptType, AIRootModelProps } from '#Shared/API/AI'

class AILanguageModelHandler {
/* **************************************************************************/
@@ -55,13 +53,15 @@
/**
* Builds the prompt from the user's variables
* @param manifest: the manifest object
* @param options: the prompt options
* @param systemPrompt: the system prompt
* @param initialPrompts: the array of initial prompts
* @param messages: the array of messages
* @returns the prompt to pass to the LLM
*/
async #buildPrompt (
manifest: AIModelManifest,
props: AIRootModelProps,
systemPrompt: string | undefined,
initialPrompts: AILanguageModelInitialPrompt[],
messages: AILanguageModelPrompt[]
@@ -71,35 +71,32 @@
}
const promptConfig = manifest.prompts[AICapabilityPromptType.LanguageModel]

// Build the messages
let tokenCount = systemPrompt
? await AIPrompter.countTokens(systemPrompt, manifest.tokens.method)
: 0

const history = [...initialPrompts, ...messages]
const countedMessages = []
for (let i = history.length - 1; i >= 0; i--) {
const message = history[i]
tokenCount += await AIPrompter.countTokens(message.content, manifest.tokens.method)
if (tokenCount > manifest.tokens.max) {
break
let droppedMessages = 0
while (true) {
const allMessages = [...initialPrompts, ...messages]
if (droppedMessages >= allMessages.length && droppedMessages > 0) {
throw new Error('Failed to build prompt. Context window overflow.')
}
countedMessages.unshift(message)
const history = [
...systemPrompt
? [{ content: systemPrompt, role: 'system' }]
: [],
...allMessages.slice(droppedMessages)
]
const template = new Template(promptConfig.template)
const prompt = template.render({
messages: history,
bos_token: promptConfig.bosToken,
eos_token: promptConfig.eosToken
})

const tokenCount = await AIPrompter.countTokens(prompt, props, {})
if (tokenCount <= props.contextSize) {
return prompt
}

droppedMessages++
}
const messagesWindow = [
...systemPrompt
? [{ content: systemPrompt, role: 'system' }, ...countedMessages]
: countedMessages
]

// Send to the template
const template = new Template(promptConfig.template)
const prompt = template.render({
messages: messagesWindow,
bos_token: promptConfig.bosToken,
eos_token: promptConfig.eosToken
})
return prompt
}

/* **************************************************************************/
@@ -123,8 +120,9 @@
const systemPrompt = payload.getString('systemPrompt')
const initialPrompts = payload.getAILanguageModelInitialPrompts('initialPrompts')
const tokensSoFar = await AIPrompter.countTokens(
await this.#buildPrompt(manifest, systemPrompt, initialPrompts, []),
manifest.tokens.method
await this.#buildPrompt(manifest, props, systemPrompt, initialPrompts, []),
props,
{}
)

return {
@@ -153,13 +151,29 @@
/* **************************************************************************/

#handleCountTokens = async (channel: IPCInflightChannel) => {
const modelId = await APIHelper.getModelId(channel.payload?.props?.model)
const input = getNonEmptyString(channel.payload?.input)
return await APIHelper.handleStandardPromptPreflight(channel, async (
manifest,
payload,
props
) => {
const input = payload.getString('input')
const prompt = await this.#buildPrompt(
manifest,
props,
undefined,
[],
[{ role: AILanguageModelPromptRole.User, content: input }]
)
if (channel.abortSignal?.aborted) { throw new Error(kModelPromptAborted) }

await PermissionProvider.ensureModelPermission(channel, modelId)
const count = (await AIPrompter.countTokens(
prompt,
props,
{ signal: channel.abortSignal }
)) as number

const manifest = await AIModelFileSystem.readModelManifest(modelId)
return await AIPrompter.countTokens(input, manifest.tokens.method)
return count
})
}

/* **************************************************************************/
@@ -175,21 +189,19 @@
return await APIHelper.handleStandardPromptPreflight(channel, async (
manifest,
payload,
options
props
) => {
const systemPrompt = payload.getString('props.systemPrompt')
const initialPrompts = payload.getAILanguageModelInitialPrompts('props.initialPrompts')
const messages = payload.getAILanguageModelPrompts('messages')
const prompt = await this.#buildPrompt(manifest, systemPrompt, initialPrompts, messages)
const grammar = payload.getAny('props.grammar')
const prompt = await this.#buildPrompt(manifest, props, systemPrompt, initialPrompts, messages)
const sessionId = payload.getNonEmptyString('sessionId')
if (channel.abortSignal?.aborted) { throw new Error(kModelPromptAborted) }

const reply = (await AIPrompter.prompt(
{
...options,
prompt,
grammar
},
sessionId,
prompt,
props,
{
signal: channel.abortSignal,
stream: (chunk: string) => channel.emit(chunk)
Expand All @@ -200,11 +212,13 @@ class AILanguageModelHandler {
tokensSoFar: await AIPrompter.countTokens(
await this.#buildPrompt(
manifest,
props,
systemPrompt,
initialPrompts,
[...messages, { role: AILanguageModelPromptRole.Assistant, content: reply }]
),
manifest.tokens.method
props,
{}
)
}
})
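The reworked #buildPrompt above is what fixes the context-window hang: instead of summing per-message token estimates, it renders the full template, asks the model for a real token count, and drops the oldest message until the prompt fits — or everything has been dropped, in which case it throws rather than hanging. A self-contained sketch of that control flow, with the Jinja templating and native token counting stubbed out as parameters:

    type Message = { role: string, content: string }

    async function fitToContext (
      systemPrompt: string | undefined,
      messages: Message[],
      contextSize: number,
      render: (messages: Message[]) => string,
      countTokens: (prompt: string) => Promise<number>
    ): Promise<string> {
      let dropped = 0
      while (true) {
        // Once every message has been dropped, fail loudly rather than hang
        if (dropped >= messages.length && dropped > 0) {
          throw new Error('Failed to build prompt. Context window overflow.')
        }
        const messageWindow: Message[] = [
          ...(systemPrompt ? [{ role: 'system', content: systemPrompt }] : []),
          ...messages.slice(dropped) // keep the newest messages; drop the oldest
        ]
        const prompt = render(messageWindow)
        if (await countTokens(prompt) <= contextSize) { return prompt }
        dropped++
      }
    }

Counting the rendered prompt as a whole (rather than message by message) also accounts for template overhead such as role markers and BOS/EOS tokens, which the old per-message sum ignored.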
7 changes: 5 additions & 2 deletions src/extension/background/APIHandler/AIRewriterHandler.ts
@@ -87,7 +87,7 @@
return await APIHelper.handleStandardPromptPreflight(channel, async (
manifest,
payload,
options
props
) => {
const sharedContext = payload.getString('props.sharedContext')
const tone = payload.getEnum('props.tone', AIRewriterTone, AIRewriterTone.AsIs)
@@ -96,9 +96,12 @@
const context = payload.getString('context')
const input = payload.getString('input')
const prompt = this.#getPrompt(manifest, tone, format, length, sharedContext, context, input)
const sessionId = payload.getNonEmptyString('sessionId')

await AIPrompter.prompt(
{ ...options, prompt },
sessionId,
prompt,
props,
{
signal: channel.abortSignal,
stream: (chunk: string) => channel.emit(chunk)
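Across all of the handlers above, AIPrompter.prompt() now takes an explicit sessionId plus the model props rather than a single merged options object. A condensed, hypothetical caller is sketched below; the variable names are assumed, chunks arrive incrementally through the stream callback, and the full reply is the promise's resolved value, as in AILanguageModelHandler.

    const controller = new AbortController()
    let streamed = ''

    const reply = (await AIPrompter.prompt(
      sessionId,       // routes the prompt to an existing native session
      renderedPrompt,  // already templated and trimmed to the context window
      props,           // AIRootModelProps (gpuEngine, contextSize, ...)
      {
        signal: controller.signal,                       // optional cancellation
        stream: (chunk: string) => { streamed += chunk } // partial output as it arrives
      }
    )) as string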
(The remaining 8 changed files in this commit are not shown here.)
