from __future__ import annotations
from asyncio import iscoroutinefunction, sleep
from collections.abc import Coroutine
from functools import wraps
from time import sleep as block
from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar, Union, cast
from ._sync import SyncRateLimit, SyncStore
from ._utils import maybe_awaitable
from .errors import RateLimitError
from .ratelimit import RateLimit
from .store import BaseStore
if TYPE_CHECKING:
from .rate import Rate, RateGroup
# Public names exported by a star-import of this module.
__all__ = (
    "on_retry_sleep",
    "on_retry_block",
    "ratelimit",
)
# Type of the bucket value returned by the ``key`` callback of ``ratelimit``.
Key = TypeVar("Key")
"""An unbound and unconstrained TypeVar"""
# Return type of a rate limited callable; covariant so it can parameterise
# the ``LimitedCallable`` protocol.
R = TypeVar("R", covariant=True)
# Either flavour of rate limit: the synchronous one or the coroutine based one.
AnyRateLimit = Union[SyncRateLimit, RateLimit]
class LimitedCallable(Protocol[R]):
    """Structural type of a callable that carries its own rate limit.

    Any object that can be called with arbitrary arguments to produce an
    ``R`` and that exposes a ``limit`` attribute satisfies this protocol.
    """

    # The rate limit object attached to the callable.
    limit: AnyRateLimit

    def __call__(self, *args, **kwds) -> R:
        """Invoke the callable with arbitrary positional and keyword arguments."""
def _apply_attrs(func: Callable[..., R], **attrs) -> LimitedCallable[R]:
    """Attach every keyword argument to ``func`` as an attribute.

    Parameters
    ----------
    func : Callable[..., R]
        The callable to decorate with attributes; it is modified in place.
    **attrs
        Attribute names mapped to the values to set on ``func``.

    Returns
    -------
    LimitedCallable[R]
        The very same ``func`` object, re-typed for static checkers.
    """
    # ``cast`` is a no-op at runtime; it only changes the static type.
    target = cast(LimitedCallable[R], func)
    for name in attrs:
        setattr(target, name, attrs[name])
    return target
async def on_retry_sleep(error: RateLimitError) -> None:
    """Make the current task yield to the event loop
    till a usage token is available for the rate limit
    which raised :exc:`uprate.errors.RateLimitError`.

    The delay awaited is ``error.retry_after``, handed to
    :func:`asyncio.sleep` unchanged.

    Parameters
    ----------
    error : :exc:`.RateLimitError`
        The rate limit error to sleep for.
    """
    await sleep(error.retry_after)
def on_retry_block(error: RateLimitError) -> None:
    """Block the current thread till a usage token is
    available for the rate limit which raised
    :exc:`uprate.errors.RateLimitError`.

    The delay slept is ``error.retry_after``; a negative value is treated
    as "no wait" rather than an error.

    Parameters
    ----------
    error : :exc:`.RateLimitError`
        The rate limit error to block for.
    """
    # time.sleep raises ValueError for a negative duration, whereas
    # asyncio.sleep (used by on_retry_sleep) simply returns; clamp so both
    # retry hooks behave the same way.
    block(max(error.retry_after, 0))
