1. import os
    
  2. import subprocess
    
  3. import sys
    
  4. from pathlib import Path
    
  5. 
    
  6. from django.db.backends.mysql.client import DatabaseClient
    
  7. from django.test import SimpleTestCase
    
  8. 
    
  9. 
    
  10. class MySqlDbshellCommandTestCase(SimpleTestCase):
    
  11.     def settings_to_cmd_args_env(self, settings_dict, parameters=None):
    
  12.         if parameters is None:
    
  13.             parameters = []
    
  14.         return DatabaseClient.settings_to_cmd_args_env(settings_dict, parameters)
    
  15. 
    
  16.     def test_fails_with_keyerror_on_incomplete_config(self):
    
  17.         with self.assertRaises(KeyError):
    
  18.             self.settings_to_cmd_args_env({})
    
  19. 
    
  20.     def test_basic_params_specified_in_settings(self):
    
  21.         expected_args = [
    
  22.             "mysql",
    
  23.             "--user=someuser",
    
  24.             "--host=somehost",
    
  25.             "--port=444",
    
  26.             "somedbname",
    
  27.         ]
    
  28.         expected_env = {"MYSQL_PWD": "somepassword"}
    
  29.         self.assertEqual(
    
  30.             self.settings_to_cmd_args_env(
    
  31.                 {
    
  32.                     "NAME": "somedbname",
    
  33.                     "USER": "someuser",
    
  34.                     "PASSWORD": "somepassword",
    
  35.                     "HOST": "somehost",
    
  36.                     "PORT": 444,
    
  37.                     "OPTIONS": {},
    
  38.                 }
    
  39.             ),
    
  40.             (expected_args, expected_env),
    
  41.         )
    
  42. 
    
  43.     def test_options_override_settings_proper_values(self):
    
  44.         settings_port = 444
    
  45.         options_port = 555
    
  46.         self.assertNotEqual(settings_port, options_port, "test pre-req")
    
  47.         expected_args = [
    
  48.             "mysql",
    
  49.             "--user=optionuser",
    
  50.             "--host=optionhost",
    
  51.             "--port=%s" % options_port,
    
  52.             "optiondbname",
    
  53.         ]
    
  54.         expected_env = {"MYSQL_PWD": "optionpassword"}
    
  55.         for keys in [("database", "password"), ("db", "passwd")]:
    
  56.             with self.subTest(keys=keys):
    
  57.                 database, password = keys
    
  58.                 self.assertEqual(
    
  59.                     self.settings_to_cmd_args_env(
    
  60.                         {
    
  61.                             "NAME": "settingdbname",
    
  62.                             "USER": "settinguser",
    
  63.                             "PASSWORD": "settingpassword",
    
  64.                             "HOST": "settinghost",
    
  65.                             "PORT": settings_port,
    
  66.                             "OPTIONS": {
    
  67.                                 database: "optiondbname",
    
  68.                                 "user": "optionuser",
    
  69.                                 password: "optionpassword",
    
  70.                                 "host": "optionhost",
    
  71.                                 "port": options_port,
    
  72.                             },
    
  73.                         }
    
  74.                     ),
    
  75.                     (expected_args, expected_env),
    
  76.                 )
    
  77. 
    
  78.     def test_options_non_deprecated_keys_preferred(self):
    
  79.         expected_args = [
    
  80.             "mysql",
    
  81.             "--user=someuser",
    
  82.             "--host=somehost",
    
  83.             "--port=444",
    
  84.             "optiondbname",
    
  85.         ]
    
  86.         expected_env = {"MYSQL_PWD": "optionpassword"}
    
  87.         self.assertEqual(
    
  88.             self.settings_to_cmd_args_env(
    
  89.                 {
    
  90.                     "NAME": "settingdbname",
    
  91.                     "USER": "someuser",
    
  92.                     "PASSWORD": "settingpassword",
    
  93.                     "HOST": "somehost",
    
  94.                     "PORT": 444,
    
  95.                     "OPTIONS": {
    
  96.                         "database": "optiondbname",
    
  97.                         "db": "deprecatedoptiondbname",
    
  98.                         "password": "optionpassword",
    
  99.                         "passwd": "deprecatedoptionpassword",
    
  100.                     },
    
  101.                 }
    
  102.             ),
    
  103.             (expected_args, expected_env),
    
  104.         )
    
  105. 
    
  106.     def test_options_charset(self):
    
  107.         expected_args = [
    
  108.             "mysql",
    
  109.             "--user=someuser",
    
  110.             "--host=somehost",
    
  111.             "--port=444",
    
  112.             "--default-character-set=utf8",
    
  113.             "somedbname",
    
  114.         ]
    
  115.         expected_env = {"MYSQL_PWD": "somepassword"}
    
  116.         self.assertEqual(
    
  117.             self.settings_to_cmd_args_env(
    
  118.                 {
    
  119.                     "NAME": "somedbname",
    
  120.                     "USER": "someuser",
    
  121.                     "PASSWORD": "somepassword",
    
  122.                     "HOST": "somehost",
    
  123.                     "PORT": 444,
    
  124.                     "OPTIONS": {"charset": "utf8"},
    
  125.                 }
    
  126.             ),
    
  127.             (expected_args, expected_env),
    
  128.         )
    
  129. 
    
  130.     def test_can_connect_using_sockets(self):
    
  131.         expected_args = [
    
  132.             "mysql",
    
  133.             "--user=someuser",
    
  134.             "--socket=/path/to/mysql.socket.file",
    
  135.             "somedbname",
    
  136.         ]
    
  137.         expected_env = {"MYSQL_PWD": "somepassword"}
    
  138.         self.assertEqual(
    
  139.             self.settings_to_cmd_args_env(
    
  140.                 {
    
  141.                     "NAME": "somedbname",
    
  142.                     "USER": "someuser",
    
  143.                     "PASSWORD": "somepassword",
    
  144.                     "HOST": "/path/to/mysql.socket.file",
    
  145.                     "PORT": None,
    
  146.                     "OPTIONS": {},
    
  147.                 }
    
  148.             ),
    
  149.             (expected_args, expected_env),
    
  150.         )
    
  151. 
    
  152.     def test_ssl_certificate_is_added(self):
    
  153.         expected_args = [
    
  154.             "mysql",
    
  155.             "--user=someuser",
    
  156.             "--host=somehost",
    
  157.             "--port=444",
    
  158.             "--ssl-ca=sslca",
    
  159.             "--ssl-cert=sslcert",
    
  160.             "--ssl-key=sslkey",
    
  161.             "somedbname",
    
  162.         ]
    
  163.         expected_env = {"MYSQL_PWD": "somepassword"}
    
  164.         self.assertEqual(
    
  165.             self.settings_to_cmd_args_env(
    
  166.                 {
    
  167.                     "NAME": "somedbname",
    
  168.                     "USER": "someuser",
    
  169.                     "PASSWORD": "somepassword",
    
  170.                     "HOST": "somehost",
    
  171.                     "PORT": 444,
    
  172.                     "OPTIONS": {
    
  173.                         "ssl": {
    
  174.                             "ca": "sslca",
    
  175.                             "cert": "sslcert",
    
  176.                             "key": "sslkey",
    
  177.                         },
    
  178.                     },
    
  179.                 }
    
  180.             ),
    
  181.             (expected_args, expected_env),
    
  182.         )
    
  183. 
    
  184.     def test_parameters(self):
    
  185.         self.assertEqual(
    
  186.             self.settings_to_cmd_args_env(
    
  187.                 {
    
  188.                     "NAME": "somedbname",
    
  189.                     "USER": None,
    
  190.                     "PASSWORD": None,
    
  191.                     "HOST": None,
    
  192.                     "PORT": None,
    
  193.                     "OPTIONS": {},
    
  194.                 },
    
  195.                 ["--help"],
    
  196.             ),
    
  197.             (["mysql", "somedbname", "--help"], None),
    
  198.         )
    
  199. 
    
  200.     def test_crash_password_does_not_leak(self):
    
  201.         # The password doesn't leak in an exception that results from a client
    
  202.         # crash.
    
  203.         args, env = DatabaseClient.settings_to_cmd_args_env(
    
  204.             {
    
  205.                 "NAME": "somedbname",
    
  206.                 "USER": "someuser",
    
  207.                 "PASSWORD": "somepassword",
    
  208.                 "HOST": "somehost",
    
  209.                 "PORT": 444,
    
  210.                 "OPTIONS": {},
    
  211.             },
    
  212.             [],
    
  213.         )
    
  214.         if env:
    
  215.             env = {**os.environ, **env}
    
  216.         fake_client = Path(__file__).with_name("fake_client.py")
    
  217.         args[0:1] = [sys.executable, str(fake_client)]
    
  218.         with self.assertRaises(subprocess.CalledProcessError) as ctx:
    
  219.             subprocess.run(args, check=True, env=env)
    
  220.         self.assertNotIn("somepassword", str(ctx.exception))