1. import asyncio
    
  2. import sys
    
  3. import threading
    
  4. from pathlib import Path
    
  5. 
    
  6. from asgiref.testing import ApplicationCommunicator
    
  7. 
    
  8. from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
    
  9. from django.core.asgi import get_asgi_application
    
  10. from django.core.signals import request_finished, request_started
    
  11. from django.db import close_old_connections
    
  12. from django.test import (
    
  13.     AsyncRequestFactory,
    
  14.     SimpleTestCase,
    
  15.     modify_settings,
    
  16.     override_settings,
    
  17. )
    
  18. from django.utils.http import http_date
    
  19. 
    
  20. from .urls import sync_waiter, test_filename
    
  21. 
    
  22. TEST_STATIC_ROOT = Path(__file__).parent / "project" / "static"
    
  23. 
    
  24. 
    
  25. @override_settings(ROOT_URLCONF="asgi.urls")
    
  26. class ASGITest(SimpleTestCase):
    
  27.     async_request_factory = AsyncRequestFactory()
    
  28. 
    
  29.     def setUp(self):
    
  30.         request_started.disconnect(close_old_connections)
    
  31. 
    
  32.     def tearDown(self):
    
  33.         request_started.connect(close_old_connections)
    
  34. 
    
  35.     async def test_get_asgi_application(self):
    
  36.         """
    
  37.         get_asgi_application() returns a functioning ASGI callable.
    
  38.         """
    
  39.         application = get_asgi_application()
    
  40.         # Construct HTTP request.
    
  41.         scope = self.async_request_factory._base_scope(path="/")
    
  42.         communicator = ApplicationCommunicator(application, scope)
    
  43.         await communicator.send_input({"type": "http.request"})
    
  44.         # Read the response.
    
  45.         response_start = await communicator.receive_output()
    
  46.         self.assertEqual(response_start["type"], "http.response.start")
    
  47.         self.assertEqual(response_start["status"], 200)
    
  48.         self.assertEqual(
    
  49.             set(response_start["headers"]),
    
  50.             {
    
  51.                 (b"Content-Length", b"12"),
    
  52.                 (b"Content-Type", b"text/html; charset=utf-8"),
    
  53.             },
    
  54.         )
    
  55.         response_body = await communicator.receive_output()
    
  56.         self.assertEqual(response_body["type"], "http.response.body")
    
  57.         self.assertEqual(response_body["body"], b"Hello World!")
    
  58.         # Allow response.close() to finish.
    
  59.         await communicator.wait()
    
  60. 
    
  61.     async def test_file_response(self):
    
  62.         """
    
  63.         Makes sure that FileResponse works over ASGI.
    
  64.         """
    
  65.         application = get_asgi_application()
    
  66.         # Construct HTTP request.
    
  67.         scope = self.async_request_factory._base_scope(path="/file/")
    
  68.         communicator = ApplicationCommunicator(application, scope)
    
  69.         await communicator.send_input({"type": "http.request"})
    
  70.         # Get the file content.
    
  71.         with open(test_filename, "rb") as test_file:
    
  72.             test_file_contents = test_file.read()
    
  73.         # Read the response.
    
  74.         response_start = await communicator.receive_output()
    
  75.         self.assertEqual(response_start["type"], "http.response.start")
    
  76.         self.assertEqual(response_start["status"], 200)
    
  77.         headers = response_start["headers"]
    
  78.         self.assertEqual(len(headers), 3)
    
  79.         expected_headers = {
    
  80.             b"Content-Length": str(len(test_file_contents)).encode("ascii"),
    
  81.             b"Content-Type": b"text/x-python",
    
  82.             b"Content-Disposition": b'inline; filename="urls.py"',
    
  83.         }
    
  84.         for key, value in headers:
    
  85.             try:
    
  86.                 self.assertEqual(value, expected_headers[key])
    
  87.             except AssertionError:
    
  88.                 # Windows registry may not be configured with correct
    
  89.                 # mimetypes.
    
  90.                 if sys.platform == "win32" and key == b"Content-Type":
    
  91.                     self.assertEqual(value, b"text/plain")
    
  92.                 else:
    
  93.                     raise
    
  94.         response_body = await communicator.receive_output()
    
  95.         self.assertEqual(response_body["type"], "http.response.body")
    
  96.         self.assertEqual(response_body["body"], test_file_contents)
    
  97.         # Allow response.close() to finish.
    
  98.         await communicator.wait()
    
  99. 
    
  100.     @modify_settings(INSTALLED_APPS={"append": "django.contrib.staticfiles"})
    
  101.     @override_settings(
    
  102.         STATIC_URL="static/",
    
  103.         STATIC_ROOT=TEST_STATIC_ROOT,
    
  104.         STATICFILES_DIRS=[TEST_STATIC_ROOT],
    
  105.         STATICFILES_FINDERS=[
    
  106.             "django.contrib.staticfiles.finders.FileSystemFinder",
    
  107.         ],
    
  108.     )
    
  109.     async def test_static_file_response(self):
    
  110.         application = ASGIStaticFilesHandler(get_asgi_application())
    
  111.         # Construct HTTP request.
    
  112.         scope = self.async_request_factory._base_scope(path="/static/file.txt")
    
  113.         communicator = ApplicationCommunicator(application, scope)
    
  114.         await communicator.send_input({"type": "http.request"})
    
  115.         # Get the file content.
    
  116.         file_path = TEST_STATIC_ROOT / "file.txt"
    
  117.         with open(file_path, "rb") as test_file:
    
  118.             test_file_contents = test_file.read()
    
  119.         # Read the response.
    
  120.         stat = file_path.stat()
    
  121.         response_start = await communicator.receive_output()
    
  122.         self.assertEqual(response_start["type"], "http.response.start")
    
  123.         self.assertEqual(response_start["status"], 200)
    
  124.         self.assertEqual(
    
  125.             set(response_start["headers"]),
    
  126.             {
    
  127.                 (b"Content-Length", str(len(test_file_contents)).encode("ascii")),
    
  128.                 (b"Content-Type", b"text/plain"),
    
  129.                 (b"Content-Disposition", b'inline; filename="file.txt"'),
    
  130.                 (b"Last-Modified", http_date(stat.st_mtime).encode("ascii")),
    
  131.             },
    
  132.         )
    
  133.         response_body = await communicator.receive_output()
    
  134.         self.assertEqual(response_body["type"], "http.response.body")
    
  135.         self.assertEqual(response_body["body"], test_file_contents)
    
  136.         # Allow response.close() to finish.
    
  137.         await communicator.wait()
    
  138. 
    
  139.     async def test_headers(self):
    
  140.         application = get_asgi_application()
    
  141.         communicator = ApplicationCommunicator(
    
  142.             application,
    
  143.             self.async_request_factory._base_scope(
    
  144.                 path="/meta/",
    
  145.                 headers=[
    
  146.                     [b"content-type", b"text/plain; charset=utf-8"],
    
  147.                     [b"content-length", b"77"],
    
  148.                     [b"referer", b"Scotland"],
    
  149.                     [b"referer", b"Wales"],
    
  150.                 ],
    
  151.             ),
    
  152.         )
    
  153.         await communicator.send_input({"type": "http.request"})
    
  154.         response_start = await communicator.receive_output()
    
  155.         self.assertEqual(response_start["type"], "http.response.start")
    
  156.         self.assertEqual(response_start["status"], 200)
    
  157.         self.assertEqual(
    
  158.             set(response_start["headers"]),
    
  159.             {
    
  160.                 (b"Content-Length", b"19"),
    
  161.                 (b"Content-Type", b"text/plain; charset=utf-8"),
    
  162.             },
    
  163.         )
    
  164.         response_body = await communicator.receive_output()
    
  165.         self.assertEqual(response_body["type"], "http.response.body")
    
  166.         self.assertEqual(response_body["body"], b"From Scotland,Wales")
    
  167.         # Allow response.close() to finish
    
  168.         await communicator.wait()
    
  169. 
    
  170.     async def test_post_body(self):
    
  171.         application = get_asgi_application()
    
  172.         scope = self.async_request_factory._base_scope(method="POST", path="/post/")
    
  173.         communicator = ApplicationCommunicator(application, scope)
    
  174.         await communicator.send_input({"type": "http.request", "body": b"Echo!"})
    
  175.         response_start = await communicator.receive_output()
    
  176.         self.assertEqual(response_start["type"], "http.response.start")
    
  177.         self.assertEqual(response_start["status"], 200)
    
  178.         response_body = await communicator.receive_output()
    
  179.         self.assertEqual(response_body["type"], "http.response.body")
    
  180.         self.assertEqual(response_body["body"], b"Echo!")
    
  181. 
    
  182.     async def test_get_query_string(self):
    
  183.         application = get_asgi_application()
    
  184.         for query_string in (b"name=Andrew", "name=Andrew"):
    
  185.             with self.subTest(query_string=query_string):
    
  186.                 scope = self.async_request_factory._base_scope(
    
  187.                     path="/",
    
  188.                     query_string=query_string,
    
  189.                 )
    
  190.                 communicator = ApplicationCommunicator(application, scope)
    
  191.                 await communicator.send_input({"type": "http.request"})
    
  192.                 response_start = await communicator.receive_output()
    
  193.                 self.assertEqual(response_start["type"], "http.response.start")
    
  194.                 self.assertEqual(response_start["status"], 200)
    
  195.                 response_body = await communicator.receive_output()
    
  196.                 self.assertEqual(response_body["type"], "http.response.body")
    
  197.                 self.assertEqual(response_body["body"], b"Hello Andrew!")
    
  198.                 # Allow response.close() to finish
    
  199.                 await communicator.wait()
    
  200. 
    
  201.     async def test_disconnect(self):
    
  202.         application = get_asgi_application()
    
  203.         scope = self.async_request_factory._base_scope(path="/")
    
  204.         communicator = ApplicationCommunicator(application, scope)
    
  205.         await communicator.send_input({"type": "http.disconnect"})
    
  206.         with self.assertRaises(asyncio.TimeoutError):
    
  207.             await communicator.receive_output()
    
  208. 
    
  209.     async def test_wrong_connection_type(self):
    
  210.         application = get_asgi_application()
    
  211.         scope = self.async_request_factory._base_scope(path="/", type="other")
    
  212.         communicator = ApplicationCommunicator(application, scope)
    
  213.         await communicator.send_input({"type": "http.request"})
    
  214.         msg = "Django can only handle ASGI/HTTP connections, not other."
    
  215.         with self.assertRaisesMessage(ValueError, msg):
    
  216.             await communicator.receive_output()
    
  217. 
    
  218.     async def test_non_unicode_query_string(self):
    
  219.         application = get_asgi_application()
    
  220.         scope = self.async_request_factory._base_scope(path="/", query_string=b"\xff")
    
  221.         communicator = ApplicationCommunicator(application, scope)
    
  222.         await communicator.send_input({"type": "http.request"})
    
  223.         response_start = await communicator.receive_output()
    
  224.         self.assertEqual(response_start["type"], "http.response.start")
    
  225.         self.assertEqual(response_start["status"], 400)
    
  226.         response_body = await communicator.receive_output()
    
  227.         self.assertEqual(response_body["type"], "http.response.body")
    
  228.         self.assertEqual(response_body["body"], b"")
    
  229. 
    
  230.     async def test_request_lifecycle_signals_dispatched_with_thread_sensitive(self):
    
  231.         class SignalHandler:
    
  232.             """Track threads handler is dispatched on."""
    
  233. 
    
  234.             threads = []
    
  235. 
    
  236.             def __call__(self, **kwargs):
    
  237.                 self.threads.append(threading.current_thread())
    
  238. 
    
  239.         signal_handler = SignalHandler()
    
  240.         request_started.connect(signal_handler)
    
  241.         request_finished.connect(signal_handler)
    
  242. 
    
  243.         # Perform a basic request.
    
  244.         application = get_asgi_application()
    
  245.         scope = self.async_request_factory._base_scope(path="/")
    
  246.         communicator = ApplicationCommunicator(application, scope)
    
  247.         await communicator.send_input({"type": "http.request"})
    
  248.         response_start = await communicator.receive_output()
    
  249.         self.assertEqual(response_start["type"], "http.response.start")
    
  250.         self.assertEqual(response_start["status"], 200)
    
  251.         response_body = await communicator.receive_output()
    
  252.         self.assertEqual(response_body["type"], "http.response.body")
    
  253.         self.assertEqual(response_body["body"], b"Hello World!")
    
  254.         # Give response.close() time to finish.
    
  255.         await communicator.wait()
    
  256. 
    
  257.         # AsyncToSync should have executed the signals in the same thread.
    
  258.         request_started_thread, request_finished_thread = signal_handler.threads
    
  259.         self.assertEqual(request_started_thread, request_finished_thread)
    
  260.         request_started.disconnect(signal_handler)
    
  261.         request_finished.disconnect(signal_handler)
    
  262. 
    
  263.     async def test_concurrent_async_uses_multiple_thread_pools(self):
    
  264.         sync_waiter.active_threads.clear()
    
  265. 
    
  266.         # Send 2 requests concurrently
    
  267.         application = get_asgi_application()
    
  268.         scope = self.async_request_factory._base_scope(path="/wait/")
    
  269.         communicators = []
    
  270.         for _ in range(2):
    
  271.             communicators.append(ApplicationCommunicator(application, scope))
    
  272.             await communicators[-1].send_input({"type": "http.request"})
    
  273. 
    
  274.         # Each request must complete with a status code of 200
    
  275.         # If requests aren't scheduled concurrently, the barrier in the
    
  276.         # sync_wait view will time out, resulting in a 500 status code.
    
  277.         for communicator in communicators:
    
  278.             response_start = await communicator.receive_output()
    
  279.             self.assertEqual(response_start["type"], "http.response.start")
    
  280.             self.assertEqual(response_start["status"], 200)
    
  281.             response_body = await communicator.receive_output()
    
  282.             self.assertEqual(response_body["type"], "http.response.body")
    
  283.             self.assertEqual(response_body["body"], b"Hello World!")
    
  284.             # Give response.close() time to finish.
    
  285.             await communicator.wait()
    
  286. 
    
  287.         # The requests should have scheduled on different threads. Note
    
  288.         # active_threads is a set (a thread can only appear once), therefore
    
  289.         # length is a sufficient check.
    
  290.         self.assertEqual(len(sync_waiter.active_threads), 2)
    
  291. 
    
  292.         sync_waiter.active_threads.clear()