From cd6e0aa8b00bf5fb58567f6726834023b68c9f7a Mon Sep 17 00:00:00 2001 From: Faster Speeding Date: Sat, 15 Oct 2022 12:57:59 +0100 Subject: [PATCH 1/2] Add support for sync DI to sync callbacks --- tanjun/checks.py | 124 +++++++++++++++++++------------- tanjun/clients.py | 81 +++++++++------------ tanjun/dependencies/limiters.py | 38 ++++++---- tests/test_clients.py | 12 ---- 4 files changed, 136 insertions(+), 119 deletions(-) diff --git a/tanjun/checks.py b/tanjun/checks.py index c8dd85bfc..be066be22 100644 --- a/tanjun/checks.py +++ b/tanjun/checks.py @@ -70,20 +70,26 @@ if typing.TYPE_CHECKING: from collections import abc as collections - _ContextT_contra = typing.TypeVar("_ContextT_contra", bound=tanjun.Context, contravariant=True) + import typing_extensions - class _AnyCallback(typing.Protocol[_ContextT_contra]): - async def __call__( - self, ctx: _ContextT_contra, /, *, localiser: typing.Optional[dependencies.AbstractLocaliser] = None - ) -> bool: - raise NotImplementedError + _P = typing_extensions.ParamSpec("_P") + + _PermissionErrorSigBase = collections.Callable[typing_extensions.Concatenate[hikari.Permissions, _P], Exception] + _PermissionErrorSig = _PermissionErrorSigBase[...] _CommandT = typing.TypeVar("_CommandT", bound=tanjun.ExecutableCommand[typing.Any]) + _ContextT_contra = typing.TypeVar("_ContextT_contra", bound=tanjun.Context, contravariant=True) _CallbackReturnT = typing.Union[_CommandT, collections.Callable[[_CommandT], _CommandT]] _MenuCommandT = typing.TypeVar("_MenuCommandT", bound=tanjun.MenuCommand[typing.Any, typing.Any]) _MessageCommandT = typing.TypeVar("_MessageCommandT", bound=tanjun.MessageCommand[typing.Any]) _SlashCommandT = typing.TypeVar("_SlashCommandT", bound=tanjun.BaseSlashCommand) + class _AnyCallback(typing.Protocol[_ContextT_contra]): + async def __call__( + self, ctx: _ContextT_contra, /, *, localiser: typing.Optional[dependencies.AbstractLocaliser] = None + ) -> bool: + raise NotImplementedError + _ContextT = typing.TypeVar("_ContextT", bound=tanjun.Context) @@ -132,7 +138,7 @@ def _handle_result( ) -> bool: if not result: if self._error: - raise self._error(*args) from None + raise ctx.call_with_di(self._error, args) from None if self._halt_execution: raise errors.HaltExecution from None if self._error_message: @@ -153,7 +159,7 @@ class OwnerCheck(_Check): def __init__( self, *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None] = "Only bot owners can use this command", halt_execution: bool = False, ) -> None: @@ -164,7 +170,8 @@ def __init__( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positional arguments, supports sync DI and takes + priority over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -210,7 +217,7 @@ class NsfwCheck(_Check): def __init__( self, *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[ str, collections.Mapping[str, str], None ] = "Command can only be used in NSFW channels", @@ -223,7 +230,8 @@ def __init__( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positional arguments, supports sync DI and takes + priority over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -261,7 +269,7 @@ class SfwCheck(_Check): def __init__( self, *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[ str, collections.Mapping[str, str], None ] = "Command can only be used in SFW channels", @@ -274,7 +282,8 @@ def __init__( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positonal arguments, supports sync DI and takes + priority over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -312,7 +321,7 @@ class DmCheck(_Check): def __init__( self, *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in DMs", halt_execution: bool = False, ) -> None: @@ -323,7 +332,8 @@ def __init__( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positonal arguments, supports sync DI and takes + priority over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -361,7 +371,7 @@ class GuildCheck(_Check): def __init__( self, *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[ str, collections.Mapping[str, str], None ] = "Command can only be used in guild channels", @@ -374,7 +384,8 @@ def __init__( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positonal arguments, supports sync DI and takes + priority over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -414,7 +425,7 @@ def __init__( permissions: typing.Union[hikari.Permissions, int], /, *, - error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None, + error: typing.Optional[_PermissionErrorSig] = None, error_message: typing.Union[ str, collections.Mapping[str, str], None ] = "You don't have the permissions required to use this command", @@ -429,8 +440,10 @@ def __init__( error Callback used to create a custom error to raise if the check fails. - This should take 1 positional argument of type [hikari.permissions.Permissions][] - which represents the missing permissions required for this command to run. + This should match the signature `def (hikari.permissions.Permissions, ...) -> Exception` + where the first argument is the missing permissions required for this command + and sync dependency injection is supported. The returned [Exception][] will + be raised. This takes priority over `error_message`. error_message @@ -495,7 +508,7 @@ def __init__( permissions: typing.Union[hikari.Permissions, int], /, *, - error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None, + error: typing.Optional[_PermissionErrorSig] = None, error_message: typing.Union[ str, collections.Mapping[str, str], None ] = "Bot doesn't have the permissions required to run this command", @@ -510,8 +523,10 @@ def __init__( error Callback used to create a custom error to raise if the check fails. - This should take 1 positional argument of type [hikari.permissions.Permissions][] - which represents the missing permissions required for this command to run. + This should match the signature `def (hikari.permissions.Permissions, ...) -> Exception` + where the first argument is the missing permissions required for this command + and sync dependency injection is supported. The returned [Exception][] will + be raised. This takes priority over `error_message`. error_message @@ -567,7 +582,7 @@ def with_dm_check(command: _CommandT, /) -> _CommandT: @typing.overload def with_dm_check( *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in DMs", follow_wrapped: bool = False, halt_execution: bool = False, @@ -579,7 +594,7 @@ def with_dm_check( command: typing.Optional[_CommandT] = None, /, *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in DMs", follow_wrapped: bool = False, halt_execution: bool = False, @@ -593,7 +608,8 @@ def with_dm_check( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positonal arguments, supports sync DI and takes priority + over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -629,7 +645,7 @@ def with_guild_check(command: _CommandT, /) -> _CommandT: @typing.overload def with_guild_check( *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[ str, collections.Mapping[str, str], None ] = "Command can only be used in guild channels", @@ -643,7 +659,7 @@ def with_guild_check( command: typing.Optional[_CommandT] = None, /, *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[ str, collections.Mapping[str, str], None ] = "Command can only be used in guild channels", @@ -659,7 +675,8 @@ def with_guild_check( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positonal arguments, supports sync DI and takes priority + over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -695,7 +712,7 @@ def with_nsfw_check(command: _CommandT, /) -> _CommandT: @typing.overload def with_nsfw_check( *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in NSFW channels", follow_wrapped: bool = False, halt_execution: bool = False, @@ -707,7 +724,7 @@ def with_nsfw_check( command: typing.Optional[_CommandT] = None, /, *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in NSFW channels", follow_wrapped: bool = False, halt_execution: bool = False, @@ -721,7 +738,8 @@ def with_nsfw_check( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positonal arguments, supports sync DI and takes priority + over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -757,7 +775,7 @@ def with_sfw_check(command: _CommandT, /) -> _CommandT: @typing.overload def with_sfw_check( *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in SFW channels", follow_wrapped: bool = False, halt_execution: bool = False, @@ -769,7 +787,7 @@ def with_sfw_check( command: typing.Optional[_CommandT] = None, /, *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None] = "Command can only be used in SFW channels", follow_wrapped: bool = False, halt_execution: bool = False, @@ -783,7 +801,8 @@ def with_sfw_check( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positonal arguments, supports sync DI and takes priority + over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -819,7 +838,7 @@ def with_owner_check(command: _CommandT, /) -> _CommandT: @typing.overload def with_owner_check( *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None] = "Only bot owners can use this command", follow_wrapped: bool = False, halt_execution: bool = False, @@ -831,7 +850,7 @@ def with_owner_check( command: typing.Optional[_CommandT] = None, /, *, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None] = "Only bot owners can use this command", follow_wrapped: bool = False, halt_execution: bool = False, @@ -845,7 +864,8 @@ def with_owner_check( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positonal arguments, supports sync DI and takes priority + over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -876,7 +896,7 @@ def with_owner_check( def with_author_permission_check( permissions: typing.Union[hikari.Permissions, int], *, - error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None, + error: typing.Optional[_PermissionErrorSig] = None, error_message: typing.Union[ str, collections.Mapping[str, str], None ] = "You don't have the permissions required to use this command", @@ -896,8 +916,10 @@ def with_author_permission_check( error Callback used to create a custom error to raise if the check fails. - This should take 1 positional argument of type [hikari.permissions.Permissions][] - which represents the missing permissions required for this command to run. + This should match the signature `def (hikari.permissions.Permissions, ...) -> Exception` + where the first argument is the missing permissions required for this command + and sync dependency injection is supported. The returned [Exception][] will + be raised. This takes priority over `error_message`. error_message @@ -932,7 +954,7 @@ def with_author_permission_check( def with_own_permission_check( permissions: typing.Union[hikari.Permissions, int], *, - error: typing.Optional[collections.Callable[[hikari.Permissions], Exception]] = None, + error: typing.Optional[_PermissionErrorSig] = None, error_message: typing.Union[ str, collections.Mapping[str, str], None ] = "Bot doesn't have the permissions required to run this command", @@ -952,8 +974,10 @@ def with_own_permission_check( error Callback used to create a custom error to raise if the check fails. - This should take 1 positional argument of type [hikari.permissions.Permissions][] - which represents the missing permissions required for this command to run. + This should match the signature `def (hikari.permissions.Permissions, ...) -> Exception` + where the first argument is the missing permissions required for this command + and sync dependency injection is supported. The returned [Exception][] will + be raised. This takes priority over `error_message`. error_message @@ -1150,7 +1174,7 @@ class _AnyChecks(_Check, typing.Generic[_ContextT]): def __init__( self, checks: list[tanjun.CheckSig[_ContextT]], - error: typing.Optional[collections.Callable[[], Exception]], + error: typing.Optional[collections.Callable[..., Exception]], error_message: typing.Union[str, collections.Mapping[str, str], None], halt_execution: bool, suppress: tuple[type[Exception], ...], @@ -1180,7 +1204,7 @@ def any_checks( check: tanjun.CheckSig[_ContextT], /, *checks: tanjun.CheckSig[_ContextT], - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None], halt_execution: bool = False, suppress: tuple[type[Exception], ...] = (errors.CommandError, errors.HaltExecution), @@ -1199,7 +1223,8 @@ def any_checks( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positonal arguments, supports sync DI and takes priority + over `error_message`. error_message The error message to send in response as a command error if the check fails. @@ -1226,7 +1251,7 @@ def with_any_checks( check: tanjun.AnyCheckSig, /, *checks: tanjun.AnyCheckSig, - error: typing.Optional[collections.Callable[[], Exception]] = None, + error: typing.Optional[collections.Callable[..., Exception]] = None, error_message: typing.Union[str, collections.Mapping[str, str], None], follow_wrapped: bool = False, halt_execution: bool = False, @@ -1306,7 +1331,8 @@ def with_any_checks( error Callback used to create a custom error to raise if the check fails. - This takes priority over `error_message`. + This takes no positonal arguments, supports sync DI and takes priority + over `error_message`. error_message The error message to send in response as a command error if the check fails. diff --git a/tanjun/clients.py b/tanjun/clients.py index 14e0dcb54..caaa6c83c 100644 --- a/tanjun/clients.py +++ b/tanjun/clients.py @@ -142,6 +142,11 @@ def __call__( class _GatewayBotProto(hikari.EventManagerAware, hikari.RESTAware, hikari.ShardAware, typing.Protocol): """Protocol of a cacheless Hikari Gateway bot.""" + _LoaderSigBase = collections.Callable[typing_extensions.Concatenate[_T, _P], None] + _LoaderSig = _LoaderSigBase[_T, ...] + _LoaderSigT = typing.TypeVar("_LoaderSigT", bound=_LoaderSig[tanjun.Client]) + _StdLoaderSigT = typing.TypeVar("_StdLoaderSigT", bound=_LoaderSig["Client"]) + # 3.9 and 3.10 just can't handle ending a Paramspec with ... so we lie at runtime about this. if typing.TYPE_CHECKING: @@ -191,10 +196,12 @@ def load(self, client: tanjun.Client, /) -> bool: if not isinstance(client, Client): raise TypeError("This loader requires instances of the standard Client implementation") - self._callback(client) + client.injector.call_with_di(self._callback, client) else: - typing.cast("collections.Callable[[tanjun.Client], None]", self._callback)(client) + client.injector.call_with_di( + typing.cast("collections.Callable[[tanjun.Client], None]", self._callback), client + ) return True @@ -231,54 +238,43 @@ def unload(self, client: tanjun.Client, /) -> bool: if not isinstance(client, Client): raise TypeError("This unloader requires instances of the standard Client implementation") - self._callback(client) + client.injector.call_with_di(self._callback, client) else: - typing.cast("collections.Callable[[tanjun.Client], None]", self._callback)(client) + client.injector.call_with_di( + typing.cast("collections.Callable[[tanjun.Client], None]", self._callback), client + ) return True @typing.overload -def as_loader( - callback: collections.Callable[[Client], None], /, *, standard_impl: typing.Literal[True] = True -) -> collections.Callable[[Client], None]: +def as_loader(callback: _StdLoaderSigT, /, *, standard_impl: typing.Literal[True] = True) -> _StdLoaderSigT: ... @typing.overload -def as_loader( - *, standard_impl: typing.Literal[True] = True -) -> collections.Callable[[collections.Callable[[Client], None]], collections.Callable[[Client], None]]: +def as_loader(*, standard_impl: typing.Literal[True] = True) -> collections.Callable[[_StdLoaderSigT], _StdLoaderSigT]: ... @typing.overload -def as_loader( - callback: collections.Callable[[tanjun.Client], None], /, *, standard_impl: typing.Literal[False] -) -> collections.Callable[[tanjun.Client], None]: +def as_loader(callback: _LoaderSigT, /, *, standard_impl: typing.Literal[False]) -> _LoaderSigT: ... @typing.overload -def as_loader( - *, standard_impl: typing.Literal[False] -) -> collections.Callable[[collections.Callable[[tanjun.Client], None]], collections.Callable[[tanjun.Client], None]]: +def as_loader(*, standard_impl: typing.Literal[False]) -> collections.Callable[[_LoaderSigT], _LoaderSigT]: ... def as_loader( - callback: typing.Union[ - collections.Callable[[tanjun.Client], None], collections.Callable[[Client], None], None - ] = None, - /, - *, - standard_impl: bool = True, + callback: typing.Union[_LoaderSigT, _StdLoaderSigT, None] = None, /, *, standard_impl: bool = True ) -> typing.Union[ - collections.Callable[[tanjun.Client], None], - collections.Callable[[Client], None], - collections.Callable[[collections.Callable[[Client], None]], collections.Callable[[Client], None]], - collections.Callable[[collections.Callable[[tanjun.Client], None]], collections.Callable[[tanjun.Client], None]], + _LoaderSigT, + _StdLoaderSigT, + collections.Callable[[_StdLoaderSigT], _StdLoaderSigT], + collections.Callable[[_LoaderSigT], _LoaderSigT], ]: """Mark a callback as being used to load Tanjun components from a module. @@ -294,6 +290,8 @@ def as_loader( [tanjun.abc.Client][] if `standard_impl` is [False][]), return nothing and will be expected to initiate and add utilities such as components to the provided client. + + This supports sync DI. standard_impl Whether this loader should only allow instances of [tanjun.Client][] as opposed to [tanjun.abc.Client][]. @@ -315,45 +313,34 @@ def decorator( @typing.overload -def as_unloader( - callback: collections.Callable[[Client], None], /, *, standard_impl: typing.Literal[True] = True -) -> collections.Callable[[Client], None]: +def as_unloader(callback: _StdLoaderSigT, /, *, standard_impl: typing.Literal[True] = True) -> _StdLoaderSigT: ... @typing.overload def as_unloader( *, standard_impl: typing.Literal[True] = True -) -> collections.Callable[[collections.Callable[[Client], None]], collections.Callable[[Client], None]]: +) -> collections.Callable[[_StdLoaderSigT], _StdLoaderSigT]: ... @typing.overload -def as_unloader( - callback: collections.Callable[[tanjun.Client], None], /, *, standard_impl: typing.Literal[False] -) -> collections.Callable[[tanjun.Client], None]: +def as_unloader(callback: _LoaderSigT, /, *, standard_impl: typing.Literal[False]) -> _LoaderSigT: ... @typing.overload -def as_unloader( - *, standard_impl: typing.Literal[False] -) -> collections.Callable[[collections.Callable[[tanjun.Client], None]], collections.Callable[[tanjun.Client], None]]: +def as_unloader(*, standard_impl: typing.Literal[False]) -> collections.Callable[[_LoaderSigT], _LoaderSigT]: ... def as_unloader( - callback: typing.Union[ - collections.Callable[[Client], None], collections.Callable[[tanjun.Client], None], None - ] = None, - /, - *, - standard_impl: bool = True, + callback: typing.Union[_StdLoaderSigT, _LoaderSigT, None] = None, /, *, standard_impl: bool = True ) -> typing.Union[ - collections.Callable[[Client], None], - collections.Callable[[tanjun.Client], None], - collections.Callable[[collections.Callable[[Client], None]], collections.Callable[[Client], None]], - collections.Callable[[collections.Callable[[tanjun.Client], None]], collections.Callable[[tanjun.Client], None]], + _StdLoaderSigT, + _LoaderSigT, + collections.Callable[[_StdLoaderSigT], _StdLoaderSigT], + collections.Callable[[_LoaderSigT], _LoaderSigT], ]: """Mark a callback as being used to unload a module's utilities from a client. @@ -371,6 +358,8 @@ def as_unloader( [tanjun.abc.Client][] if `standard_impl` is [False][]), return nothing and will be expected to remove utilities such as components from the provided client. + + This supports sync DI. standard_impl Whether this unloader should only allow instances of [tanjun.Client][] as opposed to [tanjun.abc.Client][]. diff --git a/tanjun/dependencies/limiters.py b/tanjun/dependencies/limiters.py index 31060a76e..15d617322 100644 --- a/tanjun/dependencies/limiters.py +++ b/tanjun/dependencies/limiters.py @@ -67,12 +67,20 @@ if typing.TYPE_CHECKING: from collections import abc as collections + import typing_extensions from typing_extensions import Self + _P = typing_extensions.ParamSpec("_P") + _CommandT = typing.TypeVar("_CommandT", bound=tanjun.ExecutableCommand[typing.Any]) _OtherCommandT = typing.TypeVar("_OtherCommandT", bound=tanjun.ExecutableCommand[typing.Any]) _InnerResourceSig = collections.Callable[[], "_InnerResourceT"] + _ConcurrencyErrorSigBase = collections.Callable[typing_extensions.Concatenate[str, _P], Exception] + _ConcurrencyErrorSig = _ConcurrencyErrorSigBase[...] + _CooldownErrorSigBase = collections.Callable[typing_extensions.Concatenate[str, datetime.datetime, _P], Exception] + _CooldownErrorSig = _CooldownErrorSigBase[...] + _InnerResourceT = typing.TypeVar("_InnerResourceT", bound="_InnerResourceProto") @@ -662,7 +670,7 @@ def __init__( bucket_id: str, /, *, - error: typing.Optional[collections.Callable[[str, datetime.datetime], Exception]] = None, + error: typing.Optional[_CooldownErrorSig] = None, error_message: typing.Union[ str, collections.Mapping[str, str] ] = "This command is currently in cooldown. Try again {cooldown}.", @@ -677,9 +685,10 @@ def __init__( error Callback used to create a custom error to raise if the check fails. - This should two arguments one of type [str][] and [datetime.datetime][] - where the first is the limiting bucket's ID and the second is when said - bucket can be used again. + This should match the signature `def (str, datetime.timedelta ...) -> Exception` + where the arguments are the limiting bucket's ID and when the bucket + can be used again, sync dependency injection is supported and the + returned [Exception][] is raised. This takes priority over `error_message`. error_message @@ -725,7 +734,7 @@ def with_cooldown( bucket_id: str, /, *, - error: typing.Optional[collections.Callable[[str, datetime.datetime], Exception]] = None, + error: typing.Optional[_CooldownErrorSig] = None, error_message: typing.Union[ str, collections.Mapping[str, str] ] = "This command is currently in cooldown. Try again {cooldown}.", @@ -747,9 +756,10 @@ def with_cooldown( error Callback used to create a custom error to raise if the check fails. - This should two arguments one of type [str][] and [datetime.datetime][] - where the first is the limiting bucket's ID and the second is when said - bucket can be used again. + This should match the signature `def (str, datetime.timedelta ...) -> Exception` + where the arguments are the limiting bucket's ID and when the bucket + can be used again, sync dependency injection is supported and the + returned [Exception][] is raised. This takes priority over `error_message`. error_message @@ -1012,7 +1022,7 @@ def __init__( bucket_id: str, /, *, - error: typing.Optional[collections.Callable[[str], Exception]] = None, + error: typing.Optional[_ConcurrencyErrorSig] = None, error_message: typing.Union[ str, collections.Mapping[str, str] ] = "This resource is currently busy; please try again later.", @@ -1026,7 +1036,9 @@ def __init__( error Callback used to create a custom error to raise if the check fails. - This should two one [str][] argument which is the limiting bucket's ID. + This should match the signature `def (str, ...) -> Exception` where + the first argument is the limiting bucket's ID, sync dependency + injection is supported and the returned [Exception][] is raised. This takes priority over `error_message`. error_message @@ -1086,7 +1098,7 @@ def with_concurrency_limit( bucket_id: str, /, *, - error: typing.Optional[collections.Callable[[str], Exception]] = None, + error: typing.Optional[_ConcurrencyErrorSig] = None, error_message: typing.Union[ str, collections.Mapping[str, str] ] = "This resource is currently busy; please try again later.", @@ -1107,7 +1119,9 @@ def with_concurrency_limit( error Callback used to create a custom error to raise if the check fails. - This should two one [str][] argument which is the limiting bucket's ID. + This should match the signature `def (str, ...) -> Exception` where + the first argument is the limiting bucket's ID, sync dependency + injection is supported and the returned [Exception][] is raised. This takes priority over `error_message`. error_message diff --git a/tests/test_clients.py b/tests/test_clients.py index 6159fe8e6..67ce142fa 100644 --- a/tests/test_clients.py +++ b/tests/test_clients.py @@ -92,7 +92,6 @@ def test_load(self): mock_callback = mock.Mock() mock_client = mock.Mock(tanjun.Client) descriptor = tanjun.as_loader(mock_callback) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.Client], None]) assert isinstance(descriptor, tanjun.clients._LoaderDescriptor) result = descriptor.load(mock_client) @@ -104,7 +103,6 @@ def test_load_when_called_as_decorator(self): mock_callback = mock.Mock() mock_client = mock.Mock(tanjun.Client) descriptor = tanjun.as_loader()(mock_callback) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.Client], None]) assert isinstance(descriptor, tanjun.clients._LoaderDescriptor) result = descriptor.load(mock_client) @@ -116,7 +114,6 @@ def test_load_when_called_as_decorator_and_args_passed(self): mock_callback = mock.Mock() mock_client = mock.Mock(tanjun.Client) descriptor = tanjun.as_loader(standard_impl=True)(mock_callback) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.Client], None]) assert isinstance(descriptor, tanjun.clients._LoaderDescriptor) result = descriptor.load(mock_client) @@ -127,7 +124,6 @@ def test_load_when_called_as_decorator_and_args_passed(self): def test_load_when_must_be_std_and_not_std(self): mock_callback = mock.Mock() descriptor = tanjun.as_loader(mock_callback) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.Client], None]) assert isinstance(descriptor, tanjun.clients._LoaderDescriptor) with pytest.raises(TypeError, match="This loader requires instances of the standard Client implementation"): @@ -139,7 +135,6 @@ def test_load_when_abc_allowed(self): mock_callback = mock.Mock() mock_client = mock.Mock() descriptor = tanjun.as_loader(mock_callback, standard_impl=False) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.abc.Client], None]) assert isinstance(descriptor, tanjun.clients._LoaderDescriptor) result = descriptor.load(mock_client) @@ -151,7 +146,6 @@ def test_load_when_abc_allowed_and_called_as_decorator(self): mock_callback = mock.Mock() mock_client = mock.Mock() descriptor = tanjun.as_loader(standard_impl=False)(mock_callback) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.abc.Client], None]) assert isinstance(descriptor, tanjun.clients._LoaderDescriptor) result = descriptor.load(mock_client) @@ -205,7 +199,6 @@ def test_unload(self): mock_callback = mock.Mock() mock_client = mock.Mock(tanjun.Client) descriptor = tanjun.as_unloader(mock_callback) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.Client], None]) assert isinstance(descriptor, tanjun.clients._UnloaderDescriptor) result = descriptor.unload(mock_client) @@ -217,7 +210,6 @@ def test_unload_when_called_as_decorator(self): mock_callback = mock.Mock() mock_client = mock.Mock(tanjun.Client) descriptor = tanjun.as_unloader()(mock_callback) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.Client], None]) assert isinstance(descriptor, tanjun.clients._UnloaderDescriptor) result = descriptor.unload(mock_client) @@ -229,7 +221,6 @@ def test_unload_called_as_decorator_and_args_passed(self): mock_callback = mock.Mock() mock_client = mock.Mock(tanjun.Client) descriptor = tanjun.as_unloader(standard_impl=True)(mock_callback) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.Client], None]) assert isinstance(descriptor, tanjun.clients._UnloaderDescriptor) result = descriptor.unload(mock_client) @@ -240,7 +231,6 @@ def test_unload_called_as_decorator_and_args_passed(self): def test_unload_when_must_be_std_and_not_std(self): mock_callback = mock.Mock() descriptor = tanjun.as_unloader(mock_callback) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.Client], None]) assert isinstance(descriptor, tanjun.clients._UnloaderDescriptor) with pytest.raises(TypeError, match="This unloader requires instances of the standard Client implementation"): @@ -252,7 +242,6 @@ def test_unload_when_abc_allowed(self): mock_callback = mock.Mock() mock_client = mock.Mock() descriptor = tanjun.as_unloader(mock_callback, standard_impl=False) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.abc.Client], None]) assert isinstance(descriptor, tanjun.clients._UnloaderDescriptor) result = descriptor.unload(mock_client) @@ -264,7 +253,6 @@ def test_unload_when_abc_allowed_and_called_as_decorator(self): mock_callback = mock.Mock() mock_client = mock.Mock() descriptor = tanjun.as_unloader(standard_impl=False)(mock_callback) - typing_extensions.assert_type(descriptor, collections.Callable[[tanjun.abc.Client], None]) assert isinstance(descriptor, tanjun.clients._UnloaderDescriptor) result = descriptor.unload(mock_client) From 89cdee79255d618e6b70ba5086f45cf9d39b33f6 Mon Sep 17 00:00:00 2001 From: Faster Speeding Date: Tue, 31 Jan 2023 12:46:25 +0000 Subject: [PATCH 2/2] Fix as_loader and as_unloader --- tanjun/clients.py | 67 +++++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/tanjun/clients.py b/tanjun/clients.py index caaa6c83c..6f48addab 100644 --- a/tanjun/clients.py +++ b/tanjun/clients.py @@ -145,8 +145,7 @@ class _GatewayBotProto(hikari.EventManagerAware, hikari.RESTAware, hikari.ShardA _LoaderSigBase = collections.Callable[typing_extensions.Concatenate[_T, _P], None] _LoaderSig = _LoaderSigBase[_T, ...] _LoaderSigT = typing.TypeVar("_LoaderSigT", bound=_LoaderSig[tanjun.Client]) - _StdLoaderSigT = typing.TypeVar("_StdLoaderSigT", bound=_LoaderSig["Client"]) - + _StdLoaderSigT = typing.TypeVar("_StdLoaderSigT", bound="_LoaderSig[Client]") # 3.9 and 3.10 just can't handle ending a Paramspec with ... so we lie at runtime about this. if typing.TYPE_CHECKING: @@ -170,12 +169,9 @@ class _GatewayBotProto(hikari.EventManagerAware, hikari.RESTAware, hikari.ShardA _MENU_TYPES = frozenset((hikari.CommandType.MESSAGE, hikari.CommandType.USER)) -class _LoaderDescriptor(tanjun.ClientLoader): # Slots mess with functools.update_wrapper - def __init__( - self, - callback: typing.Union[collections.Callable[[Client], None], collections.Callable[[tanjun.Client], None]], - standard_impl: bool, - ) -> None: +# Slots mess with functools.update_wrapper +class _LoaderDescriptor(tanjun.ClientLoader): + def __init__(self, callback: _LoaderSig[Client], standard_impl: bool) -> None: self._callback = callback self._must_be_std = standard_impl functools.update_wrapper(self, callback) @@ -199,9 +195,7 @@ def load(self, client: tanjun.Client, /) -> bool: client.injector.call_with_di(self._callback, client) else: - client.injector.call_with_di( - typing.cast("collections.Callable[[tanjun.Client], None]", self._callback), client - ) + client.injector.call_with_di(self._callback, client) return True @@ -209,12 +203,9 @@ def unload(self, _: tanjun.Client, /) -> bool: return False -class _UnloaderDescriptor(tanjun.ClientLoader): # Slots mess with functools.update_wrapper - def __init__( - self, - callback: typing.Union[collections.Callable[[Client], None], collections.Callable[[tanjun.Client], None]], - standard_impl: bool, - ) -> None: +# Slots mess with functools.update_wrapper +class _UnloaderDescriptor(tanjun.ClientLoader): + def __init__(self, callback: _LoaderSig[Client], standard_impl: bool) -> None: self._callback = callback self._must_be_std = standard_impl functools.update_wrapper(self, callback) @@ -241,9 +232,7 @@ def unload(self, client: tanjun.Client, /) -> bool: client.injector.call_with_di(self._callback, client) else: - client.injector.call_with_di( - typing.cast("collections.Callable[[tanjun.Client], None]", self._callback), client - ) + client.injector.call_with_di(self._callback, client) return True @@ -301,13 +290,20 @@ def as_loader( collections.abc.Callable[[tanjun.abc.Client], None]] The decorated load callback. """ - if callback: - return _LoaderDescriptor(callback, standard_impl) - def decorator( - callback: collections.Callable[[tanjun.Client], None], / - ) -> collections.Callable[[tanjun.Client], None]: - return _LoaderDescriptor(callback, standard_impl) + @typing.overload + def decorator(callback: _LoaderSigT, /) -> _LoaderSigT: + ... + + @typing.overload + def decorator(callback: _StdLoaderSigT, /) -> _StdLoaderSigT: + ... + + def decorator(callback: typing.Union[_LoaderSigT, _StdLoaderSigT], /) -> typing.Union[_LoaderSigT, _StdLoaderSigT]: + return typing.cast("typing.Union[_LoaderSigT, _StdLoaderSigT]", _LoaderDescriptor(callback, standard_impl)) + + if callback: + return decorator(callback) return decorator @@ -369,13 +365,20 @@ def as_unloader( collections.abc.Callable[[tanjun.abc.Client], None]] The decorated unload callback. """ - if callback: - return _UnloaderDescriptor(callback, standard_impl) - def decorator( - callback: collections.Callable[[tanjun.Client], None], / - ) -> collections.Callable[[tanjun.Client], None]: - return _UnloaderDescriptor(callback, standard_impl) + @typing.overload + def decorator(callback: _StdLoaderSigT, /) -> _StdLoaderSigT: + ... + + @typing.overload + def decorator(callback: _LoaderSigT, /) -> _LoaderSigT: + ... + + def decorator(callback: typing.Union[_StdLoaderSigT, _LoaderSigT], /) -> typing.Union[_StdLoaderSigT, _LoaderSigT]: + return typing.cast("typing.Union[_StdLoaderSigT, _LoaderSigT]", _UnloaderDescriptor(callback, standard_impl)) + + if callback: + return decorator(callback) return decorator