1. import time
    
  2. import unittest
    
  3. from datetime import date, datetime
    
  4. 
    
  5. from django.core.exceptions import FieldError
    
  6. from django.db import connection, models
    
  7. from django.test import SimpleTestCase, TestCase, override_settings
    
  8. from django.test.utils import register_lookup
    
  9. from django.utils import timezone
    
  10. 
    
  11. from .models import Article, Author, MySQLUnixTimestamp
    
  12. 
    
  13. 
    
  14. class Div3Lookup(models.Lookup):
    
  15.     lookup_name = "div3"
    
  16. 
    
  17.     def as_sql(self, compiler, connection):
    
  18.         lhs, params = self.process_lhs(compiler, connection)
    
  19.         rhs, rhs_params = self.process_rhs(compiler, connection)
    
  20.         params.extend(rhs_params)
    
  21.         return "(%s) %%%% 3 = %s" % (lhs, rhs), params
    
  22. 
    
  23.     def as_oracle(self, compiler, connection):
    
  24.         lhs, params = self.process_lhs(compiler, connection)
    
  25.         rhs, rhs_params = self.process_rhs(compiler, connection)
    
  26.         params.extend(rhs_params)
    
  27.         return "mod(%s, 3) = %s" % (lhs, rhs), params
    
  28. 
    
  29. 
    
  30. class Div3Transform(models.Transform):
    
  31.     lookup_name = "div3"
    
  32. 
    
  33.     def as_sql(self, compiler, connection):
    
  34.         lhs, lhs_params = compiler.compile(self.lhs)
    
  35.         return "(%s) %%%% 3" % lhs, lhs_params
    
  36. 
    
  37.     def as_oracle(self, compiler, connection, **extra_context):
    
  38.         lhs, lhs_params = compiler.compile(self.lhs)
    
  39.         return "mod(%s, 3)" % lhs, lhs_params
    
  40. 
    
  41. 
    
  42. class Div3BilateralTransform(Div3Transform):
    
  43.     bilateral = True
    
  44. 
    
  45. 
    
  46. class Mult3BilateralTransform(models.Transform):
    
  47.     bilateral = True
    
  48.     lookup_name = "mult3"
    
  49. 
    
  50.     def as_sql(self, compiler, connection):
    
  51.         lhs, lhs_params = compiler.compile(self.lhs)
    
  52.         return "3 * (%s)" % lhs, lhs_params
    
  53. 
    
  54. 
    
  55. class LastDigitTransform(models.Transform):
    
  56.     lookup_name = "lastdigit"
    
  57. 
    
  58.     def as_sql(self, compiler, connection):
    
  59.         lhs, lhs_params = compiler.compile(self.lhs)
    
  60.         return "SUBSTR(CAST(%s AS CHAR(2)), 2, 1)" % lhs, lhs_params
    
  61. 
    
  62. 
    
  63. class UpperBilateralTransform(models.Transform):
    
  64.     bilateral = True
    
  65.     lookup_name = "upper"
    
  66. 
    
  67.     def as_sql(self, compiler, connection):
    
  68.         lhs, lhs_params = compiler.compile(self.lhs)
    
  69.         return "UPPER(%s)" % lhs, lhs_params
    
  70. 
    
  71. 
    
  72. class YearTransform(models.Transform):
    
  73.     # Use a name that avoids collision with the built-in year lookup.
    
  74.     lookup_name = "testyear"
    
  75. 
    
  76.     def as_sql(self, compiler, connection):
    
  77.         lhs_sql, params = compiler.compile(self.lhs)
    
  78.         return connection.ops.date_extract_sql("year", lhs_sql, params)
    
  79. 
    
  80.     @property
    
  81.     def output_field(self):
    
  82.         return models.IntegerField()
    
  83. 
    
  84. 
    
  85. @YearTransform.register_lookup
    
  86. class YearExact(models.lookups.Lookup):
    
  87.     lookup_name = "exact"
    
  88. 
    
  89.     def as_sql(self, compiler, connection):
    
  90.         # We will need to skip the extract part, and instead go
    
  91.         # directly with the originating field, that is self.lhs.lhs
    
  92.         lhs_sql, lhs_params = self.process_lhs(compiler, connection, self.lhs.lhs)
    
  93.         rhs_sql, rhs_params = self.process_rhs(compiler, connection)
    
  94.         # Note that we must be careful so that we have params in the
    
  95.         # same order as we have the parts in the SQL.
    
  96.         params = lhs_params + rhs_params + lhs_params + rhs_params
    
  97.         # We use PostgreSQL specific SQL here. Note that we must do the
    
  98.         # conversions in SQL instead of in Python to support F() references.
    
  99.         return (
    
  100.             "%(lhs)s >= (%(rhs)s || '-01-01')::date "
    
  101.             "AND %(lhs)s <= (%(rhs)s || '-12-31')::date"
    
  102.             % {"lhs": lhs_sql, "rhs": rhs_sql},
    
  103.             params,
    
  104.         )
    
  105. 
    
  106. 
    
  107. @YearTransform.register_lookup
    
  108. class YearLte(models.lookups.LessThanOrEqual):
    
  109.     """
    
  110.     The purpose of this lookup is to efficiently compare the year of the field.
    
  111.     """
    
  112. 
    
  113.     def as_sql(self, compiler, connection):
    
  114.         # Skip the YearTransform above us (no possibility for efficient
    
  115.         # lookup otherwise).
    
  116.         real_lhs = self.lhs.lhs
    
  117.         lhs_sql, params = self.process_lhs(compiler, connection, real_lhs)
    
  118.         rhs_sql, rhs_params = self.process_rhs(compiler, connection)
    
  119.         params.extend(rhs_params)
    
  120.         # Build SQL where the integer year is concatenated with last month
    
  121.         # and day, then convert that to date. (We try to have SQL like:
    
  122.         #     WHERE somecol <= '2013-12-31')
    
  123.         # but also make it work if the rhs_sql is field reference.
    
  124.         return "%s <= (%s || '-12-31')::date" % (lhs_sql, rhs_sql), params
    
  125. 
    
  126. 
    
  127. class Exactly(models.lookups.Exact):
    
  128.     """
    
  129.     This lookup is used to test lookup registration.
    
  130.     """
    
  131. 
    
  132.     lookup_name = "exactly"
    
  133. 
    
  134.     def get_rhs_op(self, connection, rhs):
    
  135.         return connection.operators["exact"] % rhs
    
  136. 
    
  137. 
    
  138. class SQLFuncMixin:
    
  139.     def as_sql(self, compiler, connection):
    
  140.         return "%s()" % self.name, []
    
  141. 
    
  142.     @property
    
  143.     def output_field(self):
    
  144.         return CustomField()
    
  145. 
    
  146. 
    
  147. class SQLFuncLookup(SQLFuncMixin, models.Lookup):
    
  148.     def __init__(self, name, *args, **kwargs):
    
  149.         super().__init__(*args, **kwargs)
    
  150.         self.name = name
    
  151. 
    
  152. 
    
  153. class SQLFuncTransform(SQLFuncMixin, models.Transform):
    
  154.     def __init__(self, name, *args, **kwargs):
    
  155.         super().__init__(*args, **kwargs)
    
  156.         self.name = name
    
  157. 
    
  158. 
    
  159. class SQLFuncFactory:
    
  160.     def __init__(self, key, name):
    
  161.         self.key = key
    
  162.         self.name = name
    
  163. 
    
  164.     def __call__(self, *args, **kwargs):
    
  165.         if self.key == "lookupfunc":
    
  166.             return SQLFuncLookup(self.name, *args, **kwargs)
    
  167.         return SQLFuncTransform(self.name, *args, **kwargs)
    
  168. 
    
  169. 
    
  170. class CustomField(models.TextField):
    
  171.     def get_lookup(self, lookup_name):
    
  172.         if lookup_name.startswith("lookupfunc_"):
    
  173.             key, name = lookup_name.split("_", 1)
    
  174.             return SQLFuncFactory(key, name)
    
  175.         return super().get_lookup(lookup_name)
    
  176. 
    
  177.     def get_transform(self, lookup_name):
    
  178.         if lookup_name.startswith("transformfunc_"):
    
  179.             key, name = lookup_name.split("_", 1)
    
  180.             return SQLFuncFactory(key, name)
    
  181.         return super().get_transform(lookup_name)
    
  182. 
    
  183. 
    
  184. class CustomModel(models.Model):
    
  185.     field = CustomField()
    
  186. 
    
  187. 
    
  188. # We will register this class temporarily in the test method.
    
  189. 
    
  190. 
    
  191. class InMonth(models.lookups.Lookup):
    
  192.     """
    
  193.     InMonth matches if the column's month is the same as value's month.
    
  194.     """
    
  195. 
    
  196.     lookup_name = "inmonth"
    
  197. 
    
  198.     def as_sql(self, compiler, connection):
    
  199.         lhs, lhs_params = self.process_lhs(compiler, connection)
    
  200.         rhs, rhs_params = self.process_rhs(compiler, connection)
    
  201.         # We need to be careful so that we get the params in right
    
  202.         # places.
    
  203.         params = lhs_params + rhs_params + lhs_params + rhs_params
    
  204.         return (
    
  205.             "%s >= date_trunc('month', %s) and "
    
  206.             "%s < date_trunc('month', %s) + interval '1 months'" % (lhs, rhs, lhs, rhs),
    
  207.             params,
    
  208.         )
    
  209. 
    
  210. 
    
  211. class DateTimeTransform(models.Transform):
    
  212.     lookup_name = "as_datetime"
    
  213. 
    
  214.     @property
    
  215.     def output_field(self):
    
  216.         return models.DateTimeField()
    
  217. 
    
  218.     def as_sql(self, compiler, connection):
    
  219.         lhs, params = compiler.compile(self.lhs)
    
  220.         return "from_unixtime({})".format(lhs), params
    
  221. 
    
  222. 
    
class LookupTests(TestCase):
    """Registration and use of custom lookups and transforms on model fields."""

    def test_custom_name_lookup(self):
        # A transform and a lookup can each be registered under an explicit
        # lookup_name besides their default one; both spellings must filter
        # the same way.
        a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        with register_lookup(models.DateField, YearTransform), register_lookup(
            models.DateField, YearTransform, lookup_name="justtheyear"
        ), register_lookup(YearTransform, Exactly), register_lookup(
            YearTransform, Exactly, lookup_name="isactually"
        ):
            qs1 = Author.objects.filter(birthdate__testyear__exactly=1981)
            qs2 = Author.objects.filter(birthdate__justtheyear__isactually=1981)
            self.assertSequenceEqual(qs1, [a1])
            self.assertSequenceEqual(qs2, [a1])

    def test_custom_exact_lookup_none_rhs(self):
        """
        __exact=None is transformed to __isnull=True if a custom lookup class
        with lookup_name != 'exact' is registered as the `exact` lookup.
        """
        field = Author._meta.get_field("birthdate")
        # Remember the current "exact" lookup so it can be restored below.
        OldExactLookup = field.get_lookup("exact")
        author = Author.objects.create(name="author", birthdate=None)
        try:
            # Register, on this field object and under "exact", a class whose
            # own lookup_name is "exactly".
            field.register_lookup(Exactly, "exact")
            self.assertEqual(Author.objects.get(birthdate__exact=None), author)
        finally:
            # Re-register the saved class to restore the previous "exact".
            field.register_lookup(OldExactLookup, "exact")

    def test_basic_lookup(self):
        # age % 3 for a1..a4 is 1, 2, 0, 1.
        a1 = Author.objects.create(name="a1", age=1)
        a2 = Author.objects.create(name="a2", age=2)
        a3 = Author.objects.create(name="a3", age=3)
        a4 = Author.objects.create(name="a4", age=4)
        with register_lookup(models.IntegerField, Div3Lookup):
            self.assertSequenceEqual(Author.objects.filter(age__div3=0), [a3])
            self.assertSequenceEqual(
                Author.objects.filter(age__div3=1).order_by("age"), [a1, a4]
            )
            self.assertSequenceEqual(Author.objects.filter(age__div3=2), [a2])
            # No remainder modulo 3 can equal 3.
            self.assertSequenceEqual(Author.objects.filter(age__div3=3), [])

    @unittest.skipUnless(
        connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    )
    def test_birthdate_month(self):
        # Only the month (and year) of the rhs date matters, not its day.
        a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
        a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
        a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
        a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))
        with register_lookup(models.DateField, InMonth):
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 1, 15)), [a3]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 2, 1)), [a2]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(1981, 2, 28)), [a1]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 3, 12)), [a4]
            )
            self.assertSequenceEqual(
                Author.objects.filter(birthdate__inmonth=date(2012, 4, 1)), []
            )

    def test_div3_extract(self):
        # Div3Transform is not bilateral: only the column becomes age % 3
        # (1, 2, 0, 1 for a1..a4); rhs values are compared unchanged.
        with register_lookup(models.IntegerField, Div3Transform):
            a1 = Author.objects.create(name="a1", age=1)
            a2 = Author.objects.create(name="a2", age=2)
            a3 = Author.objects.create(name="a3", age=3)
            a4 = Author.objects.create(name="a4", age=4)
            baseqs = Author.objects.order_by("name")
            self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
            self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a1, a2, a3, a4])
            self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
            self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a2])
            self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [])
            self.assertSequenceEqual(
                baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4]
            )

    def test_foreignobject_lookup_registration(self):
        field = Article._meta.get_field("author")

        with register_lookup(models.ForeignObject, Exactly):
            self.assertIs(field.get_lookup("exactly"), Exactly)

        # ForeignObject should ignore regular Field lookups
        with register_lookup(models.Field, Exactly):
            self.assertIsNone(field.get_lookup("exactly"))

    def test_lookups_caching(self):
        field = Article._meta.get_field("author")

        # clear and re-cache
        field.get_lookups.cache_clear()
        self.assertNotIn("exactly", field.get_lookups())

        # registration should bust the cache
        with register_lookup(models.ForeignObject, Exactly):
            # getting the lookups again should re-cache
            self.assertIn("exactly", field.get_lookups())
        # Unregistration should bust the cache.
        self.assertNotIn("exactly", field.get_lookups())
    
  328. 
    
  329. 
    
class BilateralTransformTests(TestCase):
    """Transforms with ``bilateral = True`` are applied to the rhs value(s) too."""

    def test_bilateral_upper(self):
        # UPPER() on both sides makes the comparisons case-insensitive.
        with register_lookup(models.CharField, UpperBilateralTransform):
            author1 = Author.objects.create(name="Doe")
            author2 = Author.objects.create(name="doe")
            author3 = Author.objects.create(name="Foo")
            self.assertCountEqual(
                Author.objects.filter(name__upper="doe"),
                [author1, author2],
            )
            self.assertSequenceEqual(
                Author.objects.filter(name__upper__contains="f"),
                [author3],
            )

    def test_bilateral_inner_qs(self):
        # A bilateral transform cannot be applied to a subquery rhs.
        with register_lookup(models.CharField, UpperBilateralTransform):
            msg = "Bilateral transformations on nested querysets are not implemented."
            with self.assertRaisesMessage(NotImplementedError, msg):
                Author.objects.filter(
                    name__upper__in=Author.objects.values_list("name")
                )

    def test_bilateral_multi_value(self):
        # Every value of an __in list is transformed as well.
        # NOTE(review): newer Django versions deprecate assertQuerysetEqual in
        # favour of assertQuerySetEqual -- confirm the targeted version.
        with register_lookup(models.CharField, UpperBilateralTransform):
            Author.objects.bulk_create(
                [
                    Author(name="Foo"),
                    Author(name="Bar"),
                    Author(name="Ray"),
                ]
            )
            self.assertQuerysetEqual(
                Author.objects.filter(name__upper__in=["foo", "bar", "doe"]).order_by(
                    "name"
                ),
                ["Bar", "Foo"],
                lambda a: a.name,
            )

    def test_div3_bilateral_extract(self):
        # Both sides become value % 3: ages 1..4 map to 1, 2, 0, 1, and an rhs
        # such as lte=3 is compared as 3 % 3 == 0.
        with register_lookup(models.IntegerField, Div3BilateralTransform):
            a1 = Author.objects.create(name="a1", age=1)
            a2 = Author.objects.create(name="a2", age=2)
            a3 = Author.objects.create(name="a3", age=3)
            a4 = Author.objects.create(name="a4", age=4)
            baseqs = Author.objects.order_by("name")
            self.assertSequenceEqual(baseqs.filter(age__div3=2), [a2])
            self.assertSequenceEqual(baseqs.filter(age__div3__lte=3), [a3])
            self.assertSequenceEqual(baseqs.filter(age__div3__in=[0, 2]), [a2, a3])
            self.assertSequenceEqual(baseqs.filter(age__div3__in=[2, 4]), [a1, a2, a4])
            self.assertSequenceEqual(baseqs.filter(age__div3__gte=3), [a1, a2, a3, a4])
            self.assertSequenceEqual(
                baseqs.filter(age__div3__range=(1, 2)), [a1, a2, a4]
            )

    def test_bilateral_order(self):
        # Chained bilateral transforms are applied in the written order.
        with register_lookup(
            models.IntegerField, Mult3BilateralTransform, Div3BilateralTransform
        ):
            a1 = Author.objects.create(name="a1", age=1)
            a2 = Author.objects.create(name="a2", age=2)
            a3 = Author.objects.create(name="a3", age=3)
            a4 = Author.objects.create(name="a4", age=4)
            baseqs = Author.objects.order_by("name")

            # mult3__div3 always leads to 0
            self.assertSequenceEqual(
                baseqs.filter(age__mult3__div3=42), [a1, a2, a3, a4]
            )
            # div3__mult3: 3 * (age % 3) == 3 * (42 % 3) == 0, so only age 3.
            self.assertSequenceEqual(baseqs.filter(age__div3__mult3=42), [a3])

    def test_transform_order_by(self):
        # Transforms can be used in order_by(); last digits are 1, 3, 2, 0.
        with register_lookup(models.IntegerField, LastDigitTransform):
            a1 = Author.objects.create(name="a1", age=11)
            a2 = Author.objects.create(name="a2", age=23)
            a3 = Author.objects.create(name="a3", age=32)
            a4 = Author.objects.create(name="a4", age=40)
            qs = Author.objects.order_by("age__lastdigit")
            self.assertSequenceEqual(qs, [a4, a1, a3, a2])

    def test_bilateral_fexpr(self):
        # An F() rhs is wrapped by the bilateral transform too.
        with register_lookup(models.IntegerField, Mult3BilateralTransform):
            a1 = Author.objects.create(name="a1", age=1, average_rating=3.2)
            a2 = Author.objects.create(name="a2", age=2, average_rating=0.5)
            a3 = Author.objects.create(name="a3", age=3, average_rating=1.5)
            a4 = Author.objects.create(name="a4", age=4)
            baseqs = Author.objects.order_by("name")
            self.assertSequenceEqual(
                baseqs.filter(age__mult3=models.F("age")), [a1, a2, a3, a4]
            )
            # Same as age >= average_rating
            self.assertSequenceEqual(
                baseqs.filter(age__mult3__gte=models.F("average_rating")), [a2, a3]
            )
    
  425. 
    
  426. 
    
  427. @override_settings(USE_TZ=True)
    
  428. class DateTimeLookupTests(TestCase):
    
  429.     @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific SQL used")
    
  430.     def test_datetime_output_field(self):
    
  431.         with register_lookup(models.PositiveIntegerField, DateTimeTransform):
    
  432.             ut = MySQLUnixTimestamp.objects.create(timestamp=time.time())
    
  433.             y2k = timezone.make_aware(datetime(2000, 1, 1))
    
  434.             self.assertSequenceEqual(
    
  435.                 MySQLUnixTimestamp.objects.filter(timestamp__as_datetime__gt=y2k), [ut]
    
  436.             )
    
  437. 
    
  438. 
    
  439. class YearLteTests(TestCase):
    
  440.     @classmethod
    
  441.     def setUpTestData(cls):
    
  442.         cls.a1 = Author.objects.create(name="a1", birthdate=date(1981, 2, 16))
    
  443.         cls.a2 = Author.objects.create(name="a2", birthdate=date(2012, 2, 29))
    
  444.         cls.a3 = Author.objects.create(name="a3", birthdate=date(2012, 1, 31))
    
  445.         cls.a4 = Author.objects.create(name="a4", birthdate=date(2012, 3, 1))
    
  446. 
    
  447.     def setUp(self):
    
  448.         models.DateField.register_lookup(YearTransform)
    
  449. 
    
  450.     def tearDown(self):
    
  451.         models.DateField._unregister_lookup(YearTransform)
    
  452. 
    
  453.     @unittest.skipUnless(
    
  454.         connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    
  455.     )
    
  456.     def test_year_lte(self):
    
  457.         baseqs = Author.objects.order_by("name")
    
  458.         self.assertSequenceEqual(
    
  459.             baseqs.filter(birthdate__testyear__lte=2012),
    
  460.             [self.a1, self.a2, self.a3, self.a4],
    
  461.         )
    
  462.         self.assertSequenceEqual(
    
  463.             baseqs.filter(birthdate__testyear=2012), [self.a2, self.a3, self.a4]
    
  464.         )
    
  465. 
    
  466.         self.assertNotIn("BETWEEN", str(baseqs.filter(birthdate__testyear=2012).query))
    
  467.         self.assertSequenceEqual(
    
  468.             baseqs.filter(birthdate__testyear__lte=2011), [self.a1]
    
  469.         )
    
  470.         # The non-optimized version works, too.
    
  471.         self.assertSequenceEqual(baseqs.filter(birthdate__testyear__lt=2012), [self.a1])
    
  472. 
    
  473.     @unittest.skipUnless(
    
  474.         connection.vendor == "postgresql", "PostgreSQL specific SQL used"
    
  475.     )
    
  476.     def test_year_lte_fexpr(self):
    
  477.         self.a2.age = 2011
    
  478.         self.a2.save()
    
  479.         self.a3.age = 2012
    
  480.         self.a3.save()
    
  481.         self.a4.age = 2013
    
  482.         self.a4.save()
    
  483.         baseqs = Author.objects.order_by("name")
    
  484.         self.assertSequenceEqual(
    
  485.             baseqs.filter(birthdate__testyear__lte=models.F("age")), [self.a3, self.a4]
    
  486.         )
    
  487.         self.assertSequenceEqual(
    
  488.             baseqs.filter(birthdate__testyear__lt=models.F("age")), [self.a4]
    
  489.         )
    
  490. 
    
  491.     def test_year_lte_sql(self):
    
  492.         # This test will just check the generated SQL for __lte. This
    
  493.         # doesn't require running on PostgreSQL and spots the most likely
    
  494.         # error - not running YearLte SQL at all.
    
  495.         baseqs = Author.objects.order_by("name")
    
  496.         self.assertIn(
    
  497.             "<= (2011 || ", str(baseqs.filter(birthdate__testyear__lte=2011).query)
    
  498.         )
    
  499.         self.assertIn("-12-31", str(baseqs.filter(birthdate__testyear__lte=2011).query))
    
  500. 
    
  501.     def test_postgres_year_exact(self):
    
  502.         baseqs = Author.objects.order_by("name")
    
  503.         self.assertIn("= (2011 || ", str(baseqs.filter(birthdate__testyear=2011).query))
    
  504.         self.assertIn("-12-31", str(baseqs.filter(birthdate__testyear=2011).query))
    
  505. 
    
  506.     def test_custom_implementation_year_exact(self):
    
  507.         try:
    
  508.             # Two ways to add a customized implementation for different backends:
    
  509.             # First is MonkeyPatch of the class.
    
  510.             def as_custom_sql(self, compiler, connection):
    
  511.                 lhs_sql, lhs_params = self.process_lhs(
    
  512.                     compiler, connection, self.lhs.lhs
    
  513.                 )
    
  514.                 rhs_sql, rhs_params = self.process_rhs(compiler, connection)
    
  515.                 params = lhs_params + rhs_params + lhs_params + rhs_params
    
  516.                 return (
    
  517.                     "%(lhs)s >= "
    
  518.                     "str_to_date(concat(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
    
  519.                     "AND %(lhs)s <= "
    
  520.                     "str_to_date(concat(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
    
  521.                     % {"lhs": lhs_sql, "rhs": rhs_sql},
    
  522.                     params,
    
  523.                 )
    
  524. 
    
  525.             setattr(YearExact, "as_" + connection.vendor, as_custom_sql)
    
  526.             self.assertIn(
    
  527.                 "concat(", str(Author.objects.filter(birthdate__testyear=2012).query)
    
  528.             )
    
  529.         finally:
    
  530.             delattr(YearExact, "as_" + connection.vendor)
    
  531.         try:
    
  532.             # The other way is to subclass the original lookup and register the
    
  533.             # subclassed lookup instead of the original.
    
  534.             class CustomYearExact(YearExact):
    
  535.                 # This method should be named "as_mysql" for MySQL,
    
  536.                 # "as_postgresql" for postgres and so on, but as we don't know
    
  537.                 # which DB we are running on, we need to use setattr.
    
  538.                 def as_custom_sql(self, compiler, connection):
    
  539.                     lhs_sql, lhs_params = self.process_lhs(
    
  540.                         compiler, connection, self.lhs.lhs
    
  541.                     )
    
  542.                     rhs_sql, rhs_params = self.process_rhs(compiler, connection)
    
  543.                     params = lhs_params + rhs_params + lhs_params + rhs_params
    
  544.                     return (
    
  545.                         "%(lhs)s >= "
    
  546.                         "str_to_date(CONCAT(%(rhs)s, '-01-01'), '%%%%Y-%%%%m-%%%%d') "
    
  547.                         "AND %(lhs)s <= "
    
  548.                         "str_to_date(CONCAT(%(rhs)s, '-12-31'), '%%%%Y-%%%%m-%%%%d')"
    
  549.                         % {"lhs": lhs_sql, "rhs": rhs_sql},
    
  550.                         params,
    
  551.                     )
    
  552. 
    
  553.             setattr(
    
  554.                 CustomYearExact,
    
  555.                 "as_" + connection.vendor,
    
  556.                 CustomYearExact.as_custom_sql,
    
  557.             )
    
  558.             YearTransform.register_lookup(CustomYearExact)
    
  559.             self.assertIn(
    
  560.                 "CONCAT(", str(Author.objects.filter(birthdate__testyear=2012).query)
    
  561.             )
    
  562.         finally:
    
  563.             YearTransform._unregister_lookup(CustomYearExact)
    
  564.             YearTransform.register_lookup(YearExact)
    
  565. 
    
  566. 
    
  567. class TrackCallsYearTransform(YearTransform):
    
  568.     # Use a name that avoids collision with the built-in year lookup.
    
  569.     lookup_name = "testyear"
    
  570.     call_order = []
    
  571. 
    
  572.     def as_sql(self, compiler, connection):
    
  573.         lhs_sql, params = compiler.compile(self.lhs)
    
  574.         return connection.ops.date_extract_sql("year", lhs_sql), params
    
  575. 
    
  576.     @property
    
  577.     def output_field(self):
    
  578.         return models.IntegerField()
    
  579. 
    
  580.     def get_lookup(self, lookup_name):
    
  581.         self.call_order.append("lookup")
    
  582.         return super().get_lookup(lookup_name)
    
  583. 
    
  584.     def get_transform(self, lookup_name):
    
  585.         self.call_order.append("transform")
    
  586.         return super().get_transform(lookup_name)
    
  587. 
    
  588. 
    
  589. class LookupTransformCallOrderTests(SimpleTestCase):
    
  590.     def test_call_order(self):
    
  591.         with register_lookup(models.DateField, TrackCallsYearTransform):
    
  592.             # junk lookup - tries lookup, then transform, then fails
    
  593.             msg = (
    
  594.                 "Unsupported lookup 'junk' for IntegerField or join on the field not "
    
  595.                 "permitted."
    
  596.             )
    
  597.             with self.assertRaisesMessage(FieldError, msg):
    
  598.                 Author.objects.filter(birthdate__testyear__junk=2012)
    
  599.             self.assertEqual(
    
  600.                 TrackCallsYearTransform.call_order, ["lookup", "transform"]
    
  601.             )
    
  602.             TrackCallsYearTransform.call_order = []
    
  603.             # junk transform - tries transform only, then fails
    
  604.             with self.assertRaisesMessage(FieldError, msg):
    
  605.                 Author.objects.filter(birthdate__testyear__junk__more_junk=2012)
    
  606.             self.assertEqual(TrackCallsYearTransform.call_order, ["transform"])
    
  607.             TrackCallsYearTransform.call_order = []
    
  608.             # Just getting the year (implied __exact) - lookup only
    
  609.             Author.objects.filter(birthdate__testyear=2012)
    
  610.             self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])
    
  611.             TrackCallsYearTransform.call_order = []
    
  612.             # Just getting the year (explicit __exact) - lookup only
    
  613.             Author.objects.filter(birthdate__testyear__exact=2012)
    
  614.             self.assertEqual(TrackCallsYearTransform.call_order, ["lookup"])
    
  615. 
    
  616. 
    
  617. class CustomisedMethodsTests(SimpleTestCase):
    
  618.     def test_overridden_get_lookup(self):
    
  619.         q = CustomModel.objects.filter(field__lookupfunc_monkeys=3)
    
  620.         self.assertIn("monkeys()", str(q.query))
    
  621. 
    
  622.     def test_overridden_get_transform(self):
    
  623.         q = CustomModel.objects.filter(field__transformfunc_banana=3)
    
  624.         self.assertIn("banana()", str(q.query))
    
  625. 
    
  626.     def test_overridden_get_lookup_chain(self):
    
  627.         q = CustomModel.objects.filter(
    
  628.             field__transformfunc_banana__lookupfunc_elephants=3
    
  629.         )
    
  630.         self.assertIn("elephants()", str(q.query))
    
  631. 
    
  632.     def test_overridden_get_transform_chain(self):
    
  633.         q = CustomModel.objects.filter(
    
  634.             field__transformfunc_banana__transformfunc_pear=3
    
  635.         )
    
  636.         self.assertIn("pear()", str(q.query))
    
  637. 
    
  638. 
    
  639. class SubqueryTransformTests(TestCase):
    
  640.     def test_subquery_usage(self):
    
  641.         with register_lookup(models.IntegerField, Div3Transform):
    
  642.             Author.objects.create(name="a1", age=1)
    
  643.             a2 = Author.objects.create(name="a2", age=2)
    
  644.             Author.objects.create(name="a3", age=3)
    
  645.             Author.objects.create(name="a4", age=4)
    
  646.             qs = Author.objects.order_by("name").filter(
    
  647.                 id__in=Author.objects.filter(age__div3=2)
    
  648.             )
    
  649.             self.assertSequenceEqual(qs, [a2])