1. from django.db import connection, models
    
  2. from django.test import SimpleTestCase
    
  3. 
    
  4. from .utils import FuncTestMixin
    
  5. 
    
  6. 
    
  def test_mutation(raises=True):
      """
      Build a decorator that turns a mutation function into a test method.

      The decorated function (``mutation_func``) is called with the
      ``TestFunc`` expression being compiled, from inside its ``as_sql()``, and
      may mutate it. The generated test compiles a fresh ``TestFunc`` through
      the ``as_<vendor>()`` method named after ``connection.vendor``. Then:

      * if ``raises`` is true, it asserts that an ``AssertionError`` with the
        message "TestFunc Func was mutated during compilation." is raised
        (presumably by ``FuncTestMixin``, which is not shown here);
      * otherwise it just compiles the expression and expects no error.

      ``functools.wraps`` is applied so the generated test keeps the decorated
      function's ``__name__``, ``__qualname__`` and ``__doc__``. unittest uses
      that docstring for the test's short description.

      :param raises: whether the generated test expects the mutation to be
          reported.
      :returns: a decorator mapping ``mutation_func`` to a test method.
      """
      from functools import wraps

      def wrapper(mutation_func):
          @wraps(mutation_func)
          def test(test_case_instance, *args, **kwargs):
              class TestFunc(models.Func):
                  output_field = models.IntegerField()

                  def __init__(self):
                      # Plain instance attribute a mutation function may rebind.
                      self.attribute = "initial"
                      # One str and one list source expression, so mutations
                      # can target a nested object as well as the top level.
                      super().__init__("initial", ["initial"])

                  def as_sql(self, *args, **kwargs):
                      # Run the mutation mid-compilation, then return dummy SQL.
                      mutation_func(self)
                      return "", ()

              def compile_func():
                  # A fresh instance per compilation. (None, None) are
                  # placeholder arguments; the as_sql() above ignores them.
                  # NOTE(review): assumes as_<vendor> exists on Func (or is
                  # provided by FuncTestMixin) and reaches as_sql() -- confirm.
                  getattr(TestFunc(), "as_" + connection.vendor)(None, None)

              if raises:
                  msg = "TestFunc Func was mutated during compilation."
                  with test_case_instance.assertRaisesMessage(AssertionError, msg):
                      compile_func()
              else:
                  compile_func()

          return test

      return wrapper
    
  31. 
    
  32. 
    
  class FuncTestMixinTests(FuncTestMixin, SimpleTestCase):
      # Each method below is a mutation function, not a test body.
      # ``test_mutation`` wraps it into the real test method. The wrapped
      # function receives the ``TestFunc`` expression being compiled (hence
      # the ``func`` parameter in place of ``self``). Every method except
      # ``test_not_mutated`` expects the generated test to see
      # "TestFunc Func was mutated during compilation." raised.

      @test_mutation()
      def test_mutated_attribute(func):
          # Rebind a plain instance attribute set in TestFunc.__init__.
          func.attribute = "mutated"

      @test_mutation()
      def test_mutated_expressions(func):
          # Empty the source-expression list in place.
          expressions = func.source_expressions
          expressions.clear()

      @test_mutation()
      def test_mutated_expression(func):
          # Change the ``name`` of the first source expression.
          first_expression = func.source_expressions[0]
          first_expression.name = "mutated"

      @test_mutation()
      def test_mutated_expression_deep(func):
          # Change an item nested inside the second expression's ``value``.
          nested_value = func.source_expressions[1].value
          nested_value[0] = "mutated"

      @test_mutation(raises=False)
      def test_not_mutated(func):
          # Leave the expression untouched; nothing should be reported.
          return None