diff --git a/django_multitenant/settings.py b/django_multitenant/settings.py index 77aa45a..da182ae 100644 --- a/django_multitenant/settings.py +++ b/django_multitenant/settings.py @@ -6,3 +6,4 @@ TENANT_MODEL_NAME = getattr(settings, "TENANT_MODEL_NAME", None) CITUS_EXTENSION_INSTALLED = getattr(settings, "CITUS_EXTENSION_INSTALLED", False) TENANT_STRICT_MODE = getattr(settings, "TENANT_STRICT_MODE", False) +TENANT_USE_ASGIREF = getattr(settings, "TENANT_USE_ASGIREF", False) diff --git a/django_multitenant/tests/settings.py b/django_multitenant/tests/settings.py index 3e10723..ec81c39 100644 --- a/django_multitenant/tests/settings.py +++ b/django_multitenant/tests/settings.py @@ -78,3 +78,6 @@ DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" USE_TZ = True + +TENANT_USE_ASGIREF = False + diff --git a/django_multitenant/tests/test_utils.py b/django_multitenant/tests/test_utils.py index 66ddfb0..6ace95f 100644 --- a/django_multitenant/tests/test_utils.py +++ b/django_multitenant/tests/test_utils.py @@ -1,3 +1,8 @@ +import asyncio +import sys, importlib +from asgiref.sync import async_to_sync, sync_to_async + + from django_multitenant.utils import ( set_current_tenant, get_current_tenant, @@ -11,6 +16,12 @@ class UtilsTest(BaseTestCase): + async def async_get_current_tenant(self): + return get_current_tenant() + + async def async_set_current_tenant(self, tenant): + return set_current_tenant(tenant) + def test_set_current_tenant(self): projects = self.projects account = projects[0].account @@ -19,6 +30,50 @@ def test_set_current_tenant(self): self.assertEqual(get_current_tenant(), account) unset_current_tenant() + def test_tenant_persists_from_thread_to_async_task(self): + projects = self.projects + account = projects[0].account + + # Set the tenant in main thread + set_current_tenant(account) + + with self.settings(TENANT_USE_ASGIREF=True): + importlib.reload(sys.modules['django_multitenant.utils']) + from django_multitenant.utils import get_current_tenant + # Check the tenant within an async task when asgiref enabled + tenant = async_to_sync(self.async_get_current_tenant)() + self.assertEqual(get_current_tenant(), tenant) + unset_current_tenant() + + with self.settings(TENANT_USE_ASGIREF=False): + importlib.reload(sys.modules['django_multitenant.utils']) + from django_multitenant.utils import get_current_tenant + # Check the tenant within an async task when asgiref is disabled + tenant = async_to_sync(self.async_get_current_tenant)() + self.assertIsNone(get_current_tenant()) + unset_current_tenant() + + def test_tenant_persists_from_async_task_to_thread(self): + projects = self.projects + account = projects[0].account + + with self.settings(TENANT_USE_ASGIREF=True): + importlib.reload(sys.modules['django_multitenant.utils']) + from django_multitenant.utils import get_current_tenant + # Set the tenant in task + async_to_sync(self.async_set_current_tenant)(account) + self.assertEqual(get_current_tenant(), account) + unset_current_tenant() + + with self.settings(TENANT_USE_ASGIREF=False): + importlib.reload(sys.modules['django_multitenant.utils']) + from django_multitenant.utils import get_current_tenant + # Set the tenant in task + async_to_sync(self.async_set_current_tenant)(account) + self.assertIsNone(get_current_tenant()) + unset_current_tenant() + + def test_get_tenant_column(self): from .models import Project diff --git a/django_multitenant/utils.py b/django_multitenant/utils.py index 0053168..c3a9d2d 100644 --- a/django_multitenant/utils.py +++ b/django_multitenant/utils.py @@ -1,14 +1,20 @@ import inspect from django.apps import apps +from django.conf import settings -try: - from threading import local -except ImportError: - from django.utils._threading_local import local + +if settings.TENANT_USE_ASGIREF: + # asgiref must be installed, its included with Django >= 3.0 + from asgiref.local import Local as local +else: + try: + from threading import local + except ImportError: + from django.utils._threading_local import local -_thread_locals = local() +_thread_locals = _context = local() def get_model_by_db_table(db_table): @@ -26,14 +32,14 @@ def get_model_by_db_table(db_table): def get_current_tenant(): """ - Utils to get the tenant that hass been set in the current thread using `set_current_tenant`. + Utils to get the tenant that hass been set in the current thread/context using `set_current_tenant`. Can be used by doing: ``` my_class_object = get_current_tenant() ``` Will return None if the tenant is not set """ - return getattr(_thread_locals, "tenant", None) + return getattr(_context, "tenant", None) def get_tenant_column(model_class_or_instance): @@ -125,7 +131,7 @@ def get_tenant_filters(table, filters=None): def set_current_tenant(tenant): """ - Utils to set a tenant in the current thread. + Utils to set a tenant in the current thread/context. Often used in a middleware once a user is logged in to make sure all db calls are sharded to the current tenant. Can be used by doing: @@ -133,11 +139,11 @@ def set_current_tenant(tenant): get_current_tenant(my_class_object) ``` """ - setattr(_thread_locals, "tenant", tenant) + setattr(_context, "tenant", tenant) def unset_current_tenant(): - setattr(_thread_locals, "tenant", None) + setattr(_context, "tenant", None) def is_distributed_model(model): diff --git a/requirements/test-requirements.txt b/requirements/test-requirements.txt index cab14b6..26af9c8 100644 --- a/requirements/test-requirements.txt +++ b/requirements/test-requirements.txt @@ -1,15 +1,15 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.11 # by the following command: # # pip-compile --output-file=requirements/test-requirements.txt --resolver=backtracking requirements/test.in # +asgiref==3.7.2 + # via -r requirements/test.in coverage[toml]==7.2.7 # via pytest-cov exam==0.10.6 # via -r requirements/test.in -exceptiongroup==1.1.2 - # via pytest iniconfig==2.0.0 # via pytest mock==5.0.2 @@ -29,7 +29,3 @@ pytest-cov==4.1.0 # via -r requirements/test.in pytest-django==4.5.2 # via -r requirements/test.in -tomli==2.0.1 - # via - # coverage - # pytest diff --git a/requirements/test.in b/requirements/test.in index a75009f..6e67ff8 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -3,3 +3,4 @@ pytest pytest-cov pytest-django exam +asgiref>= 3.5.2