658 lines
23 KiB
Python
658 lines
23 KiB
Python
import functools
|
|
import hashlib
|
|
import logging
|
|
from collections import defaultdict
|
|
from textwrap import dedent
|
|
from typing import Any, Callable, Coroutine, Dict, List, Optional, Set, Tuple
|
|
|
|
from django.core.exceptions import ImproperlyConfigured
|
|
from typing_extensions import Protocol
|
|
|
|
from .redis import use_redis
|
|
from .schema_version import SchemaVersion
|
|
from .utils import split_element_id, str_dict_to_bytes
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
if use_redis:
|
|
from .redis import get_connection, aioredis
|
|
|
|
|
|
class CacheReset(Exception):
    """
    Raised when a redis script reports that the full data cache is gone.

    RedisCacheProvider.eval() translates the "cache_reset" error reply of a
    lua script into this exception; ensure_cache_wrapper() catches it,
    rebuilds the cache and retries the operation.
    """
|
|
|
|
|
|
class ElementCacheProvider(Protocol):
    """
    Structural interface every cache provider has to fulfil.

    See RedisCacheProvider as reference implementation.
    """

    def __init__(self, ensure_cache: Callable[[], Coroutine[Any, Any, None]]) -> None:
        """
        Receives the coroutine function that (re)builds the cache.
        """

    async def ensure_cache(self) -> None:
        """
        Makes sure the cache is filled.
        """

    async def clear_cache(self) -> None:
        """
        Removes all cache entries created by this provider.
        """

    async def reset_full_cache(self, data: Dict[str, str]) -> None:
        """
        Replaces the full data cache with `data` and clears the change ids.
        """

    async def data_exists(self) -> bool:
        """
        Returns True, when there is data in the cache.
        """

    async def get_all_data(self) -> Dict[bytes, bytes]:
        """
        Returns the full data cache as mapping from element_id to element.
        """

    async def get_collection_data(self, collection: str) -> Dict[int, bytes]:
        """
        Returns all elements of one collection, mapped from id to element.
        """

    async def get_element_data(self, element_id: str) -> Optional[bytes]:
        """
        Returns one element or None, when the element does not exist.
        """

    async def add_changed_elements(
        self,
        changed_elements: List[str],
        deleted_element_ids: List[str],
        default_change_id: int,
    ) -> int:
        """
        Stores changed elements, removes deleted ones and returns the newly
        generated change_id.
        """

    async def get_data_since(
        self, change_id: int, max_change_id: int = -1
    ) -> Tuple[Dict[str, List[bytes]], List[str]]:
        """
        Returns the changed elements (grouped by collection string) and the
        deleted element_ids since the given change_id.
        """

    async def set_lock(self, lock_name: str) -> bool:
        """
        Tries to set a lock. Returns False, if it was already set.
        """

    async def get_lock(self, lock_name: str) -> bool:
        """
        Returns True, when the lock is set.
        """

    async def del_lock(self, lock_name: str) -> None:
        """
        Deletes the lock. Does nothing when the lock is not set.
        """

    async def get_current_change_id(self) -> Optional[int]:
        """
        Returns the highest known change_id, if there is one.
        """

    async def get_lowest_change_id(self) -> Optional[int]:
        """
        Returns the lowest known change_id, if there is one.
        """

    async def get_schema_version(self) -> Optional[SchemaVersion]:
        """
        Returns the schema version stored for the cache, if there is one.
        """

    async def set_schema_version(self, schema_version: SchemaVersion) -> None:
        """
        Stores the schema version for the cache.
        """
