This commit is contained in:
rzmk 2021-05-18 18:48:42 -04:00
parent f9a2e68c8b
commit 95d0201681
1552 changed files with 367539 additions and 2 deletions

View file

@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""
discord.ext.commands
~~~~~~~~~~~~~~~~~~~~~
An extension module to facilitate creation of bot commands.
:copyright: (c) 2015-present Rapptz
:license: MIT, see LICENSE for more details.
"""
from .bot import Bot, AutoShardedBot, when_mentioned, when_mentioned_or
from .context import Context
from .core import *
from .errors import *
from .help import *
from .converter import *
from .cooldowns import *
from .cog import *

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
# This is merely a tag type to avoid circular import issues.
# Yes, this is a terrible solution but ultimately it is the only solution.
class _BaseCommand:
__slots__ = ()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,451 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
import inspect
import copy
from ._types import _BaseCommand
__all__ = (
'CogMeta',
'Cog',
)
class CogMeta(type):
"""A metaclass for defining a cog.
Note that you should probably not use this directly. It is exposed
purely for documentation purposes along with making custom metaclasses to intermix
with other metaclasses such as the :class:`abc.ABCMeta` metaclass.
For example, to create an abstract cog mixin class, the following would be done.
.. code-block:: python3
import abc
class CogABCMeta(commands.CogMeta, abc.ABCMeta):
pass
class SomeMixin(metaclass=abc.ABCMeta):
pass
class SomeCogMixin(SomeMixin, commands.Cog, metaclass=CogABCMeta):
pass
.. note::
When passing an attribute of a metaclass that is documented below, note
that you must pass it as a keyword-only argument to the class creation
like the following example:
.. code-block:: python3
class MyCog(commands.Cog, name='My Cog'):
pass
Attributes
-----------
name: :class:`str`
The cog name. By default, it is the name of the class with no modification.
description: :class:`str`
The cog description. By default, it is the cleaned docstring of the class.
.. versionadded:: 1.6
command_attrs: :class:`dict`
A list of attributes to apply to every command inside this cog. The dictionary
is passed into the :class:`Command` options at ``__init__``.
If you specify attributes inside the command attribute in the class, it will
override the one specified inside this attribute. For example:
.. code-block:: python3
class MyCog(commands.Cog, command_attrs=dict(hidden=True)):
@commands.command()
async def foo(self, ctx):
pass # hidden -> True
@commands.command(hidden=False)
async def bar(self, ctx):
pass # hidden -> False
"""
def __new__(cls, *args, **kwargs):
name, bases, attrs = args
attrs['__cog_name__'] = kwargs.pop('name', name)
attrs['__cog_settings__'] = kwargs.pop('command_attrs', {})
description = kwargs.pop('description', None)
if description is None:
description = inspect.cleandoc(attrs.get('__doc__', ''))
attrs['__cog_description__'] = description
commands = {}
listeners = {}
no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})'
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
for base in reversed(new_cls.__mro__):
for elem, value in base.__dict__.items():
if elem in commands:
del commands[elem]
if elem in listeners:
del listeners[elem]
is_static_method = isinstance(value, staticmethod)
if is_static_method:
value = value.__func__
if isinstance(value, _BaseCommand):
if is_static_method:
raise TypeError('Command in method {0}.{1!r} must not be staticmethod.'.format(base, elem))
if elem.startswith(('cog_', 'bot_')):
raise TypeError(no_bot_cog.format(base, elem))
commands[elem] = value
elif inspect.iscoroutinefunction(value):
try:
getattr(value, '__cog_listener__')
except AttributeError:
continue
else:
if elem.startswith(('cog_', 'bot_')):
raise TypeError(no_bot_cog.format(base, elem))
listeners[elem] = value
new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__
listeners_as_list = []
for listener in listeners.values():
for listener_name in listener.__cog_listener_names__:
# I use __name__ instead of just storing the value so I can inject
# the self attribute when the time comes to add them to the bot
listeners_as_list.append((listener_name, listener.__name__))
new_cls.__cog_listeners__ = listeners_as_list
return new_cls
def __init__(self, *args, **kwargs):
super().__init__(*args)
@classmethod
def qualified_name(cls):
return cls.__cog_name__
def _cog_special_method(func):
func.__cog_special_method__ = None
return func
class Cog(metaclass=CogMeta):
"""The base class that all cogs must inherit from.
A cog is a collection of commands, listeners, and optional state to
help group commands together. More information on them can be found on
the :ref:`ext_commands_cogs` page.
When inheriting from this class, the options shown in :class:`CogMeta`
are equally valid here.
"""
def __new__(cls, *args, **kwargs):
# For issue 426, we need to store a copy of the command objects
# since we modify them to inject `self` to them.
# To do this, we need to interfere with the Cog creation process.
self = super().__new__(cls)
cmd_attrs = cls.__cog_settings__
# Either update the command with the cog provided defaults or copy it.
self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__)
lookup = {
cmd.qualified_name: cmd
for cmd in self.__cog_commands__
}
# Update the Command instances dynamically as well
for command in self.__cog_commands__:
setattr(self, command.callback.__name__, command)
parent = command.parent
if parent is not None:
# Get the latest parent reference
parent = lookup[parent.qualified_name]
# Update our parent's reference to our self
parent.remove_command(command.name)
parent.add_command(command)
return self
def get_commands(self):
r"""
Returns
--------
List[:class:`.Command`]
A :class:`list` of :class:`.Command`\s that are
defined inside this cog.
.. note::
This does not include subcommands.
"""
return [c for c in self.__cog_commands__ if c.parent is None]
@property
def qualified_name(self):
""":class:`str`: Returns the cog's specified name, not the class name."""
return self.__cog_name__
@property
def description(self):
""":class:`str`: Returns the cog's description, typically the cleaned docstring."""
return self.__cog_description__
@description.setter
def description(self, description):
self.__cog_description__ = description
def walk_commands(self):
"""An iterator that recursively walks through this cog's commands and subcommands.
Yields
------
Union[:class:`.Command`, :class:`.Group`]
A command or group from the cog.
"""
from .core import GroupMixin
for command in self.__cog_commands__:
if command.parent is None:
yield command
if isinstance(command, GroupMixin):
yield from command.walk_commands()
def get_listeners(self):
"""Returns a :class:`list` of (name, function) listener pairs that are defined in this cog.
Returns
--------
List[Tuple[:class:`str`, :ref:`coroutine <coroutine>`]]
The listeners defined in this cog.
"""
return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__]
@classmethod
def _get_overridden_method(cls, method):
"""Return None if the method is not overridden. Otherwise returns the overridden method."""
return getattr(method.__func__, '__cog_special_method__', method)
@classmethod
def listener(cls, name=None):
"""A decorator that marks a function as a listener.
This is the cog equivalent of :meth:`.Bot.listen`.
Parameters
------------
name: :class:`str`
The name of the event being listened to. If not provided, it
defaults to the function's name.
Raises
--------
TypeError
The function is not a coroutine function or a string was not passed as
the name.
"""
if name is not None and not isinstance(name, str):
raise TypeError('Cog.listener expected str but received {0.__class__.__name__!r} instead.'.format(name))
def decorator(func):
actual = func
if isinstance(actual, staticmethod):
actual = actual.__func__
if not inspect.iscoroutinefunction(actual):
raise TypeError('Listener function must be a coroutine function.')
actual.__cog_listener__ = True
to_assign = name or actual.__name__
try:
actual.__cog_listener_names__.append(to_assign)
except AttributeError:
actual.__cog_listener_names__ = [to_assign]
# we have to return `func` instead of `actual` because
# we need the type to be `staticmethod` for the metaclass
# to pick it up but the metaclass unfurls the function and
# thus the assignments need to be on the actual function
return func
return decorator
def has_error_handler(self):
""":class:`bool`: Checks whether the cog has an error handler.
.. versionadded:: 1.7
"""
return not hasattr(self.cog_command_error.__func__, '__cog_special_method__')
@_cog_special_method
def cog_unload(self):
"""A special method that is called when the cog gets removed.
This function **cannot** be a coroutine. It must be a regular
function.
Subclasses must replace this if they want special unloading behaviour.
"""
pass
@_cog_special_method
def bot_check_once(self, ctx):
"""A special method that registers as a :meth:`.Bot.check_once`
check.
This function **can** be a coroutine and must take a sole parameter,
``ctx``, to represent the :class:`.Context`.
"""
return True
@_cog_special_method
def bot_check(self, ctx):
"""A special method that registers as a :meth:`.Bot.check`
check.
This function **can** be a coroutine and must take a sole parameter,
``ctx``, to represent the :class:`.Context`.
"""
return True
@_cog_special_method
def cog_check(self, ctx):
"""A special method that registers as a :func:`commands.check`
for every command and subcommand in this cog.
This function **can** be a coroutine and must take a sole parameter,
``ctx``, to represent the :class:`.Context`.
"""
return True
@_cog_special_method
async def cog_command_error(self, ctx, error):
"""A special method that is called whenever an error
is dispatched inside this cog.
This is similar to :func:`.on_command_error` except only applying
to the commands inside this cog.
This **must** be a coroutine.
Parameters
-----------
ctx: :class:`.Context`
The invocation context where the error happened.
error: :class:`CommandError`
The error that happened.
"""
pass
@_cog_special_method
async def cog_before_invoke(self, ctx):
"""A special method that acts as a cog local pre-invoke hook.
This is similar to :meth:`.Command.before_invoke`.
This **must** be a coroutine.
Parameters
-----------
ctx: :class:`.Context`
The invocation context.
"""
pass
@_cog_special_method
async def cog_after_invoke(self, ctx):
"""A special method that acts as a cog local post-invoke hook.
This is similar to :meth:`.Command.after_invoke`.
This **must** be a coroutine.
Parameters
-----------
ctx: :class:`.Context`
The invocation context.
"""
pass
def _inject(self, bot):
cls = self.__class__
# realistically, the only thing that can cause loading errors
# is essentially just the command loading, which raises if there are
# duplicates. When this condition is met, we want to undo all what
# we've added so far for some form of atomic loading.
for index, command in enumerate(self.__cog_commands__):
command.cog = self
if command.parent is None:
try:
bot.add_command(command)
except Exception as e:
# undo our additions
for to_undo in self.__cog_commands__[:index]:
if to_undo.parent is None:
bot.remove_command(to_undo.name)
raise e
# check if we're overriding the default
if cls.bot_check is not Cog.bot_check:
bot.add_check(self.bot_check)
if cls.bot_check_once is not Cog.bot_check_once:
bot.add_check(self.bot_check_once, call_once=True)
# while Bot.add_listener can raise if it's not a coroutine,
# this precondition is already met by the listener decorator
# already, thus this should never raise.
# Outside of, memory errors and the like...
for name, method_name in self.__cog_listeners__:
bot.add_listener(getattr(self, method_name), name)
return self
def _eject(self, bot):
cls = self.__class__
try:
for command in self.__cog_commands__:
if command.parent is None:
bot.remove_command(command.name)
for _, method_name in self.__cog_listeners__:
bot.remove_listener(getattr(self, method_name))
if cls.bot_check is not Cog.bot_check:
bot.remove_check(self.bot_check)
if cls.bot_check_once is not Cog.bot_check_once:
bot.remove_check(self.bot_check_once, call_once=True)
finally:
try:
self.cog_unload()
except Exception:
pass

