1. from unittest.mock import MagicMock, patch
    
  2. 
    
  3. from django.db import DEFAULT_DB_ALIAS, connection, connections
    
  4. from django.db.backends.base.base import BaseDatabaseWrapper
    
  5. from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
    
  6. 
    
  7. from ..models import Square
    
  8. 
    
  9. 
    
  10. class DatabaseWrapperTests(SimpleTestCase):
    
  11.     def test_repr(self):
    
  12.         conn = connections[DEFAULT_DB_ALIAS]
    
  13.         self.assertEqual(
    
  14.             repr(conn),
    
  15.             f"<DatabaseWrapper vendor={connection.vendor!r} alias='default'>",
    
  16.         )
    
  17. 
    
  18.     def test_initialization_class_attributes(self):
    
  19.         """
    
  20.         The "initialization" class attributes like client_class and
    
  21.         creation_class should be set on the class and reflected in the
    
  22.         corresponding instance attributes of the instantiated backend.
    
  23.         """
    
  24.         conn = connections[DEFAULT_DB_ALIAS]
    
  25.         conn_class = type(conn)
    
  26.         attr_names = [
    
  27.             ("client_class", "client"),
    
  28.             ("creation_class", "creation"),
    
  29.             ("features_class", "features"),
    
  30.             ("introspection_class", "introspection"),
    
  31.             ("ops_class", "ops"),
    
  32.             ("validation_class", "validation"),
    
  33.         ]
    
  34.         for class_attr_name, instance_attr_name in attr_names:
    
  35.             class_attr_value = getattr(conn_class, class_attr_name)
    
  36.             self.assertIsNotNone(class_attr_value)
    
  37.             instance_attr_value = getattr(conn, instance_attr_name)
    
  38.             self.assertIsInstance(instance_attr_value, class_attr_value)
    
  39. 
    
  40.     def test_initialization_display_name(self):
    
  41.         self.assertEqual(BaseDatabaseWrapper.display_name, "unknown")
    
  42.         self.assertNotEqual(connection.display_name, "unknown")
    
  43. 
    
  44.     def test_get_database_version(self):
    
  45.         with patch.object(BaseDatabaseWrapper, "__init__", return_value=None):
    
  46.             msg = (
    
  47.                 "subclasses of BaseDatabaseWrapper may require a "
    
  48.                 "get_database_version() method."
    
  49.             )
    
  50.             with self.assertRaisesMessage(NotImplementedError, msg):
    
  51.                 BaseDatabaseWrapper().get_database_version()
    
  52. 
    
  53.     def test_check_database_version_supported_with_none_as_database_version(self):
    
  54.         with patch.object(connection.features, "minimum_database_version", None):
    
  55.             connection.check_database_version_supported()
    
  56. 
    
  57. 
    
  58. class ExecuteWrapperTests(TestCase):
    
  59.     @staticmethod
    
  60.     def call_execute(connection, params=None):
    
  61.         ret_val = "1" if params is None else "%s"
    
  62.         sql = "SELECT " + ret_val + connection.features.bare_select_suffix
    
  63.         with connection.cursor() as cursor:
    
  64.             cursor.execute(sql, params)
    
  65. 
    
  66.     def call_executemany(self, connection, params=None):
    
  67.         # executemany() must use an update query. Make sure it does nothing
    
  68.         # by putting a false condition in the WHERE clause.
    
  69.         sql = "DELETE FROM {} WHERE 0=1 AND 0=%s".format(Square._meta.db_table)
    
  70.         if params is None:
    
  71.             params = [(i,) for i in range(3)]
    
  72.         with connection.cursor() as cursor:
    
  73.             cursor.executemany(sql, params)
    
  74. 
    
  75.     @staticmethod
    
  76.     def mock_wrapper():
    
  77.         return MagicMock(side_effect=lambda execute, *args: execute(*args))
    
  78. 
    
  79.     def test_wrapper_invoked(self):
    
  80.         wrapper = self.mock_wrapper()
    
  81.         with connection.execute_wrapper(wrapper):
    
  82.             self.call_execute(connection)
    
  83.         self.assertTrue(wrapper.called)
    
  84.         (_, sql, params, many, context), _ = wrapper.call_args
    
  85.         self.assertIn("SELECT", sql)
    
  86.         self.assertIsNone(params)
    
  87.         self.assertIs(many, False)
    
  88.         self.assertEqual(context["connection"], connection)
    
  89. 
    
  90.     def test_wrapper_invoked_many(self):
    
  91.         wrapper = self.mock_wrapper()
    
  92.         with connection.execute_wrapper(wrapper):
    
  93.             self.call_executemany(connection)
    
  94.         self.assertTrue(wrapper.called)
    
  95.         (_, sql, param_list, many, context), _ = wrapper.call_args
    
  96.         self.assertIn("DELETE", sql)
    
  97.         self.assertIsInstance(param_list, (list, tuple))
    
  98.         self.assertIs(many, True)
    
  99.         self.assertEqual(context["connection"], connection)
    
  100. 
    
  101.     def test_database_queried(self):
    
  102.         wrapper = self.mock_wrapper()
    
  103.         with connection.execute_wrapper(wrapper):
    
  104.             with connection.cursor() as cursor:
    
  105.                 sql = "SELECT 17" + connection.features.bare_select_suffix
    
  106.                 cursor.execute(sql)
    
  107.                 seventeen = cursor.fetchall()
    
  108.                 self.assertEqual(list(seventeen), [(17,)])
    
  109.             self.call_executemany(connection)
    
  110. 
    
  111.     def test_nested_wrapper_invoked(self):
    
  112.         outer_wrapper = self.mock_wrapper()
    
  113.         inner_wrapper = self.mock_wrapper()
    
  114.         with connection.execute_wrapper(outer_wrapper), connection.execute_wrapper(
    
  115.             inner_wrapper
    
  116.         ):
    
  117.             self.call_execute(connection)
    
  118.             self.assertEqual(inner_wrapper.call_count, 1)
    
  119.             self.call_executemany(connection)
    
  120.             self.assertEqual(inner_wrapper.call_count, 2)
    
  121. 
    
  122.     def test_outer_wrapper_blocks(self):
    
  123.         def blocker(*args):
    
  124.             pass
    
  125. 
    
  126.         wrapper = self.mock_wrapper()
    
  127.         c = connection  # This alias shortens the next line.
    
  128.         with c.execute_wrapper(wrapper), c.execute_wrapper(blocker), c.execute_wrapper(
    
  129.             wrapper
    
  130.         ):
    
  131.             with c.cursor() as cursor:
    
  132.                 cursor.execute("The database never sees this")
    
  133.                 self.assertEqual(wrapper.call_count, 1)
    
  134.                 cursor.executemany("The database never sees this %s", [("either",)])
    
  135.                 self.assertEqual(wrapper.call_count, 2)
    
  136. 
    
  137.     def test_wrapper_gets_sql(self):
    
  138.         wrapper = self.mock_wrapper()
    
  139.         sql = "SELECT 'aloha'" + connection.features.bare_select_suffix
    
  140.         with connection.execute_wrapper(wrapper), connection.cursor() as cursor:
    
  141.             cursor.execute(sql)
    
  142.         (_, reported_sql, _, _, _), _ = wrapper.call_args
    
  143.         self.assertEqual(reported_sql, sql)
    
  144. 
    
  145.     def test_wrapper_connection_specific(self):
    
  146.         wrapper = self.mock_wrapper()
    
  147.         with connections["other"].execute_wrapper(wrapper):
    
  148.             self.assertEqual(connections["other"].execute_wrappers, [wrapper])
    
  149.             self.call_execute(connection)
    
  150.         self.assertFalse(wrapper.called)
    
  151.         self.assertEqual(connection.execute_wrappers, [])
    
  152.         self.assertEqual(connections["other"].execute_wrappers, [])
    
  153. 
    
  154. 
    
  155. class ConnectionHealthChecksTests(SimpleTestCase):
    
  156.     databases = {"default"}
    
  157. 
    
  158.     def setUp(self):
    
  159.         # All test cases here need newly configured and created connections.
    
  160.         # Use the default db connection for convenience.
    
  161.         connection.close()
    
  162.         self.addCleanup(connection.close)
    
  163. 
    
  164.     def patch_settings_dict(self, conn_health_checks):
    
  165.         self.settings_dict_patcher = patch.dict(
    
  166.             connection.settings_dict,
    
  167.             {
    
  168.                 **connection.settings_dict,
    
  169.                 "CONN_MAX_AGE": None,
    
  170.                 "CONN_HEALTH_CHECKS": conn_health_checks,
    
  171.             },
    
  172.         )
    
  173.         self.settings_dict_patcher.start()
    
  174.         self.addCleanup(self.settings_dict_patcher.stop)
    
  175. 
    
  176.     def run_query(self):
    
  177.         with connection.cursor() as cursor:
    
  178.             cursor.execute("SELECT 42" + connection.features.bare_select_suffix)
    
  179. 
    
  180.     @skipUnlessDBFeature("test_db_allows_multiple_connections")
    
  181.     def test_health_checks_enabled(self):
    
  182.         self.patch_settings_dict(conn_health_checks=True)
    
  183.         self.assertIsNone(connection.connection)
    
  184.         # Newly created connections are considered healthy without performing
    
  185.         # the health check.
    
  186.         with patch.object(connection, "is_usable", side_effect=AssertionError):
    
  187.             self.run_query()
    
  188. 
    
  189.         old_connection = connection.connection
    
  190.         # Simulate request_finished.
    
  191.         connection.close_if_unusable_or_obsolete()
    
  192.         self.assertIs(old_connection, connection.connection)
    
  193. 
    
  194.         # Simulate connection health check failing.
    
  195.         with patch.object(
    
  196.             connection, "is_usable", return_value=False
    
  197.         ) as mocked_is_usable:
    
  198.             self.run_query()
    
  199.             new_connection = connection.connection
    
  200.             # A new connection is established.
    
  201.             self.assertIsNot(new_connection, old_connection)
    
  202.             # Only one health check per "request" is performed, so the next
    
  203.             # query will carry on even if the health check fails. Next query
    
  204.             # succeeds because the real connection is healthy and only the
    
  205.             # health check failure is mocked.
    
  206.             self.run_query()
    
  207.             self.assertIs(new_connection, connection.connection)
    
  208.         self.assertEqual(mocked_is_usable.call_count, 1)
    
  209. 
    
  210.         # Simulate request_finished.
    
  211.         connection.close_if_unusable_or_obsolete()
    
  212.         # The underlying connection is being reused further with health checks
    
  213.         # succeeding.
    
  214.         self.run_query()
    
  215.         self.run_query()
    
  216.         self.assertIs(new_connection, connection.connection)
    
  217. 
    
  218.     @skipUnlessDBFeature("test_db_allows_multiple_connections")
    
  219.     def test_health_checks_enabled_errors_occurred(self):
    
  220.         self.patch_settings_dict(conn_health_checks=True)
    
  221.         self.assertIsNone(connection.connection)
    
  222.         # Newly created connections are considered healthy without performing
    
  223.         # the health check.
    
  224.         with patch.object(connection, "is_usable", side_effect=AssertionError):
    
  225.             self.run_query()
    
  226. 
    
  227.         old_connection = connection.connection
    
  228.         # Simulate errors_occurred.
    
  229.         connection.errors_occurred = True
    
  230.         # Simulate request_started (the connection is healthy).
    
  231.         connection.close_if_unusable_or_obsolete()
    
  232.         # Persistent connections are enabled.
    
  233.         self.assertIs(old_connection, connection.connection)
    
  234.         # No additional health checks after the one in
    
  235.         # close_if_unusable_or_obsolete() are executed during this "request"
    
  236.         # when running queries.
    
  237.         with patch.object(connection, "is_usable", side_effect=AssertionError):
    
  238.             self.run_query()
    
  239. 
    
  240.     @skipUnlessDBFeature("test_db_allows_multiple_connections")
    
  241.     def test_health_checks_disabled(self):
    
  242.         self.patch_settings_dict(conn_health_checks=False)
    
  243.         self.assertIsNone(connection.connection)
    
  244.         # Newly created connections are considered healthy without performing
    
  245.         # the health check.
    
  246.         with patch.object(connection, "is_usable", side_effect=AssertionError):
    
  247.             self.run_query()
    
  248. 
    
  249.         old_connection = connection.connection
    
  250.         # Simulate request_finished.
    
  251.         connection.close_if_unusable_or_obsolete()
    
  252.         # Persistent connections are enabled (connection is not).
    
  253.         self.assertIs(old_connection, connection.connection)
    
  254.         # Health checks are not performed.
    
  255.         with patch.object(connection, "is_usable", side_effect=AssertionError):
    
  256.             self.run_query()
    
  257.             # Health check wasn't performed and the connection is unchanged.
    
  258.             self.assertIs(old_connection, connection.connection)
    
  259.             self.run_query()
    
  260.             # The connection is unchanged after the next query either during
    
  261.             # the current "request".
    
  262.             self.assertIs(old_connection, connection.connection)
    
  263. 
    
  264.     @skipUnlessDBFeature("test_db_allows_multiple_connections")
    
  265.     def test_set_autocommit_health_checks_enabled(self):
    
  266.         self.patch_settings_dict(conn_health_checks=True)
    
  267.         self.assertIsNone(connection.connection)
    
  268.         # Newly created connections are considered healthy without performing
    
  269.         # the health check.
    
  270.         with patch.object(connection, "is_usable", side_effect=AssertionError):
    
  271.             # Simulate outermost atomic block: changing autocommit for
    
  272.             # a connection.
    
  273.             connection.set_autocommit(False)
    
  274.             self.run_query()
    
  275.             connection.commit()
    
  276.             connection.set_autocommit(True)
    
  277. 
    
  278.         old_connection = connection.connection
    
  279.         # Simulate request_finished.
    
  280.         connection.close_if_unusable_or_obsolete()
    
  281.         # Persistent connections are enabled.
    
  282.         self.assertIs(old_connection, connection.connection)
    
  283. 
    
  284.         # Simulate connection health check failing.
    
  285.         with patch.object(
    
  286.             connection, "is_usable", return_value=False
    
  287.         ) as mocked_is_usable:
    
  288.             # Simulate outermost atomic block: changing autocommit for
    
  289.             # a connection.
    
  290.             connection.set_autocommit(False)
    
  291.             new_connection = connection.connection
    
  292.             self.assertIsNot(new_connection, old_connection)
    
  293.             # Only one health check per "request" is performed, so a query will
    
  294.             # carry on even if the health check fails. This query succeeds
    
  295.             # because the real connection is healthy and only the health check
    
  296.             # failure is mocked.
    
  297.             self.run_query()
    
  298.             connection.commit()
    
  299.             connection.set_autocommit(True)
    
  300.             # The connection is unchanged.
    
  301.             self.assertIs(new_connection, connection.connection)
    
  302.         self.assertEqual(mocked_is_usable.call_count, 1)
    
  303. 
    
  304.         # Simulate request_finished.
    
  305.         connection.close_if_unusable_or_obsolete()
    
  306.         # The underlying connection is being reused further with health checks
    
  307.         # succeeding.
    
  308.         connection.set_autocommit(False)
    
  309.         self.run_query()
    
  310.         connection.commit()
    
  311.         connection.set_autocommit(True)
    
  312.         self.assertIs(new_connection, connection.connection)
    
  313. 
    
  314. 
    
  315. class MultiDatabaseTests(TestCase):
    
  316.     databases = {"default", "other"}
    
  317. 
    
  318.     def test_multi_database_init_connection_state_called_once(self):
    
  319.         for db in self.databases:
    
  320.             with self.subTest(database=db):
    
  321.                 with patch.object(connections[db], "commit", return_value=None):
    
  322.                     with patch.object(
    
  323.                         connections[db],
    
  324.                         "check_database_version_supported",
    
  325.                     ) as mocked_check_database_version_supported:
    
  326.                         connections[db].init_connection_state()
    
  327.                         after_first_calls = len(
    
  328.                             mocked_check_database_version_supported.mock_calls
    
  329.                         )
    
  330.                         connections[db].init_connection_state()
    
  331.                         self.assertEqual(
    
  332.                             len(mocked_check_database_version_supported.mock_calls),
    
  333.                             after_first_calls,
    
  334.                         )