1. from django.db import connection, migrations, models
    
  2. from django.db.migrations.state import ProjectState
    
  3. from django.test import override_settings
    
  4. 
    
  5. from .test_base import OperationTestBase
    
  6. 
    
  7. 
    
  8. class AgnosticRouter:
    
  9.     """
    
  10.     A router that doesn't have an opinion regarding migrating.
    
  11.     """
    
  12. 
    
  13.     def allow_migrate(self, db, app_label, **hints):
    
  14.         return None
    
  15. 
    
  16. 
    
  17. class MigrateNothingRouter:
    
  18.     """
    
  19.     A router that doesn't allow migrating.
    
  20.     """
    
  21. 
    
  22.     def allow_migrate(self, db, app_label, **hints):
    
  23.         return False
    
  24. 
    
  25. 
    
  26. class MigrateEverythingRouter:
    
  27.     """
    
  28.     A router that always allows migrating.
    
  29.     """
    
  30. 
    
  31.     def allow_migrate(self, db, app_label, **hints):
    
  32.         return True
    
  33. 
    
  34. 
    
  35. class MigrateWhenFooRouter:
    
  36.     """
    
  37.     A router that allows migrating depending on a hint.
    
  38.     """
    
  39. 
    
  40.     def allow_migrate(self, db, app_label, **hints):
    
  41.         return hints.get("foo", False)
    
  42. 
    
  43. 
    
  44. class MultiDBOperationTests(OperationTestBase):
    
  45.     databases = {"default", "other"}
    
  46. 
    
  47.     def _test_create_model(self, app_label, should_run):
    
  48.         """
    
  49.         CreateModel honors multi-db settings.
    
  50.         """
    
  51.         operation = migrations.CreateModel(
    
  52.             "Pony",
    
  53.             [("id", models.AutoField(primary_key=True))],
    
  54.         )
    
  55.         # Test the state alteration
    
  56.         project_state = ProjectState()
    
  57.         new_state = project_state.clone()
    
  58.         operation.state_forwards(app_label, new_state)
    
  59.         # Test the database alteration
    
  60.         self.assertTableNotExists("%s_pony" % app_label)
    
  61.         with connection.schema_editor() as editor:
    
  62.             operation.database_forwards(app_label, editor, project_state, new_state)
    
  63.         if should_run:
    
  64.             self.assertTableExists("%s_pony" % app_label)
    
  65.         else:
    
  66.             self.assertTableNotExists("%s_pony" % app_label)
    
  67.         # And test reversal
    
  68.         with connection.schema_editor() as editor:
    
  69.             operation.database_backwards(app_label, editor, new_state, project_state)
    
  70.         self.assertTableNotExists("%s_pony" % app_label)
    
  71. 
    
  72.     @override_settings(DATABASE_ROUTERS=[AgnosticRouter()])
    
  73.     def test_create_model(self):
    
  74.         """
    
  75.         Test when router doesn't have an opinion (i.e. CreateModel should run).
    
  76.         """
    
  77.         self._test_create_model("test_mltdb_crmo", should_run=True)
    
  78. 
    
  79.     @override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
    
  80.     def test_create_model2(self):
    
  81.         """
    
  82.         Test when router returns False (i.e. CreateModel shouldn't run).
    
  83.         """
    
  84.         self._test_create_model("test_mltdb_crmo2", should_run=False)
    
  85. 
    
  86.     @override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter()])
    
  87.     def test_create_model3(self):
    
  88.         """
    
  89.         Test when router returns True (i.e. CreateModel should run).
    
  90.         """
    
  91.         self._test_create_model("test_mltdb_crmo3", should_run=True)
    
  92. 
    
  93.     def test_create_model4(self):
    
  94.         """
    
  95.         Test multiple routers.
    
  96.         """
    
  97.         with override_settings(DATABASE_ROUTERS=[AgnosticRouter(), AgnosticRouter()]):
    
  98.             self._test_create_model("test_mltdb_crmo4", should_run=True)
    
  99.         with override_settings(
    
  100.             DATABASE_ROUTERS=[MigrateNothingRouter(), MigrateEverythingRouter()]
    
  101.         ):
    
  102.             self._test_create_model("test_mltdb_crmo4", should_run=False)
    
  103.         with override_settings(
    
  104.             DATABASE_ROUTERS=[MigrateEverythingRouter(), MigrateNothingRouter()]
    
  105.         ):
    
  106.             self._test_create_model("test_mltdb_crmo4", should_run=True)
    
  107. 
    
  108.     def _test_run_sql(self, app_label, should_run, hints=None):
    
  109.         with override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter()]):
    
  110.             project_state = self.set_up_test_model(app_label)
    
  111. 
    
  112.         sql = """
    
  113.         INSERT INTO {0}_pony (pink, weight) VALUES (1, 3.55);
    
  114.         INSERT INTO {0}_pony (pink, weight) VALUES (3, 5.0);
    
  115.         """.format(
    
  116.             app_label
    
  117.         )
    
  118. 
    
  119.         operation = migrations.RunSQL(sql, hints=hints or {})
    
  120.         # Test the state alteration does nothing
    
  121.         new_state = project_state.clone()
    
  122.         operation.state_forwards(app_label, new_state)
    
  123.         self.assertEqual(new_state, project_state)
    
  124.         # Test the database alteration
    
  125.         self.assertEqual(
    
  126.             project_state.apps.get_model(app_label, "Pony").objects.count(), 0
    
  127.         )
    
  128.         with connection.schema_editor() as editor:
    
  129.             operation.database_forwards(app_label, editor, project_state, new_state)
    
  130.         Pony = project_state.apps.get_model(app_label, "Pony")
    
  131.         if should_run:
    
  132.             self.assertEqual(Pony.objects.count(), 2)
    
  133.         else:
    
  134.             self.assertEqual(Pony.objects.count(), 0)
    
  135. 
    
  136.     @override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
    
  137.     def test_run_sql_migrate_nothing_router(self):
    
  138.         self._test_run_sql("test_mltdb_runsql", should_run=False)
    
  139. 
    
  140.     @override_settings(DATABASE_ROUTERS=[MigrateWhenFooRouter()])
    
  141.     def test_run_sql_migrate_foo_router_without_hints(self):
    
  142.         self._test_run_sql("test_mltdb_runsql2", should_run=False)
    
  143. 
    
  144.     @override_settings(DATABASE_ROUTERS=[MigrateWhenFooRouter()])
    
  145.     def test_run_sql_migrate_foo_router_with_hints(self):
    
  146.         self._test_run_sql("test_mltdb_runsql3", should_run=True, hints={"foo": True})
    
  147. 
    
  148.     def _test_run_python(self, app_label, should_run, hints=None):
    
  149.         with override_settings(DATABASE_ROUTERS=[MigrateEverythingRouter()]):
    
  150.             project_state = self.set_up_test_model(app_label)
    
  151. 
    
  152.         # Create the operation
    
  153.         def inner_method(models, schema_editor):
    
  154.             Pony = models.get_model(app_label, "Pony")
    
  155.             Pony.objects.create(pink=1, weight=3.55)
    
  156.             Pony.objects.create(weight=5)
    
  157. 
    
  158.         operation = migrations.RunPython(inner_method, hints=hints or {})
    
  159.         # Test the state alteration does nothing
    
  160.         new_state = project_state.clone()
    
  161.         operation.state_forwards(app_label, new_state)
    
  162.         self.assertEqual(new_state, project_state)
    
  163.         # Test the database alteration
    
  164.         self.assertEqual(
    
  165.             project_state.apps.get_model(app_label, "Pony").objects.count(), 0
    
  166.         )
    
  167.         with connection.schema_editor() as editor:
    
  168.             operation.database_forwards(app_label, editor, project_state, new_state)
    
  169.         Pony = project_state.apps.get_model(app_label, "Pony")
    
  170.         if should_run:
    
  171.             self.assertEqual(Pony.objects.count(), 2)
    
  172.         else:
    
  173.             self.assertEqual(Pony.objects.count(), 0)
    
  174. 
    
  175.     @override_settings(DATABASE_ROUTERS=[MigrateNothingRouter()])
    
  176.     def test_run_python_migrate_nothing_router(self):
    
  177.         self._test_run_python("test_mltdb_runpython", should_run=False)
    
  178. 
    
  179.     @override_settings(DATABASE_ROUTERS=[MigrateWhenFooRouter()])
    
  180.     def test_run_python_migrate_foo_router_without_hints(self):
    
  181.         self._test_run_python("test_mltdb_runpython2", should_run=False)
    
  182. 
    
  183.     @override_settings(DATABASE_ROUTERS=[MigrateWhenFooRouter()])
    
  184.     def test_run_python_migrate_foo_router_with_hints(self):
    
  185.         self._test_run_python(
    
  186.             "test_mltdb_runpython3", should_run=True, hints={"foo": True}
    
  187.         )