diff --git a/packages/backend/src/assets/webui.db b/packages/backend/src/assets/webui.db
new file mode 100644
index 000000000..0f335a153
Binary files /dev/null and b/packages/backend/src/assets/webui.db differ
diff --git a/packages/backend/src/managers/playgroundV2Manager.ts b/packages/backend/src/managers/playgroundV2Manager.ts
index 4a030e6fa..99dc13a4d 100644
--- a/packages/backend/src/managers/playgroundV2Manager.ts
+++ b/packages/backend/src/managers/playgroundV2Manager.ts
@@ -36,6 +36,8 @@ import { getRandomString } from '../utils/randomUtils';
 import type { TaskRegistry } from '../registries/TaskRegistry';
 import type { CancellationTokenRegistry } from '../registries/CancellationTokenRegistry';
 import { getHash } from '../utils/sha';
+import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry';
+import type { PodmanConnection } from './podmanConnection';
 
 export class PlaygroundV2Manager implements Disposable {
   #conversationRegistry: ConversationRegistry;
@@ -46,17 +48,24 @@ export class PlaygroundV2Manager implements Disposable {
     private taskRegistry: TaskRegistry,
     private telemetry: TelemetryLogger,
     private cancellationTokenRegistry: CancellationTokenRegistry,
+    configurationRegistry: ConfigurationRegistry,
+    podmanConnection: PodmanConnection,
   ) {
-    this.#conversationRegistry = new ConversationRegistry(webview);
+    this.#conversationRegistry = new ConversationRegistry(
+      webview,
+      configurationRegistry,
+      taskRegistry,
+      podmanConnection,
+    );
   }
 
-  deleteConversation(conversationId: string): void {
+  async deleteConversation(conversationId: string): Promise<void> {
     const conversation = this.#conversationRegistry.get(conversationId);
     this.telemetry.logUsage('playground.delete', {
       totalMessages: conversation.messages.length,
       modelId: getHash(conversation.modelId),
     });
-    this.#conversationRegistry.deleteConversation(conversationId);
+    await this.#conversationRegistry.deleteConversation(conversationId);
   }
 
   async requestCreatePlayground(name: string, model: ModelInfo): Promise<string> {
@@ -117,11 +126,11 @@ export class PlaygroundV2Manager implements Disposable {
     }
 
     // Create conversation
-    const conversationId = this.#conversationRegistry.createConversation(name, model.id);
+    const conversationId = await this.#conversationRegistry.createConversation(name, model.id);
 
     // create/start inference server if necessary
     const servers = this.inferenceManager.getServers();
-    const server = servers.find(s => s.models.map(mi => mi.id).includes(model.id));
+    let server = servers.find(s => s.models.map(mi => mi.id).includes(model.id));
     if (!server) {
       await this.inferenceManager.createInferenceServer(
         await withDefaultConfiguration({
@@ -131,10 +140,15 @@ export class PlaygroundV2Manager implements Disposable {
           },
         }),
       );
+      server = this.inferenceManager.findServerByModel(model);
     } else if (server.status === 'stopped') {
       await this.inferenceManager.startInferenceServer(server.container.containerId);
     }
 
+    if (server && server.status === 'running') {
+      await this.#conversationRegistry.startConversationContainer(server, trackingId, conversationId);
+    }
+
     return conversationId;
   }
 
diff --git a/packages/backend/src/registries/ConfigurationRegistry.ts b/packages/backend/src/registries/ConfigurationRegistry.ts
index 5cb66237b..bcc2eb67c 100644
--- a/packages/backend/src/registries/ConfigurationRegistry.ts
+++ b/packages/backend/src/registries/ConfigurationRegistry.ts
@@ -62,6 +62,10 @@ export class ConfigurationRegistry extends Publisher<ExtensionConfiguration> imp
     return path.join(this.appUserDirectory, 'models');
   }
 
+  public getConversationsPath(): string {
+    return path.join(this.appUserDirectory, 'conversations');
+  }
+
   dispose(): void {
     this.#configurationDisposable?.dispose();
   }
