198 lines
5.5 KiB
Python
198 lines
5.5 KiB
Python
# type: ignore
|
|
import asyncio
|
|
import sys
|
|
import types
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import aioredis
|
|
from django.conf import settings
|
|
|
|
from . import logging
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Limit used by ConnectionPool.pop() to throttle how many connections may be
# checked out at once. Read from the CONNECTION_POOL_LIMIT Django setting,
# falling back to 10 when the setting is absent.
connection_pool_limit = getattr(settings, "CONNECTION_POOL_LIMIT", 10)
logger.info(f"CONNECTION_POOL_LIMIT={connection_pool_limit}")


# Copied from https://github.com/django/channels_redis/blob/master/channels_redis/core.py
# and renamed..


def _parse_version(version):
    """
    Return the leading numeric components of a dotted version string.

    ``"1.3.1"`` gives ``(1, 3, 1)``. Parsing stops at the first component
    that is not purely numeric, keeping any leading digits of that component,
    so ``"2.0.0b1"`` gives ``(2, 0, 0)`` and ``"1.3.1.dev0"`` gives
    ``(1, 3, 1)`` instead of raising ``ValueError``.
    """
    numbers = []
    for piece in version.split("."):
        digits = ""
        for char in piece:
            if char not in "0123456789":
                break
            digits += char
        if not digits:
            break
        numbers.append(int(digits))
        if len(digits) != len(piece):
            # A suffix such as "b1" or "rc2" ends the numeric part.
            break
    return tuple(numbers)


# Installed aioredis version as a tuple of ints, compared against (1, 3, 1)
# when deciding whether to pass ``loop=`` to ``aioredis.create_redis``.
AIOREDIS_VERSION = _parse_version(aioredis.__version__)
|
|
|
|
|
|
def _wrap_close(loop, pool):
|
|
"""
|
|
Decorate an event loop's close method with our own.
|
|
"""
|
|
original_impl = loop.close
|
|
|
|
def _wrapper(self, *args, **kwargs):
|
|
# If the event loop was closed, there's nothing we can do anymore.
|
|
if not self.is_closed():
|
|
self.run_until_complete(pool.close_loop(self))
|
|
# Restore the original close() implementation after we're done.
|
|
self.close = original_impl
|
|
return self.close(*args, **kwargs)
|
|
|
|
loop.close = types.MethodType(_wrapper, loop)
|
|
|
|
|
|
class ChannelRedisConnectionPool:
    """
    Connection pool manager for the channel layer.

    It manages a set of connections for the given host specification and
    taking into account asyncio event loops.

    ``conn_map`` maps each event loop to the list of idle connections created
    for it. ``in_use`` maps every checked-out connection to the loop it was
    taken from, or to ``None`` once that loop has been closed.
    """

    def __init__(self, host):
        # ``host`` is expanded as keyword arguments to aioredis.create_redis.
        self.host = host
        self.conn_map = {}
        self.in_use = {}

    def _ensure_loop(self, loop):
        """
        Get connection list for the specified loop.

        Falls back to ``asyncio.get_event_loop()`` when ``loop`` is ``None``.
        Returns a ``(connections, loop)`` tuple.
        """
        if loop is None:
            loop = asyncio.get_event_loop()

        if loop not in self.conn_map:
            # Swap the loop's close method with our own so we get
            # a chance to do some cleanup.
            _wrap_close(loop, self)
            self.conn_map[loop] = []

        return self.conn_map[loop], loop

    async def pop(self, loop=None):
        """
        Get a connection for the given identifier and loop.

        Reuses an idle connection when one exists and otherwise opens a new
        one. Connections that turn out to be closed are discarded and the
        lookup is repeated. The returned connection is recorded in
        ``in_use`` until it is given back with ``push`` or ``conn_error``.
        """
        # Retry with a loop rather than by calling ``self.pop`` again: a
        # subclass that overrides ``pop`` to do per-checkout accounting would
        # otherwise run that accounting once per discarded connection.
        while True:
            conns, loop = self._ensure_loop(loop)
            if conns:
                conn = conns.pop()
            elif sys.version_info >= (3, 8, 0) and AIOREDIS_VERSION >= (1, 3, 1):
                conn = await aioredis.create_redis(**self.host)
            else:
                conn = await aioredis.create_redis(**self.host, loop=loop)
            if not conn.closed:
                break
        self.in_use[conn] = loop
        return conn

    def push(self, conn):
        """
        Return a connection to the pool.

        Raises ``KeyError`` if ``conn`` is not currently checked out. When
        the connection's loop has been closed in the meantime (recorded as
        ``None``), the connection is not put back on any idle list.
        """
        loop = self.in_use[conn]
        del self.in_use[conn]
        if loop is not None:
            conns, _ = self._ensure_loop(loop)
            conns.append(conn)

    def conn_error(self, conn):
        """
        Handle a connection that produced an error.

        Closes the connection and forgets it. Raises ``KeyError`` if ``conn``
        is not currently checked out.
        """
        conn.close()
        del self.in_use[conn]

    def reset(self):
        """
        Clear all connections from the pool.

        Only the bookkeeping is dropped; no connection is closed here.
        """
        self.conn_map = {}
        self.in_use = {}

    async def close_loop(self, loop):
        """
        Close all connections owned by the pool on the given loop.

        Idle connections for ``loop`` are closed and awaited. Connections
        from that loop which are still checked out are left open but re-tagged
        with ``None`` so that ``push`` will not return them to an idle list.
        """
        if loop in self.conn_map:
            for conn in self.conn_map[loop]:
                conn.close()
                await conn.wait_closed()
            del self.conn_map[loop]

        # Only values of existing keys are replaced, so mutating the dict
        # while iterating over it is safe here.
        for k, v in self.in_use.items():
            if v is loop:
                self.in_use[k] = None

    async def close(self):
        """
        Close all connections owned by the pool.

        The bookkeeping is reset first, then every idle and every checked-out
        connection that was tracked is closed and awaited.
        """
        conn_map = self.conn_map
        in_use = self.in_use
        self.reset()
        for conns in conn_map.values():
            for conn in conns:
                conn.close()
                await conn.wait_closed()
        for conn in in_use:
            conn.close()
            await conn.wait_closed()
