1. import os
    
  2. import re
    
  3. from io import StringIO
    
  4. 
    
  5. from django.contrib.gis.gdal import GDAL_VERSION, Driver, GDALException
    
  6. from django.contrib.gis.utils.ogrinspect import ogrinspect
    
  7. from django.core.management import call_command
    
  8. from django.db import connection, connections
    
  9. from django.test import SimpleTestCase, TestCase, skipUnlessDBFeature
    
  10. from django.test.utils import modify_settings
    
  11. 
    
  12. from ..test_data import TEST_DATA
    
  13. from .models import AllOGRFields
    
  14. 
    
  15. 
    
  16. class InspectDbTests(TestCase):
    
  17.     def test_geom_columns(self):
    
  18.         """
    
  19.         Test the geo-enabled inspectdb command.
    
  20.         """
    
  21.         out = StringIO()
    
  22.         call_command(
    
  23.             "inspectdb",
    
  24.             table_name_filter=lambda tn: tn == "inspectapp_allogrfields",
    
  25.             stdout=out,
    
  26.         )
    
  27.         output = out.getvalue()
    
  28.         if connection.features.supports_geometry_field_introspection:
    
  29.             self.assertIn("geom = models.PolygonField()", output)
    
  30.             self.assertIn("point = models.PointField()", output)
    
  31.         else:
    
  32.             self.assertIn("geom = models.GeometryField(", output)
    
  33.             self.assertIn("point = models.GeometryField(", output)
    
  34. 
    
  35.     @skipUnlessDBFeature("supports_3d_storage")
    
  36.     def test_3d_columns(self):
    
  37.         out = StringIO()
    
  38.         call_command(
    
  39.             "inspectdb",
    
  40.             table_name_filter=lambda tn: tn == "inspectapp_fields3d",
    
  41.             stdout=out,
    
  42.         )
    
  43.         output = out.getvalue()
    
  44.         if connection.features.supports_geometry_field_introspection:
    
  45.             self.assertIn("point = models.PointField(dim=3)", output)
    
  46.             if connection.features.supports_geography:
    
  47.                 self.assertIn(
    
  48.                     "pointg = models.PointField(geography=True, dim=3)", output
    
  49.                 )
    
  50.             else:
    
  51.                 self.assertIn("pointg = models.PointField(dim=3)", output)
    
  52.             self.assertIn("line = models.LineStringField(dim=3)", output)
    
  53.             self.assertIn("poly = models.PolygonField(dim=3)", output)
    
  54.         else:
    
  55.             self.assertIn("point = models.GeometryField(", output)
    
  56.             self.assertIn("pointg = models.GeometryField(", output)
    
  57.             self.assertIn("line = models.GeometryField(", output)
    
  58.             self.assertIn("poly = models.GeometryField(", output)
    
  59. 
    
  60. 
    
  61. @modify_settings(
    
  62.     INSTALLED_APPS={"append": "django.contrib.gis"},
    
  63. )
    
  64. class OGRInspectTest(SimpleTestCase):
    
  65.     maxDiff = 1024
    
  66. 
    
  67.     def test_poly(self):
    
  68.         shp_file = os.path.join(TEST_DATA, "test_poly", "test_poly.shp")
    
  69.         model_def = ogrinspect(shp_file, "MyModel")
    
  70. 
    
  71.         expected = [
    
  72.             "# This is an auto-generated Django model module created by ogrinspect.",
    
  73.             "from django.contrib.gis.db import models",
    
  74.             "",
    
  75.             "",
    
  76.             "class MyModel(models.Model):",
    
  77.             "    float = models.FloatField()",
    
  78.             "    int = models.BigIntegerField()",
    
  79.             "    str = models.CharField(max_length=80)",
    
  80.             "    geom = models.PolygonField()",
    
  81.         ]
    
  82. 
    
  83.         self.assertEqual(model_def, "\n".join(expected))
    
  84. 
    
  85.     def test_poly_multi(self):
    
  86.         shp_file = os.path.join(TEST_DATA, "test_poly", "test_poly.shp")
    
  87.         model_def = ogrinspect(shp_file, "MyModel", multi_geom=True)
    
  88.         self.assertIn("geom = models.MultiPolygonField()", model_def)
    
  89.         # Same test with a 25D-type geometry field
    
  90.         shp_file = os.path.join(TEST_DATA, "gas_lines", "gas_leitung.shp")
    
  91.         model_def = ogrinspect(shp_file, "MyModel", multi_geom=True)
    
  92.         srid = "-1" if GDAL_VERSION < (2, 3) else "31253"
    
  93.         self.assertIn("geom = models.MultiLineStringField(srid=%s)" % srid, model_def)
    
  94. 
    
  95.     def test_date_field(self):
    
  96.         shp_file = os.path.join(TEST_DATA, "cities", "cities.shp")
    
  97.         model_def = ogrinspect(shp_file, "City")
    
  98. 
    
  99.         expected = [
    
  100.             "# This is an auto-generated Django model module created by ogrinspect.",
    
  101.             "from django.contrib.gis.db import models",
    
  102.             "",
    
  103.             "",
    
  104.             "class City(models.Model):",
    
  105.             "    name = models.CharField(max_length=80)",
    
  106.             "    population = models.BigIntegerField()",
    
  107.             "    density = models.FloatField()",
    
  108.             "    created = models.DateField()",
    
  109.             "    geom = models.PointField()",
    
  110.         ]
    
  111. 
    
  112.         self.assertEqual(model_def, "\n".join(expected))
    
  113. 
    
  114.     def test_time_field(self):
    
  115.         # Getting the database identifier used by OGR, if None returned
    
  116.         # GDAL does not have the support compiled in.
    
  117.         ogr_db = get_ogr_db_string()
    
  118.         if not ogr_db:
    
  119.             self.skipTest("Unable to setup an OGR connection to your database")
    
  120. 
    
  121.         try:
    
  122.             # Writing shapefiles via GDAL currently does not support writing OGRTime
    
  123.             # fields, so we need to actually use a database
    
  124.             model_def = ogrinspect(
    
  125.                 ogr_db,
    
  126.                 "Measurement",
    
  127.                 layer_key=AllOGRFields._meta.db_table,
    
  128.                 decimal=["f_decimal"],
    
  129.             )
    
  130.         except GDALException:
    
  131.             self.skipTest("Unable to setup an OGR connection to your database")
    
  132. 
    
  133.         self.assertTrue(
    
  134.             model_def.startswith(
    
  135.                 "# This is an auto-generated Django model module created by "
    
  136.                 "ogrinspect.\n"
    
  137.                 "from django.contrib.gis.db import models\n"
    
  138.                 "\n"
    
  139.                 "\n"
    
  140.                 "class Measurement(models.Model):\n"
    
  141.             )
    
  142.         )
    
  143. 
    
  144.         # The ordering of model fields might vary depending on several factors
    
  145.         # (version of GDAL, etc.).
    
  146.         if connection.vendor == "sqlite" and GDAL_VERSION < (3, 4):
    
  147.             # SpatiaLite introspection is somewhat lacking on GDAL < 3.4 (#29461).
    
  148.             self.assertIn("    f_decimal = models.CharField(max_length=0)", model_def)
    
  149.         else:
    
  150.             self.assertIn(
    
  151.                 "    f_decimal = models.DecimalField(max_digits=0, decimal_places=0)",
    
  152.                 model_def,
    
  153.             )
    
  154.         self.assertIn("    f_int = models.IntegerField()", model_def)
    
  155.         if not connection.ops.mariadb:
    
  156.             # Probably a bug between GDAL and MariaDB on time fields.
    
  157.             self.assertIn("    f_datetime = models.DateTimeField()", model_def)
    
  158.             self.assertIn("    f_time = models.TimeField()", model_def)
    
  159.         if connection.vendor == "sqlite" and GDAL_VERSION < (3, 4):
    
  160.             self.assertIn("    f_float = models.CharField(max_length=0)", model_def)
    
  161.         else:
    
  162.             self.assertIn("    f_float = models.FloatField()", model_def)
    
  163.         max_length = 0 if connection.vendor == "sqlite" else 10
    
  164.         self.assertIn(
    
  165.             "    f_char = models.CharField(max_length=%s)" % max_length, model_def
    
  166.         )
    
  167.         self.assertIn("    f_date = models.DateField()", model_def)
    
  168. 
    
  169.         # Some backends may have srid=-1
    
  170.         self.assertIsNotNone(
    
  171.             re.search(r"    geom = models.PolygonField\(([^\)])*\)", model_def)
    
  172.         )
    
  173. 
    
  174.     def test_management_command(self):
    
  175.         shp_file = os.path.join(TEST_DATA, "cities", "cities.shp")
    
  176.         out = StringIO()
    
  177.         call_command("ogrinspect", shp_file, "City", stdout=out)
    
  178.         output = out.getvalue()
    
  179.         self.assertIn("class City(models.Model):", output)
    
  180. 
    
  181.     def test_mapping_option(self):
    
  182.         expected = (
    
  183.             "    geom = models.PointField()\n"
    
  184.             "\n"
    
  185.             "\n"
    
  186.             "# Auto-generated `LayerMapping` dictionary for City model\n"
    
  187.             "city_mapping = {\n"
    
  188.             "    'name': 'Name',\n"
    
  189.             "    'population': 'Population',\n"
    
  190.             "    'density': 'Density',\n"
    
  191.             "    'created': 'Created',\n"
    
  192.             "    'geom': 'POINT',\n"
    
  193.             "}\n"
    
  194.         )
    
  195.         shp_file = os.path.join(TEST_DATA, "cities", "cities.shp")
    
  196.         out = StringIO()
    
  197.         call_command("ogrinspect", shp_file, "--mapping", "City", stdout=out)
    
  198.         self.assertIn(expected, out.getvalue())
    
  199. 
    
  200. 
    
  201. def get_ogr_db_string():
    
  202.     """
    
  203.     Construct the DB string that GDAL will use to inspect the database.
    
  204.     GDAL will create its own connection to the database, so we re-use the
    
  205.     connection settings from the Django test.
    
  206.     """
    
  207.     db = connections.settings["default"]
    
  208. 
    
  209.     # Map from the django backend into the OGR driver name and database identifier
    
  210.     # https://gdal.org/drivers/vector/
    
  211.     #
    
  212.     # TODO: Support Oracle (OCI).
    
  213.     drivers = {
    
  214.         "django.contrib.gis.db.backends.postgis": (
    
  215.             "PostgreSQL",
    
  216.             "PG:dbname='%(db_name)s'",
    
  217.             " ",
    
  218.         ),
    
  219.         "django.contrib.gis.db.backends.mysql": ("MySQL", 'MYSQL:"%(db_name)s"', ","),
    
  220.         "django.contrib.gis.db.backends.spatialite": ("SQLite", "%(db_name)s", ""),
    
  221.     }
    
  222. 
    
  223.     db_engine = db["ENGINE"]
    
  224.     if db_engine not in drivers:
    
  225.         return None
    
  226. 
    
  227.     drv_name, db_str, param_sep = drivers[db_engine]
    
  228. 
    
  229.     # Ensure that GDAL library has driver support for the database.
    
  230.     try:
    
  231.         Driver(drv_name)
    
  232.     except GDALException:
    
  233.         return None
    
  234. 
    
  235.     # SQLite/SpatiaLite in-memory databases
    
  236.     if db["NAME"] == ":memory:":
    
  237.         return None
    
  238. 
    
  239.     # Build the params of the OGR database connection string
    
  240.     params = [db_str % {"db_name": db["NAME"]}]
    
  241. 
    
  242.     def add(key, template):
    
  243.         value = db.get(key, None)
    
  244.         # Don't add the parameter if it is not in django's settings
    
  245.         if value:
    
  246.             params.append(template % value)
    
  247. 
    
  248.     add("HOST", "host='%s'")
    
  249.     add("PORT", "port='%s'")
    
  250.     add("USER", "user='%s'")
    
  251.     add("PASSWORD", "password='%s'")
    
  252. 
    
  253.     return param_sep.join(params)