diff --git a/packages/backend/src/registries/ConversationRegistry.ts b/packages/backend/src/registries/ConversationRegistry.ts
index eab300242..c1f7af748 100644
--- a/packages/backend/src/registries/ConversationRegistry.ts
+++ b/packages/backend/src/registries/ConversationRegistry.ts
@@ -25,14 +25,36 @@ import type {
   Message,
   PendingChat,
 } from '@shared/src/models/IPlaygroundMessage';
-import type { Disposable, Webview } from '@podman-desktop/api';
+import {
+  type Disposable,
+  type Webview,
+  type ContainerCreateOptions,
+  containerEngine,
+  type ContainerProviderConnection,
+  type ImageInfo,
+  type PullEvent,
+} from '@podman-desktop/api';
 import { Messages } from '@shared/Messages';
+import type { ConfigurationRegistry } from './ConfigurationRegistry';
+import path from 'node:path';
+import fs from 'node:fs';
+import type { InferenceServer } from '@shared/src/models/IInference';
+import { getFreeRandomPort } from '../utils/ports';
+import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../utils/utils';
+import { getImageInfo } from '../utils/inferenceUtils';
+import type { TaskRegistry } from './TaskRegistry';
+import type { PodmanConnection } from '../managers/podmanConnection';
 
 export class ConversationRegistry extends Publisher<Conversation[]> implements Disposable {
   #conversations: Map<string, Conversation>;
   #counter: number;
 
-  constructor(webview: Webview) {
+  constructor(
+    webview: Webview,
+    private configurationRegistry: ConfigurationRegistry,
+    private taskRegistry: TaskRegistry,
+    private podmanConnection: PodmanConnection,
+  ) {
     super(webview, Messages.MSG_CONVERSATIONS_UPDATE, () => this.getAll());
     this.#conversations = new Map();
     this.#counter = 0;
@@ -76,13 +98,32 @@ export class ConversationRegistry extends Publisher<Conversation[]> implements D
     this.notify();
   }
 
-  deleteConversation(id: string): void {
+  async deleteConversation(id: string): Promise<void> {
+    const conversation = this.get(id);
+    if (conversation.container) {
+      await containerEngine.stopContainer(conversation.container?.engineId, conversation.container?.containerId);
+    }
+    await fs.promises.rm(path.join(this.configurationRegistry.getConversationsPath(), id), {
+      recursive: true,
+      force: true,
+    });
     this.#conversations.delete(id);
     this.notify();
   }
 
-  createConversation(name: string, modelId: string): string {
+  async createConversation(name: string, modelId: string): Promise<string> {
     const conversationId = this.getUniqueId();
+    const conversationFolder = path.join(this.configurationRegistry.getConversationsPath(), conversationId);
+    await fs.promises.mkdir(conversationFolder, {
+      recursive: true,
+    });
+    //WARNING: this will not work in production mode but didn't find how to embed binary assets
+    //this code get an initialized database so that default user is not admin thus did not get the initial
+    //welcome modal dialog
+    await fs.promises.copyFile(
+      path.join(__dirname, '..', 'src', 'assets', 'webui.db'),
+      path.join(conversationFolder, 'webui.db'),
+    );
     this.#conversations.set(conversationId, {
       name: name,
       modelId: modelId,
@@ -93,6 +134,77 @@ export class ConversationRegistry extends Publisher<Conversation[]> implements D
     return conversationId;
   }
 
+  async startConversationContainer(server: InferenceServer, trackingId: string, conversationId: string): Promise<void> {
+    const conversation = this.get(conversationId);
+    const port = await getFreeRandomPort('127.0.0.1');
+    const connection = await this.podmanConnection.getConnectionByEngineId(server.container.engineId);
+    await this.pullImage(connection, 'ghcr.io/open-webui/open-webui:main', {
+      trackingId: trackingId,
+    });
+    const inferenceServerContainer = await containerEngine.inspectContainer(
+      server.container.engineId,
+      server.container.containerId,
+    );
+    const options: ContainerCreateOptions = {
+      Env: [
+        'DEFAULT_LOCALE=en-US',
+        'WEBUI_AUTH=false',
+        'ENABLE_OLLAMA_API=false',
+        `OPENAI_API_BASE_URL=http://${inferenceServerContainer.NetworkSettings.IPAddress}:8000/v1`,
+        'OPENAI_API_KEY=sk_dummy',
+        `WEBUI_URL=http://localhost:${port}`,
+        `DEFAULT_MODELS=/models/${server.models[0].file?.file}`,
+      ],
+      Image: 'ghcr.io/open-webui/open-webui:main',
+      HostConfig: {
+        AutoRemove: true,
+        Mounts: [
+          {
+            Source: path.join(this.configurationRegistry.getConversationsPath(), conversationId),
+            Target: '/app/backend/data',
+            Type: 'bind',
+          },
+        ],
+        PortBindings: {
+          '8080/tcp': [
+            {
+              HostPort: `${port}`,
+            },
+          ],
+        },
+        SecurityOpt: [DISABLE_SELINUX_LABEL_SECURITY_OPTION],
+      },
+    };
+    const c = await containerEngine.createContainer(server.container.engineId, options);
+    conversation.container = { engineId: c.engineId, containerId: c.id, port };
+  }
+
+  protected pullImage(
+    connection: ContainerProviderConnection,
+    image: string,
+    labels: { [id: string]: string },
+  ): Promise<ImageInfo> {
+    // Creating a task to follow pulling progress
+    const pullingTask = this.taskRegistry.createTask(`Pulling ${image}.`, 'loading', labels);
+
+    // get the default image info for this provider
+    return getImageInfo(connection, image, (_event: PullEvent) => {})
+      .catch((err: unknown) => {
+        pullingTask.state = 'error';
+        pullingTask.progress = undefined;
+        pullingTask.error = `Something went wrong while pulling ${image}: ${String(err)}`;
+        throw err;
+      })
+      .then(imageInfo => {
+        pullingTask.state = 'success';
+        pullingTask.progress = undefined;
+        return imageInfo;
+      })
+      .finally(() => {
+        this.taskRegistry.updateTask(pullingTask);
+      });
+  }
+
   /**
    * This method will be responsible for finalizing the message by concatenating all the choices
    * @param conversationId
diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts
index 7d25a1108..d41951e23 100644
--- a/packages/backend/src/studio.ts
+++ b/packages/backend/src/studio.ts
@@ -306,6 +306,8 @@ export class Studio {
       this.#taskRegistry,
       this.#telemetry,
      this.#cancellationTokenRegistry,
+      this.#configurationRegistry,
+      this.#podmanConnection,
     );
     this.#extensionContext.subscriptions.push(this.#playgroundManager);
 
diff --git a/packages/frontend/src/pages/Playground.svelte b/packages/frontend/src/pages/Playground.svelte
index 62d1a6200..7900f81b1 100644
--- a/packages/frontend/src/pages/Playground.svelte
+++ b/packages/frontend/src/pages/Playground.svelte
[The two hunks for this file (@@ -1,113 +1,26 @@ and @@ -188,110 +69,12 @@ function handleOnClick(): void {) did not survive this capture: the <script> contents and the Svelte/HTML markup were stripped, leaving only loose text fragments. What survives shows the old in-page playground UI being removed: the conversation message list ({#if conversation}, {#key conversation.messages.length}, {#each messages as message}), the "Next prompt will use these settings" / "Model Parameters" drawer with its temperature, max tokens and top_p fields and their help texts, the error message banner ({#if errorMsg}), and the send/stop prompt controls ({#if !sendEnabled && cancellationTokenId !== undefined}). The markup added in their place is not recoverable here.]
diff --git a/packages/shared/src/models/IPlaygroundMessage.ts b/packages/shared/src/models/IPlaygroundMessage.ts
index cdebc2046..4333305ac 100644
--- a/packages/shared/src/models/IPlaygroundMessage.ts
+++ b/packages/shared/src/models/IPlaygroundMessage.ts
@@ -57,6 +57,11 @@ export interface Conversation {
   messages: Message[];
   modelId: string;
   name: string;
+  container?: {
+    engineId: string;
+    containerId: string;
+    port: number;
+  };
 }
 
 export interface Choice {