1. import datetime
    
  2. 
    
  3. from django.core.exceptions import FieldDoesNotExist
    
  4. from django.db.models import F
    
  5. from django.db.models.functions import Lower
    
  6. from django.db.utils import IntegrityError
    
  7. from django.test import TestCase, override_settings, skipUnlessDBFeature
    
  8. 
    
  9. from .models import (
    
  10.     Article,
    
  11.     CustomDbColumn,
    
  12.     CustomPk,
    
  13.     Detail,
    
  14.     Food,
    
  15.     Individual,
    
  16.     JSONFieldNullable,
    
  17.     Member,
    
  18.     Note,
    
  19.     Number,
    
  20.     Order,
    
  21.     Paragraph,
    
  22.     RelatedObject,
    
  23.     SingleObject,
    
  24.     SpecialCategory,
    
  25.     Tag,
    
  26.     Valid,
    
  27. )
    
  28. 
    
  29. 
    
  30. class WriteToOtherRouter:
    
  31.     def db_for_write(self, model, **hints):
    
  32.         return "other"
    
  33. 
    
  34. 
    
  35. class BulkUpdateNoteTests(TestCase):
    
  36.     @classmethod
    
  37.     def setUpTestData(cls):
    
  38.         cls.notes = [Note.objects.create(note=str(i), misc=str(i)) for i in range(10)]
    
  39. 
    
  40.     def create_tags(self):
    
  41.         self.tags = [Tag.objects.create(name=str(i)) for i in range(10)]
    
  42. 
    
  43.     def test_simple(self):
    
  44.         for note in self.notes:
    
  45.             note.note = "test-%s" % note.id
    
  46.         with self.assertNumQueries(1):
    
  47.             Note.objects.bulk_update(self.notes, ["note"])
    
  48.         self.assertCountEqual(
    
  49.             Note.objects.values_list("note", flat=True),
    
  50.             [cat.note for cat in self.notes],
    
  51.         )
    
  52. 
    
  53.     def test_multiple_fields(self):
    
  54.         for note in self.notes:
    
  55.             note.note = "test-%s" % note.id
    
  56.             note.misc = "misc-%s" % note.id
    
  57.         with self.assertNumQueries(1):
    
  58.             Note.objects.bulk_update(self.notes, ["note", "misc"])
    
  59.         self.assertCountEqual(
    
  60.             Note.objects.values_list("note", flat=True),
    
  61.             [cat.note for cat in self.notes],
    
  62.         )
    
  63.         self.assertCountEqual(
    
  64.             Note.objects.values_list("misc", flat=True),
    
  65.             [cat.misc for cat in self.notes],
    
  66.         )
    
  67. 
    
  68.     def test_batch_size(self):
    
  69.         with self.assertNumQueries(len(self.notes)):
    
  70.             Note.objects.bulk_update(self.notes, fields=["note"], batch_size=1)
    
  71. 
    
  72.     def test_unsaved_models(self):
    
  73.         objs = self.notes + [Note(note="test", misc="test")]
    
  74.         msg = "All bulk_update() objects must have a primary key set."
    
  75.         with self.assertRaisesMessage(ValueError, msg):
    
  76.             Note.objects.bulk_update(objs, fields=["note"])
    
  77. 
    
  78.     def test_foreign_keys_do_not_lookup(self):
    
  79.         self.create_tags()
    
  80.         for note, tag in zip(self.notes, self.tags):
    
  81.             note.tag = tag
    
  82.         with self.assertNumQueries(1):
    
  83.             Note.objects.bulk_update(self.notes, ["tag"])
    
  84.         self.assertSequenceEqual(Note.objects.filter(tag__isnull=False), self.notes)
    
  85. 
    
  86.     def test_set_field_to_null(self):
    
  87.         self.create_tags()
    
  88.         Note.objects.update(tag=self.tags[0])
    
  89.         for note in self.notes:
    
  90.             note.tag = None
    
  91.         Note.objects.bulk_update(self.notes, ["tag"])
    
  92.         self.assertCountEqual(Note.objects.filter(tag__isnull=True), self.notes)
    
  93. 
    
  94.     def test_set_mixed_fields_to_null(self):
    
  95.         self.create_tags()
    
  96.         midpoint = len(self.notes) // 2
    
  97.         top, bottom = self.notes[:midpoint], self.notes[midpoint:]
    
  98.         for note in top:
    
  99.             note.tag = None
    
  100.         for note in bottom:
    
  101.             note.tag = self.tags[0]
    
  102.         Note.objects.bulk_update(self.notes, ["tag"])
    
  103.         self.assertCountEqual(Note.objects.filter(tag__isnull=True), top)
    
  104.         self.assertCountEqual(Note.objects.filter(tag__isnull=False), bottom)
    
  105. 
    
  106.     def test_functions(self):
    
  107.         Note.objects.update(note="TEST")
    
  108.         for note in self.notes:
    
  109.             note.note = Lower("note")
    
  110.         Note.objects.bulk_update(self.notes, ["note"])
    
  111.         self.assertEqual(set(Note.objects.values_list("note", flat=True)), {"test"})
    
  112. 
    
  113.     # Tests that use self.notes go here, otherwise put them in another class.
    
  114. 
    
  115. 
    
  116. class BulkUpdateTests(TestCase):
    
  117.     databases = {"default", "other"}
    
  118. 
    
  119.     def test_no_fields(self):
    
  120.         msg = "Field names must be given to bulk_update()."
    
  121.         with self.assertRaisesMessage(ValueError, msg):
    
  122.             Note.objects.bulk_update([], fields=[])
    
  123. 
    
  124.     def test_invalid_batch_size(self):
    
  125.         msg = "Batch size must be a positive integer."
    
  126.         with self.assertRaisesMessage(ValueError, msg):
    
  127.             Note.objects.bulk_update([], fields=["note"], batch_size=-1)
    
  128. 
    
  129.     def test_nonexistent_field(self):
    
  130.         with self.assertRaisesMessage(
    
  131.             FieldDoesNotExist, "Note has no field named 'nonexistent'"
    
  132.         ):
    
  133.             Note.objects.bulk_update([], ["nonexistent"])
    
  134. 
    
  135.     pk_fields_error = "bulk_update() cannot be used with primary key fields."
    
  136. 
    
  137.     def test_update_primary_key(self):
    
  138.         with self.assertRaisesMessage(ValueError, self.pk_fields_error):
    
  139.             Note.objects.bulk_update([], ["id"])
    
  140. 
    
  141.     def test_update_custom_primary_key(self):
    
  142.         with self.assertRaisesMessage(ValueError, self.pk_fields_error):
    
  143.             CustomPk.objects.bulk_update([], ["name"])
    
  144. 
    
  145.     def test_empty_objects(self):
    
  146.         with self.assertNumQueries(0):
    
  147.             rows_updated = Note.objects.bulk_update([], ["note"])
    
  148.         self.assertEqual(rows_updated, 0)
    
  149. 
    
  150.     def test_large_batch(self):
    
  151.         Note.objects.bulk_create(
    
  152.             [Note(note=str(i), misc=str(i)) for i in range(0, 2000)]
    
  153.         )
    
  154.         notes = list(Note.objects.all())
    
  155.         rows_updated = Note.objects.bulk_update(notes, ["note"])
    
  156.         self.assertEqual(rows_updated, 2000)
    
  157. 
    
  158.     def test_updated_rows_when_passing_duplicates(self):
    
  159.         note = Note.objects.create(note="test-note", misc="test")
    
  160.         rows_updated = Note.objects.bulk_update([note, note], ["note"])
    
  161.         self.assertEqual(rows_updated, 1)
    
  162.         # Duplicates in different batches.
    
  163.         rows_updated = Note.objects.bulk_update([note, note], ["note"], batch_size=1)
    
  164.         self.assertEqual(rows_updated, 2)
    
  165. 
    
  166.     def test_only_concrete_fields_allowed(self):
    
  167.         obj = Valid.objects.create(valid="test")
    
  168.         detail = Detail.objects.create(data="test")
    
  169.         paragraph = Paragraph.objects.create(text="test")
    
  170.         Member.objects.create(name="test", details=detail)
    
  171.         msg = "bulk_update() can only be used with concrete fields."
    
  172.         with self.assertRaisesMessage(ValueError, msg):
    
  173.             Detail.objects.bulk_update([detail], fields=["member"])
    
  174.         with self.assertRaisesMessage(ValueError, msg):
    
  175.             Paragraph.objects.bulk_update([paragraph], fields=["page"])
    
  176.         with self.assertRaisesMessage(ValueError, msg):
    
  177.             Valid.objects.bulk_update([obj], fields=["parent"])
    
  178. 
    
  179.     def test_custom_db_columns(self):
    
  180.         model = CustomDbColumn.objects.create(custom_column=1)
    
  181.         model.custom_column = 2
    
  182.         CustomDbColumn.objects.bulk_update([model], fields=["custom_column"])
    
  183.         model.refresh_from_db()
    
  184.         self.assertEqual(model.custom_column, 2)
    
  185. 
    
  186.     def test_custom_pk(self):
    
  187.         custom_pks = [
    
  188.             CustomPk.objects.create(name="pk-%s" % i, extra="") for i in range(10)
    
  189.         ]
    
  190.         for model in custom_pks:
    
  191.             model.extra = "extra-%s" % model.pk
    
  192.         CustomPk.objects.bulk_update(custom_pks, ["extra"])
    
  193.         self.assertCountEqual(
    
  194.             CustomPk.objects.values_list("extra", flat=True),
    
  195.             [cat.extra for cat in custom_pks],
    
  196.         )
    
  197. 
    
  198.     def test_falsey_pk_value(self):
    
  199.         order = Order.objects.create(pk=0, name="test")
    
  200.         order.name = "updated"
    
  201.         Order.objects.bulk_update([order], ["name"])
    
  202.         order.refresh_from_db()
    
  203.         self.assertEqual(order.name, "updated")
    
  204. 
    
  205.     def test_inherited_fields(self):
    
  206.         special_categories = [
    
  207.             SpecialCategory.objects.create(name=str(i), special_name=str(i))
    
  208.             for i in range(10)
    
  209.         ]
    
  210.         for category in special_categories:
    
  211.             category.name = "test-%s" % category.id
    
  212.             category.special_name = "special-test-%s" % category.special_name
    
  213.         SpecialCategory.objects.bulk_update(
    
  214.             special_categories, ["name", "special_name"]
    
  215.         )
    
  216.         self.assertCountEqual(
    
  217.             SpecialCategory.objects.values_list("name", flat=True),
    
  218.             [cat.name for cat in special_categories],
    
  219.         )
    
  220.         self.assertCountEqual(
    
  221.             SpecialCategory.objects.values_list("special_name", flat=True),
    
  222.             [cat.special_name for cat in special_categories],
    
  223.         )
    
  224. 
    
  225.     def test_field_references(self):
    
  226.         numbers = [Number.objects.create(num=0) for _ in range(10)]
    
  227.         for number in numbers:
    
  228.             number.num = F("num") + 1
    
  229.         Number.objects.bulk_update(numbers, ["num"])
    
  230.         self.assertCountEqual(Number.objects.filter(num=1), numbers)
    
  231. 
    
  232.     def test_f_expression(self):
    
  233.         notes = [
    
  234.             Note.objects.create(note="test_note", misc="test_misc") for _ in range(10)
    
  235.         ]
    
  236.         for note in notes:
    
  237.             note.misc = F("note")
    
  238.         Note.objects.bulk_update(notes, ["misc"])
    
  239.         self.assertCountEqual(Note.objects.filter(misc="test_note"), notes)
    
  240. 
    
  241.     def test_booleanfield(self):
    
  242.         individuals = [Individual.objects.create(alive=False) for _ in range(10)]
    
  243.         for individual in individuals:
    
  244.             individual.alive = True
    
  245.         Individual.objects.bulk_update(individuals, ["alive"])
    
  246.         self.assertCountEqual(Individual.objects.filter(alive=True), individuals)
    
  247. 
    
  248.     def test_ipaddressfield(self):
    
  249.         for ip in ("2001::1", "1.2.3.4"):
    
  250.             with self.subTest(ip=ip):
    
  251.                 models = [
    
  252.                     CustomDbColumn.objects.create(ip_address="0.0.0.0")
    
  253.                     for _ in range(10)
    
  254.                 ]
    
  255.                 for model in models:
    
  256.                     model.ip_address = ip
    
  257.                 CustomDbColumn.objects.bulk_update(models, ["ip_address"])
    
  258.                 self.assertCountEqual(
    
  259.                     CustomDbColumn.objects.filter(ip_address=ip), models
    
  260.                 )
    
  261. 
    
  262.     def test_datetime_field(self):
    
  263.         articles = [
    
  264.             Article.objects.create(name=str(i), created=datetime.datetime.today())
    
  265.             for i in range(10)
    
  266.         ]
    
  267.         point_in_time = datetime.datetime(1991, 10, 31)
    
  268.         for article in articles:
    
  269.             article.created = point_in_time
    
  270.         Article.objects.bulk_update(articles, ["created"])
    
  271.         self.assertCountEqual(Article.objects.filter(created=point_in_time), articles)
    
  272. 
    
  273.     @skipUnlessDBFeature("supports_json_field")
    
  274.     def test_json_field(self):
    
  275.         JSONFieldNullable.objects.bulk_create(
    
  276.             [JSONFieldNullable(json_field={"a": i}) for i in range(10)]
    
  277.         )
    
  278.         objs = JSONFieldNullable.objects.all()
    
  279.         for obj in objs:
    
  280.             obj.json_field = {"c": obj.json_field["a"] + 1}
    
  281.         JSONFieldNullable.objects.bulk_update(objs, ["json_field"])
    
  282.         self.assertCountEqual(
    
  283.             JSONFieldNullable.objects.filter(json_field__has_key="c"), objs
    
  284.         )
    
  285. 
    
  286.     def test_nullable_fk_after_related_save(self):
    
  287.         parent = RelatedObject.objects.create()
    
  288.         child = SingleObject()
    
  289.         parent.single = child
    
  290.         parent.single.save()
    
  291.         RelatedObject.objects.bulk_update([parent], fields=["single"])
    
  292.         self.assertEqual(parent.single_id, parent.single.pk)
    
  293.         parent.refresh_from_db()
    
  294.         self.assertEqual(parent.single, child)
    
  295. 
    
  296.     def test_unsaved_parent(self):
    
  297.         parent = RelatedObject.objects.create()
    
  298.         parent.single = SingleObject()
    
  299.         msg = (
    
  300.             "bulk_update() prohibited to prevent data loss due to unsaved "
    
  301.             "related object 'single'."
    
  302.         )
    
  303.         with self.assertRaisesMessage(ValueError, msg):
    
  304.             RelatedObject.objects.bulk_update([parent], fields=["single"])
    
  305. 
    
  306.     def test_unspecified_unsaved_parent(self):
    
  307.         parent = RelatedObject.objects.create()
    
  308.         parent.single = SingleObject()
    
  309.         parent.f = 42
    
  310.         RelatedObject.objects.bulk_update([parent], fields=["f"])
    
  311.         parent.refresh_from_db()
    
  312.         self.assertEqual(parent.f, 42)
    
  313.         self.assertIsNone(parent.single)
    
  314. 
    
  315.     @override_settings(DATABASE_ROUTERS=[WriteToOtherRouter()])
    
  316.     def test_database_routing(self):
    
  317.         note = Note.objects.create(note="create")
    
  318.         note.note = "bulk_update"
    
  319.         with self.assertNumQueries(1, using="other"):
    
  320.             Note.objects.bulk_update([note], fields=["note"])
    
  321. 
    
  322.     @override_settings(DATABASE_ROUTERS=[WriteToOtherRouter()])
    
  323.     def test_database_routing_batch_atomicity(self):
    
  324.         f1 = Food.objects.create(name="Banana")
    
  325.         f2 = Food.objects.create(name="Apple")
    
  326.         f1.name = "Kiwi"
    
  327.         f2.name = "Kiwi"
    
  328.         with self.assertRaises(IntegrityError):
    
  329.             Food.objects.bulk_update([f1, f2], fields=["name"], batch_size=1)
    
  330.         self.assertIs(Food.objects.filter(name="Kiwi").exists(), False)