From 94c85390fc2dd2c6cfc03ff969d2d7d4e96ca028 Mon Sep 17 00:00:00 2001 From: Benjamin Wrensch Date: Sun, 18 Aug 2024 14:43:38 +0200 Subject: [PATCH] [add] prototype --- iolite_plugins/CMakeLists.txt | 7 + iolite_plugins/chatgpt/chatgpt_plugin.cpp | 332 ++++++++++++++++++++++ 2 files changed, 339 insertions(+) create mode 100644 iolite_plugins/chatgpt/chatgpt_plugin.cpp diff --git a/iolite_plugins/CMakeLists.txt b/iolite_plugins/CMakeLists.txt index d651457..a2eebf5 100644 --- a/iolite_plugins/CMakeLists.txt +++ b/iolite_plugins/CMakeLists.txt @@ -59,6 +59,13 @@ add_library(IoliteBenchmarkPlugin SHARED ) list(APPEND PLUGINS IoliteBenchmarkPlugin) +# ChatGPT plugin +add_library(IoliteChatGPTPlugin SHARED + chatgpt/chatgpt_plugin.cpp + ${IMGUI_SOURCES} +) +list(APPEND PLUGINS IoliteChatGPTPlugin) + # OIDN denoiser plugin add_library(IoliteDenoiserOIDNPlugin SHARED denoiser_oidn_plugin/denoiser_oidn_plugin.cpp diff --git a/iolite_plugins/chatgpt/chatgpt_plugin.cpp b/iolite_plugins/chatgpt/chatgpt_plugin.cpp new file mode 100644 index 0000000..6ccb381 --- /dev/null +++ b/iolite_plugins/chatgpt/chatgpt_plugin.cpp @@ -0,0 +1,332 @@ +// MIT License +// +// Copyright (c) 2024 Missing Deadlines (Benjamin Wrensch) +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// API +#include "iolite_api.h" +#include "lua_plugin.h" +#include "lua_plugin_api.h" +#define _IO_PLUGIN_NAME "ChatGPT" +#include "iolite_plugins_common.h" +#include "iolite_plugins_libraries.h" + +// Dependencies +#include "rapidjson/document.h" +#include "rapidjson/stringbuffer.h" +#include "rapidjson/writer.h" +#include "imgui.h" +#include "IconsFontAwesome6.h" + +// Interfaces we use +//----------------------------------------------------------------------------// +static const io_api_manager_i* io_api_manager = nullptr; +static const io_logging_i* io_logging = nullptr; +static const io_plugin_lua_i* io_plugin_lua = nullptr; +static const io_low_level_imgui_i* io_low_level_imgui = nullptr; +static const io_settings_i* io_settings = nullptr; +static io_user_editor_i io_user_editor = {}; + +//----------------------------------------------------------------------------// +static const char* tune = R"( +Create a voxel asset in my voxel game engine using Lua 5.1 with the following rules: + + - The grid is exactly 64^3 voxels large + - The following API function can be used to set voxels: set_voxel(x, y, z, color_palette_index) + - The parameter 'color_palette_index' is an index into a color palette; 0 marks air/empty voxels and all values > 0 are solid voxels + - Only use palette indexes 0 to 255 + - Please make sure to only use valid values for 'color_palette_index' + - Assume that the 'set_voxel' function is there; don't create or implement the function + - Coordinate y = 0 is at the bottom + - Only output the raw Lua source code + +The asset itself should look like the following: + + %s + +Output only plain text. Do not output markdown. + )"; + +//----------------------------------------------------------------------------// +static const char* script_before = R"( +Log.load() +Entity.load() +VoxelShape.load() + +local shapeEntity = Entity.find_first_entity_with_name("chatgpt") +local shape = VoxelShape.get_component_for_entity(shapeEntity) +VoxelShape.fill(shape, U8Vec3(0, 0, 0), U8Vec3(63, 63, 63), 0) + +function set_voxel(x, y, z, palette_index) + VoxelShape.set(shape, U8Vec3(x, y, z), palette_index) +end +)"; + +//----------------------------------------------------------------------------// +static const char* script_after = R"( +VoxelShape.voxelize(shape) +)"; + +//----------------------------------------------------------------------------// +IO_API_EXPORT io_uint32_t IO_API_CALL get_api_version() +{ + return IO_API_VERSION; +} + +//----------------------------------------------------------------------------// +static std::string replace_all(std::string string, const std::string& from, + const std::string& to) +{ + size_t start_pos = 0; + while ((start_pos = string.find(from, start_pos)) != std::string::npos) + { + string.replace(start_pos, from.length(), to); + start_pos += + to.length(); // Handles case where 'to' is a substring of 'from' + } + return string; +} + +//----------------------------------------------------------------------------// +static void execute_result(const char* chatgpt_script) +{ + std::string chatgpt_script_clean = replace_all(chatgpt_script, "```lua", ""); + chatgpt_script_clean = replace_all(chatgpt_script_clean, "```", ""); + + std::vector code(128u * 1024u); + stbsp_snprintf(code.data(), code.size(), "%s\n%s\n%s", script_before, + chatgpt_script_clean.c_str(), script_after); + + common::log_message(io_logging, "Generated script:\n%s\n\n", code.data()); + io_plugin_lua->execute_script(code.data()); +} + +//----------------------------------------------------------------------------// +static std::vector request_result; +static constexpr size_t buffer_size = 128u * 1024u; +static std::vector buffer(buffer_size); + +//----------------------------------------------------------------------------// +static auto write_callback(char* ptr, size_t size, size_t nmemb, + void* userdata) -> size_t +{ + const size_t current_size_in_bytes = request_result.size(); + request_result.resize(current_size_in_bytes + size * nmemb); + + const auto new_size_in_bytes = size * nmemb; + memcpy(&request_result[current_size_in_bytes], ptr, new_size_in_bytes); + + return size * nmemb; +} + +//----------------------------------------------------------------------------// +static void send_and_execute_prompt(const char* prompt) +{ + // common::log_message(io_logging, "Prompt:\n%s\n", buffer_a.data()); + + const char* api_key = io_settings->get_string("chatgpt_api_key"); + if (strlen(api_key) == 0) + { + common::log_message(io_logging, "No ChatGPT API key set, aborting..."); + return; + } + + libraries::curl.curl_global_init(CURL_GLOBAL_ALL); + auto curl = libraries::curl.curl_easy_init(); + if (curl) + { + libraries::curl.curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, + write_callback); + libraries::curl.curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0); + libraries::curl.curl_easy_setopt( + curl, CURLOPT_URL, "https://api.openai.com/v1/chat/completions"); + + curl_slist* headers = nullptr; + { + char buffer[256]; + + headers = libraries::curl.curl_slist_append( + headers, "Content-Type: application/json"); + + stbsp_snprintf(buffer, sizeof(buffer), "Authorization: Bearer %s", + api_key); + headers = libraries::curl.curl_slist_append(headers, buffer); + + libraries::curl.curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + } + + rapidjson::StringBuffer content; + { + rapidjson::Document doc; + doc.SetObject(); + // doc.AddMember("model", "gpt-4o-mini", doc.GetAllocator()); + doc.AddMember("model", "gpt-4o", doc.GetAllocator()); + + rapidjson::Value messages; + messages.SetArray(); + + rapidjson::Value message; + message.SetObject(); + message.AddMember("role", "user", doc.GetAllocator()); + + stbsp_snprintf(buffer.data(), buffer.size(), tune, prompt); + message.AddMember("content", rapidjson::StringRef(buffer.data()), + doc.GetAllocator()); + + messages.PushBack(message, doc.GetAllocator()); + + doc.AddMember("messages", messages, doc.GetAllocator()); + doc.AddMember("temperature", 0.7f, doc.GetAllocator()); + + rapidjson::Writer writer(content); + doc.Accept(writer); + } + + libraries::curl.curl_easy_setopt(curl, CURLOPT_POSTFIELDS, + content.GetString()); + + request_result.clear(); + auto result = libraries::curl.curl_easy_perform(curl); + if (result != CURLE_OK) + { + common::log_message(io_logging, "Request failed with code '%u'...", + result); + } + else + { + request_result.push_back('\0'); + + rapidjson::Document doc; + doc.Parse(request_result.data()); + + if (doc.HasMember("choices") && !doc["choices"].Empty()) + execute_result(doc["choices"][0u]["message"]["content"].GetString()); + } + libraries::curl.curl_slist_free_all(headers); + + libraries::curl.curl_easy_cleanup(curl); + libraries::curl.curl_global_cleanup(); + } + else + { + common::log_message(io_logging, "Failed to initialize curl..."); + } +} + +//----------------------------------------------------------------------------// +static bool open_prompt = false; + +//----------------------------------------------------------------------------// +static void on_build_plugin_menu() +{ + if (ImGui::MenuItem(ICON_FA_ROBOT " ChatGPT Prompt")) + open_prompt = true; +} + +//----------------------------------------------------------------------------// +static void on_editor_tick(float dt) +{ + if (open_prompt) + { + ImGui::OpenPopup("###chatgpt_prompt"); + open_prompt = false; + } + + if (ImGui::BeginPopupModal(ICON_FA_ROBOT + " ChatGPT Prompt###chatgpt_prompt")) + { + static std::vector buffer; + if (buffer.empty()) + { + buffer.resize(128u * 1024u); + stbsp_snprintf(buffer.data(), buffer.size(), "A small 10x10x10 tree"); + } + + ImGui::InputTextMultiline("###prompt", buffer.data(), buffer.size(), + ImVec2(-1, 0)); + + ImGui::BeginDisabled(strlen(buffer.data()) == 0); + if (ImGui::Button("Execute Prompt")) + send_and_execute_prompt(buffer.data()); + ImGui::SameLine(); + if (ImGui::Button("Close")) + ImGui::CloseCurrentPopup(); + ImGui::EndDisabled(); + + ImGui::EndPopup(); + } +} + +//----------------------------------------------------------------------------// +IO_API_EXPORT int IO_API_CALL load_plugin(void* api_manager) +{ + io_api_manager = (const io_api_manager_i*)api_manager; + + // Retrieve the interfaces we use + { + io_logging = (io_logging_i*)io_api_manager->find_first(IO_LOGGING_API_NAME); + + io_plugin_lua = + (io_plugin_lua_i*)io_api_manager->find_first(IO_PLUGIN_LUA_API_NAME); + if (!io_plugin_lua) + { + common::log_message(io_logging, "Lua plugin is required, aborting..."); + return -1; + } + + io_settings = + (const io_settings_i*)io_api_manager->find_first(IO_SETTINGS_API_NAME); + io_low_level_imgui = + (const io_low_level_imgui_i*)io_api_manager->find_first( + IO_LOW_LEVEL_IMGUI_API_NAME); + } + + // Initialize curl + if (!libraries::load_curl(io_logging)) + { + common::log_message(io_logging, "curl is not available, aborting..."); + return -1; + } + + // Set up Dear ImGui + { + auto ctxt = (ImGuiContext*)io_low_level_imgui->get_imgui_context(); + ImGui::SetCurrentContext(ctxt); + + ImGuiMemAllocFunc alloc_func; + ImGuiMemFreeFunc free_func; + io_low_level_imgui->get_imgui_allocator_functions((void**)&alloc_func, + (void**)&free_func); + ImGui::SetAllocatorFunctions(alloc_func, free_func); + } + + // Register the interfaces we provide + io_user_editor = {}; + { + io_user_editor.on_build_plugin_menu = on_build_plugin_menu; + io_user_editor.on_tick = on_editor_tick; + } + io_api_manager->register_api(IO_USER_EDITOR_API_NAME, &io_user_editor); + + return 0; +} + +//----------------------------------------------------------------------------// +IO_API_EXPORT void IO_API_CALL unload_plugin() {}