1. import os
    
  2. import re
    
  3. from datetime import datetime
    
  4. from pathlib import Path
    
  5. 
    
  6. from django.contrib.gis.gdal import DataSource, Envelope, GDALException, OGRGeometry
    
  7. from django.contrib.gis.gdal.field import OFTDateTime, OFTInteger, OFTReal, OFTString
    
  8. from django.contrib.gis.geos import GEOSGeometry
    
  9. from django.test import SimpleTestCase
    
  10. 
    
  11. from ..test_data import TEST_DATA, TestDS, get_ds_file
    
  12. 
    
  13. wgs_84_wkt = (
    
  14.     'GEOGCS["GCS_WGS_1984",DATUM["WGS_1984",SPHEROID["WGS_1984",'
    
  15.     '6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",'
    
  16.     "0.017453292519943295]]"
    
  17. )
    
  18. # Using a regex because of small differences depending on GDAL versions.
    
  19. wgs_84_wkt_regex = r'^GEOGCS\["(GCS_)?WGS[ _](19)?84".*$'
    
  20. 
    
  21. datetime_format = "%Y-%m-%dT%H:%M:%S"
    
  22. 
    
  23. # List of acceptable data sources.
    
  24. ds_list = (
    
  25.     TestDS(
    
  26.         "test_point",
    
  27.         nfeat=5,
    
  28.         nfld=3,
    
  29.         geom="POINT",
    
  30.         gtype=1,
    
  31.         driver="ESRI Shapefile",
    
  32.         fields={"dbl": OFTReal, "int": OFTInteger, "str": OFTString},
    
  33.         extent=(-1.35011, 0.166623, -0.524093, 0.824508),  # Got extent from QGIS
    
  34.         srs_wkt=wgs_84_wkt,
    
  35.         field_values={
    
  36.             "dbl": [float(i) for i in range(1, 6)],
    
  37.             "int": list(range(1, 6)),
    
  38.             "str": [str(i) for i in range(1, 6)],
    
  39.         },
    
  40.         fids=range(5),
    
  41.     ),
    
  42.     TestDS(
    
  43.         "test_vrt",
    
  44.         ext="vrt",
    
  45.         nfeat=3,
    
  46.         nfld=3,
    
  47.         geom="POINT",
    
  48.         gtype="Point25D",
    
  49.         driver="OGR_VRT",
    
  50.         fields={
    
  51.             "POINT_X": OFTString,
    
  52.             "POINT_Y": OFTString,
    
  53.             "NUM": OFTString,
    
  54.         },  # VRT uses CSV, which all types are OFTString.
    
  55.         extent=(1.0, 2.0, 100.0, 523.5),  # Min/Max from CSV
    
  56.         field_values={
    
  57.             "POINT_X": ["1.0", "5.0", "100.0"],
    
  58.             "POINT_Y": ["2.0", "23.0", "523.5"],
    
  59.             "NUM": ["5", "17", "23"],
    
  60.         },
    
  61.         fids=range(1, 4),
    
  62.     ),
    
  63.     TestDS(
    
  64.         "test_poly",
    
  65.         nfeat=3,
    
  66.         nfld=3,
    
  67.         geom="POLYGON",
    
  68.         gtype=3,
    
  69.         driver="ESRI Shapefile",
    
  70.         fields={"float": OFTReal, "int": OFTInteger, "str": OFTString},
    
  71.         extent=(-1.01513, -0.558245, 0.161876, 0.839637),  # Got extent from QGIS
    
  72.         srs_wkt=wgs_84_wkt,
    
  73.     ),
    
  74.     TestDS(
    
  75.         "has_nulls",
    
  76.         nfeat=3,
    
  77.         nfld=6,
    
  78.         geom="POLYGON",
    
  79.         gtype=3,
    
  80.         driver="GeoJSON",
    
  81.         ext="geojson",
    
  82.         fields={
    
  83.             "uuid": OFTString,
    
  84.             "name": OFTString,
    
  85.             "num": OFTReal,
    
  86.             "integer": OFTInteger,
    
  87.             "datetime": OFTDateTime,
    
  88.             "boolean": OFTInteger,
    
  89.         },
    
  90.         extent=(-75.274200, 39.846504, -74.959717, 40.119040),  # Got extent from QGIS
    
  91.         field_values={
    
  92.             "uuid": [
    
  93.                 "1378c26f-cbe6-44b0-929f-eb330d4991f5",
    
  94.                 "fa2ba67c-a135-4338-b924-a9622b5d869f",
    
  95.                 "4494c1f3-55ab-4256-b365-12115cb388d5",
    
  96.             ],
    
  97.             "name": ["Philadelphia", None, "north"],
    
  98.             "num": [1.001, None, 0.0],
    
  99.             "integer": [5, None, 8],
    
  100.             "boolean": [True, None, False],
    
  101.             "datetime": [
    
  102.                 datetime.strptime("1994-08-14T11:32:14", datetime_format),
    
  103.                 None,
    
  104.                 datetime.strptime("2018-11-29T03:02:52", datetime_format),
    
  105.             ],
    
  106.         },
    
  107.         fids=range(3),
    
  108.     ),
    
  109. )
    
  110. 
    
  111. bad_ds = (TestDS("foo"),)
    
  112. 
    
  113. 
    
  114. class DataSourceTest(SimpleTestCase):
    
  115.     def test01_valid_shp(self):
    
  116.         "Testing valid SHP Data Source files."
    
  117. 
    
  118.         for source in ds_list:
    
  119.             # Loading up the data source
    
  120.             ds = DataSource(source.ds)
    
  121. 
    
  122.             # The layer count is what's expected (only 1 layer in a SHP file).
    
  123.             self.assertEqual(1, len(ds))
    
  124. 
    
  125.             # Making sure GetName works
    
  126.             self.assertEqual(source.ds, ds.name)
    
  127. 
    
  128.             # Making sure the driver name matches up
    
  129.             self.assertEqual(source.driver, str(ds.driver))
    
  130. 
    
  131.             # Making sure indexing works
    
  132.             msg = "Index out of range when accessing layers in a datasource: %s."
    
  133.             with self.assertRaisesMessage(IndexError, msg % len(ds)):
    
  134.                 ds.__getitem__(len(ds))
    
  135. 
    
  136.             with self.assertRaisesMessage(
    
  137.                 IndexError, "Invalid OGR layer name given: invalid."
    
  138.             ):
    
  139.                 ds.__getitem__("invalid")
    
  140. 
    
  141.     def test_ds_input_pathlib(self):
    
  142.         test_shp = Path(get_ds_file("test_point", "shp"))
    
  143.         ds = DataSource(test_shp)
    
  144.         self.assertEqual(len(ds), 1)
    
  145. 
    
  146.     def test02_invalid_shp(self):
    
  147.         "Testing invalid SHP files for the Data Source."
    
  148.         for source in bad_ds:
    
  149.             with self.assertRaises(GDALException):
    
  150.                 DataSource(source.ds)
    
  151. 
    
  152.     def test03a_layers(self):
    
  153.         "Testing Data Source Layers."
    
  154.         for source in ds_list:
    
  155.             ds = DataSource(source.ds)
    
  156. 
    
  157.             # Incrementing through each layer, this tests DataSource.__iter__
    
  158.             for layer in ds:
    
  159.                 self.assertEqual(layer.name, source.name)
    
  160.                 self.assertEqual(str(layer), source.name)
    
  161.                 # Making sure we get the number of features we expect
    
  162.                 self.assertEqual(len(layer), source.nfeat)
    
  163. 
    
  164.                 # Making sure we get the number of fields we expect
    
  165.                 self.assertEqual(source.nfld, layer.num_fields)
    
  166.                 self.assertEqual(source.nfld, len(layer.fields))
    
  167. 
    
  168.                 # Testing the layer's extent (an Envelope), and its properties
    
  169.                 self.assertIsInstance(layer.extent, Envelope)
    
  170.                 self.assertAlmostEqual(source.extent[0], layer.extent.min_x, 5)
    
  171.                 self.assertAlmostEqual(source.extent[1], layer.extent.min_y, 5)
    
  172.                 self.assertAlmostEqual(source.extent[2], layer.extent.max_x, 5)
    
  173.                 self.assertAlmostEqual(source.extent[3], layer.extent.max_y, 5)
    
  174. 
    
  175.                 # Now checking the field names.
    
  176.                 flds = layer.fields
    
  177.                 for f in flds:
    
  178.                     self.assertIn(f, source.fields)
    
  179. 
    
  180.                 # Negative FIDs are not allowed.
    
  181.                 with self.assertRaisesMessage(
    
  182.                     IndexError, "Negative indices are not allowed on OGR Layers."
    
  183.                 ):
    
  184.                     layer.__getitem__(-1)
    
  185.                 with self.assertRaisesMessage(IndexError, "Invalid feature id: 50000."):
    
  186.                     layer.__getitem__(50000)
    
  187. 
    
  188.                 if hasattr(source, "field_values"):
    
  189.                     # Testing `Layer.get_fields` (which uses Layer.__iter__)
    
  190.                     for fld_name, fld_value in source.field_values.items():
    
  191.                         self.assertEqual(fld_value, layer.get_fields(fld_name))
    
  192. 
    
  193.                     # Testing `Layer.__getitem__`.
    
  194.                     for i, fid in enumerate(source.fids):
    
  195.                         feat = layer[fid]
    
  196.                         self.assertEqual(fid, feat.fid)
    
  197.                         # Maybe this should be in the test below, but we might
    
  198.                         # as well test the feature values here while in this
    
  199.                         # loop.
    
  200.                         for fld_name, fld_value in source.field_values.items():
    
  201.                             self.assertEqual(fld_value[i], feat.get(fld_name))
    
  202. 
    
  203.                         msg = (
    
  204.                             "Index out of range when accessing field in a feature: %s."
    
  205.                         )
    
  206.                         with self.assertRaisesMessage(IndexError, msg % len(feat)):
    
  207.                             feat.__getitem__(len(feat))
    
  208. 
    
  209.                         with self.assertRaisesMessage(
    
  210.                             IndexError, "Invalid OFT field name given: invalid."
    
  211.                         ):
    
  212.                             feat.__getitem__("invalid")
    
  213. 
    
  214.     def test03b_layer_slice(self):
    
  215.         "Test indexing and slicing on Layers."
    
  216.         # Using the first data-source because the same slice
    
  217.         # can be used for both the layer and the control values.
    
  218.         source = ds_list[0]
    
  219.         ds = DataSource(source.ds)
    
  220. 
    
  221.         sl = slice(1, 3)
    
  222.         feats = ds[0][sl]
    
  223. 
    
  224.         for fld_name in ds[0].fields:
    
  225.             test_vals = [feat.get(fld_name) for feat in feats]
    
  226.             control_vals = source.field_values[fld_name][sl]
    
  227.             self.assertEqual(control_vals, test_vals)
    
  228. 
    
  229.     def test03c_layer_references(self):
    
  230.         """
    
  231.         Ensure OGR objects keep references to the objects they belong to.
    
  232.         """
    
  233.         source = ds_list[0]
    
  234. 
    
  235.         # See ticket #9448.
    
  236.         def get_layer():
    
  237.             # This DataSource object is not accessible outside this
    
  238.             # scope.  However, a reference should still be kept alive
    
  239.             # on the `Layer` returned.
    
  240.             ds = DataSource(source.ds)
    
  241.             return ds[0]
    
  242. 
    
  243.         # Making sure we can call OGR routines on the Layer returned.
    
  244.         lyr = get_layer()
    
  245.         self.assertEqual(source.nfeat, len(lyr))
    
  246.         self.assertEqual(source.gtype, lyr.geom_type.num)
    
  247. 
    
  248.         # Same issue for Feature/Field objects, see #18640
    
  249.         self.assertEqual(str(lyr[0]["str"]), "1")
    
  250. 
    
  251.     def test04_features(self):
    
  252.         "Testing Data Source Features."
    
  253.         for source in ds_list:
    
  254.             ds = DataSource(source.ds)
    
  255. 
    
  256.             # Incrementing through each layer
    
  257.             for layer in ds:
    
  258.                 # Incrementing through each feature in the layer
    
  259.                 for feat in layer:
    
  260.                     # Making sure the number of fields, and the geometry type
    
  261.                     # are what's expected.
    
  262.                     self.assertEqual(source.nfld, len(list(feat)))
    
  263.                     self.assertEqual(source.gtype, feat.geom_type)
    
  264. 
    
  265.                     # Making sure the fields match to an appropriate OFT type.
    
  266.                     for k, v in source.fields.items():
    
  267.                         # Making sure we get the proper OGR Field instance, using
    
  268.                         # a string value index for the feature.
    
  269.                         self.assertIsInstance(feat[k], v)
    
  270.                     self.assertIsInstance(feat.fields[0], str)
    
  271. 
    
  272.                     # Testing Feature.__iter__
    
  273.                     for fld in feat:
    
  274.                         self.assertIn(fld.name, source.fields)
    
  275. 
    
  276.     def test05_geometries(self):
    
  277.         "Testing Geometries from Data Source Features."
    
  278.         for source in ds_list:
    
  279.             ds = DataSource(source.ds)
    
  280. 
    
  281.             # Incrementing through each layer and feature.
    
  282.             for layer in ds:
    
  283.                 geoms = layer.get_geoms()
    
  284.                 geos_geoms = layer.get_geoms(geos=True)
    
  285.                 self.assertEqual(len(geoms), len(geos_geoms))
    
  286.                 self.assertEqual(len(geoms), len(layer))
    
  287.                 for feat, geom, geos_geom in zip(layer, geoms, geos_geoms):
    
  288.                     g = feat.geom
    
  289.                     self.assertEqual(geom, g)
    
  290.                     self.assertIsInstance(geos_geom, GEOSGeometry)
    
  291.                     self.assertEqual(g, geos_geom.ogr)
    
  292.                     # Making sure we get the right Geometry name & type
    
  293.                     self.assertEqual(source.geom, g.geom_name)
    
  294.                     self.assertEqual(source.gtype, g.geom_type)
    
  295. 
    
  296.                     # Making sure the SpatialReference is as expected.
    
  297.                     if hasattr(source, "srs_wkt"):
    
  298.                         self.assertIsNotNone(re.match(wgs_84_wkt_regex, g.srs.wkt))
    
  299. 
    
  300.     def test06_spatial_filter(self):
    
  301.         "Testing the Layer.spatial_filter property."
    
  302.         ds = DataSource(get_ds_file("cities", "shp"))
    
  303.         lyr = ds[0]
    
  304. 
    
  305.         # When not set, it should be None.
    
  306.         self.assertIsNone(lyr.spatial_filter)
    
  307. 
    
  308.         # Must be set a/an OGRGeometry or 4-tuple.
    
  309.         with self.assertRaises(TypeError):
    
  310.             lyr._set_spatial_filter("foo")
    
  311. 
    
  312.         # Setting the spatial filter with a tuple/list with the extent of
    
  313.         # a buffer centering around Pueblo.
    
  314.         with self.assertRaises(ValueError):
    
  315.             lyr._set_spatial_filter(list(range(5)))
    
  316.         filter_extent = (-105.609252, 37.255001, -103.609252, 39.255001)
    
  317.         lyr.spatial_filter = (-105.609252, 37.255001, -103.609252, 39.255001)
    
  318.         self.assertEqual(OGRGeometry.from_bbox(filter_extent), lyr.spatial_filter)
    
  319.         feats = [feat for feat in lyr]
    
  320.         self.assertEqual(1, len(feats))
    
  321.         self.assertEqual("Pueblo", feats[0].get("Name"))
    
  322. 
    
  323.         # Setting the spatial filter with an OGRGeometry for buffer centering
    
  324.         # around Houston.
    
  325.         filter_geom = OGRGeometry(
    
  326.             "POLYGON((-96.363151 28.763374,-94.363151 28.763374,"
    
  327.             "-94.363151 30.763374,-96.363151 30.763374,-96.363151 28.763374))"
    
  328.         )
    
  329.         lyr.spatial_filter = filter_geom
    
  330.         self.assertEqual(filter_geom, lyr.spatial_filter)
    
  331.         feats = [feat for feat in lyr]
    
  332.         self.assertEqual(1, len(feats))
    
  333.         self.assertEqual("Houston", feats[0].get("Name"))
    
  334. 
    
  335.         # Clearing the spatial filter by setting it to None.  Now
    
  336.         # should indicate that there are 3 features in the Layer.
    
  337.         lyr.spatial_filter = None
    
  338.         self.assertEqual(3, len(lyr))
    
  339. 
    
  340.     def test07_integer_overflow(self):
    
  341.         "Testing that OFTReal fields, treated as OFTInteger, do not overflow."
    
  342.         # Using *.dbf from Census 2010 TIGER Shapefile for Texas,
    
  343.         # which has land area ('ALAND10') stored in a Real field
    
  344.         # with no precision.
    
  345.         ds = DataSource(os.path.join(TEST_DATA, "texas.dbf"))
    
  346.         feat = ds[0][0]
    
  347.         # Reference value obtained using `ogrinfo`.
    
  348.         self.assertEqual(676586997978, feat.get("ALAND10"))
    
  349. 
    
  350.     def test_nonexistent_field(self):
    
  351.         source = ds_list[0]
    
  352.         ds = DataSource(source.ds)
    
  353.         msg = "invalid field name: nonexistent"
    
  354.         with self.assertRaisesMessage(GDALException, msg):
    
  355.             ds[0].get_fields("nonexistent")