1. import os
    
  2. import signal
    
  3. import subprocess
    
  4. import sys
    
  5. from pathlib import Path
    
  6. from unittest import mock, skipUnless
    
  7. 
    
  8. from django.db import connection
    
  9. from django.db.backends.postgresql.client import DatabaseClient
    
  10. from django.test import SimpleTestCase
    
  11. 
    
  12. 
    
  13. class PostgreSqlDbshellCommandTestCase(SimpleTestCase):
    
  14.     def settings_to_cmd_args_env(self, settings_dict, parameters=None):
    
  15.         if parameters is None:
    
  16.             parameters = []
    
  17.         return DatabaseClient.settings_to_cmd_args_env(settings_dict, parameters)
    
  18. 
    
  19.     def test_basic(self):
    
  20.         self.assertEqual(
    
  21.             self.settings_to_cmd_args_env(
    
  22.                 {
    
  23.                     "NAME": "dbname",
    
  24.                     "USER": "someuser",
    
  25.                     "PASSWORD": "somepassword",
    
  26.                     "HOST": "somehost",
    
  27.                     "PORT": "444",
    
  28.                 }
    
  29.             ),
    
  30.             (
    
  31.                 ["psql", "-U", "someuser", "-h", "somehost", "-p", "444", "dbname"],
    
  32.                 {"PGPASSWORD": "somepassword"},
    
  33.             ),
    
  34.         )
    
  35. 
    
  36.     def test_nopass(self):
    
  37.         self.assertEqual(
    
  38.             self.settings_to_cmd_args_env(
    
  39.                 {
    
  40.                     "NAME": "dbname",
    
  41.                     "USER": "someuser",
    
  42.                     "HOST": "somehost",
    
  43.                     "PORT": "444",
    
  44.                 }
    
  45.             ),
    
  46.             (
    
  47.                 ["psql", "-U", "someuser", "-h", "somehost", "-p", "444", "dbname"],
    
  48.                 None,
    
  49.             ),
    
  50.         )
    
  51. 
    
  52.     def test_ssl_certificate(self):
    
  53.         self.assertEqual(
    
  54.             self.settings_to_cmd_args_env(
    
  55.                 {
    
  56.                     "NAME": "dbname",
    
  57.                     "USER": "someuser",
    
  58.                     "HOST": "somehost",
    
  59.                     "PORT": "444",
    
  60.                     "OPTIONS": {
    
  61.                         "sslmode": "verify-ca",
    
  62.                         "sslrootcert": "root.crt",
    
  63.                         "sslcert": "client.crt",
    
  64.                         "sslkey": "client.key",
    
  65.                     },
    
  66.                 }
    
  67.             ),
    
  68.             (
    
  69.                 ["psql", "-U", "someuser", "-h", "somehost", "-p", "444", "dbname"],
    
  70.                 {
    
  71.                     "PGSSLCERT": "client.crt",
    
  72.                     "PGSSLKEY": "client.key",
    
  73.                     "PGSSLMODE": "verify-ca",
    
  74.                     "PGSSLROOTCERT": "root.crt",
    
  75.                 },
    
  76.             ),
    
  77.         )
    
  78. 
    
  79.     def test_service(self):
    
  80.         self.assertEqual(
    
  81.             self.settings_to_cmd_args_env({"OPTIONS": {"service": "django_test"}}),
    
  82.             (["psql"], {"PGSERVICE": "django_test"}),
    
  83.         )
    
  84. 
    
  85.     def test_passfile(self):
    
  86.         self.assertEqual(
    
  87.             self.settings_to_cmd_args_env(
    
  88.                 {
    
  89.                     "NAME": "dbname",
    
  90.                     "USER": "someuser",
    
  91.                     "HOST": "somehost",
    
  92.                     "PORT": "444",
    
  93.                     "OPTIONS": {
    
  94.                         "passfile": "~/.custompgpass",
    
  95.                     },
    
  96.                 }
    
  97.             ),
    
  98.             (
    
  99.                 ["psql", "-U", "someuser", "-h", "somehost", "-p", "444", "dbname"],
    
  100.                 {"PGPASSFILE": "~/.custompgpass"},
    
  101.             ),
    
  102.         )
    
  103.         self.assertEqual(
    
  104.             self.settings_to_cmd_args_env(
    
  105.                 {
    
  106.                     "OPTIONS": {
    
  107.                         "service": "django_test",
    
  108.                         "passfile": "~/.custompgpass",
    
  109.                     },
    
  110.                 }
    
  111.             ),
    
  112.             (
    
  113.                 ["psql"],
    
  114.                 {"PGSERVICE": "django_test", "PGPASSFILE": "~/.custompgpass"},
    
  115.             ),
    
  116.         )
    
  117. 
    
  118.     def test_column(self):
    
  119.         self.assertEqual(
    
  120.             self.settings_to_cmd_args_env(
    
  121.                 {
    
  122.                     "NAME": "dbname",
    
  123.                     "USER": "some:user",
    
  124.                     "PASSWORD": "some:password",
    
  125.                     "HOST": "::1",
    
  126.                     "PORT": "444",
    
  127.                 }
    
  128.             ),
    
  129.             (
    
  130.                 ["psql", "-U", "some:user", "-h", "::1", "-p", "444", "dbname"],
    
  131.                 {"PGPASSWORD": "some:password"},
    
  132.             ),
    
  133.         )
    
  134. 
    
  135.     def test_accent(self):
    
  136.         username = "rôle"
    
  137.         password = "sésame"
    
  138.         self.assertEqual(
    
  139.             self.settings_to_cmd_args_env(
    
  140.                 {
    
  141.                     "NAME": "dbname",
    
  142.                     "USER": username,
    
  143.                     "PASSWORD": password,
    
  144.                     "HOST": "somehost",
    
  145.                     "PORT": "444",
    
  146.                 }
    
  147.             ),
    
  148.             (
    
  149.                 ["psql", "-U", username, "-h", "somehost", "-p", "444", "dbname"],
    
  150.                 {"PGPASSWORD": password},
    
  151.             ),
    
  152.         )
    
  153. 
    
  154.     def test_parameters(self):
    
  155.         self.assertEqual(
    
  156.             self.settings_to_cmd_args_env({"NAME": "dbname"}, ["--help"]),
    
  157.             (["psql", "dbname", "--help"], None),
    
  158.         )
    
  159. 
    
  160.     @skipUnless(connection.vendor == "postgresql", "Requires a PostgreSQL connection")
    
  161.     def test_sigint_handler(self):
    
  162.         """SIGINT is ignored in Python and passed to psql to abort queries."""
    
  163. 
    
  164.         def _mock_subprocess_run(*args, **kwargs):
    
  165.             handler = signal.getsignal(signal.SIGINT)
    
  166.             self.assertEqual(handler, signal.SIG_IGN)
    
  167. 
    
  168.         sigint_handler = signal.getsignal(signal.SIGINT)
    
  169.         # The default handler isn't SIG_IGN.
    
  170.         self.assertNotEqual(sigint_handler, signal.SIG_IGN)
    
  171.         with mock.patch("subprocess.run", new=_mock_subprocess_run):
    
  172.             connection.client.runshell([])
    
  173.         # dbshell restores the original handler.
    
  174.         self.assertEqual(sigint_handler, signal.getsignal(signal.SIGINT))
    
  175. 
    
  176.     def test_crash_password_does_not_leak(self):
    
  177.         # The password doesn't leak in an exception that results from a client
    
  178.         # crash.
    
  179.         args, env = self.settings_to_cmd_args_env({"PASSWORD": "somepassword"}, [])
    
  180.         if env:
    
  181.             env = {**os.environ, **env}
    
  182.         fake_client = Path(__file__).with_name("fake_client.py")
    
  183.         args[0:1] = [sys.executable, str(fake_client)]
    
  184.         with self.assertRaises(subprocess.CalledProcessError) as ctx:
    
  185.             subprocess.run(args, check=True, env=env)
    
  186.         self.assertNotIn("somepassword", str(ctx.exception))