View file

@ -0,0 +1,340 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
import discord.abc
import discord.utils
class Context(discord.abc.Messageable):
r"""Represents the context in which a command is being invoked under.
This class contains a lot of meta data to help you understand more about
the invocation context. This class is not created manually and is instead
passed around to commands as the first parameter.
This class implements the :class:`~discord.abc.Messageable` ABC.
Attributes
-----------
message: :class:`.Message`
The message that triggered the command being executed.
bot: :class:`.Bot`
The bot that contains the command being executed.
args: :class:`list`
The list of transformed arguments that were passed into the command.
If this is accessed during the :func:`on_command_error` event
then this list could be incomplete.
kwargs: :class:`dict`
A dictionary of transformed arguments that were passed into the command.
Similar to :attr:`args`\, if this is accessed in the
:func:`on_command_error` event then this dict could be incomplete.
prefix: :class:`str`
The prefix that was used to invoke the command.
command: :class:`Command`
The command that is being invoked currently.
invoked_with: :class:`str`
The command name that triggered this invocation. Useful for finding out
which alias called the command.
invoked_parents: List[:class:`str`]
The command names of the parents that triggered this invocation. Useful for
finding out which aliases called the command.
For example in commands ``?a b c test``, the invoked parents are ``['a', 'b', 'c']``.
.. versionadded:: 1.7
invoked_subcommand: :class:`Command`
The subcommand that was invoked.
If no valid subcommand was invoked then this is equal to ``None``.
subcommand_passed: Optional[:class:`str`]
The string that was attempted to call a subcommand. This does not have
to point to a valid registered subcommand and could just point to a
nonsense string. If nothing was passed to attempt a call to a
subcommand then this is set to ``None``.
command_failed: :class:`bool`
A boolean that indicates if the command failed to be parsed, checked,
or invoked.
"""
def __init__(self, **attrs):
self.message = attrs.pop('message', None)
self.bot = attrs.pop('bot', None)
self.args = attrs.pop('args', [])
self.kwargs = attrs.pop('kwargs', {})
self.prefix = attrs.pop('prefix')
self.command = attrs.pop('command', None)
self.view = attrs.pop('view', None)
self.invoked_with = attrs.pop('invoked_with', None)
self.invoked_parents = attrs.pop('invoked_parents', [])
self.invoked_subcommand = attrs.pop('invoked_subcommand', None)
self.subcommand_passed = attrs.pop('subcommand_passed', None)
self.command_failed = attrs.pop('command_failed', False)
self._state = self.message._state
async def invoke(self, *args, **kwargs):
r"""|coro|
Calls a command with the arguments given.
This is useful if you want to just call the callback that a
:class:`.Command` holds internally.
.. note::
This does not handle converters, checks, cooldowns, pre-invoke,
or after-invoke hooks in any matter. It calls the internal callback
directly as-if it was a regular function.
You must take care in passing the proper arguments when
using this function.
.. warning::
The first parameter passed **must** be the command being invoked.
Parameters
-----------
command: :class:`.Command`
The command that is going to be called.
\*args
The arguments to to use.
\*\*kwargs
The keyword arguments to use.
Raises
-------
TypeError
The command argument to invoke is missing.
"""
try:
command = args[0]
except IndexError:
raise TypeError('Missing command to invoke.') from None
arguments = []
if command.cog is not None:
arguments.append(command.cog)
arguments.append(self)
arguments.extend(args[1:])
ret = await command.callback(*arguments, **kwargs)
return ret
async def reinvoke(self, *, call_hooks=False, restart=True):
"""|coro|
Calls the command again.
This is similar to :meth:`~.Context.invoke` except that it bypasses
checks, cooldowns, and error handlers.
.. note::
If you want to bypass :exc:`.UserInputError` derived exceptions,
it is recommended to use the regular :meth:`~.Context.invoke`
as it will work more naturally. After all, this will end up
using the old arguments the user has used and will thus just
fail again.
Parameters
------------
call_hooks: :class:`bool`
Whether to call the before and after invoke hooks.
restart: :class:`bool`
Whether to start the call chain from the very beginning
or where we left off (i.e. the command that caused the error).
The default is to start where we left off.
Raises
-------
ValueError
The context to reinvoke is not valid.
"""
cmd = self.command
view = self.view
if cmd is None:
raise ValueError('This context is not valid.')
# some state to revert to when we're done
index, previous = view.index, view.previous
invoked_with = self.invoked_with
invoked_subcommand = self.invoked_subcommand
invoked_parents = self.invoked_parents
subcommand_passed = self.subcommand_passed
if restart:
to_call = cmd.root_parent or cmd
view.index = len(self.prefix)
view.previous = 0
self.invoked_parents = []
self.invoked_with = view.get_word() # advance to get the root command
else:
to_call = cmd
try:
await to_call.reinvoke(self, call_hooks=call_hooks)
finally:
self.command = cmd
view.index = index
view.previous = previous
self.invoked_with = invoked_with
self.invoked_subcommand = invoked_subcommand
self.invoked_parents = invoked_parents
self.subcommand_passed = subcommand_passed
@property
def valid(self):
""":class:`bool`: Checks if the invocation context is valid to be invoked with."""
return self.prefix is not None and self.command is not None
async def _get_channel(self):
return self.channel
@property
def cog(self):
"""Optional[:class:`.Cog`]: Returns the cog associated with this context's command. None if it does not exist."""
if self.command is None:
return None
return self.command.cog
@discord.utils.cached_property
def guild(self):
"""Optional[:class:`.Guild`]: Returns the guild associated with this context's command. None if not available."""
return self.message.guild
@discord.utils.cached_property
def channel(self):
"""Union[:class:`.abc.Messageable`]: Returns the channel associated with this context's command.
Shorthand for :attr:`.Message.channel`.
"""
return self.message.channel
@discord.utils.cached_property
def author(self):
"""Union[:class:`~discord.User`, :class:`.Member`]:
Returns the author associated with this context's command. Shorthand for :attr:`.Message.author`
"""
return self.message.author
@discord.utils.cached_property
def me(self):
"""Union[:class:`.Member`, :class:`.ClientUser`]:
Similar to :attr:`.Guild.me` except it may return the :class:`.ClientUser` in private message contexts.
"""
return self.guild.me if self.guild is not None else self.bot.user
@property
def voice_client(self):
r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
g = self.guild
return g.voice_client if g else None
async def send_help(self, *args):
"""send_help(entity=<bot>)
|coro|
Shows the help command for the specified entity if given.
The entity can be a command or a cog.
If no entity is given, then it'll show help for the
entire bot.
If the entity is a string, then it looks up whether it's a
:class:`Cog` or a :class:`Command`.
.. note::
Due to the way this function works, instead of returning
something similar to :meth:`~.commands.HelpCommand.command_not_found`
this returns :class:`None` on bad input or no help command.
Parameters
------------
entity: Optional[Union[:class:`Command`, :class:`Cog`, :class:`str`]]
The entity to show help for.
Returns
--------
Any
The result of the help command, if any.
"""
from .core import Group, Command, wrap_callback
from .errors import CommandError
bot = self.bot
cmd = bot.help_command
if cmd is None:
return None
cmd = cmd.copy()
cmd.context = self
if len(args) == 0:
await cmd.prepare_help_command(self, None)
mapping = cmd.get_bot_mapping()
injected = wrap_callback(cmd.send_bot_help)
try:
return await injected(mapping)
except CommandError as e:
await cmd.on_help_command_error(self, e)
return None
entity = args[0]
if entity is None:
return None
if isinstance(entity, str):
entity = bot.get_cog(entity) or bot.get_command(entity)
try:
entity.qualified_name
except AttributeError:
# if we're here then it's not a cog, group, or command.
return None
await cmd.prepare_help_command(self, entity.qualified_name)
try:
if hasattr(entity, '__cog_commands__'):
injected = wrap_callback(cmd.send_cog_help)
return await injected(entity)
elif isinstance(entity, Group):
injected = wrap_callback(cmd.send_group_help)
return await injected(entity)
elif isinstance(entity, Command):
injected = wrap_callback(cmd.send_command_help)
return await injected(entity)
else:
return None
except CommandError as e:
await cmd.on_help_command_error(self, e)
@discord.utils.copy_doc(discord.Message.reply)
async def reply(self, content=None, **kwargs):
return await self.message.reply(content, **kwargs)

