1. from django.db import connection
    
  2. from django.db.backends.ddl_references import (
    
  3.     Columns,
    
  4.     Expressions,
    
  5.     ForeignKeyName,
    
  6.     IndexName,
    
  7.     Statement,
    
  8.     Table,
    
  9. )
    
  10. from django.db.models import ExpressionList, F
    
  11. from django.db.models.functions import Upper
    
  12. from django.db.models.indexes import IndexExpression
    
  13. from django.db.models.sql import Query
    
  14. from django.test import SimpleTestCase, TransactionTestCase
    
  15. 
    
  16. from .models import Person
    
  17. 
    
  18. 
    
  19. class TableTests(SimpleTestCase):
    
  20.     def setUp(self):
    
  21.         self.reference = Table("table", lambda table: table.upper())
    
  22. 
    
  23.     def test_references_table(self):
    
  24.         self.assertIs(self.reference.references_table("table"), True)
    
  25.         self.assertIs(self.reference.references_table("other"), False)
    
  26. 
    
  27.     def test_rename_table_references(self):
    
  28.         self.reference.rename_table_references("other", "table")
    
  29.         self.assertIs(self.reference.references_table("table"), True)
    
  30.         self.assertIs(self.reference.references_table("other"), False)
    
  31.         self.reference.rename_table_references("table", "other")
    
  32.         self.assertIs(self.reference.references_table("table"), False)
    
  33.         self.assertIs(self.reference.references_table("other"), True)
    
  34. 
    
  35.     def test_repr(self):
    
  36.         self.assertEqual(repr(self.reference), "<Table 'TABLE'>")
    
  37. 
    
  38.     def test_str(self):
    
  39.         self.assertEqual(str(self.reference), "TABLE")
    
  40. 
    
  41. 
    
  42. class ColumnsTests(TableTests):
    
  43.     def setUp(self):
    
  44.         self.reference = Columns(
    
  45.             "table", ["first_column", "second_column"], lambda column: column.upper()
    
  46.         )
    
  47. 
    
  48.     def test_references_column(self):
    
  49.         self.assertIs(self.reference.references_column("other", "first_column"), False)
    
  50.         self.assertIs(self.reference.references_column("table", "third_column"), False)
    
  51.         self.assertIs(self.reference.references_column("table", "first_column"), True)
    
  52. 
    
  53.     def test_rename_column_references(self):
    
  54.         self.reference.rename_column_references("other", "first_column", "third_column")
    
  55.         self.assertIs(self.reference.references_column("table", "first_column"), True)
    
  56.         self.assertIs(self.reference.references_column("table", "third_column"), False)
    
  57.         self.assertIs(self.reference.references_column("other", "third_column"), False)
    
  58.         self.reference.rename_column_references("table", "third_column", "first_column")
    
  59.         self.assertIs(self.reference.references_column("table", "first_column"), True)
    
  60.         self.assertIs(self.reference.references_column("table", "third_column"), False)
    
  61.         self.reference.rename_column_references("table", "first_column", "third_column")
    
  62.         self.assertIs(self.reference.references_column("table", "first_column"), False)
    
  63.         self.assertIs(self.reference.references_column("table", "third_column"), True)
    
  64. 
    
  65.     def test_repr(self):
    
  66.         self.assertEqual(
    
  67.             repr(self.reference), "<Columns 'FIRST_COLUMN, SECOND_COLUMN'>"
    
  68.         )
    
  69. 
    
  70.     def test_str(self):
    
  71.         self.assertEqual(str(self.reference), "FIRST_COLUMN, SECOND_COLUMN")
    
  72. 
    
  73. 
    
  74. class IndexNameTests(ColumnsTests):
    
  75.     def setUp(self):
    
  76.         def create_index_name(table_name, column_names, suffix):
    
  77.             return ", ".join(
    
  78.                 "%s_%s_%s" % (table_name, column_name, suffix)
    
  79.                 for column_name in column_names
    
  80.             )
    
  81. 
    
  82.         self.reference = IndexName(
    
  83.             "table", ["first_column", "second_column"], "suffix", create_index_name
    
  84.         )
    
  85. 
    
  86.     def test_repr(self):
    
  87.         self.assertEqual(
    
  88.             repr(self.reference),
    
  89.             "<IndexName 'table_first_column_suffix, table_second_column_suffix'>",
    
  90.         )
    
  91. 
    
  92.     def test_str(self):
    
  93.         self.assertEqual(
    
  94.             str(self.reference), "table_first_column_suffix, table_second_column_suffix"
    
  95.         )
    
  96. 
    
  97. 
    
  98. class ForeignKeyNameTests(IndexNameTests):
    
  99.     def setUp(self):
    
  100.         def create_foreign_key_name(table_name, column_names, suffix):
    
  101.             return ", ".join(
    
  102.                 "%s_%s_%s" % (table_name, column_name, suffix)
    
  103.                 for column_name in column_names
    
  104.             )
    
  105. 
    
  106.         self.reference = ForeignKeyName(
    
  107.             "table",
    
  108.             ["first_column", "second_column"],
    
  109.             "to_table",
    
  110.             ["to_first_column", "to_second_column"],
    
  111.             "%(to_table)s_%(to_column)s_fk",
    
  112.             create_foreign_key_name,
    
  113.         )
    
  114. 
    
  115.     def test_references_table(self):
    
  116.         super().test_references_table()
    
  117.         self.assertIs(self.reference.references_table("to_table"), True)
    
  118. 
    
  119.     def test_references_column(self):
    
  120.         super().test_references_column()
    
  121.         self.assertIs(
    
  122.             self.reference.references_column("to_table", "second_column"), False
    
  123.         )
    
  124.         self.assertIs(
    
  125.             self.reference.references_column("to_table", "to_second_column"), True
    
  126.         )
    
  127. 
    
  128.     def test_rename_table_references(self):
    
  129.         super().test_rename_table_references()
    
  130.         self.reference.rename_table_references("to_table", "other_to_table")
    
  131.         self.assertIs(self.reference.references_table("other_to_table"), True)
    
  132.         self.assertIs(self.reference.references_table("to_table"), False)
    
  133. 
    
  134.     def test_rename_column_references(self):
    
  135.         super().test_rename_column_references()
    
  136.         self.reference.rename_column_references(
    
  137.             "to_table", "second_column", "third_column"
    
  138.         )
    
  139.         self.assertIs(self.reference.references_column("table", "second_column"), True)
    
  140.         self.assertIs(
    
  141.             self.reference.references_column("to_table", "to_second_column"), True
    
  142.         )
    
  143.         self.reference.rename_column_references(
    
  144.             "to_table", "to_first_column", "to_third_column"
    
  145.         )
    
  146.         self.assertIs(
    
  147.             self.reference.references_column("to_table", "to_first_column"), False
    
  148.         )
    
  149.         self.assertIs(
    
  150.             self.reference.references_column("to_table", "to_third_column"), True
    
  151.         )
    
  152. 
    
  153.     def test_repr(self):
    
  154.         self.assertEqual(
    
  155.             repr(self.reference),
    
  156.             "<ForeignKeyName 'table_first_column_to_table_to_first_column_fk, "
    
  157.             "table_second_column_to_table_to_first_column_fk'>",
    
  158.         )
    
  159. 
    
  160.     def test_str(self):
    
  161.         self.assertEqual(
    
  162.             str(self.reference),
    
  163.             "table_first_column_to_table_to_first_column_fk, "
    
  164.             "table_second_column_to_table_to_first_column_fk",
    
  165.         )
    
  166. 
    
  167. 
    
  168. class MockReference:
    
  169.     def __init__(self, representation, referenced_tables, referenced_columns):
    
  170.         self.representation = representation
    
  171.         self.referenced_tables = referenced_tables
    
  172.         self.referenced_columns = referenced_columns
    
  173. 
    
  174.     def references_table(self, table):
    
  175.         return table in self.referenced_tables
    
  176. 
    
  177.     def references_column(self, table, column):
    
  178.         return (table, column) in self.referenced_columns
    
  179. 
    
  180.     def rename_table_references(self, old_table, new_table):
    
  181.         if old_table in self.referenced_tables:
    
  182.             self.referenced_tables.remove(old_table)
    
  183.             self.referenced_tables.add(new_table)
    
  184. 
    
  185.     def rename_column_references(self, table, old_column, new_column):
    
  186.         column = (table, old_column)
    
  187.         if column in self.referenced_columns:
    
  188.             self.referenced_columns.remove(column)
    
  189.             self.referenced_columns.add((table, new_column))
    
  190. 
    
  191.     def __str__(self):
    
  192.         return self.representation
    
  193. 
    
  194. 
    
  195. class StatementTests(SimpleTestCase):
    
  196.     def test_references_table(self):
    
  197.         statement = Statement(
    
  198.             "", reference=MockReference("", {"table"}, {}), non_reference=""
    
  199.         )
    
  200.         self.assertIs(statement.references_table("table"), True)
    
  201.         self.assertIs(statement.references_table("other"), False)
    
  202. 
    
  203.     def test_references_column(self):
    
  204.         statement = Statement(
    
  205.             "", reference=MockReference("", {}, {("table", "column")}), non_reference=""
    
  206.         )
    
  207.         self.assertIs(statement.references_column("table", "column"), True)
    
  208.         self.assertIs(statement.references_column("other", "column"), False)
    
  209. 
    
  210.     def test_rename_table_references(self):
    
  211.         reference = MockReference("", {"table"}, {})
    
  212.         statement = Statement("", reference=reference, non_reference="")
    
  213.         statement.rename_table_references("table", "other")
    
  214.         self.assertEqual(reference.referenced_tables, {"other"})
    
  215. 
    
  216.     def test_rename_column_references(self):
    
  217.         reference = MockReference("", {}, {("table", "column")})
    
  218.         statement = Statement("", reference=reference, non_reference="")
    
  219.         statement.rename_column_references("table", "column", "other")
    
  220.         self.assertEqual(reference.referenced_columns, {("table", "other")})
    
  221. 
    
  222.     def test_repr(self):
    
  223.         reference = MockReference("reference", {}, {})
    
  224.         statement = Statement(
    
  225.             "%(reference)s - %(non_reference)s",
    
  226.             reference=reference,
    
  227.             non_reference="non_reference",
    
  228.         )
    
  229.         self.assertEqual(repr(statement), "<Statement 'reference - non_reference'>")
    
  230. 
    
  231.     def test_str(self):
    
  232.         reference = MockReference("reference", {}, {})
    
  233.         statement = Statement(
    
  234.             "%(reference)s - %(non_reference)s",
    
  235.             reference=reference,
    
  236.             non_reference="non_reference",
    
  237.         )
    
  238.         self.assertEqual(str(statement), "reference - non_reference")
    
  239. 
    
  240. 
    
  241. class ExpressionsTests(TransactionTestCase):
    
  242.     available_apps = []
    
  243. 
    
  244.     def setUp(self):
    
  245.         compiler = Person.objects.all().query.get_compiler(connection.alias)
    
  246.         self.editor = connection.schema_editor()
    
  247.         self.expressions = Expressions(
    
  248.             table=Person._meta.db_table,
    
  249.             expressions=ExpressionList(
    
  250.                 IndexExpression(F("first_name")),
    
  251.                 IndexExpression(F("last_name").desc()),
    
  252.                 IndexExpression(Upper("last_name")),
    
  253.             ).resolve_expression(compiler.query),
    
  254.             compiler=compiler,
    
  255.             quote_value=self.editor.quote_value,
    
  256.         )
    
  257. 
    
  258.     def test_references_table(self):
    
  259.         self.assertIs(self.expressions.references_table(Person._meta.db_table), True)
    
  260.         self.assertIs(self.expressions.references_table("other"), False)
    
  261. 
    
  262.     def test_references_column(self):
    
  263.         table = Person._meta.db_table
    
  264.         self.assertIs(self.expressions.references_column(table, "first_name"), True)
    
  265.         self.assertIs(self.expressions.references_column(table, "last_name"), True)
    
  266.         self.assertIs(self.expressions.references_column(table, "other"), False)
    
  267. 
    
  268.     def test_rename_table_references(self):
    
  269.         table = Person._meta.db_table
    
  270.         self.expressions.rename_table_references(table, "other")
    
  271.         self.assertIs(self.expressions.references_table(table), False)
    
  272.         self.assertIs(self.expressions.references_table("other"), True)
    
  273.         self.assertIn(
    
  274.             "%s.%s"
    
  275.             % (
    
  276.                 self.editor.quote_name("other"),
    
  277.                 self.editor.quote_name("first_name"),
    
  278.             ),
    
  279.             str(self.expressions),
    
  280.         )
    
  281. 
    
  282.     def test_rename_table_references_without_alias(self):
    
  283.         compiler = Query(Person, alias_cols=False).get_compiler(connection=connection)
    
  284.         table = Person._meta.db_table
    
  285.         expressions = Expressions(
    
  286.             table=table,
    
  287.             expressions=ExpressionList(
    
  288.                 IndexExpression(Upper("last_name")),
    
  289.                 IndexExpression(F("first_name")),
    
  290.             ).resolve_expression(compiler.query),
    
  291.             compiler=compiler,
    
  292.             quote_value=self.editor.quote_value,
    
  293.         )
    
  294.         expressions.rename_table_references(table, "other")
    
  295.         self.assertIs(expressions.references_table(table), False)
    
  296.         self.assertIs(expressions.references_table("other"), True)
    
  297.         expected_str = "(UPPER(%s)), %s" % (
    
  298.             self.editor.quote_name("last_name"),
    
  299.             self.editor.quote_name("first_name"),
    
  300.         )
    
  301.         self.assertEqual(str(expressions), expected_str)
    
  302. 
    
  303.     def test_rename_column_references(self):
    
  304.         table = Person._meta.db_table
    
  305.         self.expressions.rename_column_references(table, "first_name", "other")
    
  306.         self.assertIs(self.expressions.references_column(table, "other"), True)
    
  307.         self.assertIs(self.expressions.references_column(table, "first_name"), False)
    
  308.         self.assertIn(
    
  309.             "%s.%s" % (self.editor.quote_name(table), self.editor.quote_name("other")),
    
  310.             str(self.expressions),
    
  311.         )
    
  312. 
    
  313.     def test_str(self):
    
  314.         table_name = self.editor.quote_name(Person._meta.db_table)
    
  315.         expected_str = "%s.%s, %s.%s DESC, (UPPER(%s.%s))" % (
    
  316.             table_name,
    
  317.             self.editor.quote_name("first_name"),
    
  318.             table_name,
    
  319.             self.editor.quote_name("last_name"),
    
  320.             table_name,
    
  321.             self.editor.quote_name("last_name"),
    
  322.         )
    
  323.         self.assertEqual(str(self.expressions), expected_str)