From 1d5a99fb3aef45fbf145736dd02c738e0e6f93d3 Mon Sep 17 00:00:00 2001 From: Astrea <25420078+AstreaTSS@users.noreply.github.com> Date: Tue, 4 Jul 2023 23:22:57 -0400 Subject: [PATCH] feat: add in hybrid slash commands (#1399) * feat: add in hybrid slash commands Still relatively untested. Needs a lot of polish. * feat: simulate specific slash restrains * fix: oops, scope is weird * fix: handle no options correctly * feat: add use_slash_command_msg argument * docs: add docs for hybrid command manager * docs: add snippet for hybrid cmds in guide * docs: finialize docs for hybrid commands * fix: properly set x_id properties * docs: adjust title of hybrid command section * fix: handle dummy base command functions * fix: parse subcommands correctly * feat: add silence_autocomplete_errors * fix: make options not keyword-only This threw the prefixed command parser in for a loop. * fix: use more logic to determine right kind of kind * fix: properly handle keyword only * feat: add support for aliases * fix: add aliases to base commands too * refactor: black learns how to not be dumb * feat: remove hybrid command dm app permission handling This wasn't matching slash command behavior --------- Co-authored-by: Astrea49 <25420078+Astrea49@users.noreply.github.com> --- .../ext/hybrid_commands/context.md | 1 + .../ext/hybrid_commands/hybrid_slash.md | 1 + .../ext/hybrid_commands/index.md | 5 + .../ext/hybrid_commands/manager.md | 1 + .../API Reference/API Reference/ext/index.md | 3 + docs/src/Guides/03 Creating Commands.md | 35 ++ interactions/ext/hybrid_commands/__init__.py | 12 + interactions/ext/hybrid_commands/context.py | 383 +++++++++++ .../ext/hybrid_commands/hybrid_slash.py | 593 ++++++++++++++++++ interactions/ext/hybrid_commands/manager.py | 149 +++++ 10 files changed, 1183 insertions(+) create mode 100644 docs/src/API Reference/API Reference/ext/hybrid_commands/context.md create mode 100644 docs/src/API Reference/API Reference/ext/hybrid_commands/hybrid_slash.md create mode 100644 docs/src/API Reference/API Reference/ext/hybrid_commands/index.md create mode 100644 docs/src/API Reference/API Reference/ext/hybrid_commands/manager.md create mode 100644 interactions/ext/hybrid_commands/__init__.py create mode 100644 interactions/ext/hybrid_commands/context.py create mode 100644 interactions/ext/hybrid_commands/hybrid_slash.py create mode 100644 interactions/ext/hybrid_commands/manager.py diff --git a/docs/src/API Reference/API Reference/ext/hybrid_commands/context.md b/docs/src/API Reference/API Reference/ext/hybrid_commands/context.md new file mode 100644 index 000000000..769cc00fb --- /dev/null +++ b/docs/src/API Reference/API Reference/ext/hybrid_commands/context.md @@ -0,0 +1 @@ +::: interactions.ext.hybrid_commands.context diff --git a/docs/src/API Reference/API Reference/ext/hybrid_commands/hybrid_slash.md b/docs/src/API Reference/API Reference/ext/hybrid_commands/hybrid_slash.md new file mode 100644 index 000000000..3be2c46c1 --- /dev/null +++ b/docs/src/API Reference/API Reference/ext/hybrid_commands/hybrid_slash.md @@ -0,0 +1 @@ +::: interactions.ext.hybrid_commands.hybrid_slash diff --git a/docs/src/API Reference/API Reference/ext/hybrid_commands/index.md b/docs/src/API Reference/API Reference/ext/hybrid_commands/index.md new file mode 100644 index 000000000..c2721707d --- /dev/null +++ b/docs/src/API Reference/API Reference/ext/hybrid_commands/index.md @@ -0,0 +1,5 @@ +# Hybrid Commands Index + +- [Context](context) +- [Hybrid Slash](hybrid_slash) +- [Manager](manager) diff --git a/docs/src/API Reference/API Reference/ext/hybrid_commands/manager.md b/docs/src/API Reference/API Reference/ext/hybrid_commands/manager.md new file mode 100644 index 000000000..d8181b9f3 --- /dev/null +++ b/docs/src/API Reference/API Reference/ext/hybrid_commands/manager.md @@ -0,0 +1 @@ +::: interactions.ext.hybrid_commands.manager diff --git a/docs/src/API Reference/API Reference/ext/index.md b/docs/src/API Reference/API Reference/ext/index.md index e83a65ca5..9e0753164 100644 --- a/docs/src/API Reference/API Reference/ext/index.md +++ b/docs/src/API Reference/API Reference/ext/index.md @@ -13,3 +13,6 @@ These files contain useful features that help you develop a bot - [Prefixed Commands](prefixed_commands) - An extension to allow prefixed/text commands + +- [Hybrid Commands](hybrid_commands) + - An extension that makes hybrid slash/prefixed commands diff --git a/docs/src/Guides/03 Creating Commands.md b/docs/src/Guides/03 Creating Commands.md index 2352bbf55..74afec8a3 100644 --- a/docs/src/Guides/03 Creating Commands.md +++ b/docs/src/Guides/03 Creating Commands.md @@ -520,3 +520,38 @@ There also is `on_command` which you can overwrite too. That fires on every inte If your bot is complex enough, you might find yourself wanting to use custom models in your commands. To do this, you'll want to use a string option, and define a converter. Information on how to use converters can be found [on the converter page](/Guides/08 Converters). + +## I Want To Make A Prefixed/Text Command Too + +You're in luck! You can use a hybrid command, which is a slash command that also gets converted to an equivalent prefixed command under the hood. + +Hybrid commands are their own extension, and require [prefixed commands to set up beforehand](/interactions.py/Guides/26 Prefixed Commands). After that, use the `setup` function in the `hybrid_commands` extension in your main bot file. + +Your setup can (but doesn't necessarily have to) look like this: + +```python +import interactions +from interactions.ext import prefixed_commands as prefixed +from interactions.ext import hybrid_commands as hybrid + +bot = interactions.Client(...) # may want to enable the message content intent +prefixed.setup(bot) # normal step for prefixed commands +hybrid.setup(bot) # note its usage AFTER prefixed commands have been set up +``` + +To actually make slash commands, simply replace `@slash_command` with `@hybrid_slash_command`, and `SlashContext` with `HybridContext`, like so: + +```python +from interactions.ext.hybrid_commands import hybrid_slash_command, HybridContext + +@hybrid_slash_command(name="my_command", description="My hybrid command!") +async def my_command_function(ctx: HybridContext): + await ctx.send("Hello World") +``` + +Suggesting you are using the default mention settings for your bot, you should be able to run this command by `@BotPing my_command`. + +As you can see, the only difference between hybrid commands and slash commands, from a developer perspective, is that they use `HybridContext`, which attempts +to seamlessly allow using the same context for slash and prefixed commands. You can always get the underlying context via `inner_context`, though. + +Of course, keep in mind that support two different types of commands is hard - some features may not get represented well in prefixed commands, and autocomplete is not possible at all. diff --git a/interactions/ext/hybrid_commands/__init__.py b/interactions/ext/hybrid_commands/__init__.py new file mode 100644 index 000000000..70d33bc59 --- /dev/null +++ b/interactions/ext/hybrid_commands/__init__.py @@ -0,0 +1,12 @@ +from .context import HybridContext +from .hybrid_slash import HybridSlashCommand, hybrid_slash_command, hybrid_slash_subcommand +from .manager import HybridManager, setup + +__all__ = ( + "HybridContext", + "HybridManager", + "HybridSlashCommand", + "hybrid_slash_command", + "hybrid_slash_subcommand", + "setup", +) diff --git a/interactions/ext/hybrid_commands/context.py b/interactions/ext/hybrid_commands/context.py new file mode 100644 index 000000000..95c444d73 --- /dev/null +++ b/interactions/ext/hybrid_commands/context.py @@ -0,0 +1,383 @@ +import datetime + +from typing import TYPE_CHECKING, Any, Optional, Union, Iterable, Sequence +from typing_extensions import Self + + +from interactions import ( + BaseContext, + Permissions, + Message, + SlashContext, + Client, + Typing, + Embed, + BaseComponent, + UPLOADABLE_TYPE, + Snowflake_Type, + Sticker, + AllowedMentions, + MessageReference, + MessageFlags, + to_snowflake, + Attachment, + process_message_payload, +) +from interactions.client.mixins.send import SendMixin +from interactions.ext import prefixed_commands as prefixed + +if TYPE_CHECKING: + from .hybrid_slash import HybridSlashCommand + +__all__ = ("HybridContext",) + + +class DeferTyping: + def __init__(self, ctx: "HybridContext", ephermal: bool) -> None: + self.ctx = ctx + self.ephermal = ephermal + + async def __aenter__(self) -> None: + await self.ctx.defer(ephemeral=self.ephermal) + + async def __aexit__(self, *_) -> None: + pass + + +class HybridContext(BaseContext, SendMixin): + prefix: str + "The prefix used to invoke this command." + + app_permissions: Permissions + """The permissions available to this context""" + + deferred: bool + """Whether the context has been deferred.""" + responded: bool + """Whether the context has been responded to.""" + ephemeral: bool + """Whether the context response is ephemeral.""" + + _command_name: str + """The command name.""" + _message: Message | None + + args: list[Any] + """The arguments passed to the command.""" + kwargs: dict[str, Any] + """The keyword arguments passed to the command.""" + + __attachment_index__: int + + _slash_ctx: SlashContext | None + _prefixed_ctx: prefixed.PrefixedContext | None + + def __init__(self, client: Client): + super().__init__(client) + self.prefix = "" + self.app_permissions = Permissions(0) + self.deferred = False + self.responded = False + self.ephemeral = False + self._command_name = "" + self.args = [] + self.kwargs = {} + self._message = None + self.__attachment_index__ = 0 + self._slash_ctx = None + self._prefixed_ctx = None + + @classmethod + def from_dict(cls, client: Client, payload: dict) -> None: + # this doesn't mean anything, so just implement it to make abc happy + raise NotImplementedError + + @classmethod + def from_slash_context(cls, ctx: SlashContext) -> Self: + self = cls(ctx.client) + self.guild_id = ctx.guild_id + self.channel_id = ctx.channel_id + self.author_id = ctx.author_id + self.message_id = ctx.message_id + self.prefix = "/" + self.app_permissions = ctx.app_permissions + self.deferred = ctx.deferred + self.responded = ctx.responded + self.ephemeral = ctx.ephemeral + self._command_name = ctx._command_name + self.args = ctx.args + self.kwargs = ctx.kwargs + self._slash_ctx = ctx + return self + + @classmethod + def from_prefixed_context(cls, ctx: prefixed.PrefixedContext) -> Self: + # this is a "best guess" on what the permissions are + # this may or may not be totally accurate + if hasattr(ctx.channel, "permissions_for"): + app_permissions = ctx.channel.permissions_for(ctx.guild.me) # type: ignore + elif ctx.channel.type in {10, 11, 12}: # it's a thread + app_permissions = ctx.channel.parent_channel.permissions_for(ctx.guild.me) # type: ignore + + self = cls(ctx.client) + self.guild_id = ctx.guild_id + self.channel_id = ctx.channel_id + self.author_id = ctx.author_id + self.message_id = ctx.message_id + self._message = ctx.message + self.prefix = ctx.prefix + self.app_permissions = app_permissions + self._command_name = ctx.command.qualified_name + self.args = ctx.args + self._prefixed_ctx = ctx + return self + + @property + def inner_context(self) -> SlashContext | prefixed.PrefixedContext: + """The inner context that this hybrid context is wrapping.""" + return self._slash_ctx or self._prefixed_ctx # type: ignore + + @property + def command(self) -> "HybridSlashCommand": + return self.client._interaction_lookup[self._command_name] + + @property + def expires_at(self) -> Optional[datetime.datetime]: + """The time at which the interaction expires.""" + if not self._slash_ctx: + return None + + if self.responded: + return self._slash_ctx.id.created_at + datetime.timedelta(minutes=15) + return self._slash_ctx.id.created_at + datetime.timedelta(seconds=3) + + @property + def expired(self) -> bool: + """Whether the interaction has expired.""" + return datetime.datetime.utcnow() > self.expires_at if self._slash_ctx else False + + @property + def deferred_ephemeral(self) -> bool: + """Whether the interaction has been deferred ephemerally.""" + return self.deferred and self.ephemeral + + @property + def message(self) -> Message | None: + """The message that invoked this context.""" + return self._message or self.client.cache.get_message(self.channel_id, self.message_id) + + @property + def typing(self) -> Typing | DeferTyping: + """A context manager to send a _typing/defer state to a given channel as long as long as the wrapped operation takes.""" + if self._slash_ctx: + return DeferTyping(self._slash_ctx, self.ephemeral) + return self.channel.typing + + async def defer(self, ephemeral: bool = False) -> None: + """ + Either defers the response (if used in an interaction) or triggers a typing indicator for 10 seconds (if used for messages). + + Args: + ephemeral: Should the response be ephemeral? Only applies to responses for interactions. + + """ + if self._slash_ctx: + await self._slash_ctx.defer(ephemeral=ephemeral) + else: + await self.channel.trigger_typing() + + self.deferred = True + + async def reply( + self, + content: Optional[str] = None, + embeds: Optional[ + Union[ + Iterable[Union[Embed, dict]], + Union[Embed, dict], + ] + ] = None, + embed: Optional[Union[Embed, dict]] = None, + **kwargs, + ) -> "Message": + """ + Reply to this message, takes all the same attributes as `send`. + + For interactions, this functions the same as `send`. + """ + kwargs = locals() + kwargs.pop("self") + extra_kwargs = kwargs.pop("kwargs") + kwargs |= extra_kwargs + + if self._slash_ctx: + result = await self.send(**kwargs) + else: + kwargs.pop("ephemeral", None) + result = await self._prefixed_ctx.reply(**kwargs) + + self.responded = True + return result + + async def _send_http_request( + self, + message_payload: dict, + files: Iterable["UPLOADABLE_TYPE"] | None = None, + ) -> dict: + if self._slash_ctx: + return await self._slash_ctx._send_http_request(message_payload, files) + return await self._prefixed_ctx._send_http_request(message_payload, files) + + async def send( + self, + content: Optional[str] = None, + *, + embeds: Optional[ + Union[ + Iterable[Union["Embed", dict]], + Union["Embed", dict], + ] + ] = None, + embed: Optional[Union["Embed", dict]] = None, + components: Optional[ + Union[ + Iterable[Iterable[Union["BaseComponent", dict]]], + Iterable[Union["BaseComponent", dict]], + "BaseComponent", + dict, + ] + ] = None, + stickers: Optional[ + Union[ + Iterable[Union["Sticker", "Snowflake_Type"]], + "Sticker", + "Snowflake_Type", + ] + ] = None, + allowed_mentions: Optional[Union["AllowedMentions", dict]] = None, + reply_to: Optional[Union["MessageReference", "Message", dict, "Snowflake_Type"]] = None, + files: Optional[Union["UPLOADABLE_TYPE", Iterable["UPLOADABLE_TYPE"]]] = None, + file: Optional["UPLOADABLE_TYPE"] = None, + tts: bool = False, + suppress_embeds: bool = False, + silent: bool = False, + flags: Optional[Union[int, "MessageFlags"]] = None, + delete_after: Optional[float] = None, + ephemeral: bool = False, + **kwargs: Any, + ) -> "Message": + """ + Send a message. + + Args: + content: Message text content. + embeds: Embedded rich content (up to 6000 characters). + embed: Embedded rich content (up to 6000 characters). + components: The components to include with the message. + stickers: IDs of up to 3 stickers in the server to send in the message. + allowed_mentions: Allowed mentions for the message. + reply_to: Message to reference, must be from the same channel. + files: Files to send, the path, bytes or File() instance, defaults to None. You may have up to 10 files. + file: Files to send, the path, bytes or File() instance, defaults to None. You may have up to 10 files. + tts: Should this message use Text To Speech. + suppress_embeds: Should embeds be suppressed on this send + silent: Should this message be sent without triggering a notification. + flags: Message flags to apply. + delete_after: Delete message after this many seconds. + ephemeral: Should this message be sent as ephemeral (hidden) - only works with interactions + + Returns: + New message object that was sent. + """ + flags = MessageFlags(flags or 0) + if ephemeral and not self._slash_ctx: + flags |= MessageFlags.EPHEMERAL + self.ephemeral = True + if suppress_embeds: + flags |= MessageFlags.SUPPRESS_EMBEDS + if silent: + flags |= MessageFlags.SILENT + + return await super().send( + content=content, + embeds=embeds, + embed=embed, + components=components, + stickers=stickers, + allowed_mentions=allowed_mentions, + reply_to=reply_to, + files=files, + file=file, + tts=tts, + flags=flags, + delete_after=delete_after, + **kwargs, + ) + + async def delete(self, message: "Snowflake_Type") -> None: + """ + Delete a message sent in response to this context. Must be in the same channel as the context. + + Args: + message: The message to delete + """ + if self._slash_ctx: + return await self._slash_ctx.delete(message) + await self.client.http.delete_message(self.channel_id, to_snowflake(message)) + + async def edit( + self, + message: "Snowflake_Type", + *, + content: Optional[str] = None, + embeds: Optional[ + Union[ + Iterable[Union["Embed", dict]], + Union["Embed", dict], + ] + ] = None, + embed: Optional[Union["Embed", dict]] = None, + components: Optional[ + Union[ + Iterable[Iterable[Union["BaseComponent", dict]]], + Iterable[Union["BaseComponent", dict]], + "BaseComponent", + dict, + ] + ] = None, + attachments: Optional[Sequence[Attachment | dict]] = None, + allowed_mentions: Optional[Union["AllowedMentions", dict]] = None, + files: Optional[Union["UPLOADABLE_TYPE", Iterable["UPLOADABLE_TYPE"]]] = None, + file: Optional["UPLOADABLE_TYPE"] = None, + tts: bool = False, + ) -> "Message": + if self._slash_ctx: + return await self._slash_ctx.edit( + message, + content=content, + embeds=embeds, + embed=embed, + components=components, + attachments=attachments, + allowed_mentions=allowed_mentions, + files=files, + file=file, + tts=tts, + ) + + message_payload = process_message_payload( + content=content, + embeds=embeds or embed, + components=components, + allowed_mentions=allowed_mentions, + attachments=attachments, + tts=tts, + ) + if file: + files = [file, *files] if files else [file] + + message_data = await self.client.http.edit_message( + message_payload, self.channel_id, to_snowflake(message), files=files + ) + if message_data: + return self.client.cache.place_message_data(message_data) diff --git a/interactions/ext/hybrid_commands/hybrid_slash.py b/interactions/ext/hybrid_commands/hybrid_slash.py new file mode 100644 index 000000000..dc848c8a5 --- /dev/null +++ b/interactions/ext/hybrid_commands/hybrid_slash.py @@ -0,0 +1,593 @@ +import asyncio +import inspect +from typing import Any, Callable, List, Optional, Union, TYPE_CHECKING, Awaitable + +import attrs +from interactions import ( + BaseContext, + Converter, + NoArgumentConverter, + Attachment, + SlashCommandChoice, + OptionType, + BaseChannelConverter, + ChannelType, + BaseChannel, + BaseCommand, + MemberConverter, + UserConverter, + RoleConverter, + SlashCommand, + SlashContext, + Absent, + LocalisedName, + LocalisedDesc, + MISSING, + SlashCommandOption, + Snowflake_Type, + Permissions, +) +from interactions.client.const import AsyncCallable, GLOBAL_SCOPE +from interactions.client.utils.serializer import no_export_meta +from interactions.client.utils.misc_utils import maybe_coroutine, get_object_name +from interactions.client.errors import BadArgument +from interactions.ext.prefixed_commands import PrefixedCommand, PrefixedContext +from interactions.models.internal.converters import _LiteralConverter +from interactions.models.internal.checks import guild_only + +if TYPE_CHECKING: + from .context import HybridContext + +__all__ = ("HybridSlashCommand", "hybrid_slash_command", "hybrid_slash_subcommand") + + +def _values_wrapper(a_dict: dict | None) -> list: + return list(a_dict.values()) if a_dict else [] + + +def generate_permission_check(permissions: "Permissions") -> Callable[["HybridContext"], Awaitable[bool]]: + async def _permission_check(ctx: "HybridContext") -> bool: + return ctx.author.has_permission(*permissions) if ctx.guild_id else True # type: ignore + + return _permission_check # type: ignore + + +def generate_scope_check(_scopes: list["Snowflake_Type"]) -> Callable[["HybridContext"], Awaitable[bool]]: + scopes = frozenset(int(s) for s in _scopes) + + async def _scope_check(ctx: "HybridContext") -> bool: + return int(ctx.guild_id) in scopes + + return _scope_check # type: ignore + + +class BasicConverter(Converter): + def __init__(self, type_to_convert: Any) -> None: + self.type_to_convert = type_to_convert + + async def convert(self, ctx: BaseContext, arg: str) -> Any: + return self.type_to_convert(arg) + + +class BoolConverter(Converter): + async def convert(self, ctx: BaseContext, argument: str) -> bool: + lowered = argument.lower() + if lowered in {"yes", "y", "true", "t", "1", "enable", "on"}: + return True + elif lowered in {"no", "n", "false", "f", "0", "disable", "off"}: # noqa: RET505 + return False + raise BadArgument(f"{argument} is not a recognised boolean option.") + + +class AttachmentConverter(NoArgumentConverter): + async def convert(self, ctx: "HybridContext", _: Any) -> Attachment: + try: + attachment = ctx.message.attachments[ctx.__attachment_index__] + ctx.__attachment_index__ += 1 + return attachment + except IndexError: + raise BadArgument("No attachment found.") from None + + +class ChoicesConverter(_LiteralConverter): + def __init__(self, choices: list[SlashCommandChoice | dict]) -> None: + standardized_choices = tuple((SlashCommandChoice(**o) if isinstance(o, dict) else o) for o in choices) + + names = tuple(c.name for c in standardized_choices) + self.values = {str(arg): str for arg in names} + self.choice_values = {str(c.name): c.value for c in standardized_choices} + + async def convert(self, ctx: BaseContext, argument: str) -> Any: + val = await super().convert(ctx, argument) + return self.choice_values[val] + + +class RangeConverter(Converter[float | int]): + def __init__( + self, + number_type: int, + min_value: Optional[float | int], + max_value: Optional[float | int], + ) -> None: + self.number_type = number_type + self.min_value = min_value + self.max_value = max_value + + self.number_convert = int if number_type == OptionType.INTEGER else float + + async def convert(self, ctx: BaseContext, argument: str) -> float | int: + try: + converted: float | int = await maybe_coroutine(self.number_convert, ctx, argument) + + if self.min_value and converted < self.min_value: + raise BadArgument(f'Value "{argument}" is less than {self.min_value}.') + if self.max_value and converted > self.max_value: + raise BadArgument(f'Value "{argument}" is greater than {self.max_value}.') + + return converted + except ValueError: + type_name = "number" if self.number_type == OptionType.NUMBER else "integer" + + if type_name.startswith("i"): + raise BadArgument(f'Argument "{argument}" is not an {type_name}.') from None + raise BadArgument(f'Argument "{argument}" is not a {type_name}.') from None + except BadArgument: + raise + + +class StringLengthConverter(Converter[str]): + def __init__(self, min_length: Optional[int], max_length: Optional[int]) -> None: + self.min_length = min_length + self.max_length = max_length + + async def convert(self, ctx: BaseContext, argument: str) -> str: + if self.min_length and len(argument) < self.min_length: + raise BadArgument(f'The string "{argument}" is shorter than {self.min_length} character(s).') + elif self.max_length and len(argument) > self.max_length: # noqa: RET506 + raise BadArgument(f'The string "{argument}" is longer than {self.max_length} character(s).') + + return argument + + +class NarrowedChannelConverter(BaseChannelConverter): + def __init__(self, channel_types: list[ChannelType | int]) -> None: + self.channel_types = channel_types + + def _check(self, result: BaseChannel) -> bool: + return result.type in self.channel_types + + +class HackyUnionConverter(Converter): + def __init__(self, *converters: type[Converter]) -> None: + self.converters = converters + + async def convert(self, ctx: BaseContext, arg: str) -> Any: + for converter in self.converters: + try: + return await converter().convert(ctx, arg) + except Exception: + continue + + union_names = tuple(get_object_name(t).removesuffix("Converter") for t in self.converters) + union_types_str = ", ".join(union_names[:-1]) + f", or {union_names[-1]}" + raise BadArgument(f'Could not convert "{arg}" into {union_types_str}.') + + +class ChainConverter(Converter): + def __init__( + self, + first_converter: Converter, + second_converter: type[Converter] | Converter, + name_of_cmd: str, + ) -> None: + self.first_converter = first_converter + self.second_converter = second_converter + self.name_of_cmd = name_of_cmd + + async def convert(self, ctx: BaseContext, arg: str) -> Any: + first = await self.first_converter.convert(ctx, arg) + return await maybe_coroutine( + BaseCommand._get_converter_function(self.second_converter, self.name_of_cmd)(ctx, first) + ) + + +class ChainNoArgConverter(NoArgumentConverter): + def __init__( + self, + first_converter: NoArgumentConverter, + second_converter: type[Converter] | Converter, + name_of_cmd: str, + ) -> None: + self.first_converter = first_converter + self.second_converter = second_converter + self.name_of_cmd = name_of_cmd + + async def convert(self, ctx: "HybridContext", _: Any) -> Any: + first = await self.first_converter.convert(ctx, _) + return await maybe_coroutine( + BaseCommand._get_converter_function(self.second_converter, self.name_of_cmd)(ctx, first) + ) + + +def type_from_option(option_type: OptionType | int) -> Converter: + if option_type == OptionType.STRING: + return BasicConverter(str) + elif option_type == OptionType.INTEGER: # noqa: RET505 + return BasicConverter(int) + elif option_type == OptionType.NUMBER: + return BasicConverter(float) + elif option_type == OptionType.BOOLEAN: + return BoolConverter() + elif option_type == OptionType.USER: + return HackyUnionConverter(MemberConverter, UserConverter) + elif option_type == OptionType.CHANNEL: + return BaseChannelConverter() + elif option_type == OptionType.ROLE: + return RoleConverter() + elif option_type == OptionType.MENTIONABLE: + return HackyUnionConverter(MemberConverter, UserConverter, RoleConverter) + elif option_type == OptionType.ATTACHMENT: + return AttachmentConverter() + raise NotImplementedError(f"Unknown option type: {option_type}") + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class HybridSlashCommand(SlashCommand): + aliases: list[str] = attrs.field(repr=False, factory=list, metadata=no_export_meta) + _dummy_base: bool = attrs.field(repr=False, default=False, metadata=no_export_meta) + _silence_autocomplete_errors: bool = attrs.field(repr=False, default=False, metadata=no_export_meta) + + async def __call__(self, context: SlashContext, *args, **kwargs) -> None: + new_ctx = context.client.hybrid.hybrid_context.from_slash_context(context) + await super().__call__(new_ctx, *args, **kwargs) + + def group( + self, + name: str = None, + description: str = "No Description Set", + inherit_checks: bool = True, + aliases: list[str] | None = None, + ) -> "HybridSlashCommand": + self._dummy_base = True + return HybridSlashCommand( + name=self.name, + description=self.description, + group_name=name, + group_description=description, + scopes=self.scopes, + default_member_permissions=self.default_member_permissions, + dm_permission=self.dm_permission, + checks=self.checks.copy() if inherit_checks else [], + aliases=aliases or [], + ) + + def subcommand( + self, + sub_cmd_name: Absent[LocalisedName | str] = MISSING, + group_name: LocalisedName | str = None, + sub_cmd_description: Absent[LocalisedDesc | str] = MISSING, + group_description: Absent[LocalisedDesc | str] = MISSING, + options: List[Union[SlashCommandOption, dict]] = None, + nsfw: bool = False, + inherit_checks: bool = True, + aliases: list[str] | None = None, + silence_autocomplete_errors: bool = True, + ) -> Callable[..., "HybridSlashCommand"]: + def wrapper(call: AsyncCallable) -> "HybridSlashCommand": + nonlocal sub_cmd_name, sub_cmd_description + + if not asyncio.iscoroutinefunction(call): + raise TypeError("Subcommand must be coroutine") + + if sub_cmd_description is MISSING: + sub_cmd_description = call.__doc__ or "No Description Set" + if sub_cmd_name is MISSING: + sub_cmd_name = call.__name__ + + self._dummy_base = True + return HybridSlashCommand( + name=self.name, + description=self.description, + group_name=group_name or self.group_name, + group_description=group_description or self.group_description, + sub_cmd_name=sub_cmd_name, + sub_cmd_description=sub_cmd_description, + default_member_permissions=self.default_member_permissions, + dm_permission=self.dm_permission, + options=options, + callback=call, + scopes=self.scopes, + nsfw=nsfw, + checks=self.checks.copy() if inherit_checks else [], + aliases=aliases or [], + silence_autocomplete_errors=silence_autocomplete_errors, + ) + + return wrapper + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class _HybridToPrefixedCommand(PrefixedCommand): + async def __call__(self, context: PrefixedContext, *args, **kwargs) -> None: + new_ctx = context.client.hybrid.hybrid_context.from_prefixed_context(context) + await super().__call__(new_ctx, *args, **kwargs) + + +def slash_to_prefixed(cmd: HybridSlashCommand) -> _HybridToPrefixedCommand: # noqa: C901 there's nothing i can do + prefixed_cmd = _HybridToPrefixedCommand( + name=str(cmd.sub_cmd_name) if cmd.is_subcommand else str(cmd.name), + aliases=list(_values_wrapper(cmd.sub_cmd_name.to_locale_dict())) + if cmd.is_subcommand + else list(_values_wrapper(cmd.name.to_locale_dict())), + help=str(cmd.description), + callback=cmd.callback, + checks=cmd.checks, + cooldown=cmd.cooldown, + max_concurrency=cmd.max_concurrency, + pre_run_callback=cmd.pre_run_callback, + post_run_callback=cmd.post_run_callback, + error_callback=cmd.error_callback, + ) + if cmd.aliases: + prefixed_cmd.aliases.extend(cmd.aliases) + + if not cmd.dm_permission: + prefixed_cmd.add_check(guild_only()) + + if cmd.scopes != [GLOBAL_SCOPE]: + prefixed_cmd.add_check(generate_scope_check(cmd.scopes)) + + if cmd.default_member_permissions: + prefixed_cmd.add_check(generate_permission_check(cmd.default_member_permissions)) + + if not cmd.options: + prefixed_cmd._inspect_signature = inspect.Signature() + return prefixed_cmd + + fake_sig_parameters: list[inspect.Parameter] = [] + + for option in cmd.options: + if isinstance(option, dict): + # makes my life easier + option = SlashCommandOption(**option) + + if option.autocomplete and not cmd._silence_autocomplete_errors: + # there isn't much we can do here + raise ValueError("Autocomplete is unsupported in hybrid commands.") + + name = str(option.name) + annotation = inspect.Parameter.empty + default = inspect.Parameter.empty + kind = inspect.Parameter.POSITIONAL_ONLY if cmd._uses_arg else inspect.Parameter.POSITIONAL_OR_KEYWORD + + if slash_param := cmd.parameters.get(name): + kind = slash_param.kind + + if kind == inspect.Parameter.KEYWORD_ONLY: # work around prefixed cmd parsing + kind = inspect.Parameter.POSITIONAL_OR_KEYWORD + + if slash_param.converter: + annotation = slash_param.converter + if slash_param.default is not MISSING: + default = slash_param.default + + if option.choices: + option_anno = ChoicesConverter(option.choices) + elif option.min_value is not None or option.max_value is not None: + option_anno = RangeConverter(option.type, option.min_value, option.max_value) + elif option.min_length is not None or option.max_length is not None: + option_anno = StringLengthConverter(option.min_length, option.max_length) + elif option.type == OptionType.CHANNEL and option.channel_types: + option_anno = NarrowedChannelConverter(option.channel_types) + else: + option_anno = type_from_option(option.type) + + if annotation is inspect.Parameter.empty: + annotation = option_anno + elif isinstance(option_anno, NoArgumentConverter): + annotation = ChainNoArgConverter(option_anno, annotation, name) + else: + annotation = ChainConverter(option_anno, annotation, name) + + if not option.required and default == inspect.Parameter.empty: + default = None + + actual_param = inspect.Parameter( + name=name, + kind=kind, + default=default, + annotation=annotation, + ) + fake_sig_parameters.append(actual_param) + + prefixed_cmd._inspect_signature = inspect.Signature(parameters=fake_sig_parameters) + return prefixed_cmd + + +def create_subcmd_func(group: bool = False) -> Callable: + async def _subcommand_base(*args, **kwargs) -> None: + if group: + raise BadArgument("Cannot run this subcommand group without a valid subcommand.") + raise BadArgument("Cannot run this command without a valid subcommand.") + + return _subcommand_base + + +def base_subcommand_generator( + name: str, aliases: list[str], description: str, group: bool = False +) -> _HybridToPrefixedCommand: + return _HybridToPrefixedCommand( + callback=create_subcmd_func(group=group), + name=name, + aliases=aliases, + help=description, + ignore_extra=False, + inspect_signature=inspect.Signature(None), # type: ignore + ) + + +def hybrid_slash_command( + name: Absent[str | LocalisedName] = MISSING, + *, + aliases: Optional[list[str]] = None, + description: Absent[str | LocalisedDesc] = MISSING, + scopes: Absent[list["Snowflake_Type"]] = MISSING, + options: Optional[list[Union[SlashCommandOption, dict]]] = None, + default_member_permissions: Optional["Permissions"] = None, + dm_permission: bool = True, + sub_cmd_name: str | LocalisedName = None, + group_name: str | LocalisedName = None, + sub_cmd_description: str | LocalisedDesc = "No Description Set", + group_description: str | LocalisedDesc = "No Description Set", + nsfw: bool = False, + silence_autocomplete_errors: bool = False, +) -> Callable[[AsyncCallable], HybridSlashCommand]: + """ + A decorator to declare a coroutine as a hybrid slash command. + + Hybrid commands are a slash command that can also function as a prefixed command. + These use a HybridContext instead of an SlashContext, but otherwise are mostly identical to normal slash commands. + + Note that hybrid commands do not support autocompletes. + They also only partially support attachments, allowing one attachment option for a command. + + !!! note + While the base and group descriptions arent visible in the discord client, currently. + We strongly advise defining them anyway, if you're using subcommands, as Discord has said they will be visible in + one of the future ui updates. + + Args: + name: 1-32 character name of the command, defaults to the name of the coroutine. + aliases: Aliases for the prefixed command varient of the command. Has no effect on the slash command. + description: 1-100 character description of the command + scopes: The scope this command exists within + options: The parameters for the command, max 25 + default_member_permissions: What permissions members need to have by default to use this command. + dm_permission: Should this command be available in DMs. + sub_cmd_name: 1-32 character name of the subcommand + sub_cmd_description: 1-100 character description of the subcommand + group_name: 1-32 character name of the group + group_description: 1-100 character description of the group + nsfw: This command should only work in NSFW channels + silence_autocomplete_errors: Should autocomplete errors be silenced. Don't use this unless you know what you're doing. + + Returns: + HybridSlashCommand Object + + """ + + def wrapper(func: AsyncCallable) -> HybridSlashCommand: + if not asyncio.iscoroutinefunction(func): + raise ValueError("Commands must be coroutines") + + perm = default_member_permissions + if hasattr(func, "default_member_permissions"): + if perm: + perm = perm | func.default_member_permissions + else: + perm = func.default_member_permissions + + _name = name + if _name is MISSING: + _name = func.__name__ + + _description = description + if _description is MISSING: + _description = func.__doc__ or "No Description Set" + + cmd = HybridSlashCommand( + name=_name, + group_name=group_name, + group_description=group_description, + sub_cmd_name=sub_cmd_name, + sub_cmd_description=sub_cmd_description, + description=_description, + scopes=scopes or [GLOBAL_SCOPE], + default_member_permissions=perm, + dm_permission=dm_permission, + callback=func, + options=options, + nsfw=nsfw, + aliases=aliases or [], + silence_autocomplete_errors=silence_autocomplete_errors, + ) + + return cmd + + return wrapper + + +def hybrid_slash_subcommand( + base: str | LocalisedName, + *, + subcommand_group: Optional[str | LocalisedName] = None, + name: Absent[str | LocalisedName] = MISSING, + aliases: Optional[list[str]] = None, + description: Absent[str | LocalisedDesc] = MISSING, + base_description: Optional[str | LocalisedDesc] = None, + base_desc: Optional[str | LocalisedDesc] = None, + base_default_member_permissions: Optional["Permissions"] = None, + base_dm_permission: bool = True, + subcommand_group_description: Optional[str | LocalisedDesc] = None, + sub_group_desc: Optional[str | LocalisedDesc] = None, + scopes: List["Snowflake_Type"] = None, + options: List[dict] = None, + nsfw: bool = False, + silence_autocomplete_errors: bool = False, +) -> Callable[[AsyncCallable], HybridSlashCommand]: + """ + A decorator specifically tailored for creating hybrid slash subcommands. + + Args: + base: The name of the base command + subcommand_group: The name of the subcommand group, if any. + name: The name of the subcommand, defaults to the name of the coroutine. + aliases: Aliases for the prefixed command varient of the subcommand. Has no effect on the slash command. + description: The description of the subcommand + base_description: The description of the base command + base_desc: An alias of `base_description` + base_default_member_permissions: What permissions members need to have by default to use this command. + base_dm_permission: Should this command be available in DMs. + subcommand_group_description: Description of the subcommand group + sub_group_desc: An alias for `subcommand_group_description` + scopes: The scopes of which this command is available, defaults to GLOBAL_SCOPE + options: The options for this command + nsfw: This command should only work in NSFW channels + silence_autocomplete_errors: Should autocomplete errors be silenced. Don't use this unless you know what you're doing. + + Returns: + A HybridSlashCommand object + + """ + + def wrapper(func: AsyncCallable) -> HybridSlashCommand: + if not asyncio.iscoroutinefunction(func): + raise ValueError("Commands must be coroutines") + + _name = name + if _name is MISSING: + _name = func.__name__ + + _description = description + if _description is MISSING: + _description = func.__doc__ or "No Description Set" + + cmd = HybridSlashCommand( + name=base, + description=(base_description or base_desc) or "No Description Set", + group_name=subcommand_group, + group_description=(subcommand_group_description or sub_group_desc) or "No Description Set", + sub_cmd_name=_name, + sub_cmd_description=_description, + default_member_permissions=base_default_member_permissions, + dm_permission=base_dm_permission, + scopes=scopes or [GLOBAL_SCOPE], + callback=func, + options=options, + nsfw=nsfw, + aliases=aliases or [], + silence_autocomplete_errors=silence_autocomplete_errors, + ) + return cmd + + return wrapper diff --git a/interactions/ext/hybrid_commands/manager.py b/interactions/ext/hybrid_commands/manager.py new file mode 100644 index 000000000..de0127d59 --- /dev/null +++ b/interactions/ext/hybrid_commands/manager.py @@ -0,0 +1,149 @@ +from typing import cast, Callable, Any + +from interactions import Client, BaseContext, listen +from interactions.api.events import CallbackAdded, ExtensionUnload +from interactions.ext import prefixed_commands as prefixed + +from .context import HybridContext +from .hybrid_slash import ( + _values_wrapper, + base_subcommand_generator, + HybridSlashCommand, + _HybridToPrefixedCommand, + slash_to_prefixed, +) + +__all__ = ("HybridManager", "setup") + + +def add_use_slash_command_message( + prefixed_cmd: _HybridToPrefixedCommand, slash_cmd: HybridSlashCommand +) -> _HybridToPrefixedCommand: + if prefixed_cmd.has_binding: + + def wrap_old_callback(func: Callable) -> Any: + async def _msg_callback(self, ctx: prefixed.PrefixedContext, *args, **kwargs): + await ctx.reply(f"This command has been updated. Please use {slash_cmd.mention(ctx.guild_id)} instead.") + await func(ctx, *args, **kwargs) + + return _msg_callback + + else: + + def wrap_old_callback(func: Callable) -> Any: + async def _msg_callback(ctx: prefixed.PrefixedContext, *args, **kwargs): + await ctx.reply(f"This command has been updated. Please use {slash_cmd.mention(ctx.guild_id)} instead.") + await func(ctx, *args, **kwargs) + + return _msg_callback + + prefixed_cmd.callback = wrap_old_callback(prefixed_cmd.callback) + return prefixed_cmd + + +class HybridManager: + """ + The main part of the extension. Deals with injecting itself in the first place. + + Parameters: + client: The client instance. + hybrid_context: The object to instantiate for Hybrid Context + use_slash_command_msg: If enabled, will send out a message encouraging users to use the slash command \ + equivalent whenever they use the prefixed command version. + """ + + def __init__( + self, client: Client, *, hybrid_context: type[BaseContext] = HybridContext, use_slash_command_msg: bool = False + ) -> None: + if not hasattr(client, "prefixed") or not isinstance(client.prefixed, prefixed.PrefixedManager): + raise TypeError("Prefixed commands are not set up for this bot.") + + self.hybrid_context = hybrid_context + self.use_slash_command_msg = use_slash_command_msg + + self.client = cast(prefixed.PrefixedInjectedClient, client) + self.ext_command_list: dict[str, list[str]] = {} + + self.client.add_listener(self.add_hybrid_command.copy_with_binding(self)) + self.client.add_listener(self.handle_ext_unload.copy_with_binding(self)) + + self.client.hybrid = self + + @listen("on_callback_added") + async def add_hybrid_command(self, event: CallbackAdded): + if ( + not isinstance(event.callback, HybridSlashCommand) + or not event.callback.callback + or event.callback._dummy_base + ): + return + + cmd = event.callback + prefixed_transform = slash_to_prefixed(cmd) + + if self.use_slash_command_msg: + prefixed_transform = add_use_slash_command_message(prefixed_transform, cmd) + + if cmd.is_subcommand: + base = None + if not (base := self.client.prefixed.commands.get(str(cmd.name))): + base = base_subcommand_generator( + str(cmd.name), + list(_values_wrapper(cmd.name.to_locale_dict())) + cmd.aliases, + str(cmd.name), + group=False, + ) + self.client.prefixed.add_command(base) + + if cmd.group_name: # group command + group = None + if not (group := base.subcommands.get(str(cmd.group_name))): + group = base_subcommand_generator( + str(cmd.group_name), + list(_values_wrapper(cmd.group_name.to_locale_dict())) + cmd.aliases, + str(cmd.group_name), + group=True, + ) + base.add_command(group) + base = group + + # since this is added *after* the base command has been added to the bot, we need to run + # this function ourselves + prefixed_transform._parse_parameters() + base.add_command(prefixed_transform) + else: + self.client.prefixed.add_command(prefixed_transform) + + if cmd.extension: + self.ext_command_list.setdefault(cmd.extension.extension_name, []).append(cmd.resolved_name) + + @listen("extension_unload") + async def handle_ext_unload(self, event: ExtensionUnload) -> None: + if not self.ext_command_list.get(event.extension.extension_name): + return + + for cmd in self.ext_command_list[event.extension.extension_name]: + self.client.prefixed.remove_command(cmd, delete_parent_if_empty=True) + + del self.ext_command_list[event.extension.extension_name] + + +def setup( + client: Client, *, hybrid_context: type[BaseContext] = HybridContext, use_slash_command_msg: bool = False +) -> HybridManager: + """ + Sets up hybrid commands. It is recommended to use this function directly to do so. + + !!! warning + Prefixed commands need to be set up prior to using this. + + Args: + client: The client instance. + hybrid_context: The object to instantiate for Hybrid Context + use_slash_command_msg: If enabled, will send out a message encouraging users to use the slash command \ + equivalent whenever they use the prefixed command version. + + Returns: + HybridManager: The class that deals with all things hybrid commands. + """ + return HybridManager(client, hybrid_context=hybrid_context, use_slash_command_msg=use_slash_command_msg)