1. import copy
    
  2. import unittest
    
  3. from functools import wraps
    
  4. from unittest import mock
    
  5. 
    
  6. from django.conf import settings
    
  7. from django.db import DEFAULT_DB_ALIAS, connection
    
  8. from django.db.models import Func
    
  9. 
    
  10. 
    
  11. def skipUnlessGISLookup(*gis_lookups):
    
  12.     """
    
  13.     Skip a test unless a database supports all of gis_lookups.
    
  14.     """
    
  15. 
    
  16.     def decorator(test_func):
    
  17.         @wraps(test_func)
    
  18.         def skip_wrapper(*args, **kwargs):
    
  19.             if any(key not in connection.ops.gis_operators for key in gis_lookups):
    
  20.                 raise unittest.SkipTest(
    
  21.                     "Database doesn't support all the lookups: %s"
    
  22.                     % ", ".join(gis_lookups)
    
  23.                 )
    
  24.             return test_func(*args, **kwargs)
    
  25. 
    
  26.         return skip_wrapper
    
  27. 
    
  28.     return decorator
    
  29. 
    
  30. 
    
  31. _default_db = settings.DATABASES[DEFAULT_DB_ALIAS]["ENGINE"].rsplit(".")[-1]
    
  32. # MySQL spatial indices can't handle NULL geometries.
    
  33. gisfield_may_be_null = _default_db != "mysql"
    
  34. 
    
  35. 
    
  36. class FuncTestMixin:
    
  37.     """Assert that Func expressions aren't mutated during their as_sql()."""
    
  38. 
    
  39.     def setUp(self):
    
  40.         def as_sql_wrapper(original_as_sql):
    
  41.             def inner(*args, **kwargs):
    
  42.                 func = original_as_sql.__self__
    
  43.                 # Resolve output_field before as_sql() so touching it in
    
  44.                 # as_sql() won't change __dict__.
    
  45.                 func.output_field
    
  46.                 __dict__original = copy.deepcopy(func.__dict__)
    
  47.                 result = original_as_sql(*args, **kwargs)
    
  48.                 msg = (
    
  49.                     "%s Func was mutated during compilation." % func.__class__.__name__
    
  50.                 )
    
  51.                 self.assertEqual(func.__dict__, __dict__original, msg)
    
  52.                 return result
    
  53. 
    
  54.             return inner
    
  55. 
    
  56.         def __getattribute__(self, name):
    
  57.             if name != vendor_impl:
    
  58.                 return __getattribute__original(self, name)
    
  59.             try:
    
  60.                 as_sql = __getattribute__original(self, vendor_impl)
    
  61.             except AttributeError:
    
  62.                 as_sql = __getattribute__original(self, "as_sql")
    
  63.             return as_sql_wrapper(as_sql)
    
  64. 
    
  65.         vendor_impl = "as_" + connection.vendor
    
  66.         __getattribute__original = Func.__getattribute__
    
  67.         self.func_patcher = mock.patch.object(
    
  68.             Func, "__getattribute__", __getattribute__
    
  69.         )
    
  70.         self.func_patcher.start()
    
  71.         super().setUp()
    
  72. 
    
  73.     def tearDown(self):
    
  74.         super().tearDown()
    
  75.         self.func_patcher.stop()