124 lines
4.1 KiB
Python
124 lines
4.1 KiB
Python
"""
|
|
Rate limiting middleware for Flask API
|
|
Implements token bucket algorithm for request throttling
|
|
"""
|
|
|
|
import math
import threading
import time
from collections import defaultdict
from functools import wraps

from flask import jsonify, request
|
|
|
|
class RateLimiter:
    """
    Thread-safe, in-memory rate limiter using the token bucket algorithm.

    Each client id gets a bucket that starts full, holds at most
    ``max_tokens`` tokens and refills continuously at ``refill_rate`` tokens
    per second. Every allowed request consumes one token.

    Bucket state lives in ``self.clients`` as
    ``{client_id: {'tokens': float, 'last_update': float, 'initialized': bool}}``
    and is only read or written while holding ``self.lock``.

    NOTE(review): entries are only removed by ``clear_client``, so memory
    grows with the number of distinct client ids seen.
    NOTE(review): a client id has exactly one bucket regardless of the
    ``max_tokens`` / ``refill_rate`` passed per call, so callers applying
    different limits to the same client id share (and reshape) one bucket.
    """

    def __init__(self):
        # Buckets are created lazily on first lookup. ``is_allowed`` fills a
        # new bucket to capacity on first use (tracked by 'initialized').
        # 'last_update' uses the monotonic clock, matching ``is_allowed``.
        self.clients = defaultdict(lambda: {
            'tokens': 0,
            'last_update': time.monotonic(),
            'initialized': False,
        })
        self.lock = threading.Lock()

    def is_allowed(self, client_id, max_tokens=60, refill_rate=1.0):
        """
        Check whether a request from ``client_id`` is allowed, consuming a token if so.

        Args:
            client_id: Unique identifier for the client (e.g. IP address).
            max_tokens: Bucket capacity (maximum burst size).
            refill_rate: Tokens added per second; must be positive.

        Returns:
            tuple: ``(True, 0)`` when allowed, otherwise
            ``(False, retry_after)`` where ``retry_after`` is the smallest
            whole number of seconds (>= 1) after which one token is available.

        Raises:
            ValueError: if ``refill_rate`` is not positive, because the bucket
                could never refill and no finite retry-after exists.
        """
        if refill_rate <= 0:
            raise ValueError(f"refill_rate must be positive, got {refill_rate!r}")

        with self.lock:
            # Monotonic clock: wall-clock adjustments must not drain or
            # overfill buckets (time.time() can jump backwards or forwards).
            now = time.monotonic()
            client = self.clients[client_id]

            # New clients start with a full bucket.
            if not client.get('initialized', False):
                client['tokens'] = max_tokens
                client['last_update'] = now
                client['initialized'] = True

            # Refill proportionally to elapsed time, capped at capacity.
            elapsed = now - client['last_update']
            client['tokens'] = min(max_tokens, client['tokens'] + elapsed * refill_rate)
            client['last_update'] = now

            if client['tokens'] >= 1:
                client['tokens'] -= 1
                return True, 0

            # Seconds until the bucket reaches one whole token, rounded up.
            # Clamped to >= 1 because Retry-After must be a positive integer.
            retry_after = max(1, math.ceil((1 - client['tokens']) / refill_rate))
            return False, retry_after

    def clear_client(self, client_id):
        """Remove a client's bucket (for testing/reset); unknown ids are ignored."""
        with self.lock:
            self.clients.pop(client_id, None)


# Process-wide limiter shared by every route decorated with ``rate_limit``
# and read by ``get_rate_limit_headers``.
rate_limiter = RateLimiter()
|
|
|
|
def rate_limit(max_per_minute=60):
    """
    Decorator to apply per-client rate limiting to Flask routes.

    Clients are identified by ``request.remote_addr`` (``'unknown'`` when
    absent). Each client's bucket holds up to ``max_per_minute`` requests and
    refills at ``max_per_minute / 60`` requests per second. A rejected request
    gets a 429 JSON response with ``Retry-After``, ``X-RateLimit-Limit`` and
    ``X-RateLimit-Remaining`` headers instead of calling the view.

    Usage:
        @app.route('/api/endpoint')
        @rate_limit(max_per_minute=30)
        def my_endpoint():
            ...

    Args:
        max_per_minute: Bucket capacity and number of requests refilled per minute.

    Raises:
        ValueError: at decoration time if ``max_per_minute`` is below 1. Such
            a bucket can never hold a whole token and would reject every
            request; at 0 the retry-after computation would divide by zero.

    NOTE(review): buckets are keyed only by client id in the shared
    ``rate_limiter``, so routes decorated with different limits share one
    bucket per client.
    """
    if max_per_minute < 1:
        raise ValueError(f"max_per_minute must be at least 1, got {max_per_minute!r}")

    # Per-minute limit expressed as a per-second refill rate. It is constant
    # for this route, so it is computed once rather than per request.
    refill_rate = max_per_minute / 60.0

    def decorator(f):
        @wraps(f)
        def wrapped(*args, **kwargs):
            # NOTE(review): behind a reverse proxy remote_addr is likely the
            # proxy's address; confirm the deployment (e.g. ProxyFix).
            client_id = request.remote_addr or 'unknown'

            is_allowed, retry_after = rate_limiter.is_allowed(
                client_id,
                max_tokens=max_per_minute,
                refill_rate=refill_rate
            )

            if not is_allowed:
                response = jsonify({
                    'error': 'rate_limit_exceeded',
                    'message': f'Too many requests. Please try again in {retry_after} seconds.',
                    'retry_after': retry_after
                })
                response.status_code = 429
                response.headers['Retry-After'] = str(retry_after)
                response.headers['X-RateLimit-Limit'] = str(max_per_minute)
                response.headers['X-RateLimit-Remaining'] = '0'
                return response

            return f(*args, **kwargs)
        return wrapped
    return decorator
|
|
|
|
def get_rate_limit_headers(client_id, max_per_minute=60):
    """
    Build the X-RateLimit-* headers describing a client's quota.

    Args:
        client_id: Key of the client's bucket in the shared ``rate_limiter``.
        max_per_minute: Limit to advertise. It is also reported as the
            remaining count for a client that has no bucket yet.

    Returns:
        dict mapping header names to string values. 'X-RateLimit-Remaining'
        is the bucket's stored token count truncated to an int and floored at
        0; no refill is applied here, so it reflects the state as of the
        client's last rate-limited request. 'X-RateLimit-Reset' is always the
        current Unix time plus 60 seconds and is not derived from bucket state.
    """
    fallback = {'tokens': max_per_minute}
    with rate_limiter.lock:
        bucket = rate_limiter.clients.get(client_id, fallback)
        tokens_left = int(bucket.get('tokens', max_per_minute))

    headers = {}
    headers['X-RateLimit-Limit'] = str(max_per_minute)
    headers['X-RateLimit-Remaining'] = str(max(0, tokens_left))
    headers['X-RateLimit-Reset'] = str(int(time.time() + 60))
    return headers
|