From bf5a5f818e422f10ed829ee53a63f29f1b41c0d1 Mon Sep 17 00:00:00 2001 From: Jeff MAURY Date: Fri, 20 Sep 2024 09:52:28 +0200 Subject: [PATCH] feat: open-webui playground prototype Signed-off-by: Jeff MAURY --- packages/backend/src/assets/webui.db | Bin 0 -> 143360 bytes .../src/managers/playgroundV2Manager.ts | 24 +- .../src/registries/ConfigurationRegistry.ts | 4 + .../src/registries/ConversationRegistry.ts | 120 ++++++++- packages/backend/src/studio.ts | 2 + packages/frontend/src/pages/Playground.svelte | 231 +----------------- .../shared/src/models/IPlaygroundMessage.ts | 5 + 7 files changed, 153 insertions(+), 233 deletions(-) create mode 100644 packages/backend/src/assets/webui.db diff --git a/packages/backend/src/assets/webui.db b/packages/backend/src/assets/webui.db new file mode 100644 index 0000000000000000000000000000000000000000..0f335a153778dc4a58a39eb44684ddd5b50962f4 GIT binary patch literal 143360 zcmeI*&vP4R8Nl(~TDD@xw$cP6(l%JA3693GBgv911zKtxwWLk##Ib2Q!0fE0mAyp2 zBCVV{{e_*t3@2{P^wJsr0tSu@<-!S=;ljY=KVV>n84i>K^uYV>YGp}TIWBEI6YDEW zmUj1j-`)4~KD(0kStmDdTs18tRH)QSx)qvJPAICXyc7y4it>zjJTD&Z?Jwem&z%s@ zYTN5xFP~9n3x6M!g#-VUWxnh@^(xMM1Q0*~0R#|0009ILKmY**5cuo`o?9Igjh-BR zRSEnt@P5DuBm)EEzZt(jzBWEJ_Sdmrj{R_Kb!=+%-=n`8eIN>PLjVB;5I_I{1Q0*~ z0R(J;^huwxun>)QrdDsPd)Asu#bfiAPWY6psFkv7ZRy1=qps=Yyj-I$ z*1ctT){W22M^hIAKINjQo9x)ux@U!GDjH8tkNcDs@*-k&dS? z9CJ!Lp-?hPmD;vv@8fgHOf=*7DT|`*w&5@8b*njunr?Ykji!^SM|d0kXZwYp`Q zC8KWXrK%GPxlO%RFpXl~V`~$$nfZ7um52`dl$dD535G^BFL?5vs}(h3iDWu6>+>ly zqJ|RG@PF9Ut&1Nm1Sj2s^a(s;S*JJXA!T`Q%(YKKGP8%uy#M zij|yR)as%)Vw^o!luAZLFsV*gVwzbtEmKy>Z5k!rvqmx|qIE>PAh^m(21@#Y#>uZdU46L(IA; zf6Z3Q8-;&N{M3Dv8^xl0kpDmIfn|LNAb&hiH*Ce+;-OJ6S<1yepm?PE?0L;pE& zOB6%^0R#|0;E5O5u~z-T)2G$@i}u@MwN@!rE$4S&aV5L3nhmWkeDi8H6n18YL+8V0 zJ{-ETw3>Y-yAoPjUJWg+UA-E*zH;T-!pa+=*RpS12!(T%Qb{k1a^GH9S^U<*%K3OQ z8EF+31?2w-q{)VQ6|hXJXoN$n*;`@-?n>Mj_8Q8TQMS5Dhy|8eGU}FIs(NgsTR-yM z8~)(b8FfcB%X#C?d?nW?iN>_DUNYQ?k!JVTmag1b%Zd(P&fW@z-KBQXaAM187tH-@_|EpE)fYECLP|QU{2PTUikq6K-VWHg!v@Z|XHe1go4F z@*J#=V$KN3S(4PP9FAl!4+Wo{P)$21>>5s!gT0%SJBSQ1w=ppt**&{xe~^7JZI7;8 zBp+Yw9M!fOJ;&hF?}K(9I|fLML#xLfD;5j%T5i+aGCYnRVqvak=$4V!q}pY1*jc@D zE$g}CH>&vq6pUoQ(Wb38tW7b1@@c?RyIH`|mJ`jrA+i&G(fQgiWoyet^5N^$ZmZGJ z=~f3s2aS?$7JKh5RlQ!nTd4`04$t-$baCw1634{uV{hb*siEMxN%fspWSMzO-)Kgx z{o3!@YZ)YN(Y-*k=d}8-ebpbFoK)|RwW7^l-l}@6Q=&EBJ)HDBzKadq{}3UY+V?Hu z0P%Z(9e&}&Q1J9A)v)m#Hu3vuPcUb(LnU6fkj4Z4;HgvU?tqOXmWca_PME#oJtJg) zj>q;^>=wEcJ&~P=>JLUD>aH_tC3B-D_L)tyZdGdAol{SBZf55~cM!WHuIt@?{|#?8 z*<;s^kpF+;KSa_21Q0*~0R#|0009ILKmY**9xE{B|AlhWH>m`EHu&7w@5kP?YPanpcu?v_!ji;SFSjlG=2UP|I}_T`1OtE-{& zi)$+@*`-xYCMs^OE?m3bp5Gi9-b+&)sIJ_UwiIgUyqS-TMDCxsZD${TS zqNC$>F4a9#+v{}u{-jy-YJZ#sogfVI~(iD z{ycKBd2hJFd#aD7LE45jT-589=HyH4jM9E>)(y)t%NuoPv&F`kYgj&+B?x*uNt}LMz?pp zWVmU7_PKkZ{3x8h&rZ0EP#=dX0C zxU)Ya8Ed=ReT9J7)}CIUH;h=fS#JiNC$d<&QLL=%A~L#-BiTkUWB!WoTl5t=Q9=UNlPUW=?A!;ApPU*|@vZPJGJL zzdOSn;VXBBf@gwi$!@k($s3Z+em?Elpj*V*PJ7C2YUfVX9}EW7`wRBYYcI2ZeW}?} z^U9z0N(YaZb-iwAc2&DCGMq$seUb>O>NUOOafJI6$Fj(^%206fw0g(kGfE;KD*fyi zL2p2}P*c#JY!NP&?N2AW&bxurj`RBqU5M@VdcK$1v9;f2Q9X}tk2!+xXBNqp+V;lg zCdj8#o*=Dll;J5ST#VJCxC~*Uc^QITN_XDw_jy5%E^@XZP*q$UN^v z-th;gr`6qId)MnaFm+9aIwGliq4>b-9^lH2QY*6N;+>yI7nhfAuC6SIH=a;e3tIJd zu05e2I^Wiw_z?QNHF1;Z= z00IagfB*srAbFww_PS=c36d|Nq|+8Cbf300IagfB*srAb { const conversation = this.#conversationRegistry.get(conversationId); this.telemetry.logUsage('playground.delete', { totalMessages: conversation.messages.length, modelId: getHash(conversation.modelId), }); - this.#conversationRegistry.deleteConversation(conversationId); + await this.#conversationRegistry.deleteConversation(conversationId); } async requestCreatePlayground(name: string, model: ModelInfo): Promise { @@ -117,11 +126,11 @@ export class PlaygroundV2Manager implements Disposable { } // Create conversation - const conversationId = this.#conversationRegistry.createConversation(name, model.id); + const conversationId = await this.#conversationRegistry.createConversation(name, model.id); // create/start inference server if necessary const servers = this.inferenceManager.getServers(); - const server = servers.find(s => s.models.map(mi => mi.id).includes(model.id)); + let server = servers.find(s => s.models.map(mi => mi.id).includes(model.id)); if (!server) { await this.inferenceManager.createInferenceServer( await withDefaultConfiguration({ @@ -131,10 +140,15 @@ export class PlaygroundV2Manager implements Disposable { }, }), ); + server = this.inferenceManager.findServerByModel(model); } else if (server.status === 'stopped') { await this.inferenceManager.startInferenceServer(server.container.containerId); } + if (server && server.status === 'running') { + await this.#conversationRegistry.startConversationContainer(server, trackingId, conversationId); + } + return conversationId; } diff --git a/packages/backend/src/registries/ConfigurationRegistry.ts b/packages/backend/src/registries/ConfigurationRegistry.ts index 5cb66237b..bcc2eb67c 100644 --- a/packages/backend/src/registries/ConfigurationRegistry.ts +++ b/packages/backend/src/registries/ConfigurationRegistry.ts @@ -62,6 +62,10 @@ export class ConfigurationRegistry extends Publisher imp return path.join(this.appUserDirectory, 'models'); } + public getConversationsPath(): string { + return path.join(this.appUserDirectory, 'conversations'); + } + dispose(): void { this.#configurationDisposable?.dispose(); } diff --git a/packages/backend/src/registries/ConversationRegistry.ts b/packages/backend/src/registries/ConversationRegistry.ts index eab300242..c1f7af748 100644 --- a/packages/backend/src/registries/ConversationRegistry.ts +++ b/packages/backend/src/registries/ConversationRegistry.ts @@ -25,14 +25,36 @@ import type { Message, PendingChat, } from '@shared/src/models/IPlaygroundMessage'; -import type { Disposable, Webview } from '@podman-desktop/api'; +import { + type Disposable, + type Webview, + type ContainerCreateOptions, + containerEngine, + type ContainerProviderConnection, + type ImageInfo, + type PullEvent, +} from '@podman-desktop/api'; import { Messages } from '@shared/Messages'; +import type { ConfigurationRegistry } from './ConfigurationRegistry'; +import path from 'node:path'; +import fs from 'node:fs'; +import type { InferenceServer } from '@shared/src/models/IInference'; +import { getFreeRandomPort } from '../utils/ports'; +import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../utils/utils'; +import { getImageInfo } from '../utils/inferenceUtils'; +import type { TaskRegistry } from './TaskRegistry'; +import type { PodmanConnection } from '../managers/podmanConnection'; export class ConversationRegistry extends Publisher implements Disposable { #conversations: Map; #counter: number; - constructor(webview: Webview) { + constructor( + webview: Webview, + private configurationRegistry: ConfigurationRegistry, + private taskRegistry: TaskRegistry, + private podmanConnection: PodmanConnection, + ) { super(webview, Messages.MSG_CONVERSATIONS_UPDATE, () => this.getAll()); this.#conversations = new Map(); this.#counter = 0; @@ -76,13 +98,32 @@ export class ConversationRegistry extends Publisher implements D this.notify(); } - deleteConversation(id: string): void { + async deleteConversation(id: string): Promise { + const conversation = this.get(id); + if (conversation.container) { + await containerEngine.stopContainer(conversation.container?.engineId, conversation.container?.containerId); + } + await fs.promises.rm(path.join(this.configurationRegistry.getConversationsPath(), id), { + recursive: true, + force: true, + }); this.#conversations.delete(id); this.notify(); } - createConversation(name: string, modelId: string): string { + async createConversation(name: string, modelId: string): Promise { const conversationId = this.getUniqueId(); + const conversationFolder = path.join(this.configurationRegistry.getConversationsPath(), conversationId); + await fs.promises.mkdir(conversationFolder, { + recursive: true, + }); + //WARNING: this will not work in production mode but didn't find how to embed binary assets + //this code get an initialized database so that default user is not admin thus did not get the initial + //welcome modal dialog + await fs.promises.copyFile( + path.join(__dirname, '..', 'src', 'assets', 'webui.db'), + path.join(conversationFolder, 'webui.db'), + ); this.#conversations.set(conversationId, { name: name, modelId: modelId, @@ -93,6 +134,77 @@ export class ConversationRegistry extends Publisher implements D return conversationId; } + async startConversationContainer(server: InferenceServer, trackingId: string, conversationId: string): Promise { + const conversation = this.get(conversationId); + const port = await getFreeRandomPort('127.0.0.1'); + const connection = await this.podmanConnection.getConnectionByEngineId(server.container.engineId); + await this.pullImage(connection, 'ghcr.io/open-webui/open-webui:main', { + trackingId: trackingId, + }); + const inferenceServerContainer = await containerEngine.inspectContainer( + server.container.engineId, + server.container.containerId, + ); + const options: ContainerCreateOptions = { + Env: [ + 'DEFAULT_LOCALE=en-US', + 'WEBUI_AUTH=false', + 'ENABLE_OLLAMA_API=false', + `OPENAI_API_BASE_URL=http://${inferenceServerContainer.NetworkSettings.IPAddress}:8000/v1`, + 'OPENAI_API_KEY=sk_dummy', + `WEBUI_URL=http://localhost:${port}`, + `DEFAULT_MODELS=/models/${server.models[0].file?.file}`, + ], + Image: 'ghcr.io/open-webui/open-webui:main', + HostConfig: { + AutoRemove: true, + Mounts: [ + { + Source: path.join(this.configurationRegistry.getConversationsPath(), conversationId), + Target: '/app/backend/data', + Type: 'bind', + }, + ], + PortBindings: { + '8080/tcp': [ + { + HostPort: `${port}`, + }, + ], + }, + SecurityOpt: [DISABLE_SELINUX_LABEL_SECURITY_OPTION], + }, + }; + const c = await containerEngine.createContainer(server.container.engineId, options); + conversation.container = { engineId: c.engineId, containerId: c.id, port }; + } + + protected pullImage( + connection: ContainerProviderConnection, + image: string, + labels: { [id: string]: string }, + ): Promise { + // Creating a task to follow pulling progress + const pullingTask = this.taskRegistry.createTask(`Pulling ${image}.`, 'loading', labels); + + // get the default image info for this provider + return getImageInfo(connection, image, (_event: PullEvent) => {}) + .catch((err: unknown) => { + pullingTask.state = 'error'; + pullingTask.progress = undefined; + pullingTask.error = `Something went wrong while pulling ${image}: ${String(err)}`; + throw err; + }) + .then(imageInfo => { + pullingTask.state = 'success'; + pullingTask.progress = undefined; + return imageInfo; + }) + .finally(() => { + this.taskRegistry.updateTask(pullingTask); + }); + } + /** * This method will be responsible for finalizing the message by concatenating all the choices * @param conversationId diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index 7d25a1108..d41951e23 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -306,6 +306,8 @@ export class Studio { this.#taskRegistry, this.#telemetry, this.#cancellationTokenRegistry, + this.#configurationRegistry, + this.#podmanConnection, ); this.#extensionContext.subscriptions.push(this.#playgroundManager); diff --git a/packages/frontend/src/pages/Playground.svelte b/packages/frontend/src/pages/Playground.svelte index 62d1a6200..7900f81b1 100644 --- a/packages/frontend/src/pages/Playground.svelte +++ b/packages/frontend/src/pages/Playground.svelte @@ -1,113 +1,26 @@ {#if conversation} @@ -188,110 +69,12 @@ function handleOnClick(): void { -
-
- - -
-
- {#if conversation} - - {#key conversation.messages.length} - - {/key} - -
    - {#each messages as message} -
  • - -
  • - {/each} -
- {/if} -
-
-
- -
Next prompt will use these settings
-
-
Model Parameters
-
-
-
- -
- - - -
- What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output - more random, while lower values like 0.2 will make it more focused and deterministic. -
-
-
-
-
-
- -
- - - -
- The maximum number of tokens that can be generated in the chat completion. -
-
-
-
-
-
- -
- - - -
- An alternative to sampling with temperature, where the model considers the results of the - tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - probability mass are considered. -
-
-
-
-
-
-
-
-
- {#if errorMsg} -
{errorMsg}
- {/if} -
- - -
- {#if !sendEnabled && cancellationTokenId !== undefined} - - {/if} -
+
+
+
diff --git a/packages/shared/src/models/IPlaygroundMessage.ts b/packages/shared/src/models/IPlaygroundMessage.ts index cdebc2046..4333305ac 100644 --- a/packages/shared/src/models/IPlaygroundMessage.ts +++ b/packages/shared/src/models/IPlaygroundMessage.ts @@ -57,6 +57,11 @@ export interface Conversation { messages: Message[]; modelId: string; name: string; + container?: { + engineId: string; + containerId: string; + port: number; + }; } export interface Choice {