2019-11-04 14:56:01 +01:00
|
|
|
from typing import Callable
|
|
|
|
|
2019-11-01 09:11:12 +01:00
|
|
|
from django.db import DEFAULT_DB_ALIAS, connections
|
|
|
|
from django.test.utils import CaptureQueriesContext
|
|
|
|
|
|
|
|
|
2019-11-04 14:56:01 +01:00
|
|
|
def count_queries(func, verbose=False) -> Callable[..., int]:
    """Wrap *func* so that calling it returns the number of captured queries.

    The returned callable forwards all positional and keyword arguments to
    *func*, running it inside a ``CaptureQueriesContext`` bound to the
    connection registered under ``DEFAULT_DB_ALIAS``.

    Args:
        func: The callable whose queries should be counted.
        verbose: When true, print the report built by
            ``get_verbose_queries`` after *func* has run.

    Returns:
        A callable that returns ``len(context)`` for the capture context.
        The return value of *func* itself is discarded.
    """
    # Function-scope import keeps this block self-contained.
    from functools import wraps

    # wraps() copies func's metadata (__name__, __doc__, __wrapped__, ...)
    # onto the wrapper so decorated callables stay identifiable in test
    # reports and tracebacks.
    @wraps(func)
    def wrapper(*args, **kwargs) -> int:
        context = CaptureQueriesContext(connections[DEFAULT_DB_ALIAS])
        with context:
            # Only the captured queries matter; the result of func is dropped.
            func(*args, **kwargs)

        if verbose:
            print(get_verbose_queries(context))

        return len(context)

    return wrapper
|
2019-11-01 09:11:12 +01:00
|
|
|
|
|
|
|
|
2019-11-04 14:56:01 +01:00
|
|
|
class AssertNumQueriesContext(CaptureQueriesContext):
    """Capture context that asserts an exact query count when it exits.

    The context is bound to the connection registered under
    ``DEFAULT_DB_ALIAS``. On a clean exit it compares ``len(self)`` with the
    expected number through ``test_case.assertEqual``.
    """

    def __init__(self, test_case, num, verbose):
        """Store the assertion settings and bind to the default connection.

        Args:
            test_case: Object providing ``assertEqual`` used for the check.
            num: Expected value of ``len(self)`` on exit.
            verbose: When true, the query report is printed on exit instead
                of being passed to ``assertEqual`` as the failure message.
        """
        self.test_case = test_case
        self.num = num
        self.verbose = verbose
        super().__init__(connections[DEFAULT_DB_ALIAS])

    def __exit__(self, exc_type, exc_value, traceback):
        """Run the parent exit logic, then assert on the query count.

        If the ``with`` body raised, no assertion is made and ``None`` is
        returned, so the original exception is not suppressed here.
        """
        super().__exit__(exc_type, exc_value, traceback)
        if exc_type is not None:
            return

        query_count = len(self)
        report = get_verbose_queries(self)

        if not self.verbose:
            # Quiet mode: the report only surfaces as the failure message.
            self.test_case.assertEqual(query_count, self.num, report)
            return

        # Verbose mode: always print the report, assert without a message.
        print(report)
        self.test_case.assertEqual(query_count, self.num)
|
2019-11-01 09:11:12 +01:00
|
|
|
|
|
|
|
|
2019-11-04 14:56:01 +01:00
|
|
|
def get_verbose_queries(context):
    """Build a human-readable report of the queries held by *context*.

    Args:
        context: Object exposing ``captured_queries`` (an iterable of
            mappings with an ``"sql"`` key) and supporting ``len()``.

    Returns:
        A string with a header giving ``len(context)`` followed by one
        numbered line (starting at 1) per captured query's SQL text.
    """
    numbered_lines = [
        f"{position}. {entry['sql']}"
        for position, entry in enumerate(context.captured_queries, start=1)
    ]
    listing = "\n".join(numbered_lines)
    header = f"{len(context)} queries executed\nCaptured queries were:\n"
    return header + listing
|