|
|
|
|
|
|
def ensure_cache_wrapper() -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Wraps a cache function to ensure, that the cache is filled.

    When the function raises a CacheReset-Error the cache will be ensured (call
    to `ensure_cache`) and the method will be recalled. This is done, until the
    operation was successful.

    The decorated function must be a coroutine function whose first argument
    is the cache provider (i.e. a method of a provider class). Every exception
    other than CacheReset is propagated unchanged.
    """

    def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        async def wrapped(
            cache_provider: ElementCacheProvider, *args: Any, **kwargs: Any
        ) -> Any:
            # Retry until the call gets through without hitting an empty cache.
            while True:
                try:
                    return await func(cache_provider, *args, **kwargs)
                except CacheReset:
                    # Logger.warn() is a deprecated alias of Logger.warning().
                    logger.warning(
                        f"Redis was flushed before method '{func.__name__}'. Ensures cache now."
                    )
                    await cache_provider.ensure_cache()

        return wrapped

    return wrapper
|
|
|
|
|
|
class RedisCacheProvider:
    """
    Cache provider that loads and saves the data to redis.
    """

    full_data_cache_key: str = "full_data"
    change_id_cache_key: str = "change_id"
    schema_cache_key: str = "schema"
    prefix: str = "element_cache_"

    # All lua-scripts used by this provider. Every entry is a Tuple (str, bool) with the
    # script and an ensure_cache-indicator. If the indicator is True, a short ensure_cache-script
    # will be prepended to the script which raises a CacheReset, if the full data cache is empty.
    # This requires the full_data_cache_key to be the first key given in `keys`!
    # Convention: The keys of this member are the methods that need these scripts.
    # This class-level mapping holds the raw templates and is never modified; __init__
    # builds the dedented (and, if requested, prefixed) scripts and their hashes per instance.
    scripts: Dict[str, Tuple[str, bool]] = {
        "clear_cache": (
            "return redis.call('del', 'fake_key', unpack(redis.call('keys', ARGV[1])))",
            False,
        ),
        "get_all_data": ("return redis.call('hgetall', KEYS[1])", True),
        "get_collection_data": (
            """
            local cursor = 0
            local collection = {}
            repeat
                local result = redis.call('HSCAN', KEYS[1], cursor, 'MATCH', ARGV[1])
                cursor = tonumber(result[1])
                for _, v in pairs(result[2]) do
                    table.insert(collection, v)
                end
            until cursor == 0
            return collection
            """,
            True,
        ),
        "get_element_data": ("return redis.call('hget', KEYS[1], ARGV[1])", True),
        "add_changed_elements": (
            """
            -- Generate a new change_id
            local tmp = redis.call('zrevrangebyscore', KEYS[2], '+inf', '-inf', 'WITHSCORES', 'LIMIT', 0, 1)
            local change_id
            if next(tmp) == nil then
                -- The key does not exist
                change_id = ARGV[1]
            else
                change_id = tmp[2] + 1
            end

            local nc = tonumber(ARGV[2])
            local nd = tonumber(ARGV[3])

            local i, max
            -- Add changed_elements to the cache and sorted set (the first of the pairs)
            if (nc > 0) then
                max = 2 + nc
                redis.call('hmset', KEYS[1], unpack(ARGV, 4, max + 1))
                for i = 4, max, 2 do
                    redis.call('zadd', KEYS[2], change_id, ARGV[i])
                end
            end

            -- Delete deleted_element_ids and add them to sorted set.
            -- The deleted ids are a flat list (no id/value pairs), so every
            -- single argument has to be visited.
            if (nd > 0) then
                max = 3 + nc + nd
                redis.call('hdel', KEYS[1], unpack(ARGV, 4 + nc, max))
                for i = 4 + nc, max do
                    redis.call('zadd', KEYS[2], change_id, ARGV[i])
                end
            end

            -- Set lowest_change_id if it does not exist
            redis.call('zadd', KEYS[2], 'NX', change_id, '_config:lowest_change_id')

            return change_id
            """,
            True,
        ),
        "get_data_since": (
            """
            -- Get the element ids of the changed elements
            local element_ids = redis.call('zrangebyscore', KEYS[2], ARGV[1], ARGV[2])

            -- Save elements in array. Rotate element_id and element_json
            local elements = {}
            for _, element_id in pairs(element_ids) do
                table.insert(elements, element_id)
                table.insert(elements, redis.call('hget', KEYS[1], element_id))
            end
            return elements
            """,
            True,
        ),
    }

    def __init__(self, ensure_cache: Callable[[], Coroutine[Any, Any, None]]) -> None:
        """
        Stores the ensure_cache callback and prepares all lua scripts.

        The scripts are dedented, optionally prefixed with the ensure_cache
        check and hashed (sha1) for the use with EVALSHA.
        """
        self._ensure_cache = ensure_cache

        ensure_cache_prefix = dedent(
            """
            local exist = redis.call('exists', KEYS[1])
            if (exist == 0) then
                redis.log(redis.LOG_WARNING, "empty: "..KEYS[1])
                return redis.error_reply("cache_reset")
            end
            """
        )

        # Build the prepared scripts in a fresh dict and bind it to the
        # instance. Writing back into the class-level dict would prepend the
        # prefix once more for every further instance.
        prepared: Dict[str, Tuple[str, bool]] = {}
        for name, (script, add_ensure_cache) in self.scripts.items():
            script = dedent(script)
            if add_ensure_cache:
                script = ensure_cache_prefix + script
            prepared[name] = (script, add_ensure_cache)
        self.scripts = prepared

        self._script_hashes = {
            name: hashlib.sha1(script.encode()).hexdigest()
            for name, (script, _) in prepared.items()
        }

    async def ensure_cache(self) -> None:
        """
        Awaits the ensure_cache callback given to the constructor.
        """
        await self._ensure_cache()

    def get_full_data_cache_key(self) -> str:
        """ Returns the prefixed redis key of the full data hash. """
        return "".join((self.prefix, self.full_data_cache_key))

    def get_change_id_cache_key(self) -> str:
        """ Returns the prefixed redis key of the change id sorted set. """
        return "".join((self.prefix, self.change_id_cache_key))

    def get_schema_cache_key(self) -> str:
        """ Returns the prefixed redis key of the schema version hash. """
        return "".join((self.prefix, self.schema_cache_key))

    async def clear_cache(self) -> None:
        """
        Deletes all cache entries created with this element cache.
        """
        await self.eval("clear_cache", keys=[], args=[f"{self.prefix}*"])

    async def reset_full_cache(self, data: Dict[str, str]) -> None:
        """
        Deletes the full_data_cache and write new data in it. Clears the change id key.
        Does not clear locks.
        """
        async with get_connection() as redis:
            # All three commands are queued in one MULTI/EXEC transaction.
            tr = redis.multi_exec()
            tr.delete(self.get_change_id_cache_key())
            tr.delete(self.get_full_data_cache_key())
            tr.hmset_dict(self.get_full_data_cache_key(), data)
            await tr.execute()

    async def data_exists(self) -> bool:
        """
        Returns True, when there is data in the cache.
        """
        async with get_connection() as redis:
            # Coerce the reply, so the annotated return type holds.
            return bool(await redis.exists(self.get_full_data_cache_key()))

    @ensure_cache_wrapper()
    async def get_all_data(self) -> Dict[bytes, bytes]:
        """
        Returns all data from the full_data_cache in a mapping from element_id to the element.
        """
        return await aioredis.util.wait_make_dict(
            self.eval("get_all_data", [self.get_full_data_cache_key()])
        )

    @ensure_cache_wrapper()
    async def get_collection_data(self, collection: str) -> Dict[int, bytes]:
        """
        Returns all elements for a collection from the cache. The data is mapped
        from element_id to the element.
        """
        response = await self.eval(
            "get_collection_data", [self.get_full_data_cache_key()], [f"{collection}:*"]
        )

        # The script returns a flat list: element_id, element, element_id, ...
        collection_data: Dict[int, bytes] = {}
        for i in range(0, len(response), 2):
            _, element_pk = split_element_id(response[i])
            collection_data[element_pk] = response[i + 1]

        return collection_data

    @ensure_cache_wrapper()
    async def get_element_data(self, element_id: str) -> Optional[bytes]:
        """
        Returns one element from the cache. Returns None, when the element does not exist.
        """
        return await self.eval(
            "get_element_data", [self.get_full_data_cache_key()], [element_id]
        )

    @ensure_cache_wrapper()
    async def add_changed_elements(
        self,
        changed_elements: List[str],
        deleted_element_ids: List[str],
        default_change_id: int,
    ) -> int:
        """
        Modifies the full_data_cache to insert the changed_elements and removes the
        deleted_element_ids (in this order). Generates a new change_id and inserts all
        element_ids (changed and deleted) with the change_id into the change_id_cache.
        The newly generated change_id is returned.

        changed_elements is a flat list of alternating element_id and element
        (the lua script reads it in pairs); deleted_element_ids contains only ids.
        default_change_id is used, if the change id cache is empty.
        """
        return int(
            await self.eval(
                "add_changed_elements",
                keys=[self.get_full_data_cache_key(), self.get_change_id_cache_key()],
                args=[
                    default_change_id,
                    len(changed_elements),
                    len(deleted_element_ids),
                    *(changed_elements + deleted_element_ids),
                ],
            )
        )

    @ensure_cache_wrapper()
    async def get_data_since(
        self, change_id: int, max_change_id: int = -1
    ) -> Tuple[Dict[str, List[bytes]], List[str]]:
        """
        Returns all elements since a change_id (included) and until the max_change_id (included).
        A negative max_change_id means, that there is no upper bound.

        The returned value is a two element tuple. The first value is a dict the elements where
        the key is the collection_string and the value a list of (json-) encoded elements. The
        second element is a list of element_ids, that have been deleted since the change_id.
        """
        changed_elements: Dict[str, List[bytes]] = defaultdict(list)
        deleted_elements: List[str] = []

        # Convert max_change_id to a string. If its negative, use the string '+inf'
        redis_max_change_id = "+inf" if max_change_id < 0 else str(max_change_id)
        # The lua script gets all element_ids from the change_id_cache_key and
        # looks each of them up in the full data. It returns a flat list of
        # alternating element_id and element (json, or nil if the element is
        # not in the full data). The function wait_make_dict creates a python
        # dict from the returned list.
        elements: Dict[bytes, Optional[bytes]] = await aioredis.util.wait_make_dict(
            self.eval(
                "get_data_since",
                keys=[self.get_full_data_cache_key(), self.get_change_id_cache_key()],
                args=[change_id, redis_max_change_id],
            )
        )

        for element_id, element_json in elements.items():
            if element_id.startswith(b"_config"):
                # Ignore config values from the change_id cache key
                continue
            if element_json is None:
                # The element is not in the cache. It has to be deleted.
                deleted_elements.append(element_id.decode())
            else:
                collection_string, _ = split_element_id(element_id)
                changed_elements[collection_string].append(element_json)
        return changed_elements, deleted_elements

    async def set_lock(self, lock_name: str) -> bool:
        """
        Tries to sets a lock.

        Returns True when the lock could be set and False, if it was already set.
        """
        # TODO: Improve lock. See: https://redis.io/topics/distlock
        async with get_connection() as redis:
            return bool(await redis.setnx(f"{self.prefix}lock_{lock_name}", 1))

    async def get_lock(self, lock_name: str) -> bool:
        """
        Returns True, when the lock is set. Else False.
        """
        async with get_connection() as redis:
            # GET yields the stored value or None; callers are promised a bool.
            return (await redis.get(f"{self.prefix}lock_{lock_name}")) is not None

    async def del_lock(self, lock_name: str) -> None:
        """
        Deletes the lock. Does nothing when the lock is not set.
        """
        async with get_connection() as redis:
            await redis.delete(f"{self.prefix}lock_{lock_name}")

    async def get_current_change_id(self) -> Optional[int]:
        """
        Get the highest change_id from redis.

        Returns None, if the change id cache is empty.
        """
        async with get_connection() as redis:
            value = await redis.zrevrangebyscore(
                self.get_change_id_cache_key(), withscores=True, count=1, offset=0
            )
        # Return the score (second element) of the first (and only) element, if exists.
        return value[0][1] if value else None

    async def get_lowest_change_id(self) -> Optional[int]:
        """
        Get the lowest change_id from redis.

        Returns None if lowest score does not exist.
        """
        async with get_connection() as redis:
            return await redis.zscore(
                self.get_change_id_cache_key(), "_config:lowest_change_id"
            )

    async def get_schema_version(self) -> Optional[SchemaVersion]:
        """ Retrieves the schema version of the cache or None, if not existent """
        async with get_connection() as redis:
            schema_version = await redis.hgetall(self.get_schema_cache_key())
        if not schema_version:
            return None

        return {
            "migration": int(schema_version[b"migration"].decode()),
            "config": int(schema_version[b"config"].decode()),
            "db": schema_version[b"db"].decode(),
        }

    async def set_schema_version(self, schema_version: SchemaVersion) -> None:
        """ Sets the schema version for this cache. """
        async with get_connection() as redis:
            await redis.hmset_dict(self.get_schema_cache_key(), schema_version)

    async def eval(
        self,
        script_name: str,
        keys: Optional[List[str]] = None,
        args: Optional[List[Any]] = None,
    ) -> Any:
        """
        Runs a lua script in redis. This wrapper around redis.eval tries to make
        usage of redis script cache. First the hash is send to the server and if
        the script is not present there (NOSCRIPT error) the actual script will be
        send.
        If the script uses the ensure_cache-prefix, the first key must be the full_data
        cache key. This is checked here; an ImproperlyConfigured error is raised otherwise
        (also if no keys are given at all).
        Also this method includes the custom "CacheReset" error, which will be raised in
        python, if the lua-script returns a "cache_reset" string as an error response.
        All other reply errors are re-raised unchanged.
        """
        keys = [] if keys is None else keys
        args = [] if args is None else args

        script_hash = self._script_hashes[script_name]
        needs_full_data_key = self.scripts[script_name][1]
        if needs_full_data_key and (
            not keys or keys[0] != self.get_full_data_cache_key()
        ):
            raise ImproperlyConfigured(
                "A script with a ensure_cache prefix must have the full_data cache key as its first key"
            )

        async with get_connection() as redis:
            try:
                return await redis.evalsha(script_hash, keys, args)
            except aioredis.errors.ReplyError as e:
                message = str(e)
                if message.startswith("NOSCRIPT"):
                    # The server does not know the script yet: send the source.
                    return await self._eval(redis, script_name, keys=keys, args=args)
                if message == "cache_reset":
                    raise CacheReset()
                raise

    async def _eval(
        self,
        redis: Any,
        script_name: str,
        keys: Optional[List[str]] = None,
        args: Optional[List[Any]] = None,
    ) -> Any:
        """ Do a real eval of the script (no hash used here). Catches "cache_reset". """
        keys = [] if keys is None else keys
        args = [] if args is None else args
        try:
            return await redis.eval(self.scripts[script_name][0], keys, args)
        except aioredis.errors.ReplyError as e:
            if str(e) == "cache_reset":
                raise CacheReset()
            raise
|
|
|
|
|
|
class MemmoryCacheProvider:
    """
    CacheProvider for the ElementCache that uses only the memory.

    See the RedisCacheProvider for a description of the methods.

    This provider supports only one process. It saves the data into the memory.
    When you use different processes they will use different data.

    For this reason, the ensure_cache is not used and the schema version always
    returns an invalid schema to always build the cache.
    """

    def __init__(self, ensure_cache: Callable[[], Coroutine[Any, Any, None]]) -> None:
        # ensure_cache is accepted for interface compatibility only.
        self.set_data_dicts()

    def set_data_dicts(self) -> None:
        """
        (Re)creates the empty containers for data, change ids and locks.
        """
        self.full_data: Dict[str, str] = {}
        self.change_id_data: Dict[int, Set[str]] = {}
        self.locks: Dict[str, str] = {}

    async def ensure_cache(self) -> None:
        pass

    async def clear_cache(self) -> None:
        self.set_data_dicts()

    async def reset_full_cache(self, data: Dict[str, str]) -> None:
        self.change_id_data = {}
        # Keep an own copy: add_changed_elements() writes into full_data and
        # must not modify the dict owned by the caller.
        self.full_data = dict(data)

    async def data_exists(self) -> bool:
        return bool(self.full_data)

    async def get_all_data(self) -> Dict[bytes, bytes]:
        return str_dict_to_bytes(self.full_data)

    async def get_collection_data(self, collection: str) -> Dict[int, bytes]:
        out: Dict[int, bytes] = {}
        query = f"{collection}:"
        for element_id, value in self.full_data.items():
            if element_id.startswith(query):
                _, element_pk = split_element_id(element_id)
                out[element_pk] = value.encode()
        return out

    async def get_element_data(self, element_id: str) -> Optional[bytes]:
        value = self.full_data.get(element_id)
        return value.encode() if value is not None else None

    async def add_changed_elements(
        self,
        changed_elements: List[str],
        deleted_element_ids: List[str],
        default_change_id: int,
    ) -> int:
        current_change_id = await self.get_current_change_id()
        if current_change_id is None:
            change_id = default_change_id
        else:
            change_id = current_change_id + 1

        # changed_elements is a flat list: element_id, element, element_id, ...
        for i in range(0, len(changed_elements), 2):
            element_id = changed_elements[i]
            self.full_data[element_id] = changed_elements[i + 1]
            self.change_id_data.setdefault(change_id, set()).add(element_id)

        for element_id in deleted_element_ids:
            # Deleting an unknown element is not an error.
            self.full_data.pop(element_id, None)
            self.change_id_data.setdefault(change_id, set()).add(element_id)

        return change_id

    async def get_data_since(
        self, change_id: int, max_change_id: int = -1
    ) -> Tuple[Dict[str, List[bytes]], List[str]]:
        changed_elements: Dict[str, List[bytes]] = defaultdict(list)
        deleted_elements: List[str] = []

        # Like in the RedisCacheProvider, every negative max_change_id means
        # "no upper bound".
        unbounded = max_change_id < 0

        all_element_ids: Set[str] = set()
        for data_change_id, element_ids in self.change_id_data.items():
            if data_change_id >= change_id and (
                unbounded or data_change_id <= max_change_id
            ):
                all_element_ids.update(element_ids)

        for element_id in all_element_ids:
            element_json = self.full_data.get(element_id)
            if element_json is None:
                # The element is not in the full data. It has been deleted.
                deleted_elements.append(element_id)
            else:
                collection_string, _ = split_element_id(element_id)
                changed_elements[collection_string].append(element_json.encode())
        return changed_elements, deleted_elements

    async def set_lock(self, lock_name: str) -> bool:
        if lock_name in self.locks:
            return False
        self.locks[lock_name] = "1"
        return True

    async def get_lock(self, lock_name: str) -> bool:
        return lock_name in self.locks

    async def del_lock(self, lock_name: str) -> None:
        # Does nothing when the lock is not set.
        self.locks.pop(lock_name, None)

    async def get_current_change_id(self) -> Optional[int]:
        change_data = self.change_id_data
        if change_data:
            return max(change_data.keys())
        return None

    async def get_lowest_change_id(self) -> Optional[int]:
        change_data = self.change_id_data
        if change_data:
            return min(change_data.keys())
        return None

    async def get_schema_version(self) -> Optional[SchemaVersion]:
        # Never report a stored schema version (see class docstring).
        return None

    async def set_schema_version(self, schema_version: SchemaVersion) -> None:
        pass
|
|
|
|
|
|
class Cachable(Protocol):
    """
    Structural interface for objects that provide elements for the cache.

    An implementation needs at least the methods defined here.
    """

    def get_collection_string(self) -> str:
        """
        Returns the string representing the name of the cachable.
        """
        ...

    def get_elements(self) -> List[Dict[str, Any]]:
        """
        Returns all elements of the cachable.
        """
        ...

    async def restrict_elements(
        self, user_id: int, elements: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Converts full_data to restricted_data.

        `elements` may be an empty list, a list with some elements of the
        cachable or a list with all elements of the cachable.
        """
        ...
|