|
|
|
|
|
|
class InvalidConnection(Exception):
    """Raised by ``ConnectionPool.try_ping`` when a connection fails its ping check."""
|
|
|
|
|
|
class ConnectionPool(ChannelRedisConnectionPool):
    """Adds a trivial, soft limit for the pool.

    ``counter`` tracks how many connections are currently checked out.
    ``pop`` waits while ``counter`` exceeds ``connection_pool_limit``, and
    ``push`` / ``conn_error`` release a slot again. Every connection handed
    out has answered a PING first.
    """

    def __init__(self, host: Any) -> None:
        self.counter = 0
        super().__init__(host)

    async def pop(self, *args: Any, **kwargs: Any) -> aioredis.commands.Redis:
        """Wait for a free slot, then return a connection that answered PING.

        Arguments are forwarded to ``pop_ensured_connection``. Any exception
        raised while obtaining the connection is re-raised after the
        reserved slot has been released.
        """
        # Poll until a slot frees up. The comparison is ``>``, so up to
        # ``connection_pool_limit + 1`` connections can be out at once.
        while self.counter > connection_pool_limit:
            await asyncio.sleep(0.1)
        self.counter += 1

        try:
            return await self.pop_ensured_connection(*args, **kwargs)
        except BaseException:
            # Nothing was handed to the caller, so no push()/conn_error()
            # will follow; release the slot here or it is lost for good.
            self.counter -= 1
            raise

    async def pop_ensured_connection(
        self, *args: Any, **kwargs: Any
    ) -> aioredis.commands.Redis:
        """Return a pooled connection that has just passed ``try_ping``.

        Connections that fail the check are closed and replaced. Does not
        touch ``counter``; the caller owns the slot accounting.
        """
        redis: Optional[aioredis.commands.Redis] = None

        while redis is None:
            redis = await super().pop(*args, **kwargs)

            try:
                await self.try_ping(redis)
            except InvalidConnection:
                # Use the base-class handler on purpose: the slot reserved
                # by pop() stays reserved for the replacement connection.
                super().conn_error(redis)
                redis = None
            except BaseException:
                # The connection never reaches a caller; close and forget it
                # instead of leaving it stranded in ``in_use``.
                super().conn_error(redis)
                raise

        return redis

    async def try_ping(self, redis: aioredis.commands.Redis) -> None:
        """Check ``redis`` with a PING.

        Raises:
            InvalidConnection: if the reply is not ``b"PONG"`` or the ping
                fails with ``ConnectionRefusedError`` / ``ConnectionResetError``.
        """
        try:
            pong = await redis.ping()
            if pong != b"PONG":
                logger.info("Redis connection invalid, did not recieve PONG")
                raise InvalidConnection()
        except (ConnectionRefusedError, ConnectionResetError) as err:
            logger.info("Redis connection invalid, connection is bad")
            raise InvalidConnection() from err

    def push(self, conn: aioredis.commands.Redis) -> None:
        """Return ``conn`` to the pool and release its slot."""
        super().push(conn)
        self.counter -= 1

    def conn_error(self, conn: aioredis.commands.Redis) -> None:
        """Close and forget a failed ``conn`` and release its slot."""
        super().conn_error(conn)
        self.counter -= 1

    def reset(self) -> None:
        """Drop all pool bookkeeping and set the slot counter back to zero."""
        super().reset()
        self.counter = 0
|