1. import os
    
  2. import pathlib
    
  3. from unittest import mock, skipUnless
    
  4. 
    
  5. from django.conf import settings
    
  6. from django.contrib.gis.geoip2 import HAS_GEOIP2
    
  7. from django.contrib.gis.geos import GEOSGeometry
    
  8. from django.test import SimpleTestCase
    
  9. 
    
  10. if HAS_GEOIP2:
    
  11.     from django.contrib.gis.geoip2 import GeoIP2, GeoIP2Exception
    
  12. 
    
  13. 
    
# Note: Requires both the GeoIP country and city datasets.
# The GEOIP_PATH setting should be the only setting set (the directory
# should contain links to, or the actual database files, 'GeoLite2-City.mmdb'
# and 'GeoLite2-Country.mmdb').
    
  18. @skipUnless(
    
  19.     HAS_GEOIP2 and getattr(settings, "GEOIP_PATH", None),
    
  20.     "GeoIP is required along with the GEOIP_PATH setting.",
    
  21. )
    
  22. class GeoIPTest(SimpleTestCase):
    
  23.     addr = "129.237.192.1"
    
  24.     fqdn = "ku.edu"
    
  25. 
    
  26.     def test01_init(self):
    
  27.         "GeoIP initialization."
    
  28.         g1 = GeoIP2()  # Everything inferred from GeoIP path
    
  29.         path = settings.GEOIP_PATH
    
  30.         g2 = GeoIP2(path, 0)  # Passing in data path explicitly.
    
  31.         g3 = GeoIP2.open(path, 0)  # MaxMind Python API syntax.
    
  32.         # path accepts str and pathlib.Path.
    
  33.         if isinstance(path, str):
    
  34.             g4 = GeoIP2(pathlib.Path(path))
    
  35.         else:
    
  36.             g4 = GeoIP2(str(path))
    
  37. 
    
  38.         for g in (g1, g2, g3, g4):
    
  39.             self.assertTrue(g._country)
    
  40.             self.assertTrue(g._city)
    
  41. 
    
  42.         # Only passing in the location of one database.
    
  43.         city = os.path.join(path, "GeoLite2-City.mmdb")
    
  44.         cntry = os.path.join(path, "GeoLite2-Country.mmdb")
    
  45.         g4 = GeoIP2(city, country="")
    
  46.         self.assertIsNone(g4._country)
    
  47.         g5 = GeoIP2(cntry, city="")
    
  48.         self.assertIsNone(g5._city)
    
  49. 
    
  50.         # Improper parameters.
    
  51.         bad_params = (23, "foo", 15.23)
    
  52.         for bad in bad_params:
    
  53.             with self.assertRaises(GeoIP2Exception):
    
  54.                 GeoIP2(cache=bad)
    
  55.             if isinstance(bad, str):
    
  56.                 e = GeoIP2Exception
    
  57.             else:
    
  58.                 e = TypeError
    
  59.             with self.assertRaises(e):
    
  60.                 GeoIP2(bad, 0)
    
  61. 
    
  62.     def test_no_database_file(self):
    
  63.         invalid_path = os.path.join(os.path.dirname(__file__), "data")
    
  64.         msg = "Could not load a database from %s." % invalid_path
    
  65.         with self.assertRaisesMessage(GeoIP2Exception, msg):
    
  66.             GeoIP2(invalid_path)
    
  67. 
    
  68.     def test02_bad_query(self):
    
  69.         "GeoIP query parameter checking."
    
  70.         cntry_g = GeoIP2(city="<foo>")
    
  71.         # No city database available, these calls should fail.
    
  72.         with self.assertRaises(GeoIP2Exception):
    
  73.             cntry_g.city("tmc.edu")
    
  74.         with self.assertRaises(GeoIP2Exception):
    
  75.             cntry_g.coords("tmc.edu")
    
  76. 
    
  77.         # Non-string query should raise TypeError
    
  78.         with self.assertRaises(TypeError):
    
  79.             cntry_g.country_code(17)
    
  80.         with self.assertRaises(TypeError):
    
  81.             cntry_g.country_name(GeoIP2)
    
  82. 
    
  83.     @mock.patch("socket.gethostbyname")
    
  84.     def test03_country(self, gethostbyname):
    
  85.         "GeoIP country querying methods."
    
  86.         gethostbyname.return_value = "128.249.1.1"
    
  87.         g = GeoIP2(city="<foo>")
    
  88. 
    
  89.         for query in (self.fqdn, self.addr):
    
  90.             self.assertEqual(
    
  91.                 "US",
    
  92.                 g.country_code(query),
    
  93.                 "Failed for func country_code and query %s" % query,
    
  94.             )
    
  95.             self.assertEqual(
    
  96.                 "United States",
    
  97.                 g.country_name(query),
    
  98.                 "Failed for func country_name and query %s" % query,
    
  99.             )
    
  100.             self.assertEqual(
    
  101.                 {"country_code": "US", "country_name": "United States"},
    
  102.                 g.country(query),
    
  103.             )
    
  104. 
    
  105.     @mock.patch("socket.gethostbyname")
    
  106.     def test04_city(self, gethostbyname):
    
  107.         "GeoIP city querying methods."
    
  108.         gethostbyname.return_value = "129.237.192.1"
    
  109.         g = GeoIP2(country="<foo>")
    
  110. 
    
  111.         for query in (self.fqdn, self.addr):
    
  112.             # Country queries should still work.
    
  113.             self.assertEqual(
    
  114.                 "US",
    
  115.                 g.country_code(query),
    
  116.                 "Failed for func country_code and query %s" % query,
    
  117.             )
    
  118.             self.assertEqual(
    
  119.                 "United States",
    
  120.                 g.country_name(query),
    
  121.                 "Failed for func country_name and query %s" % query,
    
  122.             )
    
  123.             self.assertEqual(
    
  124.                 {"country_code": "US", "country_name": "United States"},
    
  125.                 g.country(query),
    
  126.             )
    
  127. 
    
  128.             # City information dictionary.
    
  129.             d = g.city(query)
    
  130.             self.assertEqual("NA", d["continent_code"])
    
  131.             self.assertEqual("North America", d["continent_name"])
    
  132.             self.assertEqual("US", d["country_code"])
    
  133.             self.assertEqual("Lawrence", d["city"])
    
  134.             self.assertEqual("KS", d["region"])
    
  135.             self.assertEqual("America/Chicago", d["time_zone"])
    
  136.             self.assertFalse(d["is_in_european_union"])
    
  137.             geom = g.geos(query)
    
  138.             self.assertIsInstance(geom, GEOSGeometry)
    
  139. 
    
  140.             for e1, e2 in (
    
  141.                 geom.tuple,
    
  142.                 g.coords(query),
    
  143.                 g.lon_lat(query),
    
  144.                 g.lat_lon(query),
    
  145.             ):
    
  146.                 self.assertIsInstance(e1, float)
    
  147.                 self.assertIsInstance(e2, float)
    
  148. 
    
  149.     def test06_ipv6_query(self):
    
  150.         "GeoIP can lookup IPv6 addresses."
    
  151.         g = GeoIP2()
    
  152.         d = g.city("2002:81ed:c9a5::81ed:c9a5")  # IPv6 address for www.nhm.ku.edu
    
  153.         self.assertEqual("US", d["country_code"])
    
  154.         self.assertEqual("Lawrence", d["city"])
    
  155.         self.assertEqual("KS", d["region"])
    
  156. 
    
  157.     def test_repr(self):
    
  158.         path = settings.GEOIP_PATH
    
  159.         g = GeoIP2(path=path)
    
  160.         meta = g._reader.metadata()
    
  161.         version = "%s.%s" % (
    
  162.             meta.binary_format_major_version,
    
  163.             meta.binary_format_minor_version,
    
  164.         )
    
  165.         country_path = g._country_file
    
  166.         city_path = g._city_file
    
  167.         expected = (
    
  168.             '<GeoIP2 [v%(version)s] _country_file="%(country)s", _city_file="%(city)s">'
    
  169.             % {
    
  170.                 "version": version,
    
  171.                 "country": country_path,
    
  172.                 "city": city_path,
    
  173.             }
    
  174.         )
    
  175.         self.assertEqual(repr(g), expected)
    
  176. 
    
  177.     @mock.patch("socket.gethostbyname", return_value="expected")
    
  178.     def test_check_query(self, gethostbyname):
    
  179.         g = GeoIP2()
    
  180.         self.assertEqual(g._check_query("127.0.0.1"), "127.0.0.1")
    
  181.         self.assertEqual(
    
  182.             g._check_query("2002:81ed:c9a5::81ed:c9a5"), "2002:81ed:c9a5::81ed:c9a5"
    
  183.         )
    
  184.         self.assertEqual(g._check_query("invalid-ip-address"), "expected")