base.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. """Base implementation of 0MQ authentication."""
  2. # Copyright (C) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. import logging
  5. import os
  6. from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
  7. import zmq
  8. from zmq.error import _check_version
  9. from zmq.utils import z85
  10. from .certs import load_certificates
  11. CURVE_ALLOW_ANY = '*'
  12. VERSION = b'1.0'
  13. class Authenticator:
  14. """Implementation of ZAP authentication for zmq connections.
  15. This authenticator class does not register with an event loop. As a result,
  16. you will need to manually call `handle_zap_message`::
  17. auth = zmq.Authenticator()
  18. auth.allow("127.0.0.1")
  19. auth.start()
  20. while True:
  21. await auth.handle_zap_msg(auth.zap_socket.recv_multipart())
  22. Alternatively, you can register `auth.zap_socket` with a poller.
  23. Since many users will want to run ZAP in a way that does not block the
  24. main thread, other authentication classes (such as :mod:`zmq.auth.thread`)
  25. are provided.
  26. Note:
  27. - libzmq provides four levels of security: default NULL (which the Authenticator does
  28. not see), and authenticated NULL, PLAIN, CURVE, and GSSAPI, which the Authenticator can see.
  29. - until you add policies, all incoming NULL connections are allowed.
  30. (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied.
  31. - GSSAPI requires no configuration.
  32. """
  33. context: "zmq.Context"
  34. encoding: str
  35. allow_any: bool
  36. credentials_providers: Dict[str, Any]
  37. zap_socket: "zmq.Socket"
  38. _allowed: Set[str]
  39. _denied: Set[str]
  40. passwords: Dict[str, Dict[str, str]]
  41. certs: Dict[str, Dict[bytes, Any]]
  42. log: Any
  43. def __init__(
  44. self,
  45. context: Optional["zmq.Context"] = None,
  46. encoding: str = 'utf-8',
  47. log: Any = None,
  48. ):
  49. _check_version((4, 0), "security")
  50. self.context = context or zmq.Context.instance()
  51. self.encoding = encoding
  52. self.allow_any = False
  53. self.credentials_providers = {}
  54. self.zap_socket = None # type: ignore
  55. self._allowed = set()
  56. self._denied = set()
  57. # passwords is a dict keyed by domain and contains values
  58. # of dicts with username:password pairs.
  59. self.passwords = {}
  60. # certs is dict keyed by domain and contains values
  61. # of dicts keyed by the public keys from the specified location.
  62. self.certs = {}
  63. self.log = log or logging.getLogger('zmq.auth')
  64. def start(self) -> None:
  65. """Create and bind the ZAP socket"""
  66. self.zap_socket = self.context.socket(zmq.REP, socket_class=zmq.Socket)
  67. self.zap_socket.linger = 1
  68. self.zap_socket.bind("inproc://zeromq.zap.01")
  69. self.log.debug("Starting")
  70. def stop(self) -> None:
  71. """Close the ZAP socket"""
  72. if self.zap_socket:
  73. self.zap_socket.close()
  74. self.zap_socket = None # type: ignore
  75. def allow(self, *addresses: str) -> None:
  76. """Allow IP address(es).
  77. Connections from addresses not explicitly allowed will be rejected.
  78. - For NULL, all clients from this address will be accepted.
  79. - For real auth setups, they will be allowed to continue with authentication.
  80. allow is mutually exclusive with deny.
  81. """
  82. if self._denied:
  83. raise ValueError("Only use allow or deny, not both")
  84. self.log.debug("Allowing %s", ','.join(addresses))
  85. self._allowed.update(addresses)
  86. def deny(self, *addresses: str) -> None:
  87. """Deny IP address(es).
  88. Addresses not explicitly denied will be allowed to continue with authentication.
  89. deny is mutually exclusive with allow.
  90. """
  91. if self._allowed:
  92. raise ValueError("Only use a allow or deny, not both")
  93. self.log.debug("Denying %s", ','.join(addresses))
  94. self._denied.update(addresses)
  95. def configure_plain(
  96. self, domain: str = '*', passwords: Optional[Dict[str, str]] = None
  97. ) -> None:
  98. """Configure PLAIN authentication for a given domain.
  99. PLAIN authentication uses a plain-text password file.
  100. To cover all domains, use "*".
  101. You can modify the password file at any time; it is reloaded automatically.
  102. """
  103. if passwords:
  104. self.passwords[domain] = passwords
  105. self.log.debug("Configure plain: %s", domain)
  106. def configure_curve(
  107. self, domain: str = '*', location: Union[str, os.PathLike] = "."
  108. ) -> None:
  109. """Configure CURVE authentication for a given domain.
  110. CURVE authentication uses a directory that holds all public client certificates,
  111. i.e. their public keys.
  112. To cover all domains, use "*".
  113. You can add and remove certificates in that directory at any time. configure_curve must be called
  114. every time certificates are added or removed, in order to update the Authenticator's state
  115. To allow all client keys without checking, specify CURVE_ALLOW_ANY for the location.
  116. """
  117. # If location is CURVE_ALLOW_ANY then allow all clients. Otherwise
  118. # treat location as a directory that holds the certificates.
  119. self.log.debug("Configure curve: %s[%s]", domain, location)
  120. if location == CURVE_ALLOW_ANY:
  121. self.allow_any = True
  122. else:
  123. self.allow_any = False
  124. try:
  125. self.certs[domain] = load_certificates(location)
  126. except Exception as e:
  127. self.log.error("Failed to load CURVE certs from %s: %s", location, e)
  128. def configure_curve_callback(
  129. self, domain: str = '*', credentials_provider: Any = None
  130. ) -> None:
  131. """Configure CURVE authentication for a given domain.
  132. CURVE authentication using a callback function validating
  133. the client public key according to a custom mechanism, e.g. checking the
  134. key against records in a db. credentials_provider is an object of a class which
  135. implements a callback method accepting two parameters (domain and key), e.g.::
  136. class CredentialsProvider(object):
  137. def __init__(self):
  138. ...e.g. db connection
  139. def callback(self, domain, key):
  140. valid = ...lookup key and/or domain in db
  141. if valid:
  142. logging.info('Authorizing: {0}, {1}'.format(domain, key))
  143. return True
  144. else:
  145. logging.warning('NOT Authorizing: {0}, {1}'.format(domain, key))
  146. return False
  147. To cover all domains, use "*".
  148. """
  149. self.allow_any = False
  150. if credentials_provider is not None:
  151. self.credentials_providers[domain] = credentials_provider
  152. else:
  153. self.log.error("None credentials_provider provided for domain:%s", domain)
  154. def curve_user_id(self, client_public_key: bytes) -> str:
  155. """Return the User-Id corresponding to a CURVE client's public key
  156. Default implementation uses the z85-encoding of the public key.
  157. Override to define a custom mapping of public key : user-id
  158. This is only called on successful authentication.
  159. Parameters
  160. ----------
  161. client_public_key: bytes
  162. The client public key used for the given message
  163. Returns
  164. -------
  165. user_id: unicode
  166. The user ID as text
  167. """
  168. return z85.encode(client_public_key).decode('ascii')
  169. def configure_gssapi(
  170. self, domain: str = '*', location: Optional[str] = None
  171. ) -> None:
  172. """Configure GSSAPI authentication
  173. Currently this is a no-op because there is nothing to configure with GSSAPI.
  174. """
  175. async def handle_zap_message(self, msg: List[bytes]):
  176. """Perform ZAP authentication"""
  177. if len(msg) < 6:
  178. self.log.error("Invalid ZAP message, not enough frames: %r", msg)
  179. if len(msg) < 2:
  180. self.log.error("Not enough information to reply")
  181. else:
  182. self._send_zap_reply(msg[1], b"400", b"Not enough frames")
  183. return
  184. version, request_id, domain, address, identity, mechanism = msg[:6]
  185. credentials = msg[6:]
  186. domain = domain.decode(self.encoding, 'replace')
  187. address = address.decode(self.encoding, 'replace')
  188. if version != VERSION:
  189. self.log.error("Invalid ZAP version: %r", msg)
  190. self._send_zap_reply(request_id, b"400", b"Invalid version")
  191. return
  192. self.log.debug(
  193. "version: %r, request_id: %r, domain: %r,"
  194. " address: %r, identity: %r, mechanism: %r",
  195. version,
  196. request_id,
  197. domain,
  198. address,
  199. identity,
  200. mechanism,
  201. )
  202. # Is address is explicitly allowed or _denied?
  203. allowed = False
  204. denied = False
  205. reason = b"NO ACCESS"
  206. if self._allowed:
  207. if address in self._allowed:
  208. allowed = True
  209. self.log.debug("PASSED (allowed) address=%s", address)
  210. else:
  211. denied = True
  212. reason = b"Address not allowed"
  213. self.log.debug("DENIED (not allowed) address=%s", address)
  214. elif self._denied:
  215. if address in self._denied:
  216. denied = True
  217. reason = b"Address denied"
  218. self.log.debug("DENIED (denied) address=%s", address)
  219. else:
  220. allowed = True
  221. self.log.debug("PASSED (not denied) address=%s", address)
  222. # Perform authentication mechanism-specific checks if necessary
  223. username = "anonymous"
  224. if not denied:
  225. if mechanism == b'NULL' and not allowed:
  226. # For NULL, we allow if the address wasn't denied
  227. self.log.debug("ALLOWED (NULL)")
  228. allowed = True
  229. elif mechanism == b'PLAIN':
  230. # For PLAIN, even a _alloweded address must authenticate
  231. if len(credentials) != 2:
  232. self.log.error("Invalid PLAIN credentials: %r", credentials)
  233. self._send_zap_reply(request_id, b"400", b"Invalid credentials")
  234. return
  235. username, password = (
  236. c.decode(self.encoding, 'replace') for c in credentials
  237. )
  238. allowed, reason = self._authenticate_plain(domain, username, password)
  239. elif mechanism == b'CURVE':
  240. # For CURVE, even a _alloweded address must authenticate
  241. if len(credentials) != 1:
  242. self.log.error("Invalid CURVE credentials: %r", credentials)
  243. self._send_zap_reply(request_id, b"400", b"Invalid credentials")
  244. return
  245. key = credentials[0]
  246. allowed, reason = await self._authenticate_curve(domain, key)
  247. if allowed:
  248. username = self.curve_user_id(key)
  249. elif mechanism == b'GSSAPI':
  250. if len(credentials) != 1:
  251. self.log.error("Invalid GSSAPI credentials: %r", credentials)
  252. self._send_zap_reply(request_id, b"400", b"Invalid credentials")
  253. return
  254. # use principal as user-id for now
  255. principal = credentials[0]
  256. username = principal.decode("utf8")
  257. allowed, reason = self._authenticate_gssapi(domain, principal)
  258. if allowed:
  259. self._send_zap_reply(request_id, b"200", b"OK", username)
  260. else:
  261. self._send_zap_reply(request_id, b"400", reason)
  262. def _authenticate_plain(
  263. self, domain: str, username: str, password: str
  264. ) -> Tuple[bool, bytes]:
  265. """PLAIN ZAP authentication"""
  266. allowed = False
  267. reason = b""
  268. if self.passwords:
  269. # If no domain is not specified then use the default domain
  270. if not domain:
  271. domain = '*'
  272. if domain in self.passwords:
  273. if username in self.passwords[domain]:
  274. if password == self.passwords[domain][username]:
  275. allowed = True
  276. else:
  277. reason = b"Invalid password"
  278. else:
  279. reason = b"Invalid username"
  280. else:
  281. reason = b"Invalid domain"
  282. if allowed:
  283. self.log.debug(
  284. "ALLOWED (PLAIN) domain=%s username=%s password=%s",
  285. domain,
  286. username,
  287. password,
  288. )
  289. else:
  290. self.log.debug("DENIED %s", reason)
  291. else:
  292. reason = b"No passwords defined"
  293. self.log.debug("DENIED (PLAIN) %s", reason)
  294. return allowed, reason
  295. async def _authenticate_curve(
  296. self, domain: str, client_key: bytes
  297. ) -> Tuple[bool, bytes]:
  298. """CURVE ZAP authentication"""
  299. allowed = False
  300. reason = b""
  301. if self.allow_any:
  302. allowed = True
  303. reason = b"OK"
  304. self.log.debug("ALLOWED (CURVE allow any client)")
  305. elif self.credentials_providers != {}:
  306. # If no explicit domain is specified then use the default domain
  307. if not domain:
  308. domain = '*'
  309. if domain in self.credentials_providers:
  310. z85_client_key = z85.encode(client_key)
  311. # Callback to check if key is Allowed
  312. r = self.credentials_providers[domain].callback(domain, z85_client_key)
  313. if isinstance(r, Awaitable):
  314. r = await r
  315. if r:
  316. allowed = True
  317. reason = b"OK"
  318. else:
  319. reason = b"Unknown key"
  320. status = "ALLOWED" if allowed else "DENIED"
  321. self.log.debug(
  322. "%s (CURVE auth_callback) domain=%s client_key=%s",
  323. status,
  324. domain,
  325. z85_client_key,
  326. )
  327. else:
  328. reason = b"Unknown domain"
  329. else:
  330. # If no explicit domain is specified then use the default domain
  331. if not domain:
  332. domain = '*'
  333. if domain in self.certs:
  334. # The certs dict stores keys in z85 format, convert binary key to z85 bytes
  335. z85_client_key = z85.encode(client_key)
  336. if self.certs[domain].get(z85_client_key):
  337. allowed = True
  338. reason = b"OK"
  339. else:
  340. reason = b"Unknown key"
  341. status = "ALLOWED" if allowed else "DENIED"
  342. self.log.debug(
  343. "%s (CURVE) domain=%s client_key=%s",
  344. status,
  345. domain,
  346. z85_client_key,
  347. )
  348. else:
  349. reason = b"Unknown domain"
  350. return allowed, reason
  351. def _authenticate_gssapi(self, domain: str, principal: bytes) -> Tuple[bool, bytes]:
  352. """Nothing to do for GSSAPI, which has already been handled by an external service."""
  353. self.log.debug("ALLOWED (GSSAPI) domain=%s principal=%s", domain, principal)
  354. return True, b'OK'
  355. def _send_zap_reply(
  356. self,
  357. request_id: bytes,
  358. status_code: bytes,
  359. status_text: bytes,
  360. user_id: str = 'anonymous',
  361. ) -> None:
  362. """Send a ZAP reply to finish the authentication."""
  363. user_id = user_id if status_code == b'200' else b''
  364. if isinstance(user_id, str):
  365. user_id = user_id.encode(self.encoding, 'replace')
  366. metadata = b'' # not currently used
  367. self.log.debug("ZAP reply code=%s text=%s", status_code, status_text)
  368. reply = [VERSION, request_id, status_code, status_text, user_id, metadata]
  369. self.zap_socket.send_multipart(reply)
  370. __all__ = ['Authenticator', 'CURVE_ALLOW_ANY']