View file

@ -0,0 +1,852 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
import re
import inspect
import typing
import discord
from .errors import *
__all__ = (
'Converter',
'MemberConverter',
'UserConverter',
'MessageConverter',
'PartialMessageConverter',
'TextChannelConverter',
'InviteConverter',
'GuildConverter',
'RoleConverter',
'GameConverter',
'ColourConverter',
'ColorConverter',
'VoiceChannelConverter',
'StageChannelConverter',
'EmojiConverter',
'PartialEmojiConverter',
'CategoryChannelConverter',
'IDConverter',
'StoreChannelConverter',
'clean_content',
'Greedy',
)
def _get_from_guilds(bot, getter, argument):
result = None
for guild in bot.guilds:
result = getattr(guild, getter)(argument)
if result:
return result
return result
_utils_get = discord.utils.get
class Converter:
"""The base class of custom converters that require the :class:`.Context`
to be passed to be useful.
This allows you to implement converters that function similar to the
special cased ``discord`` classes.
Classes that derive from this should override the :meth:`~.Converter.convert`
method to do its conversion logic. This method must be a :ref:`coroutine <coroutine>`.
"""
async def convert(self, ctx, argument):
"""|coro|
The method to override to do conversion logic.
If an error is found while converting, it is recommended to
raise a :exc:`.CommandError` derived exception as it will
properly propagate to the error handlers.
Parameters
-----------
ctx: :class:`.Context`
The invocation context that the argument is being used in.
argument: :class:`str`
The argument that is being converted.
Raises
-------
:exc:`.CommandError`
A generic exception occurred when converting the argument.
:exc:`.BadArgument`
The converter failed to convert the argument.
"""
raise NotImplementedError('Derived classes need to implement this.')
class IDConverter(Converter):
def __init__(self):
self._id_regex = re.compile(r'([0-9]{15,20})$')
super().__init__()
def _get_id_match(self, argument):
return self._id_regex.match(argument)
class MemberConverter(IDConverter):
"""Converts to a :class:`~discord.Member`.
All lookups are via the local guild. If in a DM context, then the lookup
is done by the global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name#discrim
4. Lookup by name
5. Lookup by nickname
.. versionchanged:: 1.5
Raise :exc:`.MemberNotFound` instead of generic :exc:`.BadArgument`
.. versionchanged:: 1.5.1
This converter now lazily fetches members from the gateway and HTTP APIs,
optionally caching the result if :attr:`.MemberCacheFlags.joined` is enabled.
"""
async def query_member_named(self, guild, argument):
cache = guild._state.member_cache_flags.joined
if len(argument) > 5 and argument[-5] == '#':
username, _, discriminator = argument.rpartition('#')
members = await guild.query_members(username, limit=100, cache=cache)
return discord.utils.get(members, name=username, discriminator=discriminator)
else:
members = await guild.query_members(argument, limit=100, cache=cache)
return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members)
async def query_member_by_id(self, bot, guild, user_id):
ws = bot._get_websocket(shard_id=guild.shard_id)
cache = guild._state.member_cache_flags.joined
if ws.is_ratelimited():
# If we're being rate limited on the WS, then fall back to using the HTTP API
# So we don't have to wait ~60 seconds for the query to finish
try:
member = await guild.fetch_member(user_id)
except discord.HTTPException:
return None
if cache:
guild._add_member(member)
return member
# If we're not being rate limited then we can use the websocket to actually query
members = await guild.query_members(limit=1, user_ids=[user_id], cache=cache)
if not members:
return None
return members[0]
async def convert(self, ctx, argument):
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
guild = ctx.guild
result = None
user_id = None
if match is None:
# not a mention...
if guild:
result = guild.get_member_named(argument)
else:
result = _get_from_guilds(bot, 'get_member_named', argument)
else:
user_id = int(match.group(1))
if guild:
result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id)
else:
result = _get_from_guilds(bot, 'get_member', user_id)
if result is None:
if guild is None:
raise MemberNotFound(argument)
if user_id is not None:
result = await self.query_member_by_id(bot, guild, user_id)
else:
result = await self.query_member_named(guild, argument)
if not result:
raise MemberNotFound(argument)
return result
class UserConverter(IDConverter):
"""Converts to a :class:`~discord.User`.
All lookups are via the global user cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name#discrim
4. Lookup by name
.. versionchanged:: 1.5
Raise :exc:`.UserNotFound` instead of generic :exc:`.BadArgument`
.. versionchanged:: 1.6
This converter now lazily fetches users from the HTTP APIs if an ID is passed
and it's not available in cache.
"""
async def convert(self, ctx, argument):
match = self._get_id_match(argument) or re.match(r'<@!?([0-9]+)>$', argument)
result = None
state = ctx._state
if match is not None:
user_id = int(match.group(1))
result = ctx.bot.get_user(user_id) or _utils_get(ctx.message.mentions, id=user_id)
if result is None:
try:
result = await ctx.bot.fetch_user(user_id)
except discord.HTTPException:
raise UserNotFound(argument) from None
return result
arg = argument
# Remove the '@' character if this is the first character from the argument
if arg[0] == '@':
# Remove first character
arg = arg[1:]
# check for discriminator if it exists,
if len(arg) > 5 and arg[-5] == '#':
discrim = arg[-4:]
name = arg[:-5]
predicate = lambda u: u.name == name and u.discriminator == discrim
result = discord.utils.find(predicate, state._users.values())
if result is not None:
return result
predicate = lambda u: u.name == arg
result = discord.utils.find(predicate, state._users.values())
if result is None:
raise UserNotFound(argument)
return result
class PartialMessageConverter(Converter):
"""Converts to a :class:`discord.PartialMessage`.
.. versionadded:: 1.7
The creation strategy is as follows (in order):
1. By "{channel ID}-{message ID}" (retrieved by shift-clicking on "Copy ID")
2. By message ID (The message is assumed to be in the context channel.)
3. By message URL
"""
def _get_id_matches(self, argument):
id_regex = re.compile(r'(?:(?P<channel_id>[0-9]{15,20})-)?(?P<message_id>[0-9]{15,20})$')
link_regex = re.compile(
r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/'
r'(?:[0-9]{15,20}|@me)'
r'/(?P<channel_id>[0-9]{15,20})/(?P<message_id>[0-9]{15,20})/?$'
)
match = id_regex.match(argument) or link_regex.match(argument)
if not match:
raise MessageNotFound(argument)
channel_id = match.group("channel_id")
return int(match.group("message_id")), int(channel_id) if channel_id else None
async def convert(self, ctx, argument):
message_id, channel_id = self._get_id_matches(argument)
channel = ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
if not channel:
raise ChannelNotFound(channel_id)
return discord.PartialMessage(channel=channel, id=message_id)
class MessageConverter(PartialMessageConverter):
"""Converts to a :class:`discord.Message`.
.. versionadded:: 1.1
The lookup strategy is as follows (in order):
1. Lookup by "{channel ID}-{message ID}" (retrieved by shift-clicking on "Copy ID")
2. Lookup by message ID (the message **must** be in the context channel)
3. Lookup by message URL
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound`, :exc:`.MessageNotFound` or :exc:`.ChannelNotReadable` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx, argument):
message_id, channel_id = self._get_id_matches(argument)
message = ctx.bot._connection._get_message(message_id)
if message:
return message
channel = ctx.bot.get_channel(channel_id) if channel_id else ctx.channel
if not channel:
raise ChannelNotFound(channel_id)
try:
return await channel.fetch_message(message_id)
except discord.NotFound:
raise MessageNotFound(argument)
except discord.Forbidden:
raise ChannelNotReadable(channel)
class TextChannelConverter(IDConverter):
"""Converts to a :class:`~discord.TextChannel`.
All lookups are via the local guild. If in a DM context, then the lookup
is done by the global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx, argument):
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
result = discord.utils.get(guild.text_channels, name=argument)
else:
def check(c):
return isinstance(c, discord.TextChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.TextChannel):
raise ChannelNotFound(argument)
return result
class VoiceChannelConverter(IDConverter):
"""Converts to a :class:`~discord.VoiceChannel`.
All lookups are via the local guild. If in a DM context, then the lookup
is done by the global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx, argument):
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
result = discord.utils.get(guild.voice_channels, name=argument)
else:
def check(c):
return isinstance(c, discord.VoiceChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.VoiceChannel):
raise ChannelNotFound(argument)
return result
class StageChannelConverter(IDConverter):
"""Converts to a :class:`~discord.StageChannel`.
.. versionadded:: 1.7
All lookups are via the local guild. If in a DM context, then the lookup
is done by the global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name
"""
async def convert(self, ctx, argument):
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
result = discord.utils.get(guild.stage_channels, name=argument)
else:
def check(c):
return isinstance(c, discord.StageChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.StageChannel):
raise ChannelNotFound(argument)
return result
class CategoryChannelConverter(IDConverter):
"""Converts to a :class:`~discord.CategoryChannel`.
All lookups are via the local guild. If in a DM context, then the lookup
is done by the global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name
.. versionchanged:: 1.5
Raise :exc:`.ChannelNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx, argument):
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
result = discord.utils.get(guild.categories, name=argument)
else:
def check(c):
return isinstance(c, discord.CategoryChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.CategoryChannel):
raise ChannelNotFound(argument)
return result
class StoreChannelConverter(IDConverter):
"""Converts to a :class:`~discord.StoreChannel`.
All lookups are via the local guild. If in a DM context, then the lookup
is done by the global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name.
.. versionadded:: 1.7
"""
async def convert(self, ctx, argument):
bot = ctx.bot
match = self._get_id_match(argument) or re.match(r'<#([0-9]+)>$', argument)
result = None
guild = ctx.guild
if match is None:
# not a mention
if guild:
result = discord.utils.get(guild.channels, name=argument)
else:
def check(c):
return isinstance(c, discord.StoreChannel) and c.name == argument
result = discord.utils.find(check, bot.get_all_channels())
else:
channel_id = int(match.group(1))
if guild:
result = guild.get_channel(channel_id)
else:
result = _get_from_guilds(bot, 'get_channel', channel_id)
if not isinstance(result, discord.StoreChannel):
raise ChannelNotFound(argument)
return result
class ColourConverter(Converter):
"""Converts to a :class:`~discord.Colour`.
.. versionchanged:: 1.5
Add an alias named ColorConverter
The following formats are accepted:
- ``0x<hex>``
- ``#<hex>``
- ``0x#<hex>``
- ``rgb(<number>, <number>, <number>)``
- Any of the ``classmethod`` in :class:`Colour`
- The ``_`` in the name can be optionally replaced with spaces.
Like CSS, ``<number>`` can be either 0-255 or 0-100% and ``<hex>`` can be
either a 6 digit hex number or a 3 digit hex shortcut (e.g. #fff).
.. versionchanged:: 1.5
Raise :exc:`.BadColourArgument` instead of generic :exc:`.BadArgument`
.. versionchanged:: 1.7
Added support for ``rgb`` function and 3-digit hex shortcuts
"""
RGB_REGEX = re.compile(r'rgb\s*\((?P<r>[0-9]{1,3}%?)\s*,\s*(?P<g>[0-9]{1,3}%?)\s*,\s*(?P<b>[0-9]{1,3}%?)\s*\)')
def parse_hex_number(self, argument):
arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument
try:
value = int(arg, base=16)
if not (0 <= value <= 0xFFFFFF):
raise BadColourArgument(argument)
except ValueError:
raise BadColourArgument(argument)
else:
return discord.Color(value=value)
def parse_rgb_number(self, argument, number):
if number[-1] == '%':
value = int(number[:-1])
if not (0 <= value <= 100):
raise BadColourArgument(argument)
return round(255 * (value / 100))
value = int(number)
if not (0 <= value <= 255):
raise BadColourArgument(argument)
return value
def parse_rgb(self, argument, *, regex=RGB_REGEX):
match = regex.match(argument)
if match is None:
raise BadColourArgument(argument)
red = self.parse_rgb_number(argument, match.group('r'))
green = self.parse_rgb_number(argument, match.group('g'))
blue = self.parse_rgb_number(argument, match.group('b'))
return discord.Color.from_rgb(red, green, blue)
async def convert(self, ctx, argument):
if argument[0] == '#':
return self.parse_hex_number(argument[1:])
if argument[0:2] == '0x':
rest = argument[2:]
# Legacy backwards compatible syntax
if rest.startswith('#'):
return self.parse_hex_number(rest[1:])
return self.parse_hex_number(rest)
arg = argument.lower()
if arg[0:3] == 'rgb':
return self.parse_rgb(arg)
arg = arg.replace(' ', '_')
method = getattr(discord.Colour, arg, None)
if arg.startswith('from_') or method is None or not inspect.ismethod(method):
raise BadColourArgument(arg)
return method()
ColorConverter = ColourConverter
class RoleConverter(IDConverter):
"""Converts to a :class:`~discord.Role`.
All lookups are via the local guild. If in a DM context, the converter raises
:exc:`.NoPrivateMessage` exception.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by mention.
3. Lookup by name
.. versionchanged:: 1.5
Raise :exc:`.RoleNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx, argument):
guild = ctx.guild
if not guild:
raise NoPrivateMessage()
match = self._get_id_match(argument) or re.match(r'<@&([0-9]+)>$', argument)
if match:
result = guild.get_role(int(match.group(1)))
else:
result = discord.utils.get(guild._roles.values(), name=argument)
if result is None:
raise RoleNotFound(argument)
return result
class GameConverter(Converter):
"""Converts to :class:`~discord.Game`."""
async def convert(self, ctx, argument):
return discord.Game(name=argument)
class InviteConverter(Converter):
"""Converts to a :class:`~discord.Invite`.
This is done via an HTTP request using :meth:`.Bot.fetch_invite`.
.. versionchanged:: 1.5
Raise :exc:`.BadInviteArgument` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx, argument):
try:
invite = await ctx.bot.fetch_invite(argument)
return invite
except Exception as exc:
raise BadInviteArgument() from exc
class GuildConverter(IDConverter):
"""Converts to a :class:`~discord.Guild`.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by name. (There is no disambiguation for Guilds with multiple matching names).
.. versionadded:: 1.7
"""
async def convert(self, ctx, argument):
match = self._get_id_match(argument)
result = None
if match is not None:
guild_id = int(match.group(1))
result = ctx.bot.get_guild(guild_id)
if result is None:
result = discord.utils.get(ctx.bot.guilds, name=argument)
if result is None:
raise GuildNotFound(argument)
return result
class EmojiConverter(IDConverter):
"""Converts to a :class:`~discord.Emoji`.
All lookups are done for the local guild first, if available. If that lookup
fails, then it checks the client's global cache.
The lookup strategy is as follows (in order):
1. Lookup by ID.
2. Lookup by extracting ID from the emoji.
3. Lookup by name
.. versionchanged:: 1.5
Raise :exc:`.EmojiNotFound` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx, argument):
match = self._get_id_match(argument) or re.match(r'<a?:[a-zA-Z0-9\_]+:([0-9]+)>$', argument)
result = None
bot = ctx.bot
guild = ctx.guild
if match is None:
# Try to get the emoji by name. Try local guild first.
if guild:
result = discord.utils.get(guild.emojis, name=argument)
if result is None:
result = discord.utils.get(bot.emojis, name=argument)
else:
emoji_id = int(match.group(1))
# Try to look up emoji by id.
if guild:
result = discord.utils.get(guild.emojis, id=emoji_id)
if result is None:
result = discord.utils.get(bot.emojis, id=emoji_id)
if result is None:
raise EmojiNotFound(argument)
return result
class PartialEmojiConverter(Converter):
"""Converts to a :class:`~discord.PartialEmoji`.
This is done by extracting the animated flag, name and ID from the emoji.
.. versionchanged:: 1.5
Raise :exc:`.PartialEmojiConversionFailure` instead of generic :exc:`.BadArgument`
"""
async def convert(self, ctx, argument):
match = re.match(r'<(a?):([a-zA-Z0-9\_]+):([0-9]+)>$', argument)
if match:
emoji_animated = bool(match.group(1))
emoji_name = match.group(2)
emoji_id = int(match.group(3))
return discord.PartialEmoji.with_state(ctx.bot._connection, animated=emoji_animated, name=emoji_name,
id=emoji_id)
raise PartialEmojiConversionFailure(argument)
class clean_content(Converter):
"""Converts the argument to mention scrubbed version of
said content.
This behaves similarly to :attr:`~discord.Message.clean_content`.
Attributes
------------
fix_channel_mentions: :class:`bool`
Whether to clean channel mentions.
use_nicknames: :class:`bool`
Whether to use nicknames when transforming mentions.
escape_markdown: :class:`bool`
Whether to also escape special markdown characters.
remove_markdown: :class:`bool`
Whether to also remove special markdown characters. This option is not supported with ``escape_markdown``
.. versionadded:: 1.7
"""
def __init__(self, *, fix_channel_mentions=False, use_nicknames=True, escape_markdown=False, remove_markdown=False):
self.fix_channel_mentions = fix_channel_mentions
self.use_nicknames = use_nicknames
self.escape_markdown = escape_markdown
self.remove_markdown = remove_markdown
async def convert(self, ctx, argument):
message = ctx.message
transformations = {}
if self.fix_channel_mentions and ctx.guild:
def resolve_channel(id, *, _get=ctx.guild.get_channel):
ch = _get(id)
return ('<#%s>' % id), ('#' + ch.name if ch else '#deleted-channel')
transformations.update(resolve_channel(channel) for channel in message.raw_channel_mentions)
if self.use_nicknames and ctx.guild:
def resolve_member(id, *, _get=ctx.guild.get_member):
m = _get(id)
return '@' + m.display_name if m else '@deleted-user'
else:
def resolve_member(id, *, _get=ctx.bot.get_user):
m = _get(id)
return '@' + m.name if m else '@deleted-user'
transformations.update(
('<@%s>' % member_id, resolve_member(member_id))
for member_id in message.raw_mentions
)
transformations.update(
('<@!%s>' % member_id, resolve_member(member_id))
for member_id in message.raw_mentions
)
if ctx.guild:
def resolve_role(_id, *, _find=ctx.guild.get_role):
r = _find(_id)
return '@' + r.name if r else '@deleted-role'
transformations.update(
('<@&%s>' % role_id, resolve_role(role_id))
for role_id in message.raw_role_mentions
)
def repl(obj):
return transformations.get(obj.group(0), '')
pattern = re.compile('|'.join(transformations.keys()))
result = pattern.sub(repl, argument)
if self.escape_markdown:
result = discord.utils.escape_markdown(result)
elif self.remove_markdown:
result = discord.utils.remove_markdown(result)
# Completely ensure no mentions escape:
return discord.utils.escape_mentions(result)
class _Greedy:
__slots__ = ('converter',)
def __init__(self, *, converter=None):
self.converter = converter
def __getitem__(self, params):
if not isinstance(params, tuple):
params = (params,)
if len(params) != 1:
raise TypeError('Greedy[...] only takes a single argument')
converter = params[0]
if not (callable(converter) or isinstance(converter, Converter) or hasattr(converter, '__origin__')):
raise TypeError('Greedy[...] expects a type or a Converter instance.')
if converter is str or converter is type(None) or converter is _Greedy:
raise TypeError('Greedy[%s] is invalid.' % converter.__name__)
if getattr(converter, '__origin__', None) is typing.Union and type(None) in converter.__args__:
raise TypeError('Greedy[%r] is invalid.' % converter)
return self.__class__(converter=converter)
Greedy = _Greedy()

