1. import operator
    
  2. import unittest
    
  3. from collections import namedtuple
    
  4. from contextlib import contextmanager
    
  5. 
    
  6. from django.db import connection, models
    
  7. from django.test import TestCase
    
  8. 
    
  9. from ..models import Person
    
  10. 
    
  11. 
    
  12. @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests")
    
  13. class ServerSideCursorsPostgres(TestCase):
    
  14.     cursor_fields = (
    
  15.         "name, statement, is_holdable, is_binary, is_scrollable, creation_time"
    
  16.     )
    
  17.     PostgresCursor = namedtuple("PostgresCursor", cursor_fields)
    
  18. 
    
  19.     @classmethod
    
  20.     def setUpTestData(cls):
    
  21.         Person.objects.create(first_name="a", last_name="a")
    
  22.         Person.objects.create(first_name="b", last_name="b")
    
  23. 
    
  24.     def inspect_cursors(self):
    
  25.         with connection.cursor() as cursor:
    
  26.             cursor.execute(
    
  27.                 "SELECT {fields} FROM pg_cursors;".format(fields=self.cursor_fields)
    
  28.             )
    
  29.             cursors = cursor.fetchall()
    
  30.         return [self.PostgresCursor._make(cursor) for cursor in cursors]
    
  31. 
    
  32.     @contextmanager
    
  33.     def override_db_setting(self, **kwargs):
    
  34.         for setting in kwargs:
    
  35.             original_value = connection.settings_dict.get(setting)
    
  36.             if setting in connection.settings_dict:
    
  37.                 self.addCleanup(
    
  38.                     operator.setitem, connection.settings_dict, setting, original_value
    
  39.                 )
    
  40.             else:
    
  41.                 self.addCleanup(operator.delitem, connection.settings_dict, setting)
    
  42. 
    
  43.             connection.settings_dict[setting] = kwargs[setting]
    
  44.             yield
    
  45. 
    
  46.     def assertUsesCursor(self, queryset, num_expected=1):
    
  47.         next(queryset)  # Open a server-side cursor
    
  48.         cursors = self.inspect_cursors()
    
  49.         self.assertEqual(len(cursors), num_expected)
    
  50.         for cursor in cursors:
    
  51.             self.assertIn("_django_curs_", cursor.name)
    
  52.             self.assertFalse(cursor.is_scrollable)
    
  53.             self.assertFalse(cursor.is_holdable)
    
  54.             self.assertFalse(cursor.is_binary)
    
  55. 
    
  56.     def asserNotUsesCursor(self, queryset):
    
  57.         self.assertUsesCursor(queryset, num_expected=0)
    
  58. 
    
  59.     def test_server_side_cursor(self):
    
  60.         self.assertUsesCursor(Person.objects.iterator())
    
  61. 
    
  62.     def test_values(self):
    
  63.         self.assertUsesCursor(Person.objects.values("first_name").iterator())
    
  64. 
    
  65.     def test_values_list(self):
    
  66.         self.assertUsesCursor(Person.objects.values_list("first_name").iterator())
    
  67. 
    
  68.     def test_values_list_flat(self):
    
  69.         self.assertUsesCursor(
    
  70.             Person.objects.values_list("first_name", flat=True).iterator()
    
  71.         )
    
  72. 
    
  73.     def test_values_list_fields_not_equal_to_names(self):
    
  74.         expr = models.Count("id")
    
  75.         self.assertUsesCursor(
    
  76.             Person.objects.annotate(id__count=expr)
    
  77.             .values_list(expr, "id__count")
    
  78.             .iterator()
    
  79.         )
    
  80. 
    
  81.     def test_server_side_cursor_many_cursors(self):
    
  82.         persons = Person.objects.iterator()
    
  83.         persons2 = Person.objects.iterator()
    
  84.         next(persons)  # Open a server-side cursor
    
  85.         self.assertUsesCursor(persons2, num_expected=2)
    
  86. 
    
  87.     def test_closed_server_side_cursor(self):
    
  88.         persons = Person.objects.iterator()
    
  89.         next(persons)  # Open a server-side cursor
    
  90.         del persons
    
  91.         cursors = self.inspect_cursors()
    
  92.         self.assertEqual(len(cursors), 0)
    
  93. 
    
  94.     def test_server_side_cursors_setting(self):
    
  95.         with self.override_db_setting(DISABLE_SERVER_SIDE_CURSORS=False):
    
  96.             persons = Person.objects.iterator()
    
  97.             self.assertUsesCursor(persons)
    
  98.             del persons  # Close server-side cursor
    
  99. 
    
  100.         with self.override_db_setting(DISABLE_SERVER_SIDE_CURSORS=True):
    
  101.             self.asserNotUsesCursor(Person.objects.iterator())