diff --git a/cogs/code_interpreter_service_cog.py b/cogs/code_interpreter_service_cog.py index 1a09607a..92f970d8 100644 --- a/cogs/code_interpreter_service_cog.py +++ b/cogs/code_interpreter_service_cog.py @@ -133,13 +133,15 @@ async def on_message(self, message): return prompt = message.content.strip() - + self.converser_cog.full_conversation_history[message.channel.id].append(prompt) # If the message channel is in self.chat_agents, then we delegate the message to the agent. if message.channel.id in self.chat_agents: if prompt.lower() in ["stop", "end", "quit", "exit"]: await message.reply( - "Ending chat session. You can access the sandbox of this session at https://" - + self.sessions[message.channel.id].get_hostname() + content = "Ending chat session. You can access the sandbox of this session at https://" + + self.sessions[message.channel.id].get_hostname(), + view = ShareView(self.converser_cog, message.channel.id) + if isinstance(message.channel, discord.Thread) else None ) self.sessions[message.channel.id].close() self.chat_agents.pop(message.channel.id) @@ -233,7 +235,6 @@ async def on_message(self, message): print(stdout_output) except: pass - except Exception as e: response = f"Error: {e}" traceback.print_exc() @@ -242,7 +243,7 @@ async def on_message(self, message): ) safe_remove_list(self.thread_awaiting_responses, message.channel.id) return - + self.converser_cog.full_conversation_history[message.channel.id].append(response) # Parse the artifact names. After Artifacts: there should be a list in form [] where the artifact names are inside, comma separated inside stdout_output artifact_names = re.findall(r"Artifacts: \[(.*?)\]", stdout_output) # The artifacts list may be formatted like ["'/home/user/artifacts/test2.txt', '/home/user/artifacts/test.txt'"], where its technically 1 element in the list, so we need to split it by comma and then remove the quotes and spaces @@ -568,3 +569,61 @@ async def callback(self, interaction: discord.Interaction): await self.ctx.channel.send( "Failed to download artifact: " + artifact, delete_after=120 ) + +class ShareView(discord.ui.View): + def __init__( + self, + converser_cog, + conversation_id, + ): + super().__init__(timeout=3600) # 1 hour interval to share the conversation. + self.converser_cog = converser_cog + self.conversation_id = conversation_id + self.add_item(ShareButton(converser_cog, conversation_id)) + + async def on_timeout(self): + # Remove the button from the view/message + self.clear_items() + + +class ShareButton(discord.ui.Button["ShareView"]): + def __init__(self, converser_cog, conversation_id): + super().__init__( + style=discord.ButtonStyle.green, + label="Share Conversation", + custom_id="share_conversation", + ) + self.converser_cog = converser_cog + self.conversation_id = conversation_id + + async def callback(self, interaction: discord.Interaction): + # Get the user + try: + id = await self.converser_cog.sharegpt_service.format_and_share( + self.converser_cog.full_conversation_history[self.conversation_id], + self.converser_cog.bot.user.default_avatar.url + if not self.converser_cog.bot.user.avatar + else self.converser_cog.bot.user.avatar.url, + ) + url = f"https://shareg.pt/{id}" + await interaction.response.send_message( + embed=EmbedStatics.get_conversation_shared_embed(url) + ) + except ValueError as e: + traceback.print_exc() + await interaction.response.send_message( + embed=EmbedStatics.get_conversation_share_failed_embed( + "The ShareGPT API returned an error: " + str(e) + ), + ephemeral=True, + delete_after=15, + ) + return + except Exception as e: + traceback.print_exc() + await interaction.response.send_message( + embed=EmbedStatics.get_conversation_share_failed_embed(str(e)), + ephemeral=True, + delete_after=15, + ) + return \ No newline at end of file diff --git a/cogs/index_service_cog.py b/cogs/index_service_cog.py index 7cd3ce52..aeee76fa 100644 --- a/cogs/index_service_cog.py +++ b/cogs/index_service_cog.py @@ -13,7 +13,9 @@ from services.environment_service import EnvService from services.moderations_service import Moderation from services.text_service import TextService +from services.sharegpt_service import ShareGPTService from models.index_model import Index_handler +from collections import defaultdict from utils.safe_ctx_respond import safe_remove_list USER_INPUT_API_KEYS = EnvService.get_user_input_api_keys() @@ -36,8 +38,10 @@ def __init__( super().__init__() self.bot = bot self.index_handler = Index_handler(bot, usage_service) + self.full_conversation_history = defaultdict(list) self.thread_awaiting_responses = [] self.deletion_queue = deletion_queue + self.sharegpt_service = ShareGPTService() async def process_indexing(self, message, index_type, content=None, link=None): """ @@ -82,6 +86,8 @@ async def process_indexing(self, message, index_type, content=None, link=None): await message.reply(embed=failure_embed) safe_remove_list(self.thread_awaiting_responses, message.channel.id) return False + # summary is type casted to string because it might be a Response object + self.full_conversation_history[message.channel.id].append(str(summary)) success_embed = discord.Embed( title=f"{index_type.capitalize()} Interpreted", @@ -129,6 +135,8 @@ async def on_message(self, message): prompt = message.content.strip() + self.full_conversation_history[message.channel.id].append(prompt) + if await self.index_handler.get_is_in_index_chat(message): self.thread_awaiting_responses.append(message.channel.id) @@ -139,6 +147,19 @@ async def on_message(self, message): # Handle file uploads file = message.attachments[0] if len(message.attachments) > 0 else None + + if prompt.lower() in ["stop", "end", "quit", "exit"]: + await message.reply( + view = ShareView(self, message.channel.id) if isinstance(message.channel, discord.Thread) else None + ) + await message.reply("Ending chat session.") + self.index_handler.index_chat_chains.pop(message.channel.id) + + # close the thread + thread = await self.bot.fetch_channel(message.channel.id) + await thread.edit(name="Closed-GPT") + await thread.edit(archived=True) + return "Ended chat session." # File operations, allow for user file upload if file: @@ -186,6 +207,8 @@ async def on_message(self, message): ) return + self.full_conversation_history[message.channel.id].append(chat_result) + if chat_result: if len(chat_result) > 2000: embed_pages = EmbedStatics.paginate_chat_embed(chat_result) @@ -492,3 +515,61 @@ async def compose_command(self, ctx, name): return await self.index_handler.compose(ctx, name, user_api_key) + +class ShareView(discord.ui.View): + def __init__( + self, + converser_cog, + conversation_id, + ): + super().__init__(timeout=3600) # 1 hour interval to share the conversation. + self.converser_cog = converser_cog + self.conversation_id = conversation_id + self.add_item(ShareButton(converser_cog, conversation_id)) + + async def on_timeout(self): + # Remove the button from the view/message + self.clear_items() + + +class ShareButton(discord.ui.Button["ShareView"]): + def __init__(self, indexer_cog, conversation_id): + super().__init__( + style=discord.ButtonStyle.green, + label="Share Conversation", + custom_id="share_conversation", + ) + self.indexer_cog = indexer_cog + self.conversation_id = conversation_id + + async def callback(self, interaction: discord.Interaction): + # Get the user + try: + id = await self.indexer_cog.sharegpt_service.format_and_share( + self.indexer_cog.full_conversation_history[self.conversation_id], + self.indexer_cog.bot.user.default_avatar.url + if not self.indexer_cog.bot.user.avatar + else self.indexer_cog.bot.user.avatar.url, + ) + url = f"https://shareg.pt/{id}" + await interaction.response.send_message( + embed=EmbedStatics.get_conversation_shared_embed(url) + ) + except ValueError as e: + traceback.print_exc() + await interaction.response.send_message( + embed=EmbedStatics.get_conversation_share_failed_embed( + "The ShareGPT API returned an error: " + str(e) + ), + ephemeral=True, + delete_after=15, + ) + return + except Exception as e: + traceback.print_exc() + await interaction.response.send_message( + embed=EmbedStatics.get_conversation_share_failed_embed(str(e)), + ephemeral=True, + delete_after=15, + ) + return diff --git a/models/index_model.py b/models/index_model.py index c2828eb0..273ea801 100644 --- a/models/index_model.py +++ b/models/index_model.py @@ -375,6 +375,7 @@ async def execute_index_chat_message(self, ctx, message): if ctx.channel.id not in self.index_chat_chains: return None + '''Moved this logic to index_service_cog.py if message.lower() in ["stop", "end", "quit", "exit"]: await ctx.reply("Ending chat session.") self.index_chat_chains.pop(ctx.channel.id) @@ -384,6 +385,7 @@ async def execute_index_chat_message(self, ctx, message): await thread.edit(name="Closed-GPT") await thread.edit(archived=True) return "Ended chat session." + ''' self.usage_service.update_usage_memory(ctx.guild.name, "index_chat_message", 1)