Skip to content

Commit

Permalink
LLM activator
Browse files Browse the repository at this point in the history
  • Loading branch information
morfeusys committed Jan 19, 2024
1 parent 0132e08 commit fc22a80
Show file tree
Hide file tree
Showing 31 changed files with 1,024 additions and 58 deletions.
22 changes: 22 additions & 0 deletions activators/llm/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Publishing metadata consumed by the `jaicf-publish` plugin below.
import plugins.publish.POM_DESCRIPTION
import plugins.publish.POM_NAME

ext[POM_NAME] = "JAICF-Kotlin LLM Activator Adapter"
ext[POM_DESCRIPTION] = "JAICF-Kotlin LLM Activator Adapter."

plugins {
    `jaicf-kotlin`
    `jaicf-kotlin-serialization`
    `jaicf-publish`
    `jaicf-junit`
}

dependencies {
    core()
    // Ktor HTTP client stack used by the OpenAI-compatible LLM client.
    api(ktor("ktor-client-apache"))
    api(ktor("ktor-client-jackson"))
    api(ktor("ktor-client-logging-jvm"))
    // jtokkit supplies tokenizers for prompt-size estimation (see LLMEncoding).
    implementation("com.knuddels:jtokkit:0.6.1")
    testImplementation("io.mockk:mockk" version { mockk })
    testImplementation(ktor("ktor-client-mock"))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.justai.jaicf.activator.llm

import com.justai.jaicf.activator.event.EventByNameActivationRule
import com.justai.jaicf.activator.llm.function.LLMFunction
import com.justai.jaicf.activator.llm.function.LLMFunctionParametersBuilder
import com.justai.jaicf.builder.ActivationRulesBuilder

class LLMFunctionActivationRule(val function: LLMFunction) : EventByNameActivationRule(LLMEvent.FUNCTION_CALL)

/** Activates the state when the LLM responds with a plain chat message. */
fun ActivationRulesBuilder.llmMessage() =
    event(LLMEvent.MESSAGE)

/** Activates the state when the LLM requests a call of the function named [name]. */
fun ActivationRulesBuilder.llmFunction(name: String) =
    event(LLMEvent.FUNCTION_CALL).onlyIf {
        activator.llmFunction?.name == name
    }

/**
 * Activates the state when the LLM requests a call of [function].
 * Uses [LLMFunctionActivationRule] so the function declaration is also
 * sent to the model as an available tool.
 */
fun ActivationRulesBuilder.llmFunction(function: LLMFunction) =
    rule(LLMFunctionActivationRule(function)).onlyIf {
        activator.llmFunction?.name == function.name
    }

/**
 * Convenience overload that builds an [LLMFunction] from [name], [description]
 * and a [parameters] builder, then registers it via [llmFunction].
 */
fun ActivationRulesBuilder.llmFunction(
    name: String,
    description: String,
    parameters: LLMFunctionParametersBuilder
) = llmFunction(LLMFunction.create(name, description, parameters))
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package com.justai.jaicf.activator.llm

import com.justai.jaicf.activator.ActivationRuleMatcher
import com.justai.jaicf.activator.ActivatorFactory
import com.justai.jaicf.activator.BaseActivator
import com.justai.jaicf.activator.event.EventActivator
import com.justai.jaicf.activator.event.EventByNameActivationRule
import com.justai.jaicf.activator.llm.client.LLMClient
import com.justai.jaicf.activator.llm.client.LLMRequest
import com.justai.jaicf.activator.llm.client.openai.OpenAIClient
import com.justai.jaicf.activator.selection.ActivationSelector
import com.justai.jaicf.api.BotRequest
import com.justai.jaicf.api.QueryBotRequest
import com.justai.jaicf.api.hasQuery
import com.justai.jaicf.context.ActivatorContext
import com.justai.jaicf.context.BotContext
import com.justai.jaicf.model.activation.Activation
import com.justai.jaicf.model.scenario.ScenarioModel

/**
 * Activator that routes query requests through an LLM chat completion API.
 *
 * On each turn it appends the incoming input (or a function result) to the session's
 * chat history, sends the trimmed history to [client], and activates either an
 * [LLMEvent.MESSAGE] or an [LLMEvent.FUNCTION_CALL] rule depending on the model's reply.
 */
class LLMActivator(
    private val model: ScenarioModel,
    private val setting: LLMSettings = LLMSettings(),
    private val client: LLMClient = OpenAIClient(setting),
) : BaseActivator(model), EventActivator {
    override val name = "llmActivator"

    // Only requests carrying a query are routed through the LLM.
    override fun canHandle(request: BotRequest) = request.hasQuery()

    override fun provideRuleMatcher(botContext: BotContext, request: BotRequest): ActivationRuleMatcher {
        // Append the current turn to the stored history: a function result for the
        // internal LLMFunctionBotRequest, otherwise the user's text (if any).
        val messages = botContext.llmChatHistory.toMutableList()
        when {
            request is LLMFunctionBotRequest -> LLMRequest.Message.function(request.name, request.result)
            request.input.isNotEmpty() -> LLMRequest.Message.user(request.input)
            else -> null
        }?.also(messages::add)

        // Collect function declarations from LLM function rules reachable from the current state,
        // merge them with any functions configured in settings, then trim the request so it
        // fits the model's context window.
        val transitions = model.generateTransitions(botContext)
        val functions = transitions.map { it.rule }.filterIsInstance(LLMFunctionActivationRule::class.java).map { it.function }
        // Session-level settings override the activator-wide defaults.
        val llmSettings = setting.assign(botContext.llmSettings)
        val chatRequest = llmSettings.createChatRequest(messages).let { it.copy(
            functions = (it.functions.orEmpty() + functions).ifEmpty { null }
        ) }.let {
            // assign() guarantees a non-null model (falls back to the default model).
            LLMEncoding.trim(it, llmSettings.model!!)
        }

        // Call the completion API and persist the updated history for the next turn.
        val resp = client.chatCompletion(chatRequest)
        val message = resp.choices.first().message
        botContext.llmChatHistory = chatRequest.messages + message.toRequestMessage()

        // The model either asks to call a function or replies with plain content.
        val context = if (message.functionCall != null) {
            LLMFunctionActivatorContext(botContext, message.functionCall.name, message.functionCall.arguments)
        } else {
            LLMMessageActivatorContext(botContext, message.content!!)
        }

        // Match only event rules whose event name equals the produced context's event.
        return ruleMatcher<EventByNameActivationRule> {
            context.takeIf { ctx -> it.event == ctx.event }
        }
    }

    override fun activate(
        botContext: BotContext,
        request: BotRequest,
        selector: ActivationSelector,
        activation: ActivatorContext
    ): Activation? {
        val act = activation as LLMActivatorContext
        // Persist history mutations made by scenario handlers (e.g. setSystemMessage).
        botContext.llmChatHistory = act.history.toList()
        // Synthesize a follow-up request when a function produced a result or the
        // handler asked for re-activation, and feed it back through the base flow.
        return when {
            act is LLMFunctionActivatorContext && act.isComplete() ->
                LLMFunctionBotRequest(request, act.name, act.result)
            act.activate -> QueryBotRequest(request.clientId, "")
            else -> null
        }?.let { activate(botContext, it, selector) }
    }

    /** Factory producing activators with explicit [settings]. */
    class Factory(private val settings: LLMSettings) : ActivatorFactory {
        override fun create(model: ScenarioModel) =
            LLMActivator(model, settings)
    }

    /** Default factory producing activators with default [LLMSettings]. */
    companion object : ActivatorFactory {
        override fun create(model: ScenarioModel) = LLMActivator(model)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package com.justai.jaicf.activator.llm

import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.justai.jaicf.activator.event.EventActivatorContext
import com.justai.jaicf.activator.llm.client.LLMRequest
import com.justai.jaicf.context.ActivatorContext
import com.justai.jaicf.context.BotContext
import com.justai.jaicf.helpers.context.sessionProperty

// Per-session LLM settings override; merged over the activator defaults via LLMSettings.assign.
var BotContext.llmSettings: LLMSettings by sessionProperty { LLMSettings() }
// Per-session chat history sent to the LLM on each turn.
var BotContext.llmChatHistory by sessionProperty { emptyList<LLMRequest.Message>() }

// Shared Jackson mapper used to (de)serialize message content and function arguments.
val llmJsonMapper = jacksonObjectMapper()

// Safe-cast accessor: non-null only when the activation was a plain LLM message.
val ActivatorContext.llmMessage
    get() = this as? LLMMessageActivatorContext

// Safe-cast accessor: non-null only when the activation was an LLM function call.
val ActivatorContext.llmFunction
    get() = this as? LLMFunctionActivatorContext

/**
 * Base activator context produced by the LLM activator.
 * Exposes the mutable chat [history] and lets scenario handlers request an
 * immediate follow-up LLM call via [activate].
 */
sealed class LLMActivatorContext(
    override val event: String,
    private val botContext: BotContext,
) : EventActivatorContext(event) {
    // When true, the activator re-sends the conversation to the LLM after the handler runs.
    internal var activate = false
    // Mutable copy of the session history; persisted back by the activator on activation.
    val history = botContext.llmChatHistory.toMutableList()

    /**
     * Removes all system messages from the history, adds a new one with [content],
     * and persists the updated history to the session.
     * NOTE(review): the new system message is appended at the END of the history,
     * not inserted at the front — confirm this ordering is intended.
     */
    fun setSystemMessage(content: String) = also {
        history.removeIf { it.role.isSystem }
        history.add(LLMRequest.Message.system(content))
        botContext.llmChatHistory = history
    }

    /** Requests an immediate follow-up LLM call after the current handler completes. */
    fun activate() {
        activate = true
    }

    /** Replaces the system message and requests an immediate follow-up LLM call. */
    fun activateWithSystemMessage(content: String) = setSystemMessage(content).also { activate() }
}

/**
 * Activator context for a plain chat message produced by the LLM.
 *
 * @property content raw text content of the assistant message.
 */
data class LLMMessageActivatorContext(
    private val botContext: BotContext,
    val content: String,
) : LLMActivatorContext(LLMEvent.MESSAGE, botContext) {
    /** Deserializes [content] as JSON into [T] (useful with a JSON response format). */
    inline fun <reified T> fromJson(): T =
        llmJsonMapper.readValue(content, T::class.java)
}

/**
 * Activator context for a function call requested by the LLM.
 *
 * @property name name of the function the model wants to invoke.
 * @property arguments raw JSON string with the call arguments.
 */
data class LLMFunctionActivatorContext(
    private val botContext: BotContext,
    val name: String,
    val arguments: String,
) : LLMActivatorContext(LLMEvent.FUNCTION_CALL, botContext) {
    // Assigned by the scenario handler once the function has been executed;
    // a completed result triggers a follow-up LLM call (see the activator's activate()).
    lateinit var result: String

    /** True once [result] has been assigned by the handler. */
    fun isComplete() = ::result.isInitialized

    /** Deserializes the JSON [arguments] into [T]. */
    inline fun <reified T> parseArguments(): T =
        llmJsonMapper.readValue(arguments, T::class.java)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.justai.jaicf.activator.llm

import com.justai.jaicf.api.BotRequest
import com.justai.jaicf.api.EventBotRequest

/**
 * Internal synthetic request that feeds a function execution [result] back into
 * the activator as an [LLMEvent.FUNCTION_CALL] event, preserving the [origin]
 * request's client id.
 */
internal data class LLMFunctionBotRequest(
    val origin: BotRequest,
    val name: String,
    val result: String,
) : EventBotRequest(
    clientId = origin.clientId,
    input = LLMEvent.FUNCTION_CALL
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package com.justai.jaicf.activator.llm

import com.justai.jaicf.activator.llm.client.LLMRequest
import com.knuddels.jtokkit.Encodings
import com.knuddels.jtokkit.api.Encoding

/**
 * Trims chat requests so they fit into a model's context window.
 */
object LLMEncoding {
    // Registry of tokenizer encodings bundled with jtokkit.
    private val registry = Encodings.newDefaultEncodingRegistry()

    /**
     * Estimates the token count of [request]: messages and function declarations are
     * rendered via toString and tokenized as one text blob, so this is an approximation
     * of the billed prompt size, not an exact figure. Falls back to the character count
     * when no [encoding] is known for the model (a deliberate overestimate).
     */
    private fun countTokens(request: LLMRequest, encoding: Encoding?) =
        (request.messages.joinToString("\n") + "\n" + request.functions?.toString().orEmpty()).let { text ->
            encoding?.countTokens(text) ?: text.length
        }

    /**
     * Drops the oldest non-system messages from [request] until the estimated prompt
     * leaves at least [minTokens] of the model's context window for the completion.
     * System messages are never removed.
     *
     * @throws IllegalArgumentException if the request still exceeds the budget after
     * all non-system messages have been removed.
     */
    fun trim(
        request: LLMRequest,
        model: LLMSettings.Model,
        minTokens: Int = model.maxContextLength / 4
    ): LLMRequest {
        val enc = registry.getEncodingForModel(model.name).orElse(null)
        val messages = request.messages.toMutableList()
        val budget = model.maxContextLength - minTokens
        var tokens = countTokens(request, enc)
        while (tokens > budget) {
            // Evict the oldest user/assistant message first; keep system messages.
            val index = messages.indexOfFirst { !it.role.isSystem }
            if (index == -1) {
                break
            } else {
                messages.removeAt(index)
                tokens = countTokens(request.copy(messages = messages), enc)
            }
        }
        // Fail only when trimming could not bring the request under budget.
        // (Previously this threw whenever no non-system messages remained, even if the
        // request already fit within the context window.)
        if (tokens > budget) {
            throw IllegalArgumentException("Messages history is too big for this model")
        }
        return request.copy(messages = messages)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.justai.jaicf.activator.llm

/** Event names produced by the LLM activator. */
object LLMEvent {
    /** The LLM responded with a plain chat message. */
    const val MESSAGE = "llm_message"
    /** The LLM requested a function call. */
    const val FUNCTION_CALL = "llm_function_call"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package com.justai.jaicf.activator.llm

import com.justai.jaicf.activator.llm.client.LLMRequest
import com.justai.jaicf.activator.llm.function.LLMFunction
import com.knuddels.jtokkit.api.ModelType

// Default OpenAI endpoint; overridable via the OPENAI_API_BASE_URL environment variable.
private const val DEFAULT_BASE_URL = "https://api.openai.com/v1"
// Model used when neither the activator settings nor the session settings specify one.
private val DEFAULT_MODEL = LLMSettings.Model("gpt-3.5-turbo-1106", 16_000)

/**
 * Settings of the LLM chat completion request.
 *
 * Nullable fields default to `null`, meaning "use the provider's default".
 * Session-level overrides are merged over these defaults via [assign].
 *
 * @property apiKey API key; read from the `OPENAI_API_KEY` environment variable when
 *  not passed explicitly. Fails fast with a descriptive message when the variable is
 *  missing (previously this surfaced as an opaque platform-type NPE at construction).
 * @property baseUrl API base URL; `OPENAI_API_BASE_URL` env var or the public OpenAI endpoint.
 * @property model chat model; falls back to a default model in [assign] when unset.
 */
data class LLMSettings(
    val apiKey: String = System.getenv("OPENAI_API_KEY")
        ?: error("OPENAI_API_KEY environment variable is not set"),
    val baseUrl: String = System.getenv("OPENAI_API_BASE_URL") ?: DEFAULT_BASE_URL,
    val model: Model? = null,
    val temperature: Float? = null,
    val maxTokens: Int? = null,
    val topP: Float? = null,
    val frequencyPenalty: Float? = null,
    val presencePenalty: Float? = null,
    val responseFormat: LLMRequest.ResponseFormat? = null,
    val functions: List<LLMFunction>? = null,
    val functionCall: LLMRequest.FunctionCall? = null,
) {
    /** True when the response is requested in JSON object format. */
    val isJsonFormat
        get() = responseFormat?.type == LLMRequest.ResponseFormat.Type.json_object

    /** Builds an [LLMRequest] from these settings and the given chat [messages]. */
    fun createChatRequest(messages: List<LLMRequest.Message>) =
        LLMRequest(
            model = model?.name.orEmpty(),
            temperature = temperature,
            maxTokens = maxTokens,
            topP = topP,
            frequencyPenalty = frequencyPenalty,
            presencePenalty = presencePenalty,
            responseFormat = responseFormat,
            functions = functions,
            functionCall = functionCall,
            messages = messages,
        )

    /**
     * Merges [from] over this instance: any field set in [from] wins, otherwise the
     * current value is kept. The model additionally falls back to [DEFAULT_MODEL]
     * when neither side specifies one.
     */
    fun assign(from: LLMSettings) = copy(
        model = from.model ?: model ?: DEFAULT_MODEL,
        temperature = from.temperature ?: temperature,
        maxTokens = from.maxTokens ?: maxTokens,
        topP = from.topP ?: topP,
        frequencyPenalty = from.frequencyPenalty ?: frequencyPenalty,
        presencePenalty = from.presencePenalty ?: presencePenalty,
        responseFormat = from.responseFormat ?: responseFormat,
        functions = from.functions ?: functions,
        functionCall = from.functionCall ?: functionCall,
    )

    /**
     * Chat model descriptor.
     *
     * @property name model name as sent to the API.
     * @property maxContextLength maximum context window size in tokens.
     */
    data class Model(
        val name: String,
        val maxContextLength: Int,
    ) {
        companion object {
            /**
             * Resolves a [Model] by [name] using jtokkit's known-model registry.
             * @throws IllegalArgumentException when the model name is unknown.
             */
            fun fromName(name: String) = ModelType.fromName(name)
                .orElse(null)?.let { type ->
                    Model(name, type.maxContextLength)
                } ?: throw IllegalArgumentException("Model $name was not found")
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.justai.jaicf.activator.llm

import com.justai.jaicf.generic.ActivatorTypeToken

// Type tokens for typed access to LLM activator contexts in scenario handlers.
typealias LLMTypeToken = ActivatorTypeToken<LLMActivatorContext>
typealias LLMMessageTypeToken = ActivatorTypeToken<LLMMessageActivatorContext>
typealias LLMFunctionTypeToken = ActivatorTypeToken<LLMFunctionActivatorContext>

/** Matches any LLM activation (message or function call). */
val llm: LLMTypeToken = ActivatorTypeToken()
/** Matches only plain LLM chat message activations. */
val llmMessage: LLMMessageTypeToken = ActivatorTypeToken()
/** Matches only LLM function call activations. */
val llmFunction: LLMFunctionTypeToken = ActivatorTypeToken()
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package com.justai.jaicf.activator.llm.client

/** Abstraction over an LLM chat completion API client. */
interface LLMClient {
    /** Sends [request] to the completion API and returns the parsed response. */
    fun chatCompletion(request: LLMRequest): LLMResponse
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.justai.jaicf.activator.llm.client

/**
 * Role of a chat message author. Constant names are lowercase on purpose —
 * they mirror the wire-format role values of the chat completion API.
 */
enum class LLMMessageRole {
    system, user, assistant;

    /** True when this role is [system]. */
    val isSystem: Boolean
        get() = this == system

    /** True when this role is [user]. */
    val isUser: Boolean
        get() = this == user

    /** True when this role is [assistant]. */
    val isAssistant: Boolean
        get() = this == assistant
}
Loading

0 comments on commit fc22a80

Please sign in to comment.