1. import json
    
  2. 
    
  3. from django.contrib.gis.db.models.fields import BaseSpatialField
    
  4. from django.contrib.gis.db.models.functions import Distance
    
  5. from django.contrib.gis.db.models.lookups import DistanceLookupBase, GISLookup
    
  6. from django.contrib.gis.gdal import GDALRaster
    
  7. from django.contrib.gis.geos import GEOSGeometry
    
  8. from django.contrib.gis.measure import D
    
  9. from django.contrib.gis.shortcuts import numpy
    
  10. from django.db import connection
    
  11. from django.db.models import F, Func, Q
    
  12. from django.test import TransactionTestCase, skipUnlessDBFeature
    
  13. from django.test.utils import CaptureQueriesContext
    
  14. 
    
  15. from ..data.rasters.textrasters import JSON_RASTER
    
  16. from .models import RasterModel, RasterRelatedModel
    
  17. 
    
  18. 
    
  19. @skipUnlessDBFeature("supports_raster")
    
  20. class RasterFieldTest(TransactionTestCase):
    
  21.     available_apps = ["gis_tests.rasterapp"]
    
  22. 
    
  23.     def setUp(self):
    
  24.         rast = GDALRaster(
    
  25.             {
    
  26.                 "srid": 4326,
    
  27.                 "origin": [0, 0],
    
  28.                 "scale": [-1, 1],
    
  29.                 "skew": [0, 0],
    
  30.                 "width": 5,
    
  31.                 "height": 5,
    
  32.                 "nr_of_bands": 2,
    
  33.                 "bands": [{"data": range(25)}, {"data": range(25, 50)}],
    
  34.             }
    
  35.         )
    
  36.         model_instance = RasterModel.objects.create(
    
  37.             rast=rast,
    
  38.             rastprojected=rast,
    
  39.             geom="POINT (-95.37040 29.70486)",
    
  40.         )
    
  41.         RasterRelatedModel.objects.create(rastermodel=model_instance)
    
  42. 
    
  43.     def test_field_null_value(self):
    
  44.         """
    
  45.         Test creating a model where the RasterField has a null value.
    
  46.         """
    
  47.         r = RasterModel.objects.create(rast=None)
    
  48.         r.refresh_from_db()
    
  49.         self.assertIsNone(r.rast)
    
  50. 
    
  51.     def test_access_band_data_directly_from_queryset(self):
    
  52.         RasterModel.objects.create(rast=JSON_RASTER)
    
  53.         qs = RasterModel.objects.all()
    
  54.         qs[0].rast.bands[0].data()
    
  55. 
    
  56.     def test_deserialize_with_pixeltype_flags(self):
    
  57.         no_data = 3
    
  58.         rast = GDALRaster(
    
  59.             {
    
  60.                 "srid": 4326,
    
  61.                 "origin": [0, 0],
    
  62.                 "scale": [-1, 1],
    
  63.                 "skew": [0, 0],
    
  64.                 "width": 1,
    
  65.                 "height": 1,
    
  66.                 "nr_of_bands": 1,
    
  67.                 "bands": [{"data": [no_data], "nodata_value": no_data}],
    
  68.             }
    
  69.         )
    
  70.         r = RasterModel.objects.create(rast=rast)
    
  71.         RasterModel.objects.filter(pk=r.pk).update(
    
  72.             rast=Func(F("rast"), function="ST_SetBandIsNoData"),
    
  73.         )
    
  74.         r.refresh_from_db()
    
  75.         band = r.rast.bands[0].data()
    
  76.         if numpy:
    
  77.             band = band.flatten().tolist()
    
  78.         self.assertEqual(band, [no_data])
    
  79.         self.assertEqual(r.rast.bands[0].nodata_value, no_data)
    
  80. 
    
  81.     def test_model_creation(self):
    
  82.         """
    
  83.         Test RasterField through a test model.
    
  84.         """
    
  85.         # Create model instance from JSON raster
    
  86.         r = RasterModel.objects.create(rast=JSON_RASTER)
    
  87.         r.refresh_from_db()
    
  88.         # Test raster metadata properties
    
  89.         self.assertEqual((5, 5), (r.rast.width, r.rast.height))
    
  90.         self.assertEqual([0.0, -1.0, 0.0, 0.0, 0.0, 1.0], r.rast.geotransform)
    
  91.         self.assertIsNone(r.rast.bands[0].nodata_value)
    
  92.         # Compare srs
    
  93.         self.assertEqual(r.rast.srs.srid, 4326)
    
  94.         # Compare pixel values
    
  95.         band = r.rast.bands[0].data()
    
  96.         # If numpy, convert result to list
    
  97.         if numpy:
    
  98.             band = band.flatten().tolist()
    
  99.         # Loop through rows in band data and assert single
    
  100.         # value is as expected.
    
  101.         self.assertEqual(
    
  102.             [
    
  103.                 0.0,
    
  104.                 1.0,
    
  105.                 2.0,
    
  106.                 3.0,
    
  107.                 4.0,
    
  108.                 5.0,
    
  109.                 6.0,
    
  110.                 7.0,
    
  111.                 8.0,
    
  112.                 9.0,
    
  113.                 10.0,
    
  114.                 11.0,
    
  115.                 12.0,
    
  116.                 13.0,
    
  117.                 14.0,
    
  118.                 15.0,
    
  119.                 16.0,
    
  120.                 17.0,
    
  121.                 18.0,
    
  122.                 19.0,
    
  123.                 20.0,
    
  124.                 21.0,
    
  125.                 22.0,
    
  126.                 23.0,
    
  127.                 24.0,
    
  128.             ],
    
  129.             band,
    
  130.         )
    
  131. 
    
  132.     def test_implicit_raster_transformation(self):
    
  133.         """
    
  134.         Test automatic transformation of rasters with srid different from the
    
  135.         field srid.
    
  136.         """
    
  137.         # Parse json raster
    
  138.         rast = json.loads(JSON_RASTER)
    
  139.         # Update srid to another value
    
  140.         rast["srid"] = 3086
    
  141.         # Save model and get it from db
    
  142.         r = RasterModel.objects.create(rast=rast)
    
  143.         r.refresh_from_db()
    
  144.         # Confirm raster has been transformed to the default srid
    
  145.         self.assertEqual(r.rast.srs.srid, 4326)
    
  146.         # Confirm geotransform is in lat/lon
    
  147.         expected = [
    
  148.             -87.9298551266551,
    
  149.             9.459646421449934e-06,
    
  150.             0.0,
    
  151.             23.94249275457565,
    
  152.             0.0,
    
  153.             -9.459646421449934e-06,
    
  154.         ]
    
  155.         for val, exp in zip(r.rast.geotransform, expected):
    
  156.             self.assertAlmostEqual(exp, val)
    
  157. 
    
  158.     def test_verbose_name_arg(self):
    
  159.         """
    
  160.         RasterField should accept a positional verbose name argument.
    
  161.         """
    
  162.         self.assertEqual(
    
  163.             RasterModel._meta.get_field("rast").verbose_name, "A Verbose Raster Name"
    
  164.         )
    
  165. 
    
  166.     def test_all_gis_lookups_with_rasters(self):
    
  167.         """
    
  168.         Evaluate all possible lookups for all input combinations (i.e.
    
  169.         raster-raster, raster-geom, geom-raster) and for projected and
    
  170.         unprojected coordinate systems. This test just checks that the lookup
    
  171.         can be called, but doesn't check if the result makes logical sense.
    
  172.         """
    
  173.         from django.contrib.gis.db.backends.postgis.operations import PostGISOperations
    
  174. 
    
  175.         # Create test raster and geom.
    
  176.         rast = GDALRaster(json.loads(JSON_RASTER))
    
  177.         stx_pnt = GEOSGeometry("POINT (-95.370401017314293 29.704867409475465)", 4326)
    
  178.         stx_pnt.transform(3086)
    
  179. 
    
  180.         lookups = [
    
  181.             (name, lookup)
    
  182.             for name, lookup in BaseSpatialField.get_lookups().items()
    
  183.             if issubclass(lookup, GISLookup)
    
  184.         ]
    
  185.         self.assertNotEqual(lookups, [], "No lookups found")
    
  186.         # Loop through all the GIS lookups.
    
  187.         for name, lookup in lookups:
    
  188.             # Construct lookup filter strings.
    
  189.             combo_keys = [
    
  190.                 field + name
    
  191.                 for field in [
    
  192.                     "rast__",
    
  193.                     "rast__",
    
  194.                     "rastprojected__0__",
    
  195.                     "rast__",
    
  196.                     "rastprojected__",
    
  197.                     "geom__",
    
  198.                     "rast__",
    
  199.                 ]
    
  200.             ]
    
  201.             if issubclass(lookup, DistanceLookupBase):
    
  202.                 # Set lookup values for distance lookups.
    
  203.                 combo_values = [
    
  204.                     (rast, 50, "spheroid"),
    
  205.                     (rast, 0, 50, "spheroid"),
    
  206.                     (rast, 0, D(km=1)),
    
  207.                     (stx_pnt, 0, 500),
    
  208.                     (stx_pnt, D(km=1000)),
    
  209.                     (rast, 500),
    
  210.                     (json.loads(JSON_RASTER), 500),
    
  211.                 ]
    
  212.             elif name == "relate":
    
  213.                 # Set lookup values for the relate lookup.
    
  214.                 combo_values = [
    
  215.                     (rast, "T*T***FF*"),
    
  216.                     (rast, 0, "T*T***FF*"),
    
  217.                     (rast, 0, "T*T***FF*"),
    
  218.                     (stx_pnt, 0, "T*T***FF*"),
    
  219.                     (stx_pnt, "T*T***FF*"),
    
  220.                     (rast, "T*T***FF*"),
    
  221.                     (json.loads(JSON_RASTER), "T*T***FF*"),
    
  222.                 ]
    
  223.             elif name == "isvalid":
    
  224.                 # The isvalid lookup doesn't make sense for rasters.
    
  225.                 continue
    
  226.             elif PostGISOperations.gis_operators[name].func:
    
  227.                 # Set lookup values for all function based operators.
    
  228.                 combo_values = [
    
  229.                     rast,
    
  230.                     (rast, 0),
    
  231.                     (rast, 0),
    
  232.                     (stx_pnt, 0),
    
  233.                     stx_pnt,
    
  234.                     rast,
    
  235.                     json.loads(JSON_RASTER),
    
  236.                 ]
    
  237.             else:
    
  238.                 # Override band lookup for these, as it's not supported.
    
  239.                 combo_keys[2] = "rastprojected__" + name
    
  240.                 # Set lookup values for all other operators.
    
  241.                 combo_values = [
    
  242.                     rast,
    
  243.                     None,
    
  244.                     rast,
    
  245.                     stx_pnt,
    
  246.                     stx_pnt,
    
  247.                     rast,
    
  248.                     json.loads(JSON_RASTER),
    
  249.                 ]
    
  250. 
    
  251.             # Create query filter combinations.
    
  252.             self.assertEqual(
    
  253.                 len(combo_keys),
    
  254.                 len(combo_values),
    
  255.                 "Number of lookup names and values should be the same",
    
  256.             )
    
  257.             combos = [x for x in zip(combo_keys, combo_values) if x[1]]
    
  258.             self.assertEqual(
    
  259.                 [(n, x) for n, x in enumerate(combos) if x in combos[:n]],
    
  260.                 [],
    
  261.                 "There are repeated test lookups",
    
  262.             )
    
  263.             combos = [{k: v} for k, v in combos]
    
  264. 
    
  265.             for combo in combos:
    
  266.                 # Apply this query filter.
    
  267.                 qs = RasterModel.objects.filter(**combo)
    
  268. 
    
  269.                 # Evaluate normal filter qs.
    
  270.                 self.assertIn(qs.count(), [0, 1])
    
  271. 
    
  272.             # Evaluate on conditional Q expressions.
    
  273.             qs = RasterModel.objects.filter(Q(**combos[0]) & Q(**combos[1]))
    
  274.             self.assertIn(qs.count(), [0, 1])
    
  275. 
    
  276.     def test_dwithin_gis_lookup_output_with_rasters(self):
    
  277.         """
    
  278.         Check the logical functionality of the dwithin lookup for different
    
  279.         input parameters.
    
  280.         """
    
  281.         # Create test raster and geom.
    
  282.         rast = GDALRaster(json.loads(JSON_RASTER))
    
  283.         stx_pnt = GEOSGeometry("POINT (-95.370401017314293 29.704867409475465)", 4326)
    
  284.         stx_pnt.transform(3086)
    
  285. 
    
  286.         # Filter raster with different lookup raster formats.
    
  287.         qs = RasterModel.objects.filter(rastprojected__dwithin=(rast, D(km=1)))
    
  288.         self.assertEqual(qs.count(), 1)
    
  289. 
    
  290.         qs = RasterModel.objects.filter(
    
  291.             rastprojected__dwithin=(json.loads(JSON_RASTER), D(km=1))
    
  292.         )
    
  293.         self.assertEqual(qs.count(), 1)
    
  294. 
    
  295.         qs = RasterModel.objects.filter(rastprojected__dwithin=(JSON_RASTER, D(km=1)))
    
  296.         self.assertEqual(qs.count(), 1)
    
  297. 
    
  298.         # Filter in an unprojected coordinate system.
    
  299.         qs = RasterModel.objects.filter(rast__dwithin=(rast, 40))
    
  300.         self.assertEqual(qs.count(), 1)
    
  301. 
    
  302.         # Filter with band index transform.
    
  303.         qs = RasterModel.objects.filter(rast__1__dwithin=(rast, 1, 40))
    
  304.         self.assertEqual(qs.count(), 1)
    
  305.         qs = RasterModel.objects.filter(rast__1__dwithin=(rast, 40))
    
  306.         self.assertEqual(qs.count(), 1)
    
  307.         qs = RasterModel.objects.filter(rast__dwithin=(rast, 1, 40))
    
  308.         self.assertEqual(qs.count(), 1)
    
  309. 
    
  310.         # Filter raster by geom.
    
  311.         qs = RasterModel.objects.filter(rast__dwithin=(stx_pnt, 500))
    
  312.         self.assertEqual(qs.count(), 1)
    
  313. 
    
  314.         qs = RasterModel.objects.filter(rastprojected__dwithin=(stx_pnt, D(km=10000)))
    
  315.         self.assertEqual(qs.count(), 1)
    
  316. 
    
  317.         qs = RasterModel.objects.filter(rast__dwithin=(stx_pnt, 5))
    
  318.         self.assertEqual(qs.count(), 0)
    
  319. 
    
  320.         qs = RasterModel.objects.filter(rastprojected__dwithin=(stx_pnt, D(km=100)))
    
  321.         self.assertEqual(qs.count(), 0)
    
  322. 
    
  323.         # Filter geom by raster.
    
  324.         qs = RasterModel.objects.filter(geom__dwithin=(rast, 500))
    
  325.         self.assertEqual(qs.count(), 1)
    
  326. 
    
  327.         # Filter through related model.
    
  328.         qs = RasterRelatedModel.objects.filter(rastermodel__rast__dwithin=(rast, 40))
    
  329.         self.assertEqual(qs.count(), 1)
    
  330. 
    
  331.         # Filter through related model with band index transform
    
  332.         qs = RasterRelatedModel.objects.filter(rastermodel__rast__1__dwithin=(rast, 40))
    
  333.         self.assertEqual(qs.count(), 1)
    
  334. 
    
  335.         # Filter through conditional statements.
    
  336.         qs = RasterModel.objects.filter(
    
  337.             Q(rast__dwithin=(rast, 40))
    
  338.             & Q(rastprojected__dwithin=(stx_pnt, D(km=10000)))
    
  339.         )
    
  340.         self.assertEqual(qs.count(), 1)
    
  341. 
    
  342.         # Filter through different lookup.
    
  343.         qs = RasterModel.objects.filter(rastprojected__bbcontains=rast)
    
  344.         self.assertEqual(qs.count(), 1)
    
  345. 
    
  346.     def test_lookup_input_tuple_too_long(self):
    
  347.         rast = GDALRaster(json.loads(JSON_RASTER))
    
  348.         msg = "Tuple too long for lookup bbcontains."
    
  349.         with self.assertRaisesMessage(ValueError, msg):
    
  350.             RasterModel.objects.filter(rast__bbcontains=(rast, 1, 2))
    
  351. 
    
  352.     def test_lookup_input_band_not_allowed(self):
    
  353.         rast = GDALRaster(json.loads(JSON_RASTER))
    
  354.         qs = RasterModel.objects.filter(rast__bbcontains=(rast, 1))
    
  355.         msg = "Band indices are not allowed for this operator, it works on bbox only."
    
  356.         with self.assertRaisesMessage(ValueError, msg):
    
  357.             qs.count()
    
  358. 
    
  359.     def test_isvalid_lookup_with_raster_error(self):
    
  360.         qs = RasterModel.objects.filter(rast__isvalid=True)
    
  361.         msg = (
    
  362.             "IsValid function requires a GeometryField in position 1, got RasterField."
    
  363.         )
    
  364.         with self.assertRaisesMessage(TypeError, msg):
    
  365.             qs.count()
    
  366. 
    
  367.     def test_result_of_gis_lookup_with_rasters(self):
    
  368.         # Point is in the interior
    
  369.         qs = RasterModel.objects.filter(
    
  370.             rast__contains=GEOSGeometry("POINT (-0.5 0.5)", 4326)
    
  371.         )
    
  372.         self.assertEqual(qs.count(), 1)
    
  373.         # Point is in the exterior
    
  374.         qs = RasterModel.objects.filter(
    
  375.             rast__contains=GEOSGeometry("POINT (0.5 0.5)", 4326)
    
  376.         )
    
  377.         self.assertEqual(qs.count(), 0)
    
  378.         # A point on the boundary is not contained properly
    
  379.         qs = RasterModel.objects.filter(
    
  380.             rast__contains_properly=GEOSGeometry("POINT (0 0)", 4326)
    
  381.         )
    
  382.         self.assertEqual(qs.count(), 0)
    
  383.         # Raster is located left of the point
    
  384.         qs = RasterModel.objects.filter(rast__left=GEOSGeometry("POINT (1 0)", 4326))
    
  385.         self.assertEqual(qs.count(), 1)
    
  386. 
    
  387.     def test_lookup_with_raster_bbox(self):
    
  388.         rast = GDALRaster(json.loads(JSON_RASTER))
    
  389.         # Shift raster upward
    
  390.         rast.origin.y = 2
    
  391.         # The raster in the model is not strictly below
    
  392.         qs = RasterModel.objects.filter(rast__strictly_below=rast)
    
  393.         self.assertEqual(qs.count(), 0)
    
  394.         # Shift raster further upward
    
  395.         rast.origin.y = 6
    
  396.         # The raster in the model is strictly below
    
  397.         qs = RasterModel.objects.filter(rast__strictly_below=rast)
    
  398.         self.assertEqual(qs.count(), 1)
    
  399. 
    
  400.     def test_lookup_with_polygonized_raster(self):
    
  401.         rast = GDALRaster(json.loads(JSON_RASTER))
    
  402.         # Move raster to overlap with the model point on the left side
    
  403.         rast.origin.x = -95.37040 + 1
    
  404.         rast.origin.y = 29.70486
    
  405.         # Raster overlaps with point in model
    
  406.         qs = RasterModel.objects.filter(geom__intersects=rast)
    
  407.         self.assertEqual(qs.count(), 1)
    
  408.         # Change left side of raster to be nodata values
    
  409.         rast.bands[0].data(data=[0, 0, 0, 1, 1], shape=(5, 1))
    
  410.         rast.bands[0].nodata_value = 0
    
  411.         qs = RasterModel.objects.filter(geom__intersects=rast)
    
  412.         # Raster does not overlap anymore after polygonization
    
  413.         # where the nodata zone is not included.
    
  414.         self.assertEqual(qs.count(), 0)
    
  415. 
    
  416.     def test_lookup_value_error(self):
    
  417.         # Test with invalid dict lookup parameter
    
  418.         obj = {}
    
  419.         msg = "Couldn't create spatial object from lookup value '%s'." % obj
    
  420.         with self.assertRaisesMessage(ValueError, msg):
    
  421.             RasterModel.objects.filter(geom__intersects=obj)
    
  422.         # Test with invalid string lookup parameter
    
  423.         obj = "00000"
    
  424.         msg = "Couldn't create spatial object from lookup value '%s'." % obj
    
  425.         with self.assertRaisesMessage(ValueError, msg):
    
  426.             RasterModel.objects.filter(geom__intersects=obj)
    
  427. 
    
  428.     def test_db_function_errors(self):
    
  429.         """
    
  430.         Errors are raised when using DB functions with raster content.
    
  431.         """
    
  432.         point = GEOSGeometry("SRID=3086;POINT (-697024.9213808845 683729.1705516104)")
    
  433.         rast = GDALRaster(json.loads(JSON_RASTER))
    
  434.         msg = "Distance function requires a geometric argument in position 2."
    
  435.         with self.assertRaisesMessage(TypeError, msg):
    
  436.             RasterModel.objects.annotate(distance_from_point=Distance("geom", rast))
    
  437.         with self.assertRaisesMessage(TypeError, msg):
    
  438.             RasterModel.objects.annotate(
    
  439.                 distance_from_point=Distance("rastprojected", rast)
    
  440.             )
    
  441.         msg = (
    
  442.             "Distance function requires a GeometryField in position 1, got RasterField."
    
  443.         )
    
  444.         with self.assertRaisesMessage(TypeError, msg):
    
  445.             RasterModel.objects.annotate(
    
  446.                 distance_from_point=Distance("rastprojected", point)
    
  447.             ).count()
    
  448. 
    
  449.     def test_lhs_with_index_rhs_without_index(self):
    
  450.         with CaptureQueriesContext(connection) as queries:
    
  451.             RasterModel.objects.filter(
    
  452.                 rast__0__contains=json.loads(JSON_RASTER)
    
  453.             ).exists()
    
  454.         # It's easier to check the indexes in the generated SQL than to write
    
  455.         # tests that cover all index combinations.
    
  456.         self.assertRegex(queries[-1]["sql"], r"WHERE ST_Contains\([^)]*, 1, [^)]*, 1\)")