From 470e66ef4638016f1f752d7c789a4fd09fc77e0b Mon Sep 17 00:00:00 2001 From: TheEdgeOfRage Date: Mon, 18 Dec 2023 15:21:55 +0100 Subject: [PATCH] Handle rate limit responses and automatically retry the requests --- dune_client/api/base.py | 81 +++++++++++++++++++++++++------------ dune_client/client_async.py | 64 ++++++++++++++++++++--------- 2 files changed, 100 insertions(+), 45 deletions(-) diff --git a/dune_client/api/base.py b/dune_client/api/base.py index 6605252..9bf0fd8 100644 --- a/dune_client/api/base.py +++ b/dune_client/api/base.py @@ -8,7 +8,8 @@ import logging.config import os from json import JSONDecodeError -from typing import Dict, Optional, Any +from time import time +from typing import Callable, Dict, Optional, Any import requests from requests import Response @@ -16,6 +17,10 @@ from dune_client.util import get_package_version +class RateLimitedError(Exception): + """Special Error for Rate Limited Requests""" + + # pylint: disable=too-few-public-methods class BaseDuneClient: """ @@ -71,6 +76,8 @@ class BaseRouter(BaseDuneClient): def _handle_response(self, response: Response) -> Any: """Generic response handler utilized by all Dune API routes""" + if response.status_code == 429: + raise RateLimitedError try: # Some responses can be decoded and converted to DuneErrors response_json = response.json() @@ -84,6 +91,18 @@ def _handle_response(self, response: Response) -> Any: def _route_url(self, route: str) -> str: return f"{self.base_url}{self.api_version}{route}" + def _handle_ratelimit(self, call: Callable[..., Any]) -> Any: + """Generic wrapper around request callables. If the request fails due to rate limiting, + it will retry it up to five times, sleeping i * 5s in between""" + for i in range(5): + try: + return call() + except RateLimitedError: + self.logger.warning(f"Rate limited. Retrying in {i * 5} seconds.") + time.sleep(i * 5) + + raise RateLimitedError + def _get( self, route: str, @@ -93,37 +112,49 @@ def _get( """Generic interface for the GET method of a Dune API request""" url = self._route_url(route) self.logger.debug(f"GET received input url={url}") - response = requests.get( - url=url, - headers=self.default_headers(), - timeout=self.request_timeout, - params=params, - ) - if raw: - return response - return self._handle_response(response) + + def _get() -> Any: + response = requests.get( + url=url, + headers=self.default_headers(), + timeout=self.request_timeout, + params=params, + ) + if raw: + return response + return self._handle_response(response) + + return self._handle_ratelimit(_get) def _post(self, route: str, params: Optional[Any] = None) -> Any: """Generic interface for the POST method of a Dune API request""" url = self._route_url(route) self.logger.debug(f"POST received input url={url}, params={params}") - response = requests.post( - url=url, - json=params, - headers=self.default_headers(), - timeout=self.request_timeout, - ) - return self._handle_response(response) + + def _post() -> Any: + response = requests.post( + url=url, + json=params, + headers=self.default_headers(), + timeout=self.request_timeout, + ) + return self._handle_response(response) + + return self._handle_ratelimit(_post) def _patch(self, route: str, params: Any) -> Any: """Generic interface for the PATCH method of a Dune API request""" url = self._route_url(route) self.logger.debug(f"PATCH received input url={url}, params={params}") - response = requests.request( - method="PATCH", - url=url, - json=params, - headers=self.default_headers(), - timeout=self.request_timeout, - ) - return self._handle_response(response) + + def _patch() -> Any: + response = requests.request( + method="PATCH", + url=url, + json=params, + headers=self.default_headers(), + timeout=self.request_timeout, + ) + return self._handle_response(response) + + return self._handle_ratelimit(_patch) diff --git a/dune_client/client_async.py b/dune_client/client_async.py index 3e264e1..547c9c6 100644 --- a/dune_client/client_async.py +++ b/dune_client/client_async.py @@ -7,7 +7,8 @@ import asyncio from io import BytesIO -from typing import Any, Optional, Union +from time import time +from typing import Any, Callable, Optional, Union from aiohttp import ( ClientSession, @@ -17,7 +18,7 @@ ClientTimeout, ) -from dune_client.api.base import BaseDuneClient +from dune_client.api.base import BaseDuneClient, RateLimitedError from dune_client.models import ( ExecutionResponse, ExecutionResultCSV, @@ -77,6 +78,9 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.disconnect() async def _handle_response(self, response: ClientResponse) -> Any: + if response.status == 429: + raise RateLimitedError + try: # Some responses can be decoded and converted to DuneErrors response_json = await response.json() @@ -90,6 +94,18 @@ async def _handle_response(self, response: ClientResponse) -> Any: def _route_url(self, route: str) -> str: return f"{self.api_version}{route}" + async def _handle_ratelimit(self, call: Callable[..., Any]) -> Any: + """Generic wrapper around request callables. If the request fails due to rate limiting, + it will retry it up to five times, sleeping i * 5s in between""" + for i in range(5): + try: + return await call() + except RateLimitedError: + self.logger.warning(f"Rate limited. Retrying in {i * 5} seconds.") + time.sleep(i * 5) + + raise RateLimitedError + async def _get( self, route: str, @@ -97,29 +113,37 @@ async def _get( raw: bool = False, ) -> Any: url = self._route_url(route) - if self._session is None: - raise ValueError("Client is not connected; call `await cl.connect()`") self.logger.debug(f"GET received input url={url}") - response = await self._session.get( - url=url, - headers=self.default_headers(), - params=params, - ) - if raw: - return response - return await self._handle_response(response) + + async def _get() -> Any: + if self._session is None: + raise ValueError("Client is not connected; call `await cl.connect()`") + response = await self._session.get( + url=url, + headers=self.default_headers(), + params=params, + ) + if raw: + return response + return await self._handle_response(response) + + return await self._handle_ratelimit(_get) async def _post(self, route: str, params: Any) -> Any: url = self._route_url(route) - if self._session is None: - raise ValueError("Client is not connected; call `await cl.connect()`") self.logger.debug(f"POST received input url={url}, params={params}") - response = await self._session.post( - url=url, - json=params, - headers=self.default_headers(), - ) - return await self._handle_response(response) + + async def _post() -> Any: + if self._session is None: + raise ValueError("Client is not connected; call `await cl.connect()`") + response = await self._session.post( + url=url, + json=params, + headers=self.default_headers(), + ) + return await self._handle_response(response) + + return await self._handle_ratelimit(_post) async def execute( self, query: QueryBase, performance: Optional[str] = None