1. import unittest
    
  2. 
    
  3. import sqlparse
    
  4. 
    
  5. from django.db import connection
    
  6. from django.test import TestCase
    
  7. 
    
  8. 
    
  9. @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
    
  10. class IntrospectionTests(TestCase):
    
  11.     def test_get_primary_key_column(self):
    
  12.         """
    
  13.         Get the primary key column regardless of whether or not it has
    
  14.         quotation.
    
  15.         """
    
  16.         testable_column_strings = (
    
  17.             ("id", "id"),
    
  18.             ("[id]", "id"),
    
  19.             ("`id`", "id"),
    
  20.             ('"id"', "id"),
    
  21.             ("[id col]", "id col"),
    
  22.             ("`id col`", "id col"),
    
  23.             ('"id col"', "id col"),
    
  24.         )
    
  25.         with connection.cursor() as cursor:
    
  26.             for column, expected_string in testable_column_strings:
    
  27.                 sql = "CREATE TABLE test_primary (%s int PRIMARY KEY NOT NULL)" % column
    
  28.                 with self.subTest(column=column):
    
  29.                     try:
    
  30.                         cursor.execute(sql)
    
  31.                         field = connection.introspection.get_primary_key_column(
    
  32.                             cursor, "test_primary"
    
  33.                         )
    
  34.                         self.assertEqual(field, expected_string)
    
  35.                     finally:
    
  36.                         cursor.execute("DROP TABLE test_primary")
    
  37. 
    
  38.     def test_get_primary_key_column_pk_constraint(self):
    
  39.         sql = """
    
  40.             CREATE TABLE test_primary(
    
  41.                 id INTEGER NOT NULL,
    
  42.                 created DATE,
    
  43.                 PRIMARY KEY(id)
    
  44.             )
    
  45.         """
    
  46.         with connection.cursor() as cursor:
    
  47.             try:
    
  48.                 cursor.execute(sql)
    
  49.                 field = connection.introspection.get_primary_key_column(
    
  50.                     cursor,
    
  51.                     "test_primary",
    
  52.                 )
    
  53.                 self.assertEqual(field, "id")
    
  54.             finally:
    
  55.                 cursor.execute("DROP TABLE test_primary")
    
  56. 
    
  57. 
    
  58. @unittest.skipUnless(connection.vendor == "sqlite", "SQLite tests")
    
  59. class ParsingTests(TestCase):
    
  60.     def parse_definition(self, sql, columns):
    
  61.         """Parse a column or constraint definition."""
    
  62.         statement = sqlparse.parse(sql)[0]
    
  63.         tokens = (token for token in statement.flatten() if not token.is_whitespace)
    
  64.         with connection.cursor():
    
  65.             return connection.introspection._parse_column_or_constraint_definition(
    
  66.                 tokens, set(columns)
    
  67.             )
    
  68. 
    
  69.     def assertConstraint(self, constraint_details, cols, unique=False, check=False):
    
  70.         self.assertEqual(
    
  71.             constraint_details,
    
  72.             {
    
  73.                 "unique": unique,
    
  74.                 "columns": cols,
    
  75.                 "primary_key": False,
    
  76.                 "foreign_key": None,
    
  77.                 "check": check,
    
  78.                 "index": False,
    
  79.             },
    
  80.         )
    
  81. 
    
  82.     def test_unique_column(self):
    
  83.         tests = (
    
  84.             ('"ref" integer UNIQUE,', ["ref"]),
    
  85.             ("ref integer UNIQUE,", ["ref"]),
    
  86.             ('"customname" integer UNIQUE,', ["customname"]),
    
  87.             ("customname integer UNIQUE,", ["customname"]),
    
  88.         )
    
  89.         for sql, columns in tests:
    
  90.             with self.subTest(sql=sql):
    
  91.                 constraint, details, check, _ = self.parse_definition(sql, columns)
    
  92.                 self.assertIsNone(constraint)
    
  93.                 self.assertConstraint(details, columns, unique=True)
    
  94.                 self.assertIsNone(check)
    
  95. 
    
  96.     def test_unique_constraint(self):
    
  97.         tests = (
    
  98.             ('CONSTRAINT "ref" UNIQUE ("ref"),', "ref", ["ref"]),
    
  99.             ("CONSTRAINT ref UNIQUE (ref),", "ref", ["ref"]),
    
  100.             (
    
  101.                 'CONSTRAINT "customname1" UNIQUE ("customname2"),',
    
  102.                 "customname1",
    
  103.                 ["customname2"],
    
  104.             ),
    
  105.             (
    
  106.                 "CONSTRAINT customname1 UNIQUE (customname2),",
    
  107.                 "customname1",
    
  108.                 ["customname2"],
    
  109.             ),
    
  110.         )
    
  111.         for sql, constraint_name, columns in tests:
    
  112.             with self.subTest(sql=sql):
    
  113.                 constraint, details, check, _ = self.parse_definition(sql, columns)
    
  114.                 self.assertEqual(constraint, constraint_name)
    
  115.                 self.assertConstraint(details, columns, unique=True)
    
  116.                 self.assertIsNone(check)
    
  117. 
    
  118.     def test_unique_constraint_multicolumn(self):
    
  119.         tests = (
    
  120.             (
    
  121.                 'CONSTRAINT "ref" UNIQUE ("ref", "customname"),',
    
  122.                 "ref",
    
  123.                 ["ref", "customname"],
    
  124.             ),
    
  125.             ("CONSTRAINT ref UNIQUE (ref, customname),", "ref", ["ref", "customname"]),
    
  126.         )
    
  127.         for sql, constraint_name, columns in tests:
    
  128.             with self.subTest(sql=sql):
    
  129.                 constraint, details, check, _ = self.parse_definition(sql, columns)
    
  130.                 self.assertEqual(constraint, constraint_name)
    
  131.                 self.assertConstraint(details, columns, unique=True)
    
  132.                 self.assertIsNone(check)
    
  133. 
    
  134.     def test_check_column(self):
    
  135.         tests = (
    
  136.             ('"ref" varchar(255) CHECK ("ref" != \'test\'),', ["ref"]),
    
  137.             ("ref varchar(255) CHECK (ref != 'test'),", ["ref"]),
    
  138.             (
    
  139.                 '"customname1" varchar(255) CHECK ("customname2" != \'test\'),',
    
  140.                 ["customname2"],
    
  141.             ),
    
  142.             (
    
  143.                 "customname1 varchar(255) CHECK (customname2 != 'test'),",
    
  144.                 ["customname2"],
    
  145.             ),
    
  146.         )
    
  147.         for sql, columns in tests:
    
  148.             with self.subTest(sql=sql):
    
  149.                 constraint, details, check, _ = self.parse_definition(sql, columns)
    
  150.                 self.assertIsNone(constraint)
    
  151.                 self.assertIsNone(details)
    
  152.                 self.assertConstraint(check, columns, check=True)
    
  153. 
    
  154.     def test_check_constraint(self):
    
  155.         tests = (
    
  156.             ('CONSTRAINT "ref" CHECK ("ref" != \'test\'),', "ref", ["ref"]),
    
  157.             ("CONSTRAINT ref CHECK (ref != 'test'),", "ref", ["ref"]),
    
  158.             (
    
  159.                 'CONSTRAINT "customname1" CHECK ("customname2" != \'test\'),',
    
  160.                 "customname1",
    
  161.                 ["customname2"],
    
  162.             ),
    
  163.             (
    
  164.                 "CONSTRAINT customname1 CHECK (customname2 != 'test'),",
    
  165.                 "customname1",
    
  166.                 ["customname2"],
    
  167.             ),
    
  168.         )
    
  169.         for sql, constraint_name, columns in tests:
    
  170.             with self.subTest(sql=sql):
    
  171.                 constraint, details, check, _ = self.parse_definition(sql, columns)
    
  172.                 self.assertEqual(constraint, constraint_name)
    
  173.                 self.assertIsNone(details)
    
  174.                 self.assertConstraint(check, columns, check=True)
    
  175. 
    
  176.     def test_check_column_with_operators_and_functions(self):
    
  177.         tests = (
    
  178.             ('"ref" integer CHECK ("ref" BETWEEN 1 AND 10),', ["ref"]),
    
  179.             ('"ref" varchar(255) CHECK ("ref" LIKE \'test%\'),', ["ref"]),
    
  180.             (
    
  181.                 '"ref" varchar(255) CHECK (LENGTH(ref) > "max_length"),',
    
  182.                 ["ref", "max_length"],
    
  183.             ),
    
  184.         )
    
  185.         for sql, columns in tests:
    
  186.             with self.subTest(sql=sql):
    
  187.                 constraint, details, check, _ = self.parse_definition(sql, columns)
    
  188.                 self.assertIsNone(constraint)
    
  189.                 self.assertIsNone(details)
    
  190.                 self.assertConstraint(check, columns, check=True)
    
  191. 
    
  192.     def test_check_and_unique_column(self):
    
  193.         tests = (
    
  194.             ('"ref" varchar(255) CHECK ("ref" != \'test\') UNIQUE,', ["ref"]),
    
  195.             ("ref varchar(255) UNIQUE CHECK (ref != 'test'),", ["ref"]),
    
  196.         )
    
  197.         for sql, columns in tests:
    
  198.             with self.subTest(sql=sql):
    
  199.                 constraint, details, check, _ = self.parse_definition(sql, columns)
    
  200.                 self.assertIsNone(constraint)
    
  201.                 self.assertConstraint(details, columns, unique=True)
    
  202.                 self.assertConstraint(check, columns, check=True)