Improves the relevance of codestral completion

brichet committed Nov 8, 2024
1 parent 8043bf9 commit 2d30689
Showing 7 changed files with 91 additions and 17 deletions.
1 change: 1 addition & 0 deletions package.json
@@ -62,6 +62,7 @@
"@jupyterlab/settingregistry": "^4.2.0",
"@langchain/core": "^0.3.13",
"@langchain/mistralai": "^0.1.1",
"@lumino/commands": "^2.1.2",
"@lumino/coreutils": "^2.1.2",
"@lumino/polling": "^2.1.2",
"@lumino/signaling": "^2.1.2"
6 changes: 6 additions & 0 deletions src/completion-provider.ts
@@ -16,6 +16,7 @@ export class CompletionProvider implements IInlineCompletionProvider {

constructor(options: CompletionProvider.IOptions) {
const { name, settings } = options;
this._requestCompletion = options.requestCompletion;
this.setCompleter(name, settings);
}

@@ -28,6 +29,9 @@ export class CompletionProvider implements IInlineCompletionProvider {
setCompleter(name: string, settings: ReadonlyPartialJSONObject) {
try {
this._completer = getCompleter(name, settings);
if (this._completer) {
this._completer.requestCompletion = this._requestCompletion;
}
this._name = this._completer === null ? 'None' : name;
} catch (e: any) {
this._completer = null;
@@ -65,11 +69,13 @@ export class CompletionProvider implements IInlineCompletionProvider {
}

private _name: string = 'None';
private _requestCompletion: () => void;
private _completer: IBaseCompleter | null = null;
}

export namespace CompletionProvider {
export interface IOptions extends BaseCompleter.IOptions {
name: string;
requestCompletion: () => void;
}
}
5 changes: 4 additions & 1 deletion src/index.ts
@@ -100,7 +100,10 @@ const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
manager: ICompletionProviderManager,
settingRegistry: ISettingRegistry
): IAIProvider => {
- const aiProvider = new AIProvider({ completionProviderManager: manager });
const aiProvider = new AIProvider({
completionProviderManager: manager,
requestCompletion: () => app.commands.execute('inline-completer:invoke')
});

settingRegistry
.load(aiProviderPlugin.id)
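For reference, the callback registered above delegates to JupyterLab's stock inline-completer command, which re-runs the provider's fetch() against the current editor state. A minimal sketch of the wiring, with an illustrative helper name (createAIProvider is not part of the commit):

import { JupyterFrontEnd } from '@jupyterlab/application';
import { ICompletionProviderManager } from '@jupyterlab/completer';
import { AIProvider } from './provider';

// Illustrative helper: builds the AIProvider so any registered completer can
// ask JupyterLab for a fresh inline-completion round trip on demand.
function createAIProvider(
  app: JupyterFrontEnd,
  manager: ICompletionProviderManager
): AIProvider {
  return new AIProvider({
    completionProviderManager: manager,
    // 'inline-completer:invoke' triggers the inline completer to call the
    // provider's fetch() again with the now-current prompt.
    requestCompletion: () => app.commands.execute('inline-completer:invoke')
  });
}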
5 changes: 5 additions & 0 deletions src/llm-models/base-completer.ts
@@ -11,6 +11,11 @@ export interface IBaseCompleter {
*/
provider: LLM;

/**
* The function to fetch a new completion.
*/
requestCompletion?: () => void;

/**
* The fetch request for the LLM completer.
*/
81 changes: 67 additions & 14 deletions src/llm-models/codestral-completer.ts
@@ -16,27 +16,70 @@ const INTERVAL = 1000;

export class CodestralCompleter implements IBaseCompleter {
constructor(options: BaseCompleter.IOptions) {
// this._requestCompletion = options.requestCompletion;
this._mistralProvider = new MistralAI({ ...options.settings });
- this._throttler = new Throttler(async (data: CompletionRequest) => {
- const response = await this._mistralProvider.completionWithRetry(
- data,
- {},
- false
- );
- const items = response.choices.map((choice: any) => {
- return { insertText: choice.message.content as string };
- });
this._throttler = new Throttler(
async (data: CompletionRequest) => {
this._invokedData = data;
let fetchAgain = false;

- return {
- items
- };
- }, INTERVAL);
// Request completion.
const response = await this._mistralProvider.completionWithRetry(
data,
{},
false
);

// Extract results of completion request.
let items = response.choices.map((choice: any) => {
return { insertText: choice.message.content as string };
});

// Check if the prompt has changed during the request.
if (this._invokedData.prompt !== this._currentData?.prompt) {
// The current prompt no longer starts with the invoked one: discard the
// results and request a new completion.
if (!this._currentData?.prompt.startsWith(this._invokedData.prompt)) {
fetchAgain = true;
items = [];
} else {
// Otherwise keep the results that still match the current prompt, trimming
// the characters typed since the request; if none remain, request again.
const newItems: { insertText: string }[] = [];
items.forEach(item => {
const result = this._invokedData!.prompt + item.insertText;
if (result.startsWith(this._currentData!.prompt)) {
const insertText = result.slice(
this._currentData!.prompt.length
);
newItems.push({ insertText });
}
});
if (newItems.length) {
items = newItems;
} else {
fetchAgain = true;
items = [];
}
}
}
return {
items,
fetchAgain
};
},
{ limit: INTERVAL }
);
}

get provider(): LLM {
return this._mistralProvider;
}

set requestCompletion(value: () => void) {
this._requestCompletion = value;
}

async fetch(
request: CompletionHandler.IRequest,
context: IInlineCompletionContext
@@ -59,13 +102,23 @@ export class CodestralCompleter implements IBaseCompleter {
};

try {
- return this._throttler.invoke(data);
this._currentData = data;
const completionResult = await this._throttler.invoke(data);
if (completionResult.fetchAgain) {
if (this._requestCompletion) {
this._requestCompletion();
}
}
return { items: completionResult.items };
} catch (error) {
console.error('Error fetching completions', error);
return { items: [] };
}
}

private _requestCompletion?: () => void;
private _throttler: Throttler;
private _mistralProvider: MistralAI;
private _invokedData: CompletionRequest | null = null;
private _currentData: CompletionRequest | null = null;
}
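At the heart of the change is the reconciliation step inside the throttled callback: because requests are throttled, the user may keep typing between the moment a prompt is sent to Codestral and the moment the response arrives. The sketch below restates that logic as a standalone helper (a hypothetical function, not part of the commit), assuming completions are plain insertText strings:

interface IReconciled {
  items: { insertText: string }[];
  fetchAgain: boolean;
}

// Compare the prompt that was sent ("invoked") with the prompt currently in
// the editor ("current"): keep suggestions that still apply, trimmed of the
// characters typed in the meantime, or signal that a new request is needed.
function reconcile(
  invokedPrompt: string,
  currentPrompt: string,
  items: { insertText: string }[]
): IReconciled {
  // Prompt unchanged: return the suggestions untouched.
  if (invokedPrompt === currentPrompt) {
    return { items, fetchAgain: false };
  }
  // The edit was not a simple continuation of the invoked prompt: the
  // suggestions are stale, ask for a new completion.
  if (!currentPrompt.startsWith(invokedPrompt)) {
    return { items: [], fetchAgain: true };
  }
  // The user kept typing: keep only the suggestions that still contain the
  // newly typed characters, and strip those characters from the insert text.
  const trimmed = items
    .map(item => invokedPrompt + item.insertText)
    .filter(full => full.startsWith(currentPrompt))
    .map(full => ({ insertText: full.slice(currentPrompt.length) }));
  return trimmed.length
    ? { items: trimmed, fetchAgain: false }
    : { items: [], fetchAgain: true };
}

For example, reconcile('def add(a, ', 'def add(a, b', [{ insertText: 'b): return a + b' }]) keeps the suggestion and trims it to '): return a + b'; an edit that diverges from the invoked prompt instead yields { items: [], fetchAgain: true }, and fetch() then re-triggers the completer through requestCompletion().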
7 changes: 6 additions & 1 deletion src/provider.ts
@@ -12,7 +12,8 @@ export class AIProvider implements IAIProvider {
constructor(options: AIProvider.IOptions) {
this._completionProvider = new CompletionProvider({
name: 'None',
- settings: {}
settings: {},
requestCompletion: options.requestCompletion
});
options.completionProviderManager.registerInlineProvider(
this._completionProvider
@@ -103,6 +104,10 @@ export namespace AIProvider {
* The completion provider manager in which to register the LLM completer.
*/
completionProviderManager: ICompletionProviderManager;
/**
* The function to request a new inline completion.
*/
requestCompletion: () => void;
}

/**
3 changes: 2 additions & 1 deletion yarn.lock
@@ -1740,7 +1740,7 @@ __metadata:
languageName: node
linkType: hard

"@lumino/commands@npm:^2.3.0":
"@lumino/commands@npm:^2.1.2, @lumino/commands@npm:^2.3.0":
version: 2.3.0
resolution: "@lumino/commands@npm:2.3.0"
dependencies:
@@ -4885,6 +4885,7 @@ __metadata:
"@jupyterlab/settingregistry": ^4.2.0
"@langchain/core": ^0.3.13
"@langchain/mistralai": ^0.1.1
"@lumino/commands": ^2.1.2
"@lumino/coreutils": ^2.1.2
"@lumino/polling": ^2.1.2
"@lumino/signaling": ^2.1.2
