Skip to content

Commit

Permalink
[add] prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
begla committed Aug 18, 2024
1 parent 4edcd52 commit 94c8539
Show file tree
Hide file tree
Showing 2 changed files with 339 additions and 0 deletions.
7 changes: 7 additions & 0 deletions iolite_plugins/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ add_library(IoliteBenchmarkPlugin SHARED
)
list(APPEND PLUGINS IoliteBenchmarkPlugin)

# ChatGPT plugin
add_library(IoliteChatGPTPlugin SHARED
chatgpt/chatgpt_plugin.cpp
${IMGUI_SOURCES}
)
list(APPEND PLUGINS IoliteChatGPTPlugin)

# OIDN denoiser plugin
add_library(IoliteDenoiserOIDNPlugin SHARED
denoiser_oidn_plugin/denoiser_oidn_plugin.cpp
Expand Down
332 changes: 332 additions & 0 deletions iolite_plugins/chatgpt/chatgpt_plugin.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,332 @@
// MIT License
//
// Copyright (c) 2024 Missing Deadlines (Benjamin Wrensch)
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

// API
#include "iolite_api.h"
#include "lua_plugin.h"
#include "lua_plugin_api.h"
#define _IO_PLUGIN_NAME "ChatGPT"
#include "iolite_plugins_common.h"
#include "iolite_plugins_libraries.h"

// Dependencies
#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"
#include "imgui.h"
#include "IconsFontAwesome6.h"

// Interfaces we use
//----------------------------------------------------------------------------//
static const io_api_manager_i* io_api_manager = nullptr;
static const io_logging_i* io_logging = nullptr;
static const io_plugin_lua_i* io_plugin_lua = nullptr;
static const io_low_level_imgui_i* io_low_level_imgui = nullptr;
static const io_settings_i* io_settings = nullptr;
static io_user_editor_i io_user_editor = {};

//----------------------------------------------------------------------------//
static const char* tune = R"(
Create a voxel asset in my voxel game engine using Lua 5.1 with the following rules:
- The grid is exactly 64^3 voxels large
- The following API function can be used to set voxels: set_voxel(x, y, z, color_palette_index)
- The parameter 'color_palette_index' is an index into a color palette; 0 marks air/empty voxels and all values > 0 are solid voxels
- Only use palette indexes 0 to 255
- Please make sure to only use valid values for 'color_palette_index'
- Assume that the 'set_voxel' function is there; don't create or implement the function
- Coordinate y = 0 is at the bottom
- Only output the raw Lua source code
The asset itself should look like the following:
%s
Output only plain text. Do not output markdown.
)";

//----------------------------------------------------------------------------//
static const char* script_before = R"(
Log.load()
Entity.load()
VoxelShape.load()
local shapeEntity = Entity.find_first_entity_with_name("chatgpt")
local shape = VoxelShape.get_component_for_entity(shapeEntity)
VoxelShape.fill(shape, U8Vec3(0, 0, 0), U8Vec3(63, 63, 63), 0)
function set_voxel(x, y, z, palette_index)
VoxelShape.set(shape, U8Vec3(x, y, z), palette_index)
end
)";

//----------------------------------------------------------------------------//
static const char* script_after = R"(
VoxelShape.voxelize(shape)
)";

//----------------------------------------------------------------------------//
IO_API_EXPORT io_uint32_t IO_API_CALL get_api_version()
{
return IO_API_VERSION;
}

//----------------------------------------------------------------------------//
static std::string replace_all(std::string string, const std::string& from,
const std::string& to)
{
size_t start_pos = 0;
while ((start_pos = string.find(from, start_pos)) != std::string::npos)
{
string.replace(start_pos, from.length(), to);
start_pos +=
to.length(); // Handles case where 'to' is a substring of 'from'
}
return string;
}

//----------------------------------------------------------------------------//
static void execute_result(const char* chatgpt_script)
{
std::string chatgpt_script_clean = replace_all(chatgpt_script, "```lua", "");
chatgpt_script_clean = replace_all(chatgpt_script_clean, "```", "");

std::vector<char> code(128u * 1024u);
stbsp_snprintf(code.data(), code.size(), "%s\n%s\n%s", script_before,
chatgpt_script_clean.c_str(), script_after);

common::log_message(io_logging, "Generated script:\n%s\n\n", code.data());
io_plugin_lua->execute_script(code.data());
}

//----------------------------------------------------------------------------//
static std::vector<char> request_result;
static constexpr size_t buffer_size = 128u * 1024u;
static std::vector<char> buffer(buffer_size);

//----------------------------------------------------------------------------//
static auto write_callback(char* ptr, size_t size, size_t nmemb,
void* userdata) -> size_t
{
const size_t current_size_in_bytes = request_result.size();
request_result.resize(current_size_in_bytes + size * nmemb);

const auto new_size_in_bytes = size * nmemb;
memcpy(&request_result[current_size_in_bytes], ptr, new_size_in_bytes);

return size * nmemb;
}