View file

@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from discord.enums import Enum
import time
import asyncio
from collections import deque
from ...abc import PrivateChannel
from .errors import MaxConcurrencyReached
__all__ = (
'BucketType',
'Cooldown',
'CooldownMapping',
'MaxConcurrency',
)
class BucketType(Enum):
default = 0
user = 1
guild = 2
channel = 3
member = 4
category = 5
role = 6
def get_key(self, msg):
if self is BucketType.user:
return msg.author.id
elif self is BucketType.guild:
return (msg.guild or msg.author).id
elif self is BucketType.channel:
return msg.channel.id
elif self is BucketType.member:
return ((msg.guild and msg.guild.id), msg.author.id)
elif self is BucketType.category:
return (msg.channel.category or msg.channel).id
elif self is BucketType.role:
# we return the channel id of a private-channel as there are only roles in guilds
# and that yields the same result as for a guild with only the @everyone role
# NOTE: PrivateChannel doesn't actually have an id attribute but we assume we are
# recieving a DMChannel or GroupChannel which inherit from PrivateChannel and do
return (msg.channel if isinstance(msg.channel, PrivateChannel) else msg.author.top_role).id
def __call__(self, msg):
return self.get_key(msg)
class Cooldown:
__slots__ = ('rate', 'per', 'type', '_window', '_tokens', '_last')
def __init__(self, rate, per, type):
self.rate = int(rate)
self.per = float(per)
self.type = type
self._window = 0.0
self._tokens = self.rate
self._last = 0.0
if not callable(self.type):
raise TypeError('Cooldown type must be a BucketType or callable')
def get_tokens(self, current=None):
if not current:
current = time.time()
tokens = self._tokens
if current > self._window + self.per:
tokens = self.rate
return tokens
def get_retry_after(self, current=None):
current = current or time.time()
tokens = self.get_tokens(current)
if tokens == 0:
return self.per - (current - self._window)
return 0.0
def update_rate_limit(self, current=None):
current = current or time.time()
self._last = current
self._tokens = self.get_tokens(current)
# first token used means that we start a new rate limit window
if self._tokens == self.rate:
self._window = current
# check if we are rate limited
if self._tokens == 0:
return self.per - (current - self._window)
# we're not so decrement our tokens
self._tokens -= 1
# see if we got rate limited due to this token change, and if
# so update the window to point to our current time frame
if self._tokens == 0:
self._window = current
def reset(self):
self._tokens = self.rate
self._last = 0.0
def copy(self):
return Cooldown(self.rate, self.per, self.type)
def __repr__(self):
return '<Cooldown rate: {0.rate} per: {0.per} window: {0._window} tokens: {0._tokens}>'.format(self)
class CooldownMapping:
def __init__(self, original):
self._cache = {}
self._cooldown = original
def copy(self):
ret = CooldownMapping(self._cooldown)
ret._cache = self._cache.copy()
return ret
@property
def valid(self):
return self._cooldown is not None
@classmethod
def from_cooldown(cls, rate, per, type):
return cls(Cooldown(rate, per, type))
def _bucket_key(self, msg):
return self._cooldown.type(msg)
def _verify_cache_integrity(self, current=None):
# we want to delete all cache objects that haven't been used
# in a cooldown window. e.g. if we have a command that has a
# cooldown of 60s and it has not been used in 60s then that key should be deleted
current = current or time.time()
dead_keys = [k for k, v in self._cache.items() if current > v._last + v.per]
for k in dead_keys:
del self._cache[k]
def get_bucket(self, message, current=None):
if self._cooldown.type is BucketType.default:
return self._cooldown
self._verify_cache_integrity(current)
key = self._bucket_key(message)
if key not in self._cache:
bucket = self._cooldown.copy()
self._cache[key] = bucket
else:
bucket = self._cache[key]
return bucket
def update_rate_limit(self, message, current=None):
bucket = self.get_bucket(message, current)
return bucket.update_rate_limit(current)
class _Semaphore:
"""This class is a version of a semaphore.
If you're wondering why asyncio.Semaphore isn't being used,
it's because it doesn't expose the internal value. This internal
value is necessary because I need to support both `wait=True` and
`wait=False`.
An asyncio.Queue could have been used to do this as well -- but it is
not as inefficient since internally that uses two queues and is a bit
overkill for what is basically a counter.
"""
__slots__ = ('value', 'loop', '_waiters')
def __init__(self, number):
self.value = number
self.loop = asyncio.get_event_loop()
self._waiters = deque()
def __repr__(self):
return '<_Semaphore value={0.value} waiters={1}>'.format(self, len(self._waiters))
def locked(self):
return self.value == 0
def is_active(self):
return len(self._waiters) > 0
def wake_up(self):
while self._waiters:
future = self._waiters.popleft()
if not future.done():
future.set_result(None)
return
async def acquire(self, *, wait=False):
if not wait and self.value <= 0:
# signal that we're not acquiring
return False
while self.value <= 0:
future = self.loop.create_future()
self._waiters.append(future)
try:
await future
except:
future.cancel()
if self.value > 0 and not future.cancelled():
self.wake_up()
raise
self.value -= 1
return True
def release(self):
self.value += 1
self.wake_up()
class MaxConcurrency:
__slots__ = ('number', 'per', 'wait', '_mapping')
def __init__(self, number, *, per, wait):
self._mapping = {}
self.per = per
self.number = number
self.wait = wait
if number <= 0:
raise ValueError('max_concurrency \'number\' cannot be less than 1')
if not isinstance(per, BucketType):
raise TypeError('max_concurrency \'per\' must be of type BucketType not %r' % type(per))
def copy(self):
return self.__class__(self.number, per=self.per, wait=self.wait)
def __repr__(self):
return '<MaxConcurrency per={0.per!r} number={0.number} wait={0.wait}>'.format(self)
def get_key(self, message):
return self.per.get_key(message)
async def acquire(self, message):
key = self.get_key(message)
try:
sem = self._mapping[key]
except KeyError:
self._mapping[key] = sem = _Semaphore(self.number)
acquired = await sem.acquire(wait=self.wait)
if not acquired:
raise MaxConcurrencyReached(self.number, self.per)
async def release(self, message):
# Technically there's no reason for this function to be async
# But it might be more useful in the future
key = self.get_key(message)
try:
sem = self._mapping[key]
except KeyError:
# ...? peculiar
return
else:
sem.release()
if sem.value >= self.number and not sem.is_active():
del self._mapping[key]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,811 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from discord.errors import ClientException, DiscordException
__all__ = (
'CommandError',
'MissingRequiredArgument',
'BadArgument',
'PrivateMessageOnly',
'NoPrivateMessage',
'CheckFailure',
'CheckAnyFailure',
'CommandNotFound',
'DisabledCommand',
'CommandInvokeError',
'TooManyArguments',
'UserInputError',
'CommandOnCooldown',
'MaxConcurrencyReached',
'NotOwner',
'MessageNotFound',
'MemberNotFound',
'GuildNotFound',
'UserNotFound',
'ChannelNotFound',
'ChannelNotReadable',
'BadColourArgument',
'RoleNotFound',
'BadInviteArgument',
'EmojiNotFound',
'PartialEmojiConversionFailure',
'BadBoolArgument',
'MissingRole',
'BotMissingRole',
'MissingAnyRole',
'BotMissingAnyRole',
'MissingPermissions',
'BotMissingPermissions',
'NSFWChannelRequired',
'ConversionError',
'BadUnionArgument',
'ArgumentParsingError',
'UnexpectedQuoteError',
'InvalidEndOfQuotedStringError',
'ExpectedClosingQuoteError',
'ExtensionError',
'ExtensionAlreadyLoaded',
'ExtensionNotLoaded',
'NoEntryPointError',
'ExtensionFailed',
'ExtensionNotFound',
'CommandRegistrationError',
)
class CommandError(DiscordException):
r"""The base exception type for all command related errors.
This inherits from :exc:`discord.DiscordException`.
This exception and exceptions inherited from it are handled
in a special way as they are caught and passed into a special event
from :class:`.Bot`\, :func:`on_command_error`.
"""
def __init__(self, message=None, *args):
if message is not None:
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
super().__init__(m, *args)
else:
super().__init__(*args)
class ConversionError(CommandError):
"""Exception raised when a Converter class raises non-CommandError.
This inherits from :exc:`CommandError`.
Attributes
----------
converter: :class:`discord.ext.commands.Converter`
The converter that failed.
original: :exc:`Exception`
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, converter, original):
self.converter = converter
self.original = original
class UserInputError(CommandError):
"""The base exception type for errors that involve errors
regarding user input.
This inherits from :exc:`CommandError`.
"""
pass
class CommandNotFound(CommandError):
"""Exception raised when a command is attempted to be invoked
but no command under that name is found.
This is not raised for invalid subcommands, rather just the
initial main command that is attempted to be invoked.
This inherits from :exc:`CommandError`.
"""
pass
class MissingRequiredArgument(UserInputError):
"""Exception raised when parsing a command and a parameter
that is required is not encountered.
This inherits from :exc:`UserInputError`
Attributes
-----------
param: :class:`inspect.Parameter`
The argument that is missing.
"""
def __init__(self, param):
self.param = param
super().__init__('{0.name} is a required argument that is missing.'.format(param))
class TooManyArguments(UserInputError):
"""Exception raised when the command was passed too many arguments and its
:attr:`.Command.ignore_extra` attribute was not set to ``True``.
This inherits from :exc:`UserInputError`
"""
pass
class BadArgument(UserInputError):
"""Exception raised when a parsing or conversion failure is encountered
on an argument to pass into a command.
This inherits from :exc:`UserInputError`
"""
pass
class CheckFailure(CommandError):
"""Exception raised when the predicates in :attr:`.Command.checks` have failed.
This inherits from :exc:`CommandError`
"""
pass
class CheckAnyFailure(CheckFailure):
"""Exception raised when all predicates in :func:`check_any` fail.
This inherits from :exc:`CheckFailure`.
.. versionadded:: 1.3
Attributes
------------
errors: List[:class:`CheckFailure`]
A list of errors that were caught during execution.
checks: List[Callable[[:class:`Context`], :class:`bool`]]
A list of check predicates that failed.
"""
def __init__(self, checks, errors):
self.checks = checks
self.errors = errors
super().__init__('You do not have permission to run this command.')
class PrivateMessageOnly(CheckFailure):
"""Exception raised when an operation does not work outside of private
message contexts.
This inherits from :exc:`CheckFailure`
"""
def __init__(self, message=None):
super().__init__(message or 'This command can only be used in private messages.')
class NoPrivateMessage(CheckFailure):
"""Exception raised when an operation does not work in private message
contexts.
This inherits from :exc:`CheckFailure`
"""
def __init__(self, message=None):
super().__init__(message or 'This command cannot be used in private messages.')
class NotOwner(CheckFailure):
"""Exception raised when the message author is not the owner of the bot.
This inherits from :exc:`CheckFailure`
"""
pass
class MemberNotFound(BadArgument):
"""Exception raised when the member provided was not found in the bot's
cache.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The member supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
super().__init__('Member "{}" not found.'.format(argument))
class GuildNotFound(BadArgument):
"""Exception raised when the guild provided was not found in the bot's cache.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.7
Attributes
-----------
argument: :class:`str`
The guild supplied by the called that was not found
"""
def __init__(self, argument):
self.argument = argument
super().__init__('Guild "{}" not found.'.format(argument))
class UserNotFound(BadArgument):
"""Exception raised when the user provided was not found in the bot's
cache.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The user supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
super().__init__('User "{}" not found.'.format(argument))
class MessageNotFound(BadArgument):
"""Exception raised when the message provided was not found in the channel.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The message supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
super().__init__('Message "{}" not found.'.format(argument))
class ChannelNotReadable(BadArgument):
"""Exception raised when the bot does not have permission to read messages
in the channel.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`.abc.GuildChannel`
The channel supplied by the caller that was not readable
"""
def __init__(self, argument):
self.argument = argument
super().__init__("Can't read messages in {}.".format(argument.mention))
class ChannelNotFound(BadArgument):
"""Exception raised when the bot can not find the channel.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The channel supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
super().__init__('Channel "{}" not found.'.format(argument))
class BadColourArgument(BadArgument):
"""Exception raised when the colour is not valid.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The colour supplied by the caller that was not valid
"""
def __init__(self, argument):
self.argument = argument
super().__init__('Colour "{}" is invalid.'.format(argument))
BadColorArgument = BadColourArgument
class RoleNotFound(BadArgument):
"""Exception raised when the bot can not find the role.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The role supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
super().__init__('Role "{}" not found.'.format(argument))
class BadInviteArgument(BadArgument):
"""Exception raised when the invite is invalid or expired.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
"""
def __init__(self):
super().__init__('Invite is invalid or expired.')
class EmojiNotFound(BadArgument):
"""Exception raised when the bot can not find the emoji.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The emoji supplied by the caller that was not found
"""
def __init__(self, argument):
self.argument = argument
super().__init__('Emoji "{}" not found.'.format(argument))
class PartialEmojiConversionFailure(BadArgument):
"""Exception raised when the emoji provided does not match the correct
format.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The emoji supplied by the caller that did not match the regex
"""
def __init__(self, argument):
self.argument = argument
super().__init__('Couldn\'t convert "{}" to PartialEmoji.'.format(argument))
class BadBoolArgument(BadArgument):
"""Exception raised when a boolean argument was not convertable.
This inherits from :exc:`BadArgument`
.. versionadded:: 1.5
Attributes
-----------
argument: :class:`str`
The boolean argument supplied by the caller that is not in the predefined list
"""
def __init__(self, argument):
self.argument = argument
super().__init__('{} is not a recognised boolean option'.format(argument))
class DisabledCommand(CommandError):
"""Exception raised when the command being invoked is disabled.
This inherits from :exc:`CommandError`
"""
pass
class CommandInvokeError(CommandError):
"""Exception raised when the command being invoked raised an exception.
This inherits from :exc:`CommandError`
Attributes
-----------
original: :exc:`Exception`
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, e):
self.original = e
super().__init__('Command raised an exception: {0.__class__.__name__}: {0}'.format(e))
class CommandOnCooldown(CommandError):
"""Exception raised when the command being invoked is on cooldown.
This inherits from :exc:`CommandError`
Attributes
-----------
cooldown: Cooldown
A class with attributes ``rate``, ``per``, and ``type`` similar to
the :func:`.cooldown` decorator.
retry_after: :class:`float`
The amount of seconds to wait before you can retry again.
"""
def __init__(self, cooldown, retry_after):
self.cooldown = cooldown
self.retry_after = retry_after
super().__init__('You are on cooldown. Try again in {:.2f}s'.format(retry_after))
class MaxConcurrencyReached(CommandError):
"""Exception raised when the command being invoked has reached its maximum concurrency.
This inherits from :exc:`CommandError`.
Attributes
------------
number: :class:`int`
The maximum number of concurrent invokers allowed.
per: :class:`.BucketType`
The bucket type passed to the :func:`.max_concurrency` decorator.
"""
def __init__(self, number, per):
self.number = number
self.per = per
name = per.name
suffix = 'per %s' % name if per.name != 'default' else 'globally'
plural = '%s times %s' if number > 1 else '%s time %s'
fmt = plural % (number, suffix)
super().__init__('Too many people using this command. It can only be used {} concurrently.'.format(fmt))
class MissingRole(CheckFailure):
"""Exception raised when the command invoker lacks a role to run a command.
This inherits from :exc:`CheckFailure`
.. versionadded:: 1.1
Attributes
-----------
missing_role: Union[:class:`str`, :class:`int`]
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role):
self.missing_role = missing_role
message = 'Role {0!r} is required to run this command.'.format(missing_role)
super().__init__(message)
class BotMissingRole(CheckFailure):
"""Exception raised when the bot's member lacks a role to run a command.
This inherits from :exc:`CheckFailure`
.. versionadded:: 1.1
Attributes
-----------
missing_role: Union[:class:`str`, :class:`int`]
The required role that is missing.
This is the parameter passed to :func:`~.commands.has_role`.
"""
def __init__(self, missing_role):
self.missing_role = missing_role
message = 'Bot requires the role {0!r} to run this command'.format(missing_role)
super().__init__(message)
class MissingAnyRole(CheckFailure):
"""Exception raised when the command invoker lacks any of
the roles specified to run a command.
This inherits from :exc:`CheckFailure`
.. versionadded:: 1.1
Attributes
-----------
missing_roles: List[Union[:class:`str`, :class:`int`]]
The roles that the invoker is missing.
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles):
self.missing_roles = missing_roles
missing = ["'{}'".format(role) for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
message = "You are missing at least one of the required roles: {}".format(fmt)
super().__init__(message)
class BotMissingAnyRole(CheckFailure):
"""Exception raised when the bot's member lacks any of
the roles specified to run a command.
This inherits from :exc:`CheckFailure`
.. versionadded:: 1.1
Attributes
-----------
missing_roles: List[Union[:class:`str`, :class:`int`]]
The roles that the bot's member is missing.
These are the parameters passed to :func:`~.commands.has_any_role`.
"""
def __init__(self, missing_roles):
self.missing_roles = missing_roles
missing = ["'{}'".format(role) for role in missing_roles]
if len(missing) > 2:
fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' or '.join(missing)
message = "Bot is missing at least one of the required roles: {}".format(fmt)
super().__init__(message)
class NSFWChannelRequired(CheckFailure):
"""Exception raised when a channel does not have the required NSFW setting.
This inherits from :exc:`CheckFailure`.
.. versionadded:: 1.1
Parameters
-----------
channel: :class:`discord.abc.GuildChannel`
The channel that does not have NSFW enabled.
"""
def __init__(self, channel):
self.channel = channel
super().__init__("Channel '{}' needs to be NSFW for this command to work.".format(channel))
class MissingPermissions(CheckFailure):
"""Exception raised when the command invoker lacks permissions to run a
command.
This inherits from :exc:`CheckFailure`
Attributes
-----------
missing_perms: :class:`list`
The required permissions that are missing.
"""
def __init__(self, missing_perms, *args):
self.missing_perms = missing_perms
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_perms]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
message = 'You are missing {} permission(s) to run this command.'.format(fmt)
super().__init__(message, *args)
class BotMissingPermissions(CheckFailure):
"""Exception raised when the bot's member lacks permissions to run a
command.
This inherits from :exc:`CheckFailure`
Attributes
-----------
missing_perms: :class:`list`
The required permissions that are missing.
"""
def __init__(self, missing_perms, *args):
self.missing_perms = missing_perms
missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_perms]
if len(missing) > 2:
fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1])
else:
fmt = ' and '.join(missing)
message = 'Bot requires {} permission(s) to run this command.'.format(fmt)
super().__init__(message, *args)
class BadUnionArgument(UserInputError):
"""Exception raised when a :data:`typing.Union` converter fails for all
its associated types.
This inherits from :exc:`UserInputError`
Attributes
-----------
param: :class:`inspect.Parameter`
The parameter that failed being converted.
converters: Tuple[Type, ...]
A tuple of converters attempted in conversion, in order of failure.
errors: List[:class:`CommandError`]
A list of errors that were caught from failing the conversion.
"""
def __init__(self, param, converters, errors):
self.param = param
self.converters = converters
self.errors = errors
def _get_name(x):
try:
return x.__name__
except AttributeError:
return x.__class__.__name__
to_string = [_get_name(x) for x in converters]
if len(to_string) > 2:
fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1])
else:
fmt = ' or '.join(to_string)
super().__init__('Could not convert "{0.name}" into {1}.'.format(param, fmt))
class ArgumentParsingError(UserInputError):
"""An exception raised when the parser fails to parse a user's input.
This inherits from :exc:`UserInputError`.
There are child classes that implement more granular parsing errors for
i18n purposes.
"""
pass
class UnexpectedQuoteError(ArgumentParsingError):
"""An exception raised when the parser encounters a quote mark inside a non-quoted string.
This inherits from :exc:`ArgumentParsingError`.
Attributes
------------
quote: :class:`str`
The quote mark that was found inside the non-quoted string.
"""
def __init__(self, quote):
self.quote = quote
super().__init__('Unexpected quote mark, {0!r}, in non-quoted string'.format(quote))
class InvalidEndOfQuotedStringError(ArgumentParsingError):
"""An exception raised when a space is expected after the closing quote in a string
but a different character is found.
This inherits from :exc:`ArgumentParsingError`.
Attributes
-----------
char: :class:`str`
The character found instead of the expected string.
"""
def __init__(self, char):
self.char = char
super().__init__('Expected space after closing quotation but received {0!r}'.format(char))
class ExpectedClosingQuoteError(ArgumentParsingError):
"""An exception raised when a quote character is expected but not found.
This inherits from :exc:`ArgumentParsingError`.
Attributes
-----------
close_quote: :class:`str`
The quote character expected.
"""
def __init__(self, close_quote):
self.close_quote = close_quote
super().__init__('Expected closing {}.'.format(close_quote))
class ExtensionError(DiscordException):
"""Base exception for extension related errors.
This inherits from :exc:`~discord.DiscordException`.
Attributes
------------
name: :class:`str`
The extension that had an error.
"""
def __init__(self, message=None, *args, name):
self.name = name
message = message or 'Extension {!r} had an error.'.format(name)
# clean-up @everyone and @here mentions
m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere')
super().__init__(m, *args)
class ExtensionAlreadyLoaded(ExtensionError):
"""An exception raised when an extension has already been loaded.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name):
super().__init__('Extension {!r} is already loaded.'.format(name), name=name)
class ExtensionNotLoaded(ExtensionError):
"""An exception raised when an extension was not loaded.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name):
super().__init__('Extension {!r} has not been loaded.'.format(name), name=name)
class NoEntryPointError(ExtensionError):
"""An exception raised when an extension does not have a ``setup`` entry point function.
This inherits from :exc:`ExtensionError`
"""
def __init__(self, name):
super().__init__("Extension {!r} has no 'setup' function.".format(name), name=name)
class ExtensionFailed(ExtensionError):
"""An exception raised when an extension failed to load during execution of the module or ``setup`` entry point.
This inherits from :exc:`ExtensionError`
Attributes
-----------
name: :class:`str`
The extension that had the error.
original: :exc:`Exception`
The original exception that was raised. You can also get this via
the ``__cause__`` attribute.
"""
def __init__(self, name, original):
self.original = original
fmt = 'Extension {0!r} raised an error: {1.__class__.__name__}: {1}'
super().__init__(fmt.format(name, original), name=name)
class ExtensionNotFound(ExtensionError):
"""An exception raised when an extension is not found.
This inherits from :exc:`ExtensionError`
.. versionchanged:: 1.3
Made the ``original`` attribute always None.
Attributes
-----------
name: :class:`str`
The extension that had the error.
original: :class:`NoneType`
Always ``None`` for backwards compatibility.
"""
def __init__(self, name, original=None):
self.original = None
fmt = 'Extension {0!r} could not be loaded.'
super().__init__(fmt.format(name), name=name)
class CommandRegistrationError(ClientException):
"""An exception raised when the command can't be added
because the name is already taken by a different command.
This inherits from :exc:`discord.ClientException`
.. versionadded:: 1.4
Attributes
----------
name: :class:`str`
The command name that had the error.
alias_conflict: :class:`bool`
Whether the name that conflicts is an alias of the command we try to add.
"""
def __init__(self, name, *, alias_conflict=False):
self.name = name
self.alias_conflict = alias_conflict
type_ = 'alias' if alias_conflict else 'command'
super().__init__('The {} {} is already an existing command or alias.'.format(type_, name))

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError
# map from opening quotes to closing quotes
_quotes = {
'"': '"',
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"": "",
"«": "»",
"": "",
"": "",
"": "",
}
_all_quotes = set(_quotes.keys()) | set(_quotes.values())
class StringView:
def __init__(self, buffer):
self.index = 0
self.buffer = buffer
self.end = len(buffer)
self.previous = 0
@property
def current(self):
return None if self.eof else self.buffer[self.index]
@property
def eof(self):
return self.index >= self.end
def undo(self):
self.index = self.previous
def skip_ws(self):
pos = 0
while not self.eof:
try:
current = self.buffer[self.index + pos]
if not current.isspace():
break
pos += 1
except IndexError:
break
self.previous = self.index
self.index += pos
return self.previous != self.index
def skip_string(self, string):
strlen = len(string)
if self.buffer[self.index:self.index + strlen] == string:
self.previous = self.index
self.index += strlen
return True
return False
def read_rest(self):
result = self.buffer[self.index:]
self.previous = self.index
self.index = self.end
return result
def read(self, n):
result = self.buffer[self.index:self.index + n]
self.previous = self.index
self.index += n
return result
def get(self):
try:
result = self.buffer[self.index + 1]
except IndexError:
result = None
self.previous = self.index
self.index += 1
return result
def get_word(self):
pos = 0
while not self.eof:
try:
current = self.buffer[self.index + pos]
if current.isspace():
break
pos += 1
except IndexError:
break
self.previous = self.index
result = self.buffer[self.index:self.index + pos]
self.index += pos
return result
def get_quoted_word(self):
current = self.current
if current is None:
return None
close_quote = _quotes.get(current)
is_quoted = bool(close_quote)
if is_quoted:
result = []
_escaped_quotes = (current, close_quote)
else:
result = [current]
_escaped_quotes = _all_quotes
while not self.eof:
current = self.get()
if not current:
if is_quoted:
# unexpected EOF
raise ExpectedClosingQuoteError(close_quote)
return ''.join(result)
# currently we accept strings in the format of "hello world"
# to embed a quote inside the string you must escape it: "a \"world\""
if current == '\\':
next_char = self.get()
if not next_char:
# string ends with \ and no character after it
if is_quoted:
# if we're quoted then we're expecting a closing quote
raise ExpectedClosingQuoteError(close_quote)
# if we aren't then we just let it through
return ''.join(result)
if next_char in _escaped_quotes:
# escaped quote
result.append(next_char)
else:
# different escape character, ignore it
self.undo()
result.append(current)
continue
if not is_quoted and current in _all_quotes:
# we aren't quoted
raise UnexpectedQuoteError(current)
# closing quote
if is_quoted and current == close_quote:
next_char = self.get()
valid_eof = not next_char or next_char.isspace()
if not valid_eof:
raise InvalidEndOfQuotedStringError(next_char)
# we're quoted so it's okay
return ''.join(result)
if current.isspace() and not is_quoted:
# end of word found
return ''.join(result)
result.append(current)
def __repr__(self):
return '<StringView pos: {0.index} prev: {0.previous} end: {0.end} eof: {0.eof}>'.format(self)

View file

@ -0,0 +1,507 @@
# -*- coding: utf-8 -*-
"""
The MIT License (MIT)
Copyright (c) 2015-present Rapptz
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
import asyncio
import datetime
import aiohttp
import discord
import inspect
import logging
import sys
import traceback
from discord.backoff import ExponentialBackoff
log = logging.getLogger(__name__)
class Loop:
"""A background task helper that abstracts the loop and reconnection logic for you.
The main interface to create this is through :func:`loop`.
"""
def __init__(self, coro, seconds, hours, minutes, count, reconnect, loop):
self.coro = coro
self.reconnect = reconnect
self.loop = loop
self.count = count
self._current_loop = 0
self._task = None
self._injected = None
self._valid_exception = (
OSError,
discord.GatewayNotFound,
discord.ConnectionClosed,
aiohttp.ClientError,
asyncio.TimeoutError,
)
self._before_loop = None
self._after_loop = None
self._is_being_cancelled = False
self._has_failed = False
self._stop_next_iteration = False
if self.count is not None and self.count <= 0:
raise ValueError('count must be greater than 0 or None.')
self.change_interval(seconds=seconds, minutes=minutes, hours=hours)
self._last_iteration_failed = False
self._last_iteration = None
self._next_iteration = None
if not inspect.iscoroutinefunction(self.coro):
raise TypeError('Expected coroutine function, not {0.__name__!r}.'.format(type(self.coro)))
async def _call_loop_function(self, name, *args, **kwargs):
coro = getattr(self, '_' + name)
if coro is None:
return
if self._injected is not None:
await coro(self._injected, *args, **kwargs)
else:
await coro(*args, **kwargs)
async def _loop(self, *args, **kwargs):
backoff = ExponentialBackoff()
await self._call_loop_function('before_loop')
sleep_until = discord.utils.sleep_until
self._last_iteration_failed = False
self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
try:
await asyncio.sleep(0) # allows canceling in before_loop
while True:
if not self._last_iteration_failed:
self._last_iteration = self._next_iteration
self._next_iteration = self._get_next_sleep_time()
try:
await self.coro(*args, **kwargs)
self._last_iteration_failed = False
now = datetime.datetime.now(datetime.timezone.utc)
if now > self._next_iteration:
self._next_iteration = now
except self._valid_exception:
self._last_iteration_failed = True
if not self.reconnect:
raise
await asyncio.sleep(backoff.delay())
else:
await sleep_until(self._next_iteration)
if self._stop_next_iteration:
return
self._current_loop += 1
if self._current_loop == self.count:
break
except asyncio.CancelledError:
self._is_being_cancelled = True
raise
except Exception as exc:
self._has_failed = True
await self._call_loop_function('error', exc)
raise exc
finally:
await self._call_loop_function('after_loop')
self._is_being_cancelled = False
self._current_loop = 0
self._stop_next_iteration = False
self._has_failed = False
def __get__(self, obj, objtype):
if obj is None:
return self
copy = Loop(self.coro, seconds=self.seconds, hours=self.hours, minutes=self.minutes,
count=self.count, reconnect=self.reconnect, loop=self.loop)
copy._injected = obj
copy._before_loop = self._before_loop
copy._after_loop = self._after_loop
copy._error = self._error
setattr(obj, self.coro.__name__, copy)
return copy
@property
def current_loop(self):
""":class:`int`: The current iteration of the loop."""
return self._current_loop
@property
def next_iteration(self):
"""Optional[:class:`datetime.datetime`]: When the next iteration of the loop will occur.
.. versionadded:: 1.3
"""
if self._task is None:
return None
elif self._task and self._task.done() or self._stop_next_iteration:
return None
return self._next_iteration
async def __call__(self, *args, **kwargs):
r"""|coro|
Calls the internal callback that the task holds.
.. versionadded:: 1.6
Parameters
------------
\*args
The arguments to use.
\*\*kwargs
The keyword arguments to use.
"""
if self._injected is not None:
args = (self._injected, *args)
return await self.coro(*args, **kwargs)
def start(self, *args, **kwargs):
r"""Starts the internal task in the event loop.
Parameters
------------
\*args
The arguments to use.
\*\*kwargs
The keyword arguments to use.
Raises
--------
RuntimeError
A task has already been launched and is running.
Returns
---------
:class:`asyncio.Task`
The task that has been created.
"""
if self._task is not None and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.')
if self._injected is not None:
args = (self._injected, *args)
if self.loop is None:
self.loop = asyncio.get_event_loop()
self._task = self.loop.create_task(self._loop(*args, **kwargs))
return self._task
def stop(self):
r"""Gracefully stops the task from running.
Unlike :meth:`cancel`\, this allows the task to finish its
current iteration before gracefully exiting.
.. note::
If the internal function raises an error that can be
handled before finishing then it will retry until
it succeeds.
If this is undesirable, either remove the error handling
before stopping via :meth:`clear_exception_types` or
use :meth:`cancel` instead.
.. versionadded:: 1.2
"""
if self._task and not self._task.done():
self._stop_next_iteration = True
def _can_be_cancelled(self):
return not self._is_being_cancelled and self._task and not self._task.done()
def cancel(self):
"""Cancels the internal task, if it is running."""
if self._can_be_cancelled():
self._task.cancel()
def restart(self, *args, **kwargs):
r"""A convenience method to restart the internal task.
.. note::
Due to the way this function works, the task is not
returned like :meth:`start`.
Parameters
------------
\*args
The arguments to to use.
\*\*kwargs
The keyword arguments to use.
"""
def restart_when_over(fut, *, args=args, kwargs=kwargs):
self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs)
if self._can_be_cancelled():
self._task.add_done_callback(restart_when_over)
self._task.cancel()
def add_exception_type(self, *exceptions):
r"""Adds exception types to be handled during the reconnect logic.
By default the exception types handled are those handled by
:meth:`discord.Client.connect`\, which includes a lot of internet disconnection
errors.
This function is useful if you're interacting with a 3rd party library that
raises its own set of exceptions.
Parameters
------------
\*exceptions: Type[:class:`BaseException`]
An argument list of exception classes to handle.
Raises
--------
TypeError
An exception passed is either not a class or not inherited from :class:`BaseException`.
"""
for exc in exceptions:
if not inspect.isclass(exc):
raise TypeError('{0!r} must be a class.'.format(exc))
if not issubclass(exc, BaseException):
raise TypeError('{0!r} must inherit from BaseException.'.format(exc))
self._valid_exception = (*self._valid_exception, *exceptions)
def clear_exception_types(self):
"""Removes all exception types that are handled.
.. note::
This operation obviously cannot be undone!
"""
self._valid_exception = tuple()
def remove_exception_type(self, *exceptions):
r"""Removes exception types from being handled during the reconnect logic.
Parameters
------------
\*exceptions: Type[:class:`BaseException`]
An argument list of exception classes to handle.
Returns
---------
:class:`bool`
Whether all exceptions were successfully removed.
"""
old_length = len(self._valid_exception)
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
return len(self._valid_exception) == old_length - len(exceptions)
def get_task(self):
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
return self._task
def is_being_cancelled(self):
"""Whether the task is being cancelled."""
return self._is_being_cancelled
def failed(self):
""":class:`bool`: Whether the internal task has failed.
.. versionadded:: 1.2
"""
return self._has_failed
def is_running(self):
""":class:`bool`: Check if the task is currently running.
.. versionadded:: 1.4
"""
return not bool(self._task.done()) if self._task else False
async def _error(self, *args):
exception = args[-1]
print('Unhandled exception in internal background task {0.__name__!r}.'.format(self.coro), file=sys.stderr)
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
def before_loop(self, coro):
"""A decorator that registers a coroutine to be called before the loop starts running.
This is useful if you want to wait for some bot state before the loop starts,
such as :meth:`discord.Client.wait_until_ready`.
The coroutine must take no arguments (except ``self`` in a class context).
Parameters
------------
coro: :ref:`coroutine <coroutine>`
The coroutine to register before the loop runs.
Raises
-------
TypeError
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError('Expected coroutine function, received {0.__name__!r}.'.format(type(coro)))
self._before_loop = coro
return coro
def after_loop(self, coro):
"""A decorator that register a coroutine to be called after the loop finished running.
The coroutine must take no arguments (except ``self`` in a class context).
.. note::
This coroutine is called even during cancellation. If it is desirable
to tell apart whether something was cancelled or not, check to see
whether :meth:`is_being_cancelled` is ``True`` or not.
Parameters
------------
coro: :ref:`coroutine <coroutine>`
The coroutine to register after the loop finishes.
Raises
-------
TypeError
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError('Expected coroutine function, received {0.__name__!r}.'.format(type(coro)))
self._after_loop = coro
return coro
def error(self, coro):
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
The coroutine must take only one argument the exception raised (except ``self`` in a class context).
By default this prints to :data:`sys.stderr` however it could be
overridden to have a different implementation.
.. versionadded:: 1.4
Parameters
------------
coro: :ref:`coroutine <coroutine>`
The coroutine to register in the event of an unhandled exception.
Raises
-------
TypeError
The function was not a coroutine.
"""
if not inspect.iscoroutinefunction(coro):
raise TypeError('Expected coroutine function, received {0.__name__!r}.'.format(type(coro)))
self._error = coro
return coro
def _get_next_sleep_time(self):
return self._last_iteration + datetime.timedelta(seconds=self._sleep)
def change_interval(self, *, seconds=0, minutes=0, hours=0):
"""Changes the interval for the sleep time.
.. note::
This only applies on the next loop iteration. If it is desirable for the change of interval
to be applied right away, cancel the task with :meth:`cancel`.
.. versionadded:: 1.2
Parameters
------------
seconds: :class:`float`
The number of seconds between every iteration.
minutes: :class:`float`
The number of minutes between every iteration.
hours: :class:`float`
The number of hours between every iteration.
Raises
-------
ValueError
An invalid value was given.
"""
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
if sleep < 0:
raise ValueError('Total number of seconds cannot be less than zero.')
self._sleep = sleep
self.seconds = seconds
self.hours = hours
self.minutes = minutes
def loop(*, seconds=0, minutes=0, hours=0, count=None, reconnect=True, loop=None):
"""A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`.
Parameters
------------
seconds: :class:`float`
The number of seconds between every iteration.
minutes: :class:`float`
The number of minutes between every iteration.
hours: :class:`float`
The number of hours between every iteration.
count: Optional[:class:`int`]
The number of loops to do, ``None`` if it should be an
infinite loop.
reconnect: :class:`bool`
Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`.
loop: :class:`asyncio.AbstractEventLoop`
The loop to use to register the task, if not given
defaults to :func:`asyncio.get_event_loop`.
Raises
--------
ValueError
An invalid value was given.
TypeError
The function was not a coroutine.
"""
def decorator(func):
kwargs = {
'seconds': seconds,
'minutes': minutes,
'hours': hours,
'count': count,
'reconnect': reconnect,
'loop': loop
}
return Loop(func, **kwargs)
return decorator