Skip to content

Commit

Permalink
Handle rate limit responses and automatically retry the requests
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEdgeOfRage committed Dec 18, 2023
1 parent d2195b2 commit 470e66e
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 45 deletions.
81 changes: 56 additions & 25 deletions dune_client/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@
import logging.config
import os
from json import JSONDecodeError
from typing import Dict, Optional, Any
from time import time
from typing import Callable, Dict, Optional, Any

import requests
from requests import Response

from dune_client.util import get_package_version


class RateLimitedError(Exception):
"""Special Error for Rate Limited Requests"""


# pylint: disable=too-few-public-methods
class BaseDuneClient:
"""
Expand Down Expand Up @@ -71,6 +76,8 @@ class BaseRouter(BaseDuneClient):

def _handle_response(self, response: Response) -> Any:
"""Generic response handler utilized by all Dune API routes"""
if response.status_code == 429:
raise RateLimitedError
try:
# Some responses can be decoded and converted to DuneErrors
response_json = response.json()
Expand All @@ -84,6 +91,18 @@ def _handle_response(self, response: Response) -> Any:
def _route_url(self, route: str) -> str:
return f"{self.base_url}{self.api_version}{route}"

def _handle_ratelimit(self, call: Callable[..., Any]) -> Any:
"""Generic wrapper around request callables. If the request fails due to rate limiting,
it will retry it up to five times, sleeping i * 5s in between"""
for i in range(5):
try:
return call()
except RateLimitedError:
self.logger.warning(f"Rate limited. Retrying in {i * 5} seconds.")
time.sleep(i * 5)

raise RateLimitedError

def _get(
self,
route: str,
Expand All @@ -93,37 +112,49 @@ def _get(
"""Generic interface for the GET method of a Dune API request"""
url = self._route_url(route)
self.logger.debug(f"GET received input url={url}")
response = requests.get(
url=url,
headers=self.default_headers(),
timeout=self.request_timeout,
params=params,
)
if raw:
return response
return self._handle_response(response)

def _get() -> Any:
response = requests.get(
url=url,
headers=self.default_headers(),
timeout=self.request_timeout,
params=params,
)
if raw:
return response
return self._handle_response(response)

return self._handle_ratelimit(_get)

def _post(self, route: str, params: Optional[Any] = None) -> Any:
"""Generic interface for the POST method of a Dune API request"""
url = self._route_url(route)
self.logger.debug(f"POST received input url={url}, params={params}")
response = requests.post(
url=url,
json=params,
headers=self.default_headers(),
timeout=self.request_timeout,
)
return self._handle_response(response)

def _post() -> Any:
response = requests.post(
url=url,
json=params,
headers=self.default_headers(),
timeout=self.request_timeout,
)
return self._handle_response(response)

return self._handle_ratelimit(_post)

def _patch(self, route: str, params: Any) -> Any:
"""Generic interface for the PATCH method of a Dune API request"""
url = self._route_url(route)
self.logger.debug(f"PATCH received input url={url}, params={params}")
response = requests.request(
method="PATCH",
url=url,
json=params,
headers=self.default_headers(),
timeout=self.request_timeout,
)
return self._handle_response(response)

def _patch() -> Any:
response = requests.request(
method="PATCH",
url=url,
json=params,
headers=self.default_headers(),
timeout=self.request_timeout,
)
return self._handle_response(response)

return self._handle_ratelimit(_patch)
64 changes: 44 additions & 20 deletions dune_client/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

import asyncio
from io import BytesIO
from typing import Any, Optional, Union
from time import time
from typing import Any, Callable, Optional, Union

from aiohttp import (
ClientSession,
Expand All @@ -17,7 +18,7 @@
ClientTimeout,
)

from dune_client.api.base import BaseDuneClient
from dune_client.api.base import BaseDuneClient, RateLimitedError
from dune_client.models import (
ExecutionResponse,
ExecutionResultCSV,
Expand Down Expand Up @@ -77,6 +78,9 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.disconnect()

async def _handle_response(self, response: ClientResponse) -> Any:
if response.status == 429:
raise RateLimitedError

try:
# Some responses can be decoded and converted to DuneErrors
response_json = await response.json()
Expand All @@ -90,36 +94,56 @@ async def _handle_response(self, response: ClientResponse) -> Any:
def _route_url(self, route: str) -> str:
return f"{self.api_version}{route}"

async def _handle_ratelimit(self, call: Callable[..., Any]) -> Any:
"""Generic wrapper around request callables. If the request fails due to rate limiting,
it will retry it up to five times, sleeping i * 5s in between"""
for i in range(5):
try:
return await call()
except RateLimitedError:
self.logger.warning(f"Rate limited. Retrying in {i * 5} seconds.")
time.sleep(i * 5)

raise RateLimitedError

async def _get(
self,
route: str,
params: Optional[Any] = None,
raw: bool = False,
) -> Any:
url = self._route_url(route)
if self._session is None:
raise ValueError("Client is not connected; call `await cl.connect()`")
self.logger.debug(f"GET received input url={url}")
response = await self._session.get(
url=url,
headers=self.default_headers(),
params=params,
)
if raw:
return response
return await self._handle_response(response)

async def _get() -> Any:
if self._session is None:
raise ValueError("Client is not connected; call `await cl.connect()`")
response = await self._session.get(
url=url,
headers=self.default_headers(),
params=params,
)
if raw:
return response
return await self._handle_response(response)

return await self._handle_ratelimit(_get)

async def _post(self, route: str, params: Any) -> Any:
url = self._route_url(route)
if self._session is None:
raise ValueError("Client is not connected; call `await cl.connect()`")
self.logger.debug(f"POST received input url={url}, params={params}")
response = await self._session.post(
url=url,
json=params,
headers=self.default_headers(),
)
return await self._handle_response(response)

async def _post() -> Any:
if self._session is None:
raise ValueError("Client is not connected; call `await cl.connect()`")
response = await self._session.post(
url=url,
json=params,
headers=self.default_headers(),
)
return await self._handle_response(response)

return await self._handle_ratelimit(_post)

async def execute(
self, query: QueryBase, performance: Optional[str] = None
Expand Down

0 comments on commit 470e66e

Please sign in to comment.