//----------------------------------------------------------------------------//
static void send_and_execute_prompt(const char* prompt)
{
// common::log_message(io_logging, "Prompt:\n%s\n", buffer_a.data());

const char* api_key = io_settings->get_string("chatgpt_api_key");
if (strlen(api_key) == 0)
{
common::log_message(io_logging, "No ChatGPT API key set, aborting...");
return;
}

libraries::curl.curl_global_init(CURL_GLOBAL_ALL);
auto curl = libraries::curl.curl_easy_init();
if (curl)
{
libraries::curl.curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION,
write_callback);
libraries::curl.curl_easy_setopt(curl, CURLOPT_SSL_VERIFYPEER, 0);
libraries::curl.curl_easy_setopt(
curl, CURLOPT_URL, "https://api.openai.com/v1/chat/completions");

curl_slist* headers = nullptr;
{
char buffer[256];

headers = libraries::curl.curl_slist_append(
headers, "Content-Type: application/json");

stbsp_snprintf(buffer, sizeof(buffer), "Authorization: Bearer %s",
api_key);
headers = libraries::curl.curl_slist_append(headers, buffer);

libraries::curl.curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
}

rapidjson::StringBuffer content;
{
rapidjson::Document doc;
doc.SetObject();
// doc.AddMember("model", "gpt-4o-mini", doc.GetAllocator());
doc.AddMember("model", "gpt-4o", doc.GetAllocator());

rapidjson::Value messages;
messages.SetArray();

rapidjson::Value message;
message.SetObject();
message.AddMember("role", "user", doc.GetAllocator());

stbsp_snprintf(buffer.data(), buffer.size(), tune, prompt);
message.AddMember("content", rapidjson::StringRef(buffer.data()),
doc.GetAllocator());

messages.PushBack(message, doc.GetAllocator());

doc.AddMember("messages", messages, doc.GetAllocator());
doc.AddMember("temperature", 0.7f, doc.GetAllocator());

rapidjson::Writer<rapidjson::StringBuffer> writer(content);
doc.Accept(writer);
}

libraries::curl.curl_easy_setopt(curl, CURLOPT_POSTFIELDS,
content.GetString());

request_result.clear();
auto result = libraries::curl.curl_easy_perform(curl);
if (result != CURLE_OK)
{
common::log_message(io_logging, "Request failed with code '%u'...",
result);
}
else
{
request_result.push_back('\0');

rapidjson::Document doc;
doc.Parse(request_result.data());

if (doc.HasMember("choices") && !doc["choices"].Empty())
execute_result(doc["choices"][0u]["message"]["content"].GetString());
}
libraries::curl.curl_slist_free_all(headers);

libraries::curl.curl_easy_cleanup(curl);
libraries::curl.curl_global_cleanup();
}
else
{
common::log_message(io_logging, "Failed to initialize curl...");
}
}

//----------------------------------------------------------------------------//
static bool open_prompt = false;

//----------------------------------------------------------------------------//
static void on_build_plugin_menu()
{
if (ImGui::MenuItem(ICON_FA_ROBOT " ChatGPT Prompt"))
open_prompt = true;
}

//----------------------------------------------------------------------------//
static void on_editor_tick(float dt)
{
if (open_prompt)
{
ImGui::OpenPopup("###chatgpt_prompt");
open_prompt = false;
}

if (ImGui::BeginPopupModal(ICON_FA_ROBOT
" ChatGPT Prompt###chatgpt_prompt"))
{
static std::vector<char> buffer;
if (buffer.empty())
{
buffer.resize(128u * 1024u);
stbsp_snprintf(buffer.data(), buffer.size(), "A small 10x10x10 tree");
}

ImGui::InputTextMultiline("###prompt", buffer.data(), buffer.size(),
ImVec2(-1, 0));

ImGui::BeginDisabled(strlen(buffer.data()) == 0);
if (ImGui::Button("Execute Prompt"))
send_and_execute_prompt(buffer.data());
ImGui::SameLine();
if (ImGui::Button("Close"))
ImGui::CloseCurrentPopup();
ImGui::EndDisabled();

ImGui::EndPopup();
}
}

//----------------------------------------------------------------------------//
IO_API_EXPORT int IO_API_CALL load_plugin(void* api_manager)
{
io_api_manager = (const io_api_manager_i*)api_manager;

// Retrieve the interfaces we use
{
io_logging = (io_logging_i*)io_api_manager->find_first(IO_LOGGING_API_NAME);

io_plugin_lua =
(io_plugin_lua_i*)io_api_manager->find_first(IO_PLUGIN_LUA_API_NAME);
if (!io_plugin_lua)
{
common::log_message(io_logging, "Lua plugin is required, aborting...");
return -1;
}

io_settings =
(const io_settings_i*)io_api_manager->find_first(IO_SETTINGS_API_NAME);
io_low_level_imgui =
(const io_low_level_imgui_i*)io_api_manager->find_first(
IO_LOW_LEVEL_IMGUI_API_NAME);
}

// Initialize curl
if (!libraries::load_curl(io_logging))
{
common::log_message(io_logging, "curl is not available, aborting...");
return -1;
}

// Set up Dear ImGui
{
auto ctxt = (ImGuiContext*)io_low_level_imgui->get_imgui_context();
ImGui::SetCurrentContext(ctxt);

ImGuiMemAllocFunc alloc_func;
ImGuiMemFreeFunc free_func;
io_low_level_imgui->get_imgui_allocator_functions((void**)&alloc_func,
(void**)&free_func);
ImGui::SetAllocatorFunctions(alloc_func, free_func);
}

// Register the interfaces we provide
io_user_editor = {};
{
io_user_editor.on_build_plugin_menu = on_build_plugin_menu;
io_user_editor.on_tick = on_editor_tick;
}
io_api_manager->register_api(IO_USER_EDITOR_API_NAME, &io_user_editor);

return 0;
}

//----------------------------------------------------------------------------//
IO_API_EXPORT void IO_API_CALL unload_plugin() {}

0 comments on commit 94c8539

Please sign in to comment.