1. import datetime
    
  2. from unittest import mock
    
  3. 
    
  4. from django.db import connections
    
  5. from django.db.models.sql.compiler import cursor_iter
    
  6. from django.test import TestCase
    
  7. 
    
  8. from .models import Article
    
  9. 
    
  10. 
    
  11. class QuerySetIteratorTests(TestCase):
    
  12.     itersize_index_in_mock_args = 3
    
  13. 
    
  14.     @classmethod
    
  15.     def setUpTestData(cls):
    
  16.         Article.objects.create(name="Article 1", created=datetime.datetime.now())
    
  17.         Article.objects.create(name="Article 2", created=datetime.datetime.now())
    
  18. 
    
  19.     def test_iterator_invalid_chunk_size(self):
    
  20.         for size in (0, -1):
    
  21.             with self.subTest(size=size):
    
  22.                 with self.assertRaisesMessage(
    
  23.                     ValueError, "Chunk size must be strictly positive."
    
  24.                 ):
    
  25.                     Article.objects.iterator(chunk_size=size)
    
  26. 
    
  27.     def test_default_iterator_chunk_size(self):
    
  28.         qs = Article.objects.iterator()
    
  29.         with mock.patch(
    
  30.             "django.db.models.sql.compiler.cursor_iter", side_effect=cursor_iter
    
  31.         ) as cursor_iter_mock:
    
  32.             next(qs)
    
  33.         self.assertEqual(cursor_iter_mock.call_count, 1)
    
  34.         mock_args, _mock_kwargs = cursor_iter_mock.call_args
    
  35.         self.assertEqual(mock_args[self.itersize_index_in_mock_args], 2000)
    
  36. 
    
  37.     def test_iterator_chunk_size(self):
    
  38.         batch_size = 3
    
  39.         qs = Article.objects.iterator(chunk_size=batch_size)
    
  40.         with mock.patch(
    
  41.             "django.db.models.sql.compiler.cursor_iter", side_effect=cursor_iter
    
  42.         ) as cursor_iter_mock:
    
  43.             next(qs)
    
  44.         self.assertEqual(cursor_iter_mock.call_count, 1)
    
  45.         mock_args, _mock_kwargs = cursor_iter_mock.call_args
    
  46.         self.assertEqual(mock_args[self.itersize_index_in_mock_args], batch_size)
    
  47. 
    
  48.     def test_no_chunked_reads(self):
    
  49.         """
    
  50.         If the database backend doesn't support chunked reads, then the
    
  51.         result of SQLCompiler.execute_sql() is a list.
    
  52.         """
    
  53.         qs = Article.objects.all()
    
  54.         compiler = qs.query.get_compiler(using=qs.db)
    
  55.         features = connections[qs.db].features
    
  56.         with mock.patch.object(features, "can_use_chunked_reads", False):
    
  57.             result = compiler.execute_sql(chunked_fetch=True)
    
  58.         self.assertIsInstance(result, list)