Skip to content

Commit

Permalink
Improve Context typing for custom contexts (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
null-domain committed Nov 10, 2023
1 parent a42444f commit 3ec5e47
Show file tree
Hide file tree
Showing 10 changed files with 40 additions and 24 deletions.
4 changes: 3 additions & 1 deletion miru/abc/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from ..view import View
from .item_handler import ItemHandler

ViewContextT = t.TypeVar("ViewContextT", bound="ViewContext")


__all__ = ("Item", "DecoratedItem", "ViewItem", "ModalItem")

Expand Down Expand Up @@ -187,7 +189,7 @@ def _from_component(cls, component: hikari.PartialComponent, row: t.Optional[int
"""
...

async def callback(self, context: ViewContext) -> None:
async def callback(self, context: ViewContextT) -> None:
"""
The component's callback, gets called when the component receives an interaction.
"""
Expand Down
3 changes: 2 additions & 1 deletion miru/button.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .view import View

ViewT = t.TypeVar("ViewT", bound="View")
ViewContextT = t.TypeVar("ViewContextT", bound="ViewContext")

__all__ = ("Button", "button")

Expand Down Expand Up @@ -173,7 +174,7 @@ def button(
emoji: t.Optional[t.Union[str, hikari.Emoji]] = None,
row: t.Optional[int] = None,
disabled: bool = False,
) -> t.Callable[[t.Callable[[ViewT, Button, ViewContext], t.Any]], Button]:
) -> t.Callable[[t.Callable[[ViewT, Button, ViewContextT], t.Any]], Button]:
"""A decorator to transform a coroutine function into a Discord UI Button's callback.
This must be inside a subclass of View.
Expand Down
14 changes: 8 additions & 6 deletions miru/ext/nav/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
if t.TYPE_CHECKING:
from .navigator import NavigatorView

ViewContextT = t.TypeVar("ViewContextT", bound=ViewContext)

__all__ = (
"NavItem",
"NavButton",
Expand Down Expand Up @@ -104,7 +106,7 @@ def __init__(
):
super().__init__(style=style, label=label, custom_id=custom_id, emoji=emoji, row=row, position=position)

async def callback(self, context: ViewContext) -> None:
async def callback(self, context: ViewContextT) -> None:
self.view.current_page += 1
await self.view.send_page(context)

Expand Down Expand Up @@ -132,7 +134,7 @@ def __init__(
):
super().__init__(style=style, label=label, custom_id=custom_id, emoji=emoji, row=row, position=position)

async def callback(self, context: ViewContext) -> None:
async def callback(self, context: ViewContextT) -> None:
self.view.current_page -= 1
await self.view.send_page(context)

Expand Down Expand Up @@ -160,7 +162,7 @@ def __init__(
):
super().__init__(style=style, label=label, custom_id=custom_id, emoji=emoji, row=row, position=position)

async def callback(self, context: ViewContext) -> None:
async def callback(self, context: ViewContextT) -> None:
self.view.current_page = 0
await self.view.send_page(context)

Expand Down Expand Up @@ -188,7 +190,7 @@ def __init__(
):
super().__init__(style=style, label=label, custom_id=custom_id, emoji=emoji, row=row, position=position)

async def callback(self, context: ViewContext) -> None:
async def callback(self, context: ViewContextT) -> None:
self.view.current_page = len(self.view.pages) - 1
await self.view.send_page(context)

Expand Down Expand Up @@ -223,7 +225,7 @@ async def before_page_change(self) -> None:
self.label = f"{self.view.current_page+1}/{len(self.view.pages)}"
self.disabled = self.disabled if len(self.view.pages) != 1 else True

async def callback(self, context: ViewContext) -> None:
async def callback(self, context: ViewContextT) -> None:
modal = Modal(title="Jump to page")
modal.add_item(
TextInput(label="Page Number", placeholder="Enter a page number to jump to it...", custom_id="pgnum")
Expand Down Expand Up @@ -262,7 +264,7 @@ def __init__(
):
super().__init__(style=style, label=label, custom_id=custom_id, emoji=emoji, row=row, position=position)

async def callback(self, context: ViewContext) -> None:
async def callback(self, context: ViewContextT) -> None:
if not self.view.message and not self.view._inter:
return

Expand Down
15 changes: 9 additions & 6 deletions miru/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from .abc.item_handler import ItemHandler
from .context.modal import ModalContext

if t.TYPE_CHECKING:
ModalContextT = t.TypeVar("ModalContextT", bound=ModalContext)

__all__ = ("Modal",)


Expand Down Expand Up @@ -122,11 +125,11 @@ def values(self) -> t.Optional[t.Mapping[ModalItem, str]]:
return self._values

@property
def last_context(self) -> t.Optional[ModalContext]:
def last_context(self) -> t.Optional[ModalContextT]:
"""
Context proxying the last interaction that was received by the modal.
"""
return t.cast(ModalContext, self._last_context)
return t.cast(ModalContextT, self._last_context)

@property
def _builder(self) -> t.Type[hikari.impl.ModalActionRowBuilder]:
Expand Down Expand Up @@ -173,7 +176,7 @@ def remove_item(self, item: Item[hikari.impl.ModalActionRowBuilder]) -> Modal:
def clear_items(self) -> Modal:
return t.cast(Modal, super().clear_items())

async def modal_check(self, context: ModalContext) -> bool:
async def modal_check(self, context: ModalContextT) -> bool:
"""Called before any callback in the modal is called. Must evaluate to a truthy value to pass.
Override for custom check logic.
Expand All @@ -192,7 +195,7 @@ async def modal_check(self, context: ModalContext) -> bool:
async def on_error(
self,
error: Exception,
context: t.Optional[ModalContext] = None,
context: t.Optional[ModalContextT] = None,
) -> None:
"""Called when an error occurs in a callback function.
Override for custom error-handling logic.
Expand All @@ -210,7 +213,7 @@ async def on_error(

traceback.print_exception(type(error), error, error.__traceback__, file=sys.stderr)

async def callback(self, context: ModalContext) -> None:
async def callback(self, context: ModalContextT) -> None:
"""Called when the modal is submitted.
Parameters
Expand Down Expand Up @@ -244,7 +247,7 @@ def get_context(
"""
return cls(self, interaction, values)

async def _handle_callback(self, context: ModalContext) -> None:
async def _handle_callback(self, context: ModalContextT) -> None:
"""
Handle the callback of the modal. Separate task in case the modal is stopped in the callback.
"""
Expand Down
3 changes: 2 additions & 1 deletion miru/select/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..view import View

ViewT = t.TypeVar("ViewT", bound="View")
ViewContextT = t.TypeVar("ViewContextT", bound=ViewContext)

__all__ = ("ChannelSelect", "channel_select")

Expand Down Expand Up @@ -123,7 +124,7 @@ def channel_select(
max_values: int = 1,
disabled: bool = False,
row: t.Optional[int] = None,
) -> t.Callable[[t.Callable[[ViewT, ChannelSelect, ViewContext], t.Any]], ChannelSelect]:
) -> t.Callable[[t.Callable[[ViewT, ChannelSelect, ViewContextT], t.Any]], ChannelSelect]:
"""
A decorator to transform a function into a Discord UI ChannelSelectMenu's callback. This must be inside a subclass of View.
"""
Expand Down
3 changes: 2 additions & 1 deletion miru/select/mentionable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..view import View

ViewT = t.TypeVar("ViewT", bound="View")
ViewContextT = t.TypeVar("ViewContextT", bound=ViewContext)

__all__ = ("MentionableSelect", "mentionable_select")

Expand Down Expand Up @@ -115,7 +116,7 @@ def mentionable_select(
max_values: int = 1,
disabled: bool = False,
row: t.Optional[int] = None,
) -> t.Callable[[t.Callable[[ViewT, MentionableSelect, ViewContext], t.Any]], MentionableSelect]:
) -> t.Callable[[t.Callable[[ViewT, MentionableSelect, ViewContextT], t.Any]], MentionableSelect]:
"""
A decorator to transform a function into a Discord UI MentionableSelectMenu's callback.
This must be inside a subclass of View.
Expand Down
3 changes: 2 additions & 1 deletion miru/select/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..view import View

ViewT = t.TypeVar("ViewT", bound="View")
ViewContextT = t.TypeVar("ViewContextT", bound=ViewContext)

__all__ = ("RoleSelect", "role_select")

Expand Down Expand Up @@ -107,7 +108,7 @@ def role_select(
max_values: int = 1,
disabled: bool = False,
row: t.Optional[int] = None,
) -> t.Callable[[t.Callable[[ViewT, RoleSelect, ViewContext], t.Any]], RoleSelect]:
) -> t.Callable[[t.Callable[[ViewT, RoleSelect, ViewContextT], t.Any]], RoleSelect]:
"""
A decorator to transform a function into a Discord UI RoleSelectMenu's callback. This must be inside a subclass of View.
"""
Expand Down
3 changes: 2 additions & 1 deletion miru/select/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..view import View

ViewT = t.TypeVar("ViewT", bound="View")
ViewContextT = t.TypeVar("ViewContextT", bound=ViewContext)

__all__ = ("SelectOption", "TextSelect", "text_select")

Expand Down Expand Up @@ -203,7 +204,7 @@ def text_select(
max_values: int = 1,
disabled: bool = False,
row: t.Optional[int] = None,
) -> t.Callable[[t.Callable[[ViewT, TextSelect, ViewContext], t.Any]], TextSelect]:
) -> t.Callable[[t.Callable[[ViewT, TextSelect, ViewContextT], t.Any]], TextSelect]:
"""
A decorator to transform a function into a Discord UI TextSelectMenu's callback. This must be inside a subclass of View.
"""
Expand Down
3 changes: 2 additions & 1 deletion miru/select/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..view import View

ViewT = t.TypeVar("ViewT", bound="View")
ViewContextT = t.TypeVar("ViewContextT", bound=ViewContext)

__all__ = ("UserSelect", "user_select")

Expand Down Expand Up @@ -115,7 +116,7 @@ def user_select(
max_values: int = 1,
disabled: bool = False,
row: t.Optional[int] = None,
) -> t.Callable[[t.Callable[[ViewT, UserSelect, ViewContext], t.Any]], UserSelect]:
) -> t.Callable[[t.Callable[[ViewT, UserSelect, ViewContextT], t.Any]], UserSelect]:
"""
A decorator to transform a function into a Discord UI UserSelectMenu's callback. This must be inside a subclass of View.
"""
Expand Down
13 changes: 8 additions & 5 deletions miru/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

logger = logging.getLogger(__name__)

if t.TYPE_CHECKING:
ViewContextT = t.TypeVar("ViewContextT", bound=ViewContext)

__all__ = (
"View",
"get_view",
Expand Down Expand Up @@ -130,11 +133,11 @@ def autodefer(self) -> bool:
return self._autodefer

@property
def last_context(self) -> t.Optional[ViewContext]:
def last_context(self) -> t.Optional[ViewContextT]:
"""
The last context that was received by the view.
"""
return t.cast(ViewContext, self._last_context)
return t.cast(ViewContextT, self._last_context)

@property
def _builder(self) -> t.Type[hikari.impl.MessageActionRowBuilder]:
Expand Down Expand Up @@ -221,7 +224,7 @@ def remove_item(self, item: Item[hikari.impl.MessageActionRowBuilder]) -> View:
def clear_items(self) -> View:
return t.cast(View, super().clear_items())

async def view_check(self, context: ViewContext) -> bool:
async def view_check(self, context: ViewContextT) -> bool:
"""Called before any callback in the view is called. Must evaluate to a truthy value to pass.
Override for custom check logic.
Expand All @@ -241,7 +244,7 @@ async def on_error(
self,
error: Exception,
item: t.Optional[ViewItem] = None,
context: t.Optional[ViewContext] = None,
context: t.Optional[ViewContextT] = None,
) -> None:
"""Called when an error occurs in a callback function or the built-in timeout function.
Override for custom error-handling logic.
Expand Down Expand Up @@ -282,7 +285,7 @@ def get_context(
"""
return cls(self, interaction)

async def _handle_callback(self, item: ViewItem, context: ViewContext) -> None:
async def _handle_callback(self, item: ViewItem, context: ViewContextT) -> None:
"""
Handle the callback of a view item. Separate task in case the view is stopped in the callback.
"""
Expand Down

0 comments on commit 3ec5e47

Please sign in to comment.