# TODO: Complete Docs
def ratelimit(
    rate: Rate | RateGroup, *,
    key: Callable[..., Key] | Callable[..., Coroutine[Any, Any, Key]] | None = None,
    on_retry: Callable[[RateLimitError], Any] | Callable[[RateLimitError], Coroutine] | None = None,
    store: BaseStore | SyncStore | None = None
) -> Callable[[Callable[..., R]], LimitedCallable[R]]:
    """Limit a coroutine function or a callable to be called within
    provided rate. :ref:`Example here <index-example>`

    Parameters
    ----------
    rate: :class:`~uprate.rate.Rate`, :class:`~uprate.rate.RateGroup`, (``Rate | RateGroup``)
        The rate that the decorated function must follow.
    key : Callable[..., :data:`.Key`] | Callable[..., Coroutine[Any, Any, :data:`.Key`]] | :data:`None`
        The callback for generating a bucket for ratelimit from the arguments provided
        to the decorated function, this can be a coroutine function only when the decorated
        is a coroutine function as well. If :data:`None`, a default callback returning a string based on the
        decorated function's name is used, by default :data:`None`.
    on_retry : Callable[[:exc:`.RateLimitError`], Any] | Callable[[:exc:`.RateLimitError`], Coroutine] | :data:`None`
        If provided then this function will be called when the function gets ratelimited
        and then the decorated function will be called again, if :data:`None` then function call isn't retried
        and :exc:`.RateLimitError` is raised, by default :data:`None`. Like ``key``, this can be a
        coroutine function only when the decorated function is a coroutine function.
    store : :class:`.BaseStore` | :class:`.SyncStore` | :data:`None`
        The store to use for the rate limit, must be of type :class:`.BaseStore` if
        decorated function is a coroutine function else :class:`.SyncStore`.
        If :data:`None` then a suitable derived memory store is used, by default :data:`None`.

    Returns
    -------
    Callable
        A decorator. The function it returns wraps the decorated function and
        carries the created rate limit object as its ``limit`` attribute.

    Raises
    ------
    :exc:`.RateLimitError`
        ``on_retry`` parameter is not provided and the decorated function got ratelimited.
    :exc:`TypeError`
        Raised while decorating when ``store`` does not match the kind of the decorated
        function, or when ``key`` or ``on_retry`` is a coroutine function but the
        decorated function is not.
    """

    def decorator(func: Callable[..., R]) -> LimitedCallable[R]:
        # Resolve the bucket callback into a local instead of rebinding the
        # enclosing ``key``: one decorator object may be applied to several
        # functions, and each of them needs its own default bucket name.
        if key is None:
            def bucket_for(*args, **kwargs):
                return "DEFAULT_BUCKET_" + func.__name__
        else:
            bucket_for = key

        if iscoroutinefunction(func):
            if store is not None and not isinstance(store, BaseStore):
                raise TypeError("Cannot use a uprate._sync.SyncStore instance with a coroutine function.")
            limit: AnyRateLimit = RateLimit(rate, store)

            @wraps(func)
            async def rated(*args, **kwargs):
                # Keep trying to acquire a token; only a RateLimitError
                # (from the key callback or the limit) triggers a retry.
                while True:
                    try:
                        bucket = await maybe_awaitable(bucket_for(*args, **kwargs))
                        await limit.acquire(bucket)
                    except RateLimitError as err:
                        if on_retry is None:
                            raise err from None
                        await maybe_awaitable(on_retry(err))
                    else:
                        return await func(*args, **kwargs)
        else:
            if store is not None and not isinstance(store, SyncStore):
                raise TypeError("Cannot use a uprate.BaseStore instance with a subroutine.")
            # A coroutine callback would never be awaited on this path: a
            # coroutine ``key`` yields a fresh coroutine object as the bucket
            # on every call, and a coroutine ``on_retry`` returns immediately
            # so the retry loop would spin. Fail loudly at decoration time.
            if iscoroutinefunction(key):
                raise TypeError("Cannot use a coroutine function as key with a subroutine.")
            if iscoroutinefunction(on_retry):
                raise TypeError("Cannot use a coroutine function as on_retry with a subroutine.")
            limit = SyncRateLimit(rate, store)

            @wraps(func)
            def rated(*args, **kwargs):
                while True:
                    try:
                        bucket = bucket_for(*args, **kwargs)
                        limit.acquire(bucket)
                    except RateLimitError as err:
                        if on_retry is None:
                            raise err from None
                        on_retry(err)
                    else:
                        # Forward the caller's arguments to the wrapped function.
                        return func(*args, **kwargs)

        return _apply_attrs(rated, limit=limit)

    return decorator