1. import copy
    
  2. import datetime
    
  3. import os
    
  4. from unittest import mock
    
  5. 
    
  6. from django.db import DEFAULT_DB_ALIAS, connection, connections
    
  7. from django.db.backends.base.creation import TEST_DATABASE_PREFIX, BaseDatabaseCreation
    
  8. from django.test import SimpleTestCase, TransactionTestCase
    
  9. from django.test.utils import override_settings
    
  10. 
    
  11. from ..models import (
    
  12.     CircularA,
    
  13.     CircularB,
    
  14.     Object,
    
  15.     ObjectReference,
    
  16.     ObjectSelfReference,
    
  17.     SchoolClass,
    
  18. )
    
  19. 
    
  20. 
    
  21. def get_connection_copy():
    
  22.     # Get a copy of the default connection. (Can't use django.db.connection
    
  23.     # because it'll modify the default connection itself.)
    
  24.     test_connection = copy.copy(connections[DEFAULT_DB_ALIAS])
    
  25.     test_connection.settings_dict = copy.deepcopy(
    
  26.         connections[DEFAULT_DB_ALIAS].settings_dict
    
  27.     )
    
  28.     return test_connection
    
  29. 
    
  30. 
    
  31. class TestDbSignatureTests(SimpleTestCase):
    
  32.     def test_default_name(self):
    
  33.         # A test db name isn't set.
    
  34.         prod_name = "hodor"
    
  35.         test_connection = get_connection_copy()
    
  36.         test_connection.settings_dict["NAME"] = prod_name
    
  37.         test_connection.settings_dict["TEST"] = {"NAME": None}
    
  38.         signature = BaseDatabaseCreation(test_connection).test_db_signature()
    
  39.         self.assertEqual(signature[3], TEST_DATABASE_PREFIX + prod_name)
    
  40. 
    
  41.     def test_custom_test_name(self):
    
  42.         # A regular test db name is set.
    
  43.         test_name = "hodor"
    
  44.         test_connection = get_connection_copy()
    
  45.         test_connection.settings_dict["TEST"] = {"NAME": test_name}
    
  46.         signature = BaseDatabaseCreation(test_connection).test_db_signature()
    
  47.         self.assertEqual(signature[3], test_name)
    
  48. 
    
  49.     def test_custom_test_name_with_test_prefix(self):
    
  50.         # A test db name prefixed with TEST_DATABASE_PREFIX is set.
    
  51.         test_name = TEST_DATABASE_PREFIX + "hodor"
    
  52.         test_connection = get_connection_copy()
    
  53.         test_connection.settings_dict["TEST"] = {"NAME": test_name}
    
  54.         signature = BaseDatabaseCreation(test_connection).test_db_signature()
    
  55.         self.assertEqual(signature[3], test_name)
    
  56. 
    
  57. 
    
  58. @override_settings(INSTALLED_APPS=["backends.base.app_unmigrated"])
    
  59. @mock.patch.object(connection, "ensure_connection")
    
  60. @mock.patch.object(connection, "prepare_database")
    
  61. @mock.patch(
    
  62.     "django.db.migrations.recorder.MigrationRecorder.has_table", return_value=False
    
  63. )
    
  64. @mock.patch("django.core.management.commands.migrate.Command.sync_apps")
    
  65. class TestDbCreationTests(SimpleTestCase):
    
  66.     available_apps = ["backends.base.app_unmigrated"]
    
  67. 
    
  68.     @mock.patch("django.db.migrations.executor.MigrationExecutor.migrate")
    
  69.     def test_migrate_test_setting_false(
    
  70.         self, mocked_migrate, mocked_sync_apps, *mocked_objects
    
  71.     ):
    
  72.         test_connection = get_connection_copy()
    
  73.         test_connection.settings_dict["TEST"]["MIGRATE"] = False
    
  74.         creation = test_connection.creation_class(test_connection)
    
  75.         if connection.vendor == "oracle":
    
  76.             # Don't close connection on Oracle.
    
  77.             creation.connection.close = mock.Mock()
    
  78.         old_database_name = test_connection.settings_dict["NAME"]
    
  79.         try:
    
  80.             with mock.patch.object(creation, "_create_test_db"):
    
  81.                 creation.create_test_db(verbosity=0, autoclobber=True, serialize=False)
    
  82.             # Migrations don't run.
    
  83.             mocked_migrate.assert_called()
    
  84.             args, kwargs = mocked_migrate.call_args
    
  85.             self.assertEqual(args, ([],))
    
  86.             self.assertEqual(kwargs["plan"], [])
    
  87.             # App is synced.
    
  88.             mocked_sync_apps.assert_called()
    
  89.             mocked_args, _ = mocked_sync_apps.call_args
    
  90.             self.assertEqual(mocked_args[1], {"app_unmigrated"})
    
  91.         finally:
    
  92.             with mock.patch.object(creation, "_destroy_test_db"):
    
  93.                 creation.destroy_test_db(old_database_name, verbosity=0)
    
  94. 
    
  95.     @mock.patch("django.db.migrations.executor.MigrationRecorder.ensure_schema")
    
  96.     def test_migrate_test_setting_false_ensure_schema(
    
  97.         self,
    
  98.         mocked_ensure_schema,
    
  99.         mocked_sync_apps,
    
  100.         *mocked_objects,
    
  101.     ):
    
  102.         test_connection = get_connection_copy()
    
  103.         test_connection.settings_dict["TEST"]["MIGRATE"] = False
    
  104.         creation = test_connection.creation_class(test_connection)
    
  105.         if connection.vendor == "oracle":
    
  106.             # Don't close connection on Oracle.
    
  107.             creation.connection.close = mock.Mock()
    
  108.         old_database_name = test_connection.settings_dict["NAME"]
    
  109.         try:
    
  110.             with mock.patch.object(creation, "_create_test_db"):
    
  111.                 creation.create_test_db(verbosity=0, autoclobber=True, serialize=False)
    
  112.             # The django_migrations table is not created.
    
  113.             mocked_ensure_schema.assert_not_called()
    
  114.             # App is synced.
    
  115.             mocked_sync_apps.assert_called()
    
  116.             mocked_args, _ = mocked_sync_apps.call_args
    
  117.             self.assertEqual(mocked_args[1], {"app_unmigrated"})
    
  118.         finally:
    
  119.             with mock.patch.object(creation, "_destroy_test_db"):
    
  120.                 creation.destroy_test_db(old_database_name, verbosity=0)
    
  121. 
    
  122.     @mock.patch("django.db.migrations.executor.MigrationExecutor.migrate")
    
  123.     def test_migrate_test_setting_true(
    
  124.         self, mocked_migrate, mocked_sync_apps, *mocked_objects
    
  125.     ):
    
  126.         test_connection = get_connection_copy()
    
  127.         test_connection.settings_dict["TEST"]["MIGRATE"] = True
    
  128.         creation = test_connection.creation_class(test_connection)
    
  129.         if connection.vendor == "oracle":
    
  130.             # Don't close connection on Oracle.
    
  131.             creation.connection.close = mock.Mock()
    
  132.         old_database_name = test_connection.settings_dict["NAME"]
    
  133.         try:
    
  134.             with mock.patch.object(creation, "_create_test_db"):
    
  135.                 creation.create_test_db(verbosity=0, autoclobber=True, serialize=False)
    
  136.             # Migrations run.
    
  137.             mocked_migrate.assert_called()
    
  138.             args, kwargs = mocked_migrate.call_args
    
  139.             self.assertEqual(args, ([("app_unmigrated", "0001_initial")],))
    
  140.             self.assertEqual(len(kwargs["plan"]), 1)
    
  141.             # App is not synced.
    
  142.             mocked_sync_apps.assert_not_called()
    
  143.         finally:
    
  144.             with mock.patch.object(creation, "_destroy_test_db"):
    
  145.                 creation.destroy_test_db(old_database_name, verbosity=0)
    
  146. 
    
  147.     @mock.patch.dict(os.environ, {"RUNNING_DJANGOS_TEST_SUITE": ""})
    
  148.     @mock.patch("django.db.migrations.executor.MigrationExecutor.migrate")
    
  149.     @mock.patch.object(BaseDatabaseCreation, "mark_expected_failures_and_skips")
    
  150.     def test_mark_expected_failures_and_skips_call(
    
  151.         self, mark_expected_failures_and_skips, *mocked_objects
    
  152.     ):
    
  153.         """
    
  154.         mark_expected_failures_and_skips() isn't called unless
    
  155.         RUNNING_DJANGOS_TEST_SUITE is 'true'.
    
  156.         """
    
  157.         test_connection = get_connection_copy()
    
  158.         creation = test_connection.creation_class(test_connection)
    
  159.         if connection.vendor == "oracle":
    
  160.             # Don't close connection on Oracle.
    
  161.             creation.connection.close = mock.Mock()
    
  162.         old_database_name = test_connection.settings_dict["NAME"]
    
  163.         try:
    
  164.             with mock.patch.object(creation, "_create_test_db"):
    
  165.                 creation.create_test_db(verbosity=0, autoclobber=True, serialize=False)
    
  166.             self.assertIs(mark_expected_failures_and_skips.called, False)
    
  167.         finally:
    
  168.             with mock.patch.object(creation, "_destroy_test_db"):
    
  169.                 creation.destroy_test_db(old_database_name, verbosity=0)
    
  170. 
    
  171. 
    
  172. class TestDeserializeDbFromString(TransactionTestCase):
    
  173.     available_apps = ["backends"]
    
  174. 
    
  175.     def test_circular_reference(self):
    
  176.         # deserialize_db_from_string() handles circular references.
    
  177.         data = """
    
  178.         [
    
  179.             {
    
  180.                 "model": "backends.object",
    
  181.                 "pk": 1,
    
  182.                 "fields": {"obj_ref": 1, "related_objects": []}
    
  183.             },
    
  184.             {
    
  185.                 "model": "backends.objectreference",
    
  186.                 "pk": 1,
    
  187.                 "fields": {"obj": 1}
    
  188.             }
    
  189.         ]
    
  190.         """
    
  191.         connection.creation.deserialize_db_from_string(data)
    
  192.         obj = Object.objects.get()
    
  193.         obj_ref = ObjectReference.objects.get()
    
  194.         self.assertEqual(obj.obj_ref, obj_ref)
    
  195.         self.assertEqual(obj_ref.obj, obj)
    
  196. 
    
  197.     def test_self_reference(self):
    
  198.         # serialize_db_to_string() and deserialize_db_from_string() handles
    
  199.         # self references.
    
  200.         obj_1 = ObjectSelfReference.objects.create(key="X")
    
  201.         obj_2 = ObjectSelfReference.objects.create(key="Y", obj=obj_1)
    
  202.         obj_1.obj = obj_2
    
  203.         obj_1.save()
    
  204.         # Serialize objects.
    
  205.         with mock.patch("django.db.migrations.loader.MigrationLoader") as loader:
    
  206.             # serialize_db_to_string() serializes only migrated apps, so mark
    
  207.             # the backends app as migrated.
    
  208.             loader_instance = loader.return_value
    
  209.             loader_instance.migrated_apps = {"backends"}
    
  210.             data = connection.creation.serialize_db_to_string()
    
  211.         ObjectSelfReference.objects.all().delete()
    
  212.         # Deserialize objects.
    
  213.         connection.creation.deserialize_db_from_string(data)
    
  214.         obj_1 = ObjectSelfReference.objects.get(key="X")
    
  215.         obj_2 = ObjectSelfReference.objects.get(key="Y")
    
  216.         self.assertEqual(obj_1.obj, obj_2)
    
  217.         self.assertEqual(obj_2.obj, obj_1)
    
  218. 
    
  219.     def test_circular_reference_with_natural_key(self):
    
  220.         # serialize_db_to_string() and deserialize_db_from_string() handles
    
  221.         # circular references for models with natural keys.
    
  222.         obj_a = CircularA.objects.create(key="A")
    
  223.         obj_b = CircularB.objects.create(key="B", obj=obj_a)
    
  224.         obj_a.obj = obj_b
    
  225.         obj_a.save()
    
  226.         # Serialize objects.
    
  227.         with mock.patch("django.db.migrations.loader.MigrationLoader") as loader:
    
  228.             # serialize_db_to_string() serializes only migrated apps, so mark
    
  229.             # the backends app as migrated.
    
  230.             loader_instance = loader.return_value
    
  231.             loader_instance.migrated_apps = {"backends"}
    
  232.             data = connection.creation.serialize_db_to_string()
    
  233.         CircularA.objects.all().delete()
    
  234.         CircularB.objects.all().delete()
    
  235.         # Deserialize objects.
    
  236.         connection.creation.deserialize_db_from_string(data)
    
  237.         obj_a = CircularA.objects.get()
    
  238.         obj_b = CircularB.objects.get()
    
  239.         self.assertEqual(obj_a.obj, obj_b)
    
  240.         self.assertEqual(obj_b.obj, obj_a)
    
  241. 
    
  242.     def test_serialize_db_to_string_base_manager(self):
    
  243.         SchoolClass.objects.create(year=1000, last_updated=datetime.datetime.now())
    
  244.         with mock.patch("django.db.migrations.loader.MigrationLoader") as loader:
    
  245.             # serialize_db_to_string() serializes only migrated apps, so mark
    
  246.             # the backends app as migrated.
    
  247.             loader_instance = loader.return_value
    
  248.             loader_instance.migrated_apps = {"backends"}
    
  249.             data = connection.creation.serialize_db_to_string()
    
  250.         self.assertIn('"model": "backends.schoolclass"', data)
    
  251.         self.assertIn('"year": 1000', data)
    
  252. 
    
  253. 
    
  254. class SkipTestClass:
    
  255.     def skip_function(self):
    
  256.         pass
    
  257. 
    
  258. 
    
  259. def skip_test_function():
    
  260.     pass
    
  261. 
    
  262. 
    
  263. def expected_failure_test_function():
    
  264.     pass
    
  265. 
    
  266. 
    
  267. class TestMarkTests(SimpleTestCase):
    
  268.     def test_mark_expected_failures_and_skips(self):
    
  269.         test_connection = get_connection_copy()
    
  270.         creation = BaseDatabaseCreation(test_connection)
    
  271.         creation.connection.features.django_test_expected_failures = {
    
  272.             "backends.base.test_creation.expected_failure_test_function",
    
  273.         }
    
  274.         creation.connection.features.django_test_skips = {
    
  275.             "skip test class": {
    
  276.                 "backends.base.test_creation.SkipTestClass",
    
  277.             },
    
  278.             "skip test function": {
    
  279.                 "backends.base.test_creation.skip_test_function",
    
  280.             },
    
  281.         }
    
  282.         creation.mark_expected_failures_and_skips()
    
  283.         self.assertIs(
    
  284.             expected_failure_test_function.__unittest_expecting_failure__,
    
  285.             True,
    
  286.         )
    
  287.         self.assertIs(SkipTestClass.__unittest_skip__, True)
    
  288.         self.assertEqual(
    
  289.             SkipTestClass.__unittest_skip_why__,
    
  290.             "skip test class",
    
  291.         )
    
  292.         self.assertIs(skip_test_function.__unittest_skip__, True)
    
  293.         self.assertEqual(
    
  294.             skip_test_function.__unittest_skip_why__,
    
  295.             "skip test function",
    
  296.         )