1. from django.core.exceptions import ImproperlyConfigured
    
  2. from django.core.handlers.wsgi import WSGIHandler, WSGIRequest, get_script_name
    
  3. from django.core.signals import request_finished, request_started
    
  4. from django.db import close_old_connections, connection
    
  5. from django.test import (
    
  6.     RequestFactory,
    
  7.     SimpleTestCase,
    
  8.     TransactionTestCase,
    
  9.     override_settings,
    
  10. )
    
  11. 
    
  12. 
    
  13. class HandlerTests(SimpleTestCase):
    
  14.     request_factory = RequestFactory()
    
  15. 
    
  16.     def setUp(self):
    
  17.         request_started.disconnect(close_old_connections)
    
  18. 
    
  19.     def tearDown(self):
    
  20.         request_started.connect(close_old_connections)
    
  21. 
    
  22.     def test_middleware_initialized(self):
    
  23.         handler = WSGIHandler()
    
  24.         self.assertIsNotNone(handler._middleware_chain)
    
  25. 
    
  26.     def test_bad_path_info(self):
    
  27.         """
    
  28.         A non-UTF-8 path populates PATH_INFO with an URL-encoded path and
    
  29.         produces a 404.
    
  30.         """
    
  31.         environ = self.request_factory.get("/").environ
    
  32.         environ["PATH_INFO"] = "\xed"
    
  33.         handler = WSGIHandler()
    
  34.         response = handler(environ, lambda *a, **k: None)
    
  35.         # The path of the request will be encoded to '/%ED'.
    
  36.         self.assertEqual(response.status_code, 404)
    
  37. 
    
  38.     def test_non_ascii_query_string(self):
    
  39.         """
    
  40.         Non-ASCII query strings are properly decoded (#20530, #22996).
    
  41.         """
    
  42.         environ = self.request_factory.get("/").environ
    
  43.         raw_query_strings = [
    
  44.             b"want=caf%C3%A9",  # This is the proper way to encode 'café'
    
  45.             b"want=caf\xc3\xa9",  # UA forgot to quote bytes
    
  46.             b"want=caf%E9",  # UA quoted, but not in UTF-8
    
  47.             # UA forgot to convert Latin-1 to UTF-8 and to quote (typical of
    
  48.             # MSIE).
    
  49.             b"want=caf\xe9",
    
  50.         ]
    
  51.         got = []
    
  52.         for raw_query_string in raw_query_strings:
    
  53.             # Simulate http.server.BaseHTTPRequestHandler.parse_request
    
  54.             # handling of raw request.
    
  55.             environ["QUERY_STRING"] = str(raw_query_string, "iso-8859-1")
    
  56.             request = WSGIRequest(environ)
    
  57.             got.append(request.GET["want"])
    
  58.         # %E9 is converted to the Unicode replacement character by parse_qsl
    
  59.         self.assertEqual(got, ["café", "café", "caf\ufffd", "café"])
    
  60. 
    
  61.     def test_non_ascii_cookie(self):
    
  62.         """Non-ASCII cookies set in JavaScript are properly decoded (#20557)."""
    
  63.         environ = self.request_factory.get("/").environ
    
  64.         raw_cookie = 'want="café"'.encode("utf-8").decode("iso-8859-1")
    
  65.         environ["HTTP_COOKIE"] = raw_cookie
    
  66.         request = WSGIRequest(environ)
    
  67.         self.assertEqual(request.COOKIES["want"], "café")
    
  68. 
    
  69.     def test_invalid_unicode_cookie(self):
    
  70.         """
    
  71.         Invalid cookie content should result in an absent cookie, but not in a
    
  72.         crash while trying to decode it (#23638).
    
  73.         """
    
  74.         environ = self.request_factory.get("/").environ
    
  75.         environ["HTTP_COOKIE"] = "x=W\x03c(h]\x8e"
    
  76.         request = WSGIRequest(environ)
    
  77.         # We don't test COOKIES content, as the result might differ between
    
  78.         # Python version because parsing invalid content became stricter in
    
  79.         # latest versions.
    
  80.         self.assertIsInstance(request.COOKIES, dict)
    
  81. 
    
  82.     @override_settings(ROOT_URLCONF="handlers.urls")
    
  83.     def test_invalid_multipart_boundary(self):
    
  84.         """
    
  85.         Invalid boundary string should produce a "Bad Request" response, not a
    
  86.         server error (#23887).
    
  87.         """
    
  88.         environ = self.request_factory.post("/malformed_post/").environ
    
  89.         environ["CONTENT_TYPE"] = "multipart/form-data; boundary=WRONG\x07"
    
  90.         handler = WSGIHandler()
    
  91.         response = handler(environ, lambda *a, **k: None)
    
  92.         # Expect "bad request" response
    
  93.         self.assertEqual(response.status_code, 400)
    
  94. 
    
  95. 
    
  96. @override_settings(ROOT_URLCONF="handlers.urls", MIDDLEWARE=[])
    
  97. class TransactionsPerRequestTests(TransactionTestCase):
    
  98.     available_apps = []
    
  99. 
    
  100.     def test_no_transaction(self):
    
  101.         response = self.client.get("/in_transaction/")
    
  102.         self.assertContains(response, "False")
    
  103. 
    
  104.     def test_auto_transaction(self):
    
  105.         old_atomic_requests = connection.settings_dict["ATOMIC_REQUESTS"]
    
  106.         try:
    
  107.             connection.settings_dict["ATOMIC_REQUESTS"] = True
    
  108.             response = self.client.get("/in_transaction/")
    
  109.         finally:
    
  110.             connection.settings_dict["ATOMIC_REQUESTS"] = old_atomic_requests
    
  111.         self.assertContains(response, "True")
    
  112. 
    
  113.     async def test_auto_transaction_async_view(self):
    
  114.         old_atomic_requests = connection.settings_dict["ATOMIC_REQUESTS"]
    
  115.         try:
    
  116.             connection.settings_dict["ATOMIC_REQUESTS"] = True
    
  117.             msg = "You cannot use ATOMIC_REQUESTS with async views."
    
  118.             with self.assertRaisesMessage(RuntimeError, msg):
    
  119.                 await self.async_client.get("/async_regular/")
    
  120.         finally:
    
  121.             connection.settings_dict["ATOMIC_REQUESTS"] = old_atomic_requests
    
  122. 
    
  123.     def test_no_auto_transaction(self):
    
  124.         old_atomic_requests = connection.settings_dict["ATOMIC_REQUESTS"]
    
  125.         try:
    
  126.             connection.settings_dict["ATOMIC_REQUESTS"] = True
    
  127.             response = self.client.get("/not_in_transaction/")
    
  128.         finally:
    
  129.             connection.settings_dict["ATOMIC_REQUESTS"] = old_atomic_requests
    
  130.         self.assertContains(response, "False")
    
  131. 
    
  132. 
    
  133. @override_settings(ROOT_URLCONF="handlers.urls")
    
  134. class SignalsTests(SimpleTestCase):
    
  135.     def setUp(self):
    
  136.         self.signals = []
    
  137.         self.signaled_environ = None
    
  138.         request_started.connect(self.register_started)
    
  139.         request_finished.connect(self.register_finished)
    
  140. 
    
  141.     def tearDown(self):
    
  142.         request_started.disconnect(self.register_started)
    
  143.         request_finished.disconnect(self.register_finished)
    
  144. 
    
  145.     def register_started(self, **kwargs):
    
  146.         self.signals.append("started")
    
  147.         self.signaled_environ = kwargs.get("environ")
    
  148. 
    
  149.     def register_finished(self, **kwargs):
    
  150.         self.signals.append("finished")
    
  151. 
    
  152.     def test_request_signals(self):
    
  153.         response = self.client.get("/regular/")
    
  154.         self.assertEqual(self.signals, ["started", "finished"])
    
  155.         self.assertEqual(response.content, b"regular content")
    
  156.         self.assertEqual(self.signaled_environ, response.wsgi_request.environ)
    
  157. 
    
  158.     def test_request_signals_streaming_response(self):
    
  159.         response = self.client.get("/streaming/")
    
  160.         self.assertEqual(self.signals, ["started"])
    
  161.         self.assertEqual(b"".join(response.streaming_content), b"streaming content")
    
  162.         self.assertEqual(self.signals, ["started", "finished"])
    
  163. 
    
  164. 
    
  165. def empty_middleware(get_response):
    
  166.     pass
    
  167. 
    
  168. 
    
  169. @override_settings(ROOT_URLCONF="handlers.urls")
    
  170. class HandlerRequestTests(SimpleTestCase):
    
  171.     request_factory = RequestFactory()
    
  172. 
    
  173.     def test_async_view(self):
    
  174.         """Calling an async view down the normal synchronous path."""
    
  175.         response = self.client.get("/async_regular/")
    
  176.         self.assertEqual(response.status_code, 200)
    
  177. 
    
  178.     def test_suspiciousop_in_view_returns_400(self):
    
  179.         response = self.client.get("/suspicious/")
    
  180.         self.assertEqual(response.status_code, 400)
    
  181. 
    
  182.     def test_bad_request_in_view_returns_400(self):
    
  183.         response = self.client.get("/bad_request/")
    
  184.         self.assertEqual(response.status_code, 400)
    
  185. 
    
  186.     def test_invalid_urls(self):
    
  187.         response = self.client.get("~%A9helloworld")
    
  188.         self.assertEqual(response.status_code, 404)
    
  189.         self.assertEqual(response.context["request_path"], "/~%25A9helloworld")
    
  190. 
    
  191.         response = self.client.get("d%aao%aaw%aan%aal%aao%aaa%aad%aa/")
    
  192.         self.assertEqual(
    
  193.             response.context["request_path"],
    
  194.             "/d%25AAo%25AAw%25AAn%25AAl%25AAo%25AAa%25AAd%25AA",
    
  195.         )
    
  196. 
    
  197.         response = self.client.get("/%E2%99%E2%99%A5/")
    
  198.         self.assertEqual(response.context["request_path"], "/%25E2%2599%E2%99%A5/")
    
  199. 
    
  200.         response = self.client.get("/%E2%98%8E%E2%A9%E2%99%A5/")
    
  201.         self.assertEqual(
    
  202.             response.context["request_path"], "/%E2%98%8E%25E2%25A9%E2%99%A5/"
    
  203.         )
    
  204. 
    
  205.     def test_environ_path_info_type(self):
    
  206.         environ = self.request_factory.get("/%E2%A8%87%87%A5%E2%A8%A0").environ
    
  207.         self.assertIsInstance(environ["PATH_INFO"], str)
    
  208. 
    
  209.     def test_handle_accepts_httpstatus_enum_value(self):
    
  210.         def start_response(status, headers):
    
  211.             start_response.status = status
    
  212. 
    
  213.         environ = self.request_factory.get("/httpstatus_enum/").environ
    
  214.         WSGIHandler()(environ, start_response)
    
  215.         self.assertEqual(start_response.status, "200 OK")
    
  216. 
    
  217.     @override_settings(MIDDLEWARE=["handlers.tests.empty_middleware"])
    
  218.     def test_middleware_returns_none(self):
    
  219.         msg = "Middleware factory handlers.tests.empty_middleware returned None."
    
  220.         with self.assertRaisesMessage(ImproperlyConfigured, msg):
    
  221.             self.client.get("/")
    
  222. 
    
  223.     def test_no_response(self):
    
  224.         msg = (
    
  225.             "The view %s didn't return an HttpResponse object. It returned None "
    
  226.             "instead."
    
  227.         )
    
  228.         tests = (
    
  229.             ("/no_response_fbv/", "handlers.views.no_response"),
    
  230.             ("/no_response_cbv/", "handlers.views.NoResponse.__call__"),
    
  231.         )
    
  232.         for url, view in tests:
    
  233.             with self.subTest(url=url), self.assertRaisesMessage(
    
  234.                 ValueError, msg % view
    
  235.             ):
    
  236.                 self.client.get(url)
    
  237. 
    
  238. 
    
  239. class ScriptNameTests(SimpleTestCase):
    
  240.     def test_get_script_name(self):
    
  241.         # Regression test for #23173
    
  242.         # Test first without PATH_INFO
    
  243.         script_name = get_script_name({"SCRIPT_URL": "/foobar/"})
    
  244.         self.assertEqual(script_name, "/foobar/")
    
  245. 
    
  246.         script_name = get_script_name({"SCRIPT_URL": "/foobar/", "PATH_INFO": "/"})
    
  247.         self.assertEqual(script_name, "/foobar")
    
  248. 
    
  249.     def test_get_script_name_double_slashes(self):
    
  250.         """
    
  251.         WSGI squashes multiple successive slashes in PATH_INFO, get_script_name
    
  252.         should take that into account when forming SCRIPT_NAME (#17133).
    
  253.         """
    
  254.         script_name = get_script_name(
    
  255.             {
    
  256.                 "SCRIPT_URL": "/mst/milestones//accounts/login//help",
    
  257.                 "PATH_INFO": "/milestones/accounts/login/help",
    
  258.             }
    
  259.         )
    
  260.         self.assertEqual(script_name, "/mst")
    
  261. 
    
  262. 
    
  263. @override_settings(ROOT_URLCONF="handlers.urls")
    
  264. class AsyncHandlerRequestTests(SimpleTestCase):
    
  265.     """Async variants of the normal handler request tests."""
    
  266. 
    
  267.     async def test_sync_view(self):
    
  268.         """Calling a sync view down the asynchronous path."""
    
  269.         response = await self.async_client.get("/regular/")
    
  270.         self.assertEqual(response.status_code, 200)
    
  271. 
    
  272.     async def test_async_view(self):
    
  273.         """Calling an async view down the asynchronous path."""
    
  274.         response = await self.async_client.get("/async_regular/")
    
  275.         self.assertEqual(response.status_code, 200)
    
  276. 
    
  277.     async def test_suspiciousop_in_view_returns_400(self):
    
  278.         response = await self.async_client.get("/suspicious/")
    
  279.         self.assertEqual(response.status_code, 400)
    
  280. 
    
  281.     async def test_bad_request_in_view_returns_400(self):
    
  282.         response = await self.async_client.get("/bad_request/")
    
  283.         self.assertEqual(response.status_code, 400)
    
  284. 
    
  285.     async def test_no_response(self):
    
  286.         msg = (
    
  287.             "The view handlers.views.no_response didn't return an "
    
  288.             "HttpResponse object. It returned None instead."
    
  289.         )
    
  290.         with self.assertRaisesMessage(ValueError, msg):
    
  291.             await self.async_client.get("/no_response_fbv/")
    
  292. 
    
  293.     async def test_unawaited_response(self):
    
  294.         msg = (
    
  295.             "The view handlers.views.CoroutineClearingView.__call__ didn't"
    
  296.             " return an HttpResponse object. It returned an unawaited"
    
  297.             " coroutine instead. You may need to add an 'await'"
    
  298.             " into your view."
    
  299.         )
    
  300.         with self.assertRaisesMessage(ValueError, msg):
    
  301.             await self.async_client.get("/unawaited/")