certs.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. """0MQ authentication related functions and classes."""
  2. # Copyright (C) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. import datetime
  5. import glob
  6. import os
  7. from typing import Dict, Optional, Tuple, Union
  8. import zmq
  9. _cert_secret_banner = """# **** Generated on {0} by pyzmq ****
  10. # ZeroMQ CURVE **Secret** Certificate
  11. # DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions.
  12. """
  13. _cert_public_banner = """# **** Generated on {0} by pyzmq ****
  14. # ZeroMQ CURVE Public Certificate
  15. # Exchange securely, or use a secure mechanism to verify the contents
  16. # of this file after exchange. Store public certificates in your home
  17. # directory, in the .curve subdirectory.
  18. """
  19. def _write_key_file(
  20. key_filename: Union[str, os.PathLike],
  21. banner: str,
  22. public_key: Union[str, bytes],
  23. secret_key: Optional[Union[str, bytes]] = None,
  24. metadata: Optional[Dict[str, str]] = None,
  25. encoding: str = 'utf-8',
  26. ) -> None:
  27. """Create a certificate file"""
  28. if isinstance(public_key, bytes):
  29. public_key = public_key.decode(encoding)
  30. if isinstance(secret_key, bytes):
  31. secret_key = secret_key.decode(encoding)
  32. with open(key_filename, 'w', encoding='utf8') as f:
  33. f.write(banner.format(datetime.datetime.now()))
  34. f.write('metadata\n')
  35. if metadata:
  36. for k, v in metadata.items():
  37. if isinstance(k, bytes):
  38. k = k.decode(encoding)
  39. if isinstance(v, bytes):
  40. v = v.decode(encoding)
  41. f.write(f" {k} = {v}\n")
  42. f.write('curve\n')
  43. f.write(f" public-key = \"{public_key}\"\n")
  44. if secret_key:
  45. f.write(f" secret-key = \"{secret_key}\"\n")
  46. def create_certificates(
  47. key_dir: Union[str, os.PathLike],
  48. name: str,
  49. metadata: Optional[Dict[str, str]] = None,
  50. ) -> Tuple[str, str]:
  51. """Create zmq certificates.
  52. Returns the file paths to the public and secret certificate files.
  53. """
  54. public_key, secret_key = zmq.curve_keypair()
  55. base_filename = os.path.join(key_dir, name)
  56. secret_key_file = f"{base_filename}.key_secret"
  57. public_key_file = f"{base_filename}.key"
  58. now = datetime.datetime.now()
  59. _write_key_file(public_key_file, _cert_public_banner.format(now), public_key)
  60. _write_key_file(
  61. secret_key_file,
  62. _cert_secret_banner.format(now),
  63. public_key,
  64. secret_key=secret_key,
  65. metadata=metadata,
  66. )
  67. return public_key_file, secret_key_file
  68. def load_certificate(
  69. filename: Union[str, os.PathLike],
  70. ) -> Tuple[bytes, Optional[bytes]]:
  71. """Load public and secret key from a zmq certificate.
  72. Returns (public_key, secret_key)
  73. If the certificate file only contains the public key,
  74. secret_key will be None.
  75. If there is no public key found in the file, ValueError will be raised.
  76. """
  77. public_key = None
  78. secret_key = None
  79. if not os.path.exists(filename):
  80. raise OSError(f"Invalid certificate file: {filename}")
  81. with open(filename, 'rb') as f:
  82. for line in f:
  83. line = line.strip()
  84. if line.startswith(b'#'):
  85. continue
  86. if line.startswith(b'public-key'):
  87. public_key = line.split(b"=", 1)[1].strip(b' \t\'"')
  88. if line.startswith(b'secret-key'):
  89. secret_key = line.split(b"=", 1)[1].strip(b' \t\'"')
  90. if public_key and secret_key:
  91. break
  92. if public_key is None:
  93. raise ValueError(f"No public key found in {filename}")
  94. return public_key, secret_key
  95. def load_certificates(directory: Union[str, os.PathLike] = '.') -> Dict[bytes, bool]:
  96. """Load public keys from all certificates in a directory"""
  97. certs = {}
  98. if not os.path.isdir(directory):
  99. raise OSError(f"Invalid certificate directory: {directory}")
  100. # Follow czmq pattern of public keys stored in *.key files.
  101. glob_string = os.path.join(directory, "*.key")
  102. cert_files = glob.glob(glob_string)
  103. for cert_file in cert_files:
  104. public_key, _ = load_certificate(cert_file)
  105. if public_key:
  106. certs[public_key] = True
  107. return certs
  108. __all__ = ['create_certificates', 'load_certificate', 'load_certificates']