1. import asyncio
    
  2. import os
    
  3. from unittest import mock
    
  4. 
    
  5. from asgiref.sync import async_to_sync
    
  6. 
    
  7. from django.core.cache import DEFAULT_CACHE_ALIAS, caches
    
  8. from django.core.exceptions import ImproperlyConfigured, SynchronousOnlyOperation
    
  9. from django.http import HttpResponse, HttpResponseNotAllowed
    
  10. from django.test import RequestFactory, SimpleTestCase
    
  11. from django.utils.asyncio import async_unsafe
    
  12. from django.views.generic.base import View
    
  13. 
    
  14. from .models import SimpleModel
    
  15. 
    
  16. 
    
  17. class CacheTest(SimpleTestCase):
    
  18.     def test_caches_local(self):
    
  19.         @async_to_sync
    
  20.         async def async_cache():
    
  21.             return caches[DEFAULT_CACHE_ALIAS]
    
  22. 
    
  23.         cache_1 = async_cache()
    
  24.         cache_2 = async_cache()
    
  25.         self.assertIs(cache_1, cache_2)
    
  26. 
    
  27. 
    
  28. class DatabaseConnectionTest(SimpleTestCase):
    
  29.     """A database connection cannot be used in an async context."""
    
  30. 
    
  31.     async def test_get_async_connection(self):
    
  32.         with self.assertRaises(SynchronousOnlyOperation):
    
  33.             list(SimpleModel.objects.all())
    
  34. 
    
  35. 
    
  36. class AsyncUnsafeTest(SimpleTestCase):
    
  37.     """
    
  38.     async_unsafe decorator should work correctly and returns the correct
    
  39.     message.
    
  40.     """
    
  41. 
    
  42.     @async_unsafe
    
  43.     def dangerous_method(self):
    
  44.         return True
    
  45. 
    
  46.     async def test_async_unsafe(self):
    
  47.         # async_unsafe decorator catches bad access and returns the right
    
  48.         # message.
    
  49.         msg = (
    
  50.             "You cannot call this from an async context - use a thread or "
    
  51.             "sync_to_async."
    
  52.         )
    
  53.         with self.assertRaisesMessage(SynchronousOnlyOperation, msg):
    
  54.             self.dangerous_method()
    
  55. 
    
  56.     @mock.patch.dict(os.environ, {"DJANGO_ALLOW_ASYNC_UNSAFE": "true"})
    
  57.     @async_to_sync  # mock.patch() is not async-aware.
    
  58.     async def test_async_unsafe_suppressed(self):
    
  59.         # Decorator doesn't trigger check when the environment variable to
    
  60.         # suppress it is set.
    
  61.         try:
    
  62.             self.dangerous_method()
    
  63.         except SynchronousOnlyOperation:
    
  64.             self.fail("SynchronousOnlyOperation should not be raised.")
    
  65. 
    
  66. 
    
  67. class SyncView(View):
    
  68.     def get(self, request, *args, **kwargs):
    
  69.         return HttpResponse("Hello (sync) world!")
    
  70. 
    
  71. 
    
  72. class AsyncView(View):
    
  73.     async def get(self, request, *args, **kwargs):
    
  74.         return HttpResponse("Hello (async) world!")
    
  75. 
    
  76. 
    
  77. class ViewTests(SimpleTestCase):
    
  78.     def test_views_are_correctly_marked(self):
    
  79.         tests = [
    
  80.             (SyncView, False),
    
  81.             (AsyncView, True),
    
  82.         ]
    
  83.         for view_cls, is_async in tests:
    
  84.             with self.subTest(view_cls=view_cls, is_async=is_async):
    
  85.                 self.assertIs(view_cls.view_is_async, is_async)
    
  86.                 callback = view_cls.as_view()
    
  87.                 self.assertIs(asyncio.iscoroutinefunction(callback), is_async)
    
  88. 
    
  89.     def test_mixed_views_raise_error(self):
    
  90.         class MixedView(View):
    
  91.             def get(self, request, *args, **kwargs):
    
  92.                 return HttpResponse("Hello (mixed) world!")
    
  93. 
    
  94.             async def post(self, request, *args, **kwargs):
    
  95.                 return HttpResponse("Hello (mixed) world!")
    
  96. 
    
  97.         msg = (
    
  98.             f"{MixedView.__qualname__} HTTP handlers must either be all sync or all "
    
  99.             "async."
    
  100.         )
    
  101.         with self.assertRaisesMessage(ImproperlyConfigured, msg):
    
  102.             MixedView.as_view()
    
  103. 
    
  104.     def test_options_handler_responds_correctly(self):
    
  105.         tests = [
    
  106.             (SyncView, False),
    
  107.             (AsyncView, True),
    
  108.         ]
    
  109.         for view_cls, is_coroutine in tests:
    
  110.             with self.subTest(view_cls=view_cls, is_coroutine=is_coroutine):
    
  111.                 instance = view_cls()
    
  112.                 response = instance.options(None)
    
  113.                 self.assertIs(
    
  114.                     asyncio.iscoroutine(response),
    
  115.                     is_coroutine,
    
  116.                 )
    
  117.                 if is_coroutine:
    
  118.                     response = asyncio.run(response)
    
  119. 
    
  120.                 self.assertIsInstance(response, HttpResponse)
    
  121. 
    
  122.     def test_http_method_not_allowed_responds_correctly(self):
    
  123.         request_factory = RequestFactory()
    
  124.         tests = [
    
  125.             (SyncView, False),
    
  126.             (AsyncView, True),
    
  127.         ]
    
  128.         for view_cls, is_coroutine in tests:
    
  129.             with self.subTest(view_cls=view_cls, is_coroutine=is_coroutine):
    
  130.                 instance = view_cls()
    
  131.                 response = instance.http_method_not_allowed(request_factory.post("/"))
    
  132.                 self.assertIs(
    
  133.                     asyncio.iscoroutine(response),
    
  134.                     is_coroutine,
    
  135.                 )
    
  136.                 if is_coroutine:
    
  137.                     response = asyncio.run(response)
    
  138. 
    
  139.                 self.assertIsInstance(response, HttpResponseNotAllowed)
    
  140. 
    
  141.     def test_base_view_class_is_sync(self):
    
  142.         """
    
  143.         View and by extension any subclasses that don't define handlers are
    
  144.         sync.
    
  145.         """
    
  146.         self.assertIs(View.view_is_async, False)