1. from django.db import connection
    
  2. from django.db.models import (
    
  3.     CharField,
    
  4.     F,
    
  5.     Func,
    
  6.     IntegerField,
    
  7.     OuterRef,
    
  8.     Q,
    
  9.     Subquery,
    
  10.     Value,
    
  11.     Window,
    
  12. )
    
  13. from django.db.models.fields.json import KeyTextTransform, KeyTransform
    
  14. from django.db.models.functions import Cast, Concat, Substr
    
  15. from django.test import skipUnlessDBFeature
    
  16. from django.test.utils import Approximate, ignore_warnings
    
  17. from django.utils import timezone
    
  18. from django.utils.deprecation import RemovedInDjango50Warning
    
  19. 
    
  20. from . import PostgreSQLTestCase
    
  21. from .models import AggregateTestModel, HotelReservation, Room, StatTestModel
    
  22. 
    
  23. try:
    
  24.     from django.contrib.postgres.aggregates import (
    
  25.         ArrayAgg,
    
  26.         BitAnd,
    
  27.         BitOr,
    
  28.         BitXor,
    
  29.         BoolAnd,
    
  30.         BoolOr,
    
  31.         Corr,
    
  32.         CovarPop,
    
  33.         JSONBAgg,
    
  34.         RegrAvgX,
    
  35.         RegrAvgY,
    
  36.         RegrCount,
    
  37.         RegrIntercept,
    
  38.         RegrR2,
    
  39.         RegrSlope,
    
  40.         RegrSXX,
    
  41.         RegrSXY,
    
  42.         RegrSYY,
    
  43.         StatAggregate,
    
  44.         StringAgg,
    
  45.     )
    
  46.     from django.contrib.postgres.fields import ArrayField
    
  47. except ImportError:
    
  48.     pass  # psycopg2 is not installed
    
  49. 
    
  50. 
    
  51. class TestGeneralAggregate(PostgreSQLTestCase):
    
  52.     @classmethod
    
  53.     def setUpTestData(cls):
    
  54.         cls.aggs = AggregateTestModel.objects.bulk_create(
    
  55.             [
    
  56.                 AggregateTestModel(
    
  57.                     boolean_field=True,
    
  58.                     char_field="Foo1",
    
  59.                     text_field="Text1",
    
  60.                     integer_field=0,
    
  61.                 ),
    
  62.                 AggregateTestModel(
    
  63.                     boolean_field=False,
    
  64.                     char_field="Foo2",
    
  65.                     text_field="Text2",
    
  66.                     integer_field=1,
    
  67.                     json_field={"lang": "pl"},
    
  68.                 ),
    
  69.                 AggregateTestModel(
    
  70.                     boolean_field=False,
    
  71.                     char_field="Foo4",
    
  72.                     text_field="Text4",
    
  73.                     integer_field=2,
    
  74.                     json_field={"lang": "en"},
    
  75.                 ),
    
  76.                 AggregateTestModel(
    
  77.                     boolean_field=True,
    
  78.                     char_field="Foo3",
    
  79.                     text_field="Text3",
    
  80.                     integer_field=0,
    
  81.                     json_field={"breed": "collie"},
    
  82.                 ),
    
  83.             ]
    
  84.         )
    
  85. 
    
  86.     @ignore_warnings(category=RemovedInDjango50Warning)
    
  87.     def test_empty_result_set(self):
    
  88.         AggregateTestModel.objects.all().delete()
    
  89.         tests = [
    
  90.             (ArrayAgg("char_field"), []),
    
  91.             (ArrayAgg("integer_field"), []),
    
  92.             (ArrayAgg("boolean_field"), []),
    
  93.             (BitAnd("integer_field"), None),
    
  94.             (BitOr("integer_field"), None),
    
  95.             (BoolAnd("boolean_field"), None),
    
  96.             (BoolOr("boolean_field"), None),
    
  97.             (JSONBAgg("integer_field"), []),
    
  98.             (StringAgg("char_field", delimiter=";"), ""),
    
  99.         ]
    
  100.         if connection.features.has_bit_xor:
    
  101.             tests.append((BitXor("integer_field"), None))
    
  102.         for aggregation, expected_result in tests:
    
  103.             with self.subTest(aggregation=aggregation):
    
  104.                 # Empty result with non-execution optimization.
    
  105.                 with self.assertNumQueries(0):
    
  106.                     values = AggregateTestModel.objects.none().aggregate(
    
  107.                         aggregation=aggregation,
    
  108.                     )
    
  109.                     self.assertEqual(values, {"aggregation": expected_result})
    
  110.                 # Empty result when query must be executed.
    
  111.                 with self.assertNumQueries(1):
    
  112.                     values = AggregateTestModel.objects.aggregate(
    
  113.                         aggregation=aggregation,
    
  114.                     )
    
  115.                     self.assertEqual(values, {"aggregation": expected_result})
    
  116. 
    
  117.     def test_default_argument(self):
    
  118.         AggregateTestModel.objects.all().delete()
    
  119.         tests = [
    
  120.             (ArrayAgg("char_field", default=["<empty>"]), ["<empty>"]),
    
  121.             (ArrayAgg("integer_field", default=[0]), [0]),
    
  122.             (ArrayAgg("boolean_field", default=[False]), [False]),
    
  123.             (BitAnd("integer_field", default=0), 0),
    
  124.             (BitOr("integer_field", default=0), 0),
    
  125.             (BoolAnd("boolean_field", default=False), False),
    
  126.             (BoolOr("boolean_field", default=False), False),
    
  127.             (JSONBAgg("integer_field", default=Value('["<empty>"]')), ["<empty>"]),
    
  128.             (
    
  129.                 StringAgg("char_field", delimiter=";", default=Value("<empty>")),
    
  130.                 "<empty>",
    
  131.             ),
    
  132.         ]
    
  133.         if connection.features.has_bit_xor:
    
  134.             tests.append((BitXor("integer_field", default=0), 0))
    
  135.         for aggregation, expected_result in tests:
    
  136.             with self.subTest(aggregation=aggregation):
    
  137.                 # Empty result with non-execution optimization.
    
  138.                 with self.assertNumQueries(0):
    
  139.                     values = AggregateTestModel.objects.none().aggregate(
    
  140.                         aggregation=aggregation,
    
  141.                     )
    
  142.                     self.assertEqual(values, {"aggregation": expected_result})
    
  143.                 # Empty result when query must be executed.
    
  144.                 with self.assertNumQueries(1):
    
  145.                     values = AggregateTestModel.objects.aggregate(
    
  146.                         aggregation=aggregation,
    
  147.                     )
    
  148.                     self.assertEqual(values, {"aggregation": expected_result})
    
  149. 
    
  150.     def test_convert_value_deprecation(self):
    
  151.         AggregateTestModel.objects.all().delete()
    
  152.         queryset = AggregateTestModel.objects.all()
    
  153. 
    
  154.         with self.assertWarnsMessage(
    
  155.             RemovedInDjango50Warning, ArrayAgg.deprecation_msg
    
  156.         ):
    
  157.             queryset.aggregate(aggregation=ArrayAgg("boolean_field"))
    
  158. 
    
  159.         with self.assertWarnsMessage(
    
  160.             RemovedInDjango50Warning, JSONBAgg.deprecation_msg
    
  161.         ):
    
  162.             queryset.aggregate(aggregation=JSONBAgg("integer_field"))
    
  163. 
    
  164.         with self.assertWarnsMessage(
    
  165.             RemovedInDjango50Warning, StringAgg.deprecation_msg
    
  166.         ):
    
  167.             queryset.aggregate(aggregation=StringAgg("char_field", delimiter=";"))
    
  168. 
    
  169.         # No warnings raised if default argument provided.
    
  170.         self.assertEqual(
    
  171.             queryset.aggregate(aggregation=ArrayAgg("boolean_field", default=None)),
    
  172.             {"aggregation": None},
    
  173.         )
    
  174.         self.assertEqual(
    
  175.             queryset.aggregate(aggregation=JSONBAgg("integer_field", default=None)),
    
  176.             {"aggregation": None},
    
  177.         )
    
  178.         self.assertEqual(
    
  179.             queryset.aggregate(
    
  180.                 aggregation=StringAgg("char_field", delimiter=";", default=None),
    
  181.             ),
    
  182.             {"aggregation": None},
    
  183.         )
    
  184.         self.assertEqual(
    
  185.             queryset.aggregate(
    
  186.                 aggregation=ArrayAgg("boolean_field", default=Value([]))
    
  187.             ),
    
  188.             {"aggregation": []},
    
  189.         )
    
  190.         self.assertEqual(
    
  191.             queryset.aggregate(
    
  192.                 aggregation=JSONBAgg("integer_field", default=Value("[]"))
    
  193.             ),
    
  194.             {"aggregation": []},
    
  195.         )
    
  196.         self.assertEqual(
    
  197.             queryset.aggregate(
    
  198.                 aggregation=StringAgg("char_field", delimiter=";", default=Value("")),
    
  199.             ),
    
  200.             {"aggregation": ""},
    
  201.         )
    
  202. 
    
  203.     def test_array_agg_charfield(self):
    
  204.         values = AggregateTestModel.objects.aggregate(arrayagg=ArrayAgg("char_field"))
    
  205.         self.assertEqual(values, {"arrayagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
    
  206. 
    
  207.     def test_array_agg_charfield_ordering(self):
    
  208.         ordering_test_cases = (
    
  209.             (F("char_field").desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
    
  210.             (F("char_field").asc(), ["Foo1", "Foo2", "Foo3", "Foo4"]),
    
  211.             (F("char_field"), ["Foo1", "Foo2", "Foo3", "Foo4"]),
    
  212.             (
    
  213.                 [F("boolean_field"), F("char_field").desc()],
    
  214.                 ["Foo4", "Foo2", "Foo3", "Foo1"],
    
  215.             ),
    
  216.             (
    
  217.                 (F("boolean_field"), F("char_field").desc()),
    
  218.                 ["Foo4", "Foo2", "Foo3", "Foo1"],
    
  219.             ),
    
  220.             ("char_field", ["Foo1", "Foo2", "Foo3", "Foo4"]),
    
  221.             ("-char_field", ["Foo4", "Foo3", "Foo2", "Foo1"]),
    
  222.             (Concat("char_field", Value("@")), ["Foo1", "Foo2", "Foo3", "Foo4"]),
    
  223.             (Concat("char_field", Value("@")).desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
    
  224.             (
    
  225.                 (
    
  226.                     Substr("char_field", 1, 1),
    
  227.                     F("integer_field"),
    
  228.                     Substr("char_field", 4, 1).desc(),
    
  229.                 ),
    
  230.                 ["Foo3", "Foo1", "Foo2", "Foo4"],
    
  231.             ),
    
  232.         )
    
  233.         for ordering, expected_output in ordering_test_cases:
    
  234.             with self.subTest(ordering=ordering, expected_output=expected_output):
    
  235.                 values = AggregateTestModel.objects.aggregate(
    
  236.                     arrayagg=ArrayAgg("char_field", ordering=ordering)
    
  237.                 )
    
  238.                 self.assertEqual(values, {"arrayagg": expected_output})
    
  239. 
    
  240.     def test_array_agg_integerfield(self):
    
  241.         values = AggregateTestModel.objects.aggregate(
    
  242.             arrayagg=ArrayAgg("integer_field")
    
  243.         )
    
  244.         self.assertEqual(values, {"arrayagg": [0, 1, 2, 0]})
    
  245. 
    
  246.     def test_array_agg_integerfield_ordering(self):
    
  247.         values = AggregateTestModel.objects.aggregate(
    
  248.             arrayagg=ArrayAgg("integer_field", ordering=F("integer_field").desc())
    
  249.         )
    
  250.         self.assertEqual(values, {"arrayagg": [2, 1, 0, 0]})
    
  251. 
    
  252.     def test_array_agg_booleanfield(self):
    
  253.         values = AggregateTestModel.objects.aggregate(
    
  254.             arrayagg=ArrayAgg("boolean_field")
    
  255.         )
    
  256.         self.assertEqual(values, {"arrayagg": [True, False, False, True]})
    
  257. 
    
  258.     def test_array_agg_booleanfield_ordering(self):
    
  259.         ordering_test_cases = (
    
  260.             (F("boolean_field").asc(), [False, False, True, True]),
    
  261.             (F("boolean_field").desc(), [True, True, False, False]),
    
  262.             (F("boolean_field"), [False, False, True, True]),
    
  263.         )
    
  264.         for ordering, expected_output in ordering_test_cases:
    
  265.             with self.subTest(ordering=ordering, expected_output=expected_output):
    
  266.                 values = AggregateTestModel.objects.aggregate(
    
  267.                     arrayagg=ArrayAgg("boolean_field", ordering=ordering)
    
  268.                 )
    
  269.                 self.assertEqual(values, {"arrayagg": expected_output})
    
  270. 
    
  271.     def test_array_agg_jsonfield(self):
    
  272.         values = AggregateTestModel.objects.aggregate(
    
  273.             arrayagg=ArrayAgg(
    
  274.                 KeyTransform("lang", "json_field"),
    
  275.                 filter=Q(json_field__lang__isnull=False),
    
  276.             ),
    
  277.         )
    
  278.         self.assertEqual(values, {"arrayagg": ["pl", "en"]})
    
  279. 
    
  280.     def test_array_agg_jsonfield_ordering(self):
    
  281.         values = AggregateTestModel.objects.aggregate(
    
  282.             arrayagg=ArrayAgg(
    
  283.                 KeyTransform("lang", "json_field"),
    
  284.                 filter=Q(json_field__lang__isnull=False),
    
  285.                 ordering=KeyTransform("lang", "json_field"),
    
  286.             ),
    
  287.         )
    
  288.         self.assertEqual(values, {"arrayagg": ["en", "pl"]})
    
  289. 
    
  290.     def test_array_agg_filter(self):
    
  291.         values = AggregateTestModel.objects.aggregate(
    
  292.             arrayagg=ArrayAgg("integer_field", filter=Q(integer_field__gt=0)),
    
  293.         )
    
  294.         self.assertEqual(values, {"arrayagg": [1, 2]})
    
  295. 
    
  296.     def test_array_agg_lookups(self):
    
  297.         aggr1 = AggregateTestModel.objects.create()
    
  298.         aggr2 = AggregateTestModel.objects.create()
    
  299.         StatTestModel.objects.create(related_field=aggr1, int1=1, int2=0)
    
  300.         StatTestModel.objects.create(related_field=aggr1, int1=2, int2=0)
    
  301.         StatTestModel.objects.create(related_field=aggr2, int1=3, int2=0)
    
  302.         StatTestModel.objects.create(related_field=aggr2, int1=4, int2=0)
    
  303.         qs = (
    
  304.             StatTestModel.objects.values("related_field")
    
  305.             .annotate(array=ArrayAgg("int1"))
    
  306.             .filter(array__overlap=[2])
    
  307.             .values_list("array", flat=True)
    
  308.         )
    
  309.         self.assertCountEqual(qs.get(), [1, 2])
    
  310. 
    
  311.     def test_bit_and_general(self):
    
  312.         values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
    
  313.             bitand=BitAnd("integer_field")
    
  314.         )
    
  315.         self.assertEqual(values, {"bitand": 0})
    
  316. 
    
  317.     def test_bit_and_on_only_true_values(self):
    
  318.         values = AggregateTestModel.objects.filter(integer_field=1).aggregate(
    
  319.             bitand=BitAnd("integer_field")
    
  320.         )
    
  321.         self.assertEqual(values, {"bitand": 1})
    
  322. 
    
  323.     def test_bit_and_on_only_false_values(self):
    
  324.         values = AggregateTestModel.objects.filter(integer_field=0).aggregate(
    
  325.             bitand=BitAnd("integer_field")
    
  326.         )
    
  327.         self.assertEqual(values, {"bitand": 0})
    
  328. 
    
  329.     def test_bit_or_general(self):
    
  330.         values = AggregateTestModel.objects.filter(integer_field__in=[0, 1]).aggregate(
    
  331.             bitor=BitOr("integer_field")
    
  332.         )
    
  333.         self.assertEqual(values, {"bitor": 1})
    
  334. 
    
  335.     def test_bit_or_on_only_true_values(self):
    
  336.         values = AggregateTestModel.objects.filter(integer_field=1).aggregate(
    
  337.             bitor=BitOr("integer_field")
    
  338.         )
    
  339.         self.assertEqual(values, {"bitor": 1})
    
  340. 
    
  341.     def test_bit_or_on_only_false_values(self):
    
  342.         values = AggregateTestModel.objects.filter(integer_field=0).aggregate(
    
  343.             bitor=BitOr("integer_field")
    
  344.         )
    
  345.         self.assertEqual(values, {"bitor": 0})
    
  346. 
    
  347.     @skipUnlessDBFeature("has_bit_xor")
    
  348.     def test_bit_xor_general(self):
    
  349.         AggregateTestModel.objects.create(integer_field=3)
    
  350.         values = AggregateTestModel.objects.filter(
    
  351.             integer_field__in=[1, 3],
    
  352.         ).aggregate(bitxor=BitXor("integer_field"))
    
  353.         self.assertEqual(values, {"bitxor": 2})
    
  354. 
    
  355.     @skipUnlessDBFeature("has_bit_xor")
    
  356.     def test_bit_xor_on_only_true_values(self):
    
  357.         values = AggregateTestModel.objects.filter(
    
  358.             integer_field=1,
    
  359.         ).aggregate(bitxor=BitXor("integer_field"))
    
  360.         self.assertEqual(values, {"bitxor": 1})
    
  361. 
    
  362.     @skipUnlessDBFeature("has_bit_xor")
    
  363.     def test_bit_xor_on_only_false_values(self):
    
  364.         values = AggregateTestModel.objects.filter(
    
  365.             integer_field=0,
    
  366.         ).aggregate(bitxor=BitXor("integer_field"))
    
  367.         self.assertEqual(values, {"bitxor": 0})
    
  368. 
    
  369.     def test_bool_and_general(self):
    
  370.         values = AggregateTestModel.objects.aggregate(booland=BoolAnd("boolean_field"))
    
  371.         self.assertEqual(values, {"booland": False})
    
  372. 
    
  373.     def test_bool_and_q_object(self):
    
  374.         values = AggregateTestModel.objects.aggregate(
    
  375.             booland=BoolAnd(Q(integer_field__gt=2)),
    
  376.         )
    
  377.         self.assertEqual(values, {"booland": False})
    
  378. 
    
  379.     def test_bool_or_general(self):
    
  380.         values = AggregateTestModel.objects.aggregate(boolor=BoolOr("boolean_field"))
    
  381.         self.assertEqual(values, {"boolor": True})
    
  382. 
    
  383.     def test_bool_or_q_object(self):
    
  384.         values = AggregateTestModel.objects.aggregate(
    
  385.             boolor=BoolOr(Q(integer_field__gt=2)),
    
  386.         )
    
  387.         self.assertEqual(values, {"boolor": False})
    
  388. 
    
  389.     def test_string_agg_requires_delimiter(self):
    
  390.         with self.assertRaises(TypeError):
    
  391.             AggregateTestModel.objects.aggregate(stringagg=StringAgg("char_field"))
    
  392. 
    
  393.     def test_string_agg_delimiter_escaping(self):
    
  394.         values = AggregateTestModel.objects.aggregate(
    
  395.             stringagg=StringAgg("char_field", delimiter="'")
    
  396.         )
    
  397.         self.assertEqual(values, {"stringagg": "Foo1'Foo2'Foo4'Foo3"})
    
  398. 
    
  399.     def test_string_agg_charfield(self):
    
  400.         values = AggregateTestModel.objects.aggregate(
    
  401.             stringagg=StringAgg("char_field", delimiter=";")
    
  402.         )
    
  403.         self.assertEqual(values, {"stringagg": "Foo1;Foo2;Foo4;Foo3"})
    
  404. 
    
  405.     def test_string_agg_default_output_field(self):
    
  406.         values = AggregateTestModel.objects.aggregate(
    
  407.             stringagg=StringAgg("text_field", delimiter=";"),
    
  408.         )
    
  409.         self.assertEqual(values, {"stringagg": "Text1;Text2;Text4;Text3"})
    
  410. 
    
  411.     def test_string_agg_charfield_ordering(self):
    
  412.         ordering_test_cases = (
    
  413.             (F("char_field").desc(), "Foo4;Foo3;Foo2;Foo1"),
    
  414.             (F("char_field").asc(), "Foo1;Foo2;Foo3;Foo4"),
    
  415.             (F("char_field"), "Foo1;Foo2;Foo3;Foo4"),
    
  416.             ("char_field", "Foo1;Foo2;Foo3;Foo4"),
    
  417.             ("-char_field", "Foo4;Foo3;Foo2;Foo1"),
    
  418.             (Concat("char_field", Value("@")), "Foo1;Foo2;Foo3;Foo4"),
    
  419.             (Concat("char_field", Value("@")).desc(), "Foo4;Foo3;Foo2;Foo1"),
    
  420.         )
    
  421.         for ordering, expected_output in ordering_test_cases:
    
  422.             with self.subTest(ordering=ordering, expected_output=expected_output):
    
  423.                 values = AggregateTestModel.objects.aggregate(
    
  424.                     stringagg=StringAgg("char_field", delimiter=";", ordering=ordering)
    
  425.                 )
    
  426.                 self.assertEqual(values, {"stringagg": expected_output})
    
  427. 
    
  428.     def test_string_agg_jsonfield_ordering(self):
    
  429.         values = AggregateTestModel.objects.aggregate(
    
  430.             stringagg=StringAgg(
    
  431.                 KeyTextTransform("lang", "json_field"),
    
  432.                 delimiter=";",
    
  433.                 ordering=KeyTextTransform("lang", "json_field"),
    
  434.                 output_field=CharField(),
    
  435.             ),
    
  436.         )
    
  437.         self.assertEqual(values, {"stringagg": "en;pl"})
    
  438. 
    
  439.     def test_string_agg_filter(self):
    
  440.         values = AggregateTestModel.objects.aggregate(
    
  441.             stringagg=StringAgg(
    
  442.                 "char_field",
    
  443.                 delimiter=";",
    
  444.                 filter=Q(char_field__endswith="3") | Q(char_field__endswith="1"),
    
  445.             )
    
  446.         )
    
  447.         self.assertEqual(values, {"stringagg": "Foo1;Foo3"})
    
  448. 
    
  449.     def test_orderable_agg_alternative_fields(self):
    
  450.         values = AggregateTestModel.objects.aggregate(
    
  451.             arrayagg=ArrayAgg("integer_field", ordering=F("char_field").asc())
    
  452.         )
    
  453.         self.assertEqual(values, {"arrayagg": [0, 1, 0, 2]})
    
  454. 
    
  455.     def test_jsonb_agg(self):
    
  456.         values = AggregateTestModel.objects.aggregate(jsonbagg=JSONBAgg("char_field"))
    
  457.         self.assertEqual(values, {"jsonbagg": ["Foo1", "Foo2", "Foo4", "Foo3"]})
    
  458. 
    
  459.     def test_jsonb_agg_charfield_ordering(self):
    
  460.         ordering_test_cases = (
    
  461.             (F("char_field").desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
    
  462.             (F("char_field").asc(), ["Foo1", "Foo2", "Foo3", "Foo4"]),
    
  463.             (F("char_field"), ["Foo1", "Foo2", "Foo3", "Foo4"]),
    
  464.             ("char_field", ["Foo1", "Foo2", "Foo3", "Foo4"]),
    
  465.             ("-char_field", ["Foo4", "Foo3", "Foo2", "Foo1"]),
    
  466.             (Concat("char_field", Value("@")), ["Foo1", "Foo2", "Foo3", "Foo4"]),
    
  467.             (Concat("char_field", Value("@")).desc(), ["Foo4", "Foo3", "Foo2", "Foo1"]),
    
  468.         )
    
  469.         for ordering, expected_output in ordering_test_cases:
    
  470.             with self.subTest(ordering=ordering, expected_output=expected_output):
    
  471.                 values = AggregateTestModel.objects.aggregate(
    
  472.                     jsonbagg=JSONBAgg("char_field", ordering=ordering),
    
  473.                 )
    
  474.                 self.assertEqual(values, {"jsonbagg": expected_output})
    
  475. 
    
  476.     def test_jsonb_agg_integerfield_ordering(self):
    
  477.         values = AggregateTestModel.objects.aggregate(
    
  478.             jsonbagg=JSONBAgg("integer_field", ordering=F("integer_field").desc()),
    
  479.         )
    
  480.         self.assertEqual(values, {"jsonbagg": [2, 1, 0, 0]})
    
  481. 
    
  482.     def test_jsonb_agg_booleanfield_ordering(self):
    
  483.         ordering_test_cases = (
    
  484.             (F("boolean_field").asc(), [False, False, True, True]),
    
  485.             (F("boolean_field").desc(), [True, True, False, False]),
    
  486.             (F("boolean_field"), [False, False, True, True]),
    
  487.         )
    
  488.         for ordering, expected_output in ordering_test_cases:
    
  489.             with self.subTest(ordering=ordering, expected_output=expected_output):
    
  490.                 values = AggregateTestModel.objects.aggregate(
    
  491.                     jsonbagg=JSONBAgg("boolean_field", ordering=ordering),
    
  492.                 )
    
  493.                 self.assertEqual(values, {"jsonbagg": expected_output})
    
  494. 
    
  495.     def test_jsonb_agg_jsonfield_ordering(self):
    
  496.         values = AggregateTestModel.objects.aggregate(
    
  497.             jsonbagg=JSONBAgg(
    
  498.                 KeyTransform("lang", "json_field"),
    
  499.                 filter=Q(json_field__lang__isnull=False),
    
  500.                 ordering=KeyTransform("lang", "json_field"),
    
  501.             ),
    
  502.         )
    
  503.         self.assertEqual(values, {"jsonbagg": ["en", "pl"]})
    
  504. 
    
  505.     def test_jsonb_agg_key_index_transforms(self):
    
  506.         room101 = Room.objects.create(number=101)
    
  507.         room102 = Room.objects.create(number=102)
    
  508.         datetimes = [
    
  509.             timezone.datetime(2018, 6, 20),
    
  510.             timezone.datetime(2018, 6, 24),
    
  511.             timezone.datetime(2018, 6, 28),
    
  512.         ]
    
  513.         HotelReservation.objects.create(
    
  514.             datespan=(datetimes[0].date(), datetimes[1].date()),
    
  515.             start=datetimes[0],
    
  516.             end=datetimes[1],
    
  517.             room=room102,
    
  518.             requirements={"double_bed": True, "parking": True},
    
  519.         )
    
  520.         HotelReservation.objects.create(
    
  521.             datespan=(datetimes[1].date(), datetimes[2].date()),
    
  522.             start=datetimes[1],
    
  523.             end=datetimes[2],
    
  524.             room=room102,
    
  525.             requirements={"double_bed": False, "sea_view": True, "parking": False},
    
  526.         )
    
  527.         HotelReservation.objects.create(
    
  528.             datespan=(datetimes[0].date(), datetimes[2].date()),
    
  529.             start=datetimes[0],
    
  530.             end=datetimes[2],
    
  531.             room=room101,
    
  532.             requirements={"sea_view": False},
    
  533.         )
    
  534.         values = (
    
  535.             Room.objects.annotate(
    
  536.                 requirements=JSONBAgg(
    
  537.                     "hotelreservation__requirements",
    
  538.                     ordering="-hotelreservation__start",
    
  539.                 )
    
  540.             )
    
  541.             .filter(requirements__0__sea_view=True)
    
  542.             .values("number", "requirements")
    
  543.         )
    
  544.         self.assertSequenceEqual(
    
  545.             values,
    
  546.             [
    
  547.                 {
    
  548.                     "number": 102,
    
  549.                     "requirements": [
    
  550.                         {"double_bed": False, "sea_view": True, "parking": False},
    
  551.                         {"double_bed": True, "parking": True},
    
  552.                     ],
    
  553.                 },
    
  554.             ],
    
  555.         )
    
  556. 
    
  557.     def test_string_agg_array_agg_ordering_in_subquery(self):
    
  558.         stats = []
    
  559.         for i, agg in enumerate(AggregateTestModel.objects.order_by("char_field")):
    
  560.             stats.append(StatTestModel(related_field=agg, int1=i, int2=i + 1))
    
  561.             stats.append(StatTestModel(related_field=agg, int1=i + 1, int2=i))
    
  562.         StatTestModel.objects.bulk_create(stats)
    
  563. 
    
  564.         for aggregate, expected_result in (
    
  565.             (
    
  566.                 ArrayAgg("stattestmodel__int1", ordering="-stattestmodel__int2"),
    
  567.                 [
    
  568.                     ("Foo1", [0, 1]),
    
  569.                     ("Foo2", [1, 2]),
    
  570.                     ("Foo3", [2, 3]),
    
  571.                     ("Foo4", [3, 4]),
    
  572.                 ],
    
  573.             ),
    
  574.             (
    
  575.                 StringAgg(
    
  576.                     Cast("stattestmodel__int1", CharField()),
    
  577.                     delimiter=";",
    
  578.                     ordering="-stattestmodel__int2",
    
  579.                 ),
    
  580.                 [("Foo1", "0;1"), ("Foo2", "1;2"), ("Foo3", "2;3"), ("Foo4", "3;4")],
    
  581.             ),
    
  582.         ):
    
  583.             with self.subTest(aggregate=aggregate.__class__.__name__):
    
  584.                 subquery = (
    
  585.                     AggregateTestModel.objects.filter(
    
  586.                         pk=OuterRef("pk"),
    
  587.                     )
    
  588.                     .annotate(agg=aggregate)
    
  589.                     .values("agg")
    
  590.                 )
    
  591.                 values = (
    
  592.                     AggregateTestModel.objects.annotate(
    
  593.                         agg=Subquery(subquery),
    
  594.                     )
    
  595.                     .order_by("char_field")
    
  596.                     .values_list("char_field", "agg")
    
  597.                 )
    
  598.                 self.assertEqual(list(values), expected_result)
    
  599. 
    
  600.     def test_string_agg_array_agg_filter_in_subquery(self):
    
  601.         StatTestModel.objects.bulk_create(
    
  602.             [
    
  603.                 StatTestModel(related_field=self.aggs[0], int1=0, int2=5),
    
  604.                 StatTestModel(related_field=self.aggs[0], int1=1, int2=4),
    
  605.                 StatTestModel(related_field=self.aggs[0], int1=2, int2=3),
    
  606.             ]
    
  607.         )
    
  608.         for aggregate, expected_result in (
    
  609.             (
    
  610.                 ArrayAgg("stattestmodel__int1", filter=Q(stattestmodel__int2__gt=3)),
    
  611.                 [("Foo1", [0, 1]), ("Foo2", None)],
    
  612.             ),
    
  613.             (
    
  614.                 StringAgg(
    
  615.                     Cast("stattestmodel__int2", CharField()),
    
  616.                     delimiter=";",
    
  617.                     filter=Q(stattestmodel__int1__lt=2),
    
  618.                 ),
    
  619.                 [("Foo1", "5;4"), ("Foo2", None)],
    
  620.             ),
    
  621.         ):
    
  622.             with self.subTest(aggregate=aggregate.__class__.__name__):
    
  623.                 subquery = (
    
  624.                     AggregateTestModel.objects.filter(
    
  625.                         pk=OuterRef("pk"),
    
  626.                     )
    
  627.                     .annotate(agg=aggregate)
    
  628.                     .values("agg")
    
  629.                 )
    
  630.                 values = (
    
  631.                     AggregateTestModel.objects.annotate(
    
  632.                         agg=Subquery(subquery),
    
  633.                     )
    
  634.                     .filter(
    
  635.                         char_field__in=["Foo1", "Foo2"],
    
  636.                     )
    
  637.                     .order_by("char_field")
    
  638.                     .values_list("char_field", "agg")
    
  639.                 )
    
  640.                 self.assertEqual(list(values), expected_result)
    
  641. 
    
  642.     def test_string_agg_filter_in_subquery_with_exclude(self):
    
  643.         subquery = (
    
  644.             AggregateTestModel.objects.annotate(
    
  645.                 stringagg=StringAgg(
    
  646.                     "char_field",
    
  647.                     delimiter=";",
    
  648.                     filter=Q(char_field__endswith="1"),
    
  649.                 )
    
  650.             )
    
  651.             .exclude(stringagg="")
    
  652.             .values("id")
    
  653.         )
    
  654.         self.assertSequenceEqual(
    
  655.             AggregateTestModel.objects.filter(id__in=Subquery(subquery)),
    
  656.             [self.aggs[0]],
    
  657.         )
    
  658. 
    
  659.     def test_ordering_isnt_cleared_for_array_subquery(self):
    
  660.         inner_qs = AggregateTestModel.objects.order_by("-integer_field")
    
  661.         qs = AggregateTestModel.objects.annotate(
    
  662.             integers=Func(
    
  663.                 Subquery(inner_qs.values("integer_field")),
    
  664.                 function="ARRAY",
    
  665.                 output_field=ArrayField(base_field=IntegerField()),
    
  666.             ),
    
  667.         )
    
  668.         self.assertSequenceEqual(
    
  669.             qs.first().integers,
    
  670.             inner_qs.values_list("integer_field", flat=True),
    
  671.         )
    
  672. 
    
  673.     def test_window(self):
    
  674.         self.assertCountEqual(
    
  675.             AggregateTestModel.objects.annotate(
    
  676.                 integers=Window(
    
  677.                     expression=ArrayAgg("char_field"),
    
  678.                     partition_by=F("integer_field"),
    
  679.                 )
    
  680.             ).values("integers", "char_field"),
    
  681.             [
    
  682.                 {"integers": ["Foo1", "Foo3"], "char_field": "Foo1"},
    
  683.                 {"integers": ["Foo1", "Foo3"], "char_field": "Foo3"},
    
  684.                 {"integers": ["Foo2"], "char_field": "Foo2"},
    
  685.                 {"integers": ["Foo4"], "char_field": "Foo4"},
    
  686.             ],
    
  687.         )
    
  688. 
    
  689.     def test_values_list(self):
    
  690.         tests = [ArrayAgg("integer_field"), JSONBAgg("integer_field")]
    
  691.         for aggregation in tests:
    
  692.             with self.subTest(aggregation=aggregation):
    
  693.                 self.assertCountEqual(
    
  694.                     AggregateTestModel.objects.values_list(aggregation),
    
  695.                     [([0],), ([1],), ([2],), ([0],)],
    
  696.                 )
    
  697. 
    
  698. 
    
  699. class TestAggregateDistinct(PostgreSQLTestCase):
    
  700.     @classmethod
    
  701.     def setUpTestData(cls):
    
  702.         AggregateTestModel.objects.create(char_field="Foo")
    
  703.         AggregateTestModel.objects.create(char_field="Foo")
    
  704.         AggregateTestModel.objects.create(char_field="Bar")
    
  705. 
    
  706.     def test_string_agg_distinct_false(self):
    
  707.         values = AggregateTestModel.objects.aggregate(
    
  708.             stringagg=StringAgg("char_field", delimiter=" ", distinct=False)
    
  709.         )
    
  710.         self.assertEqual(values["stringagg"].count("Foo"), 2)
    
  711.         self.assertEqual(values["stringagg"].count("Bar"), 1)
    
  712. 
    
  713.     def test_string_agg_distinct_true(self):
    
  714.         values = AggregateTestModel.objects.aggregate(
    
  715.             stringagg=StringAgg("char_field", delimiter=" ", distinct=True)
    
  716.         )
    
  717.         self.assertEqual(values["stringagg"].count("Foo"), 1)
    
  718.         self.assertEqual(values["stringagg"].count("Bar"), 1)
    
  719. 
    
  720.     def test_array_agg_distinct_false(self):
    
  721.         values = AggregateTestModel.objects.aggregate(
    
  722.             arrayagg=ArrayAgg("char_field", distinct=False)
    
  723.         )
    
  724.         self.assertEqual(sorted(values["arrayagg"]), ["Bar", "Foo", "Foo"])
    
  725. 
    
  726.     def test_array_agg_distinct_true(self):
    
  727.         values = AggregateTestModel.objects.aggregate(
    
  728.             arrayagg=ArrayAgg("char_field", distinct=True)
    
  729.         )
    
  730.         self.assertEqual(sorted(values["arrayagg"]), ["Bar", "Foo"])
    
  731. 
    
  732.     def test_jsonb_agg_distinct_false(self):
    
  733.         values = AggregateTestModel.objects.aggregate(
    
  734.             jsonbagg=JSONBAgg("char_field", distinct=False),
    
  735.         )
    
  736.         self.assertEqual(sorted(values["jsonbagg"]), ["Bar", "Foo", "Foo"])
    
  737. 
    
  738.     def test_jsonb_agg_distinct_true(self):
    
  739.         values = AggregateTestModel.objects.aggregate(
    
  740.             jsonbagg=JSONBAgg("char_field", distinct=True),
    
  741.         )
    
  742.         self.assertEqual(sorted(values["jsonbagg"]), ["Bar", "Foo"])
    
  743. 
    
  744. 
    
  745. class TestStatisticsAggregate(PostgreSQLTestCase):
    
  746.     @classmethod
    
  747.     def setUpTestData(cls):
    
  748.         StatTestModel.objects.create(
    
  749.             int1=1,
    
  750.             int2=3,
    
  751.             related_field=AggregateTestModel.objects.create(integer_field=0),
    
  752.         )
    
  753.         StatTestModel.objects.create(
    
  754.             int1=2,
    
  755.             int2=2,
    
  756.             related_field=AggregateTestModel.objects.create(integer_field=1),
    
  757.         )
    
  758.         StatTestModel.objects.create(
    
  759.             int1=3,
    
  760.             int2=1,
    
  761.             related_field=AggregateTestModel.objects.create(integer_field=2),
    
  762.         )
    
  763. 
    
  764.     # Tests for base class (StatAggregate)
    
  765. 
    
  766.     def test_missing_arguments_raises_exception(self):
    
  767.         with self.assertRaisesMessage(ValueError, "Both y and x must be provided."):
    
  768.             StatAggregate(x=None, y=None)
    
  769. 
    
  770.     def test_correct_source_expressions(self):
    
  771.         func = StatAggregate(x="test", y=13)
    
  772.         self.assertIsInstance(func.source_expressions[0], Value)
    
  773.         self.assertIsInstance(func.source_expressions[1], F)
    
  774. 
    
  775.     def test_alias_is_required(self):
    
  776.         class SomeFunc(StatAggregate):
    
  777.             function = "TEST"
    
  778. 
    
  779.         with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
    
  780.             StatTestModel.objects.aggregate(SomeFunc(y="int2", x="int1"))
    
  781. 
    
  782.     # Test aggregates
    
  783. 
    
  784.     def test_empty_result_set(self):
    
  785.         StatTestModel.objects.all().delete()
    
  786.         tests = [
    
  787.             (Corr(y="int2", x="int1"), None),
    
  788.             (CovarPop(y="int2", x="int1"), None),
    
  789.             (CovarPop(y="int2", x="int1", sample=True), None),
    
  790.             (RegrAvgX(y="int2", x="int1"), None),
    
  791.             (RegrAvgY(y="int2", x="int1"), None),
    
  792.             (RegrCount(y="int2", x="int1"), 0),
    
  793.             (RegrIntercept(y="int2", x="int1"), None),
    
  794.             (RegrR2(y="int2", x="int1"), None),
    
  795.             (RegrSlope(y="int2", x="int1"), None),
    
  796.             (RegrSXX(y="int2", x="int1"), None),
    
  797.             (RegrSXY(y="int2", x="int1"), None),
    
  798.             (RegrSYY(y="int2", x="int1"), None),
    
  799.         ]
    
  800.         for aggregation, expected_result in tests:
    
  801.             with self.subTest(aggregation=aggregation):
    
  802.                 # Empty result with non-execution optimization.
    
  803.                 with self.assertNumQueries(0):
    
  804.                     values = StatTestModel.objects.none().aggregate(
    
  805.                         aggregation=aggregation,
    
  806.                     )
    
  807.                     self.assertEqual(values, {"aggregation": expected_result})
    
  808.                 # Empty result when query must be executed.
    
  809.                 with self.assertNumQueries(1):
    
  810.                     values = StatTestModel.objects.aggregate(
    
  811.                         aggregation=aggregation,
    
  812.                     )
    
  813.                     self.assertEqual(values, {"aggregation": expected_result})
    
  814. 
    
  815.     def test_default_argument(self):
    
  816.         StatTestModel.objects.all().delete()
    
  817.         tests = [
    
  818.             (Corr(y="int2", x="int1", default=0), 0),
    
  819.             (CovarPop(y="int2", x="int1", default=0), 0),
    
  820.             (CovarPop(y="int2", x="int1", sample=True, default=0), 0),
    
  821.             (RegrAvgX(y="int2", x="int1", default=0), 0),
    
  822.             (RegrAvgY(y="int2", x="int1", default=0), 0),
    
  823.             # RegrCount() doesn't support the default argument.
    
  824.             (RegrIntercept(y="int2", x="int1", default=0), 0),
    
  825.             (RegrR2(y="int2", x="int1", default=0), 0),
    
  826.             (RegrSlope(y="int2", x="int1", default=0), 0),
    
  827.             (RegrSXX(y="int2", x="int1", default=0), 0),
    
  828.             (RegrSXY(y="int2", x="int1", default=0), 0),
    
  829.             (RegrSYY(y="int2", x="int1", default=0), 0),
    
  830.         ]
    
  831.         for aggregation, expected_result in tests:
    
  832.             with self.subTest(aggregation=aggregation):
    
  833.                 # Empty result with non-execution optimization.
    
  834.                 with self.assertNumQueries(0):
    
  835.                     values = StatTestModel.objects.none().aggregate(
    
  836.                         aggregation=aggregation,
    
  837.                     )
    
  838.                     self.assertEqual(values, {"aggregation": expected_result})
    
  839.                 # Empty result when query must be executed.
    
  840.                 with self.assertNumQueries(1):
    
  841.                     values = StatTestModel.objects.aggregate(
    
  842.                         aggregation=aggregation,
    
  843.                     )
    
  844.                     self.assertEqual(values, {"aggregation": expected_result})
    
  845. 
    
  846.     def test_corr_general(self):
    
  847.         values = StatTestModel.objects.aggregate(corr=Corr(y="int2", x="int1"))
    
  848.         self.assertEqual(values, {"corr": -1.0})
    
  849. 
    
  850.     def test_covar_pop_general(self):
    
  851.         values = StatTestModel.objects.aggregate(covarpop=CovarPop(y="int2", x="int1"))
    
  852.         self.assertEqual(values, {"covarpop": Approximate(-0.66, places=1)})
    
  853. 
    
  854.     def test_covar_pop_sample(self):
    
  855.         values = StatTestModel.objects.aggregate(
    
  856.             covarpop=CovarPop(y="int2", x="int1", sample=True)
    
  857.         )
    
  858.         self.assertEqual(values, {"covarpop": -1.0})
    
  859. 
    
  860.     def test_regr_avgx_general(self):
    
  861.         values = StatTestModel.objects.aggregate(regravgx=RegrAvgX(y="int2", x="int1"))
    
  862.         self.assertEqual(values, {"regravgx": 2.0})
    
  863. 
    
  864.     def test_regr_avgy_general(self):
    
  865.         values = StatTestModel.objects.aggregate(regravgy=RegrAvgY(y="int2", x="int1"))
    
  866.         self.assertEqual(values, {"regravgy": 2.0})
    
  867. 
    
  868.     def test_regr_count_general(self):
    
  869.         values = StatTestModel.objects.aggregate(
    
  870.             regrcount=RegrCount(y="int2", x="int1")
    
  871.         )
    
  872.         self.assertEqual(values, {"regrcount": 3})
    
  873. 
    
  874.     def test_regr_count_default(self):
    
  875.         msg = "RegrCount does not allow default."
    
  876.         with self.assertRaisesMessage(TypeError, msg):
    
  877.             RegrCount(y="int2", x="int1", default=0)
    
  878. 
    
  879.     def test_regr_intercept_general(self):
    
  880.         values = StatTestModel.objects.aggregate(
    
  881.             regrintercept=RegrIntercept(y="int2", x="int1")
    
  882.         )
    
  883.         self.assertEqual(values, {"regrintercept": 4})
    
  884. 
    
  885.     def test_regr_r2_general(self):
    
  886.         values = StatTestModel.objects.aggregate(regrr2=RegrR2(y="int2", x="int1"))
    
  887.         self.assertEqual(values, {"regrr2": 1})
    
  888. 
    
  889.     def test_regr_slope_general(self):
    
  890.         values = StatTestModel.objects.aggregate(
    
  891.             regrslope=RegrSlope(y="int2", x="int1")
    
  892.         )
    
  893.         self.assertEqual(values, {"regrslope": -1})
    
  894. 
    
  895.     def test_regr_sxx_general(self):
    
  896.         values = StatTestModel.objects.aggregate(regrsxx=RegrSXX(y="int2", x="int1"))
    
  897.         self.assertEqual(values, {"regrsxx": 2.0})
    
  898. 
    
  899.     def test_regr_sxy_general(self):
    
  900.         values = StatTestModel.objects.aggregate(regrsxy=RegrSXY(y="int2", x="int1"))
    
  901.         self.assertEqual(values, {"regrsxy": -2.0})
    
  902. 
    
  903.     def test_regr_syy_general(self):
    
  904.         values = StatTestModel.objects.aggregate(regrsyy=RegrSYY(y="int2", x="int1"))
    
  905.         self.assertEqual(values, {"regrsyy": 2.0})
    
  906. 
    
  907.     def test_regr_avgx_with_related_obj_and_number_as_argument(self):
    
  908.         """
    
  909.         This is more complex test to check if JOIN on field and
    
  910.         number as argument works as expected.
    
  911.         """
    
  912.         values = StatTestModel.objects.aggregate(
    
  913.             complex_regravgx=RegrAvgX(y=5, x="related_field__integer_field")
    
  914.         )
    
  915.         self.assertEqual(values, {"complex_regravgx": 1.0})