OpenSlides/server/openslides/utils/redis_connection_pool.py
Finn Stutzenstein f9d19db9e2
Update black
2021-04-26 09:00:04 +02:00

198 lines
5.5 KiB
Python

# type: ignore
import asyncio
import sys
import types
from typing import Any, Dict, List, Optional
import aioredis
from django.conf import settings
from . import logging
logger = logging.getLogger(__name__)
connection_pool_limit = getattr(settings, "CONNECTION_POOL_LIMIT", 10)
logger.info(f"CONNECTION_POOL_LIMIT={connection_pool_limit}")
# Copied from https://github.com/django/channels_redis/blob/master/channels_redis/core.py
# and renamed..
AIOREDIS_VERSION = tuple(map(int, aioredis.__version__.split(".")))
def _wrap_close(loop, pool):
"""
Decorate an event loop's close method with our own.
"""
original_impl = loop.close
def _wrapper(self, *args, **kwargs):
# If the event loop was closed, there's nothing we can do anymore.
if not self.is_closed():
self.run_until_complete(pool.close_loop(self))
# Restore the original close() implementation after we're done.
self.close = original_impl
return self.close(*args, **kwargs)
loop.close = types.MethodType(_wrapper, loop)
class ChannelRedisConnectionPool:
"""
Connection pool manager for the channel layer.
It manages a set of connections for the given host specification and
taking into account asyncio event loops.
"""
def __init__(self, host):
self.host = host
self.conn_map = {}
self.in_use = {}
def _ensure_loop(self, loop):
"""
Get connection list for the specified loop.
"""
if loop is None:
loop = asyncio.get_event_loop()
if loop not in self.conn_map:
# Swap the loop's close method with our own so we get
# a chance to do some cleanup.
_wrap_close(loop, self)
self.conn_map[loop] = []
return self.conn_map[loop], loop
async def pop(self, loop=None):
"""
Get a connection for the given identifier and loop.
"""
conns, loop = self._ensure_loop(loop)
if not conns:
if sys.version_info >= (3, 8, 0) and AIOREDIS_VERSION >= (1, 3, 1):
conn = await aioredis.create_redis(**self.host)
else:
conn = await aioredis.create_redis(**self.host, loop=loop)
conns.append(conn)
conn = conns.pop()
if conn.closed:
conn = await self.pop(loop=loop)
return conn
self.in_use[conn] = loop
return conn
def push(self, conn):
"""
Return a connection to the pool.
"""
loop = self.in_use[conn]
del self.in_use[conn]
if loop is not None:
conns, _ = self._ensure_loop(loop)
conns.append(conn)
def conn_error(self, conn):
"""
Handle a connection that produced an error.
"""
conn.close()
del self.in_use[conn]
def reset(self):
"""
Clear all connections from the pool.
"""
self.conn_map = {}
self.in_use = {}
async def close_loop(self, loop):
"""
Close all connections owned by the pool on the given loop.
"""
if loop in self.conn_map:
for conn in self.conn_map[loop]:
conn.close()
await conn.wait_closed()
del self.conn_map[loop]
for k, v in self.in_use.items():
if v is loop:
self.in_use[k] = None
async def close(self):
"""
Close all connections owned by the pool.
"""
conn_map = self.conn_map
in_use = self.in_use
self.reset()
for conns in conn_map.values():
for conn in conns:
conn.close()
await conn.wait_closed()
for conn in in_use:
conn.close()
await conn.wait_closed()
class InvalidConnection(Exception):
pass
class ConnectionPool(ChannelRedisConnectionPool):
"""Adds a trivial, soft limit for the pool"""
def __init__(self, host: Any) -> None:
self.counter = 0
super().__init__(host)
async def pop(
self, *args: List[Any], **kwargs: Dict[str, Any]
) -> aioredis.commands.Redis:
while self.counter > connection_pool_limit:
await asyncio.sleep(0.1)
self.counter += 1
return await self.pop_ensured_connection(*args, **kwargs)
async def pop_ensured_connection(
self, *args: List[Any], **kwargs: Dict[str, Any]
) -> aioredis.commands.Redis:
redis: Optional[aioredis.commands.Redis] = None
while redis is None:
redis = await super().pop(*args, **kwargs)
try:
await self.try_ping(redis)
except InvalidConnection:
if redis is not None:
super().conn_error(redis)
redis = None
return redis
async def try_ping(self, redis: aioredis.commands.Redis) -> None:
try:
pong = await redis.ping()
if pong != b"PONG":
logger.info("Redis connection invalid, did not recieve PONG")
raise InvalidConnection()
except (ConnectionRefusedError, ConnectionResetError):
logger.info("Redis connection invalid, connection is bad")
raise InvalidConnection()
def push(self, conn: aioredis.commands.Redis) -> None:
super().push(conn)
self.counter -= 1
def conn_error(self, conn: aioredis.commands.Redis) -> None:
super().conn_error(conn)
self.counter -= 1
def reset(self) -> None:
super().reset()
self.counter = 0