sessionmanager.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. """A base class session manager."""
  2. # Copyright (c) Jupyter Development Team.
  3. # Distributed under the terms of the Modified BSD License.
  4. import os
  5. import pathlib
  6. import uuid
  7. from typing import Any, NewType, Optional, Union, cast
# Distinct static types for the two kinds of names handled below; both are
# plain ``str`` at runtime (``NewType`` only affects type checkers).
# ``KernelName`` annotates the ``kernel_name`` arguments and ``ModelName``
# annotates the session ``name`` arguments.
KernelName = NewType("KernelName", str)
ModelName = NewType("ModelName", str)
  10. try:
  11. import sqlite3
  12. except ImportError:
  13. # fallback on pysqlite2 if Python was build without sqlite
  14. from pysqlite2 import dbapi2 as sqlite3 # type:ignore[no-redef]
  15. from dataclasses import dataclass, fields
  16. from jupyter_core.utils import ensure_async
  17. from tornado import web
  18. from traitlets import Instance, TraitError, Unicode, validate
  19. from traitlets.config.configurable import LoggingConfigurable
  20. from jupyter_server.traittypes import InstanceFromClasses
  21. class KernelSessionRecordConflict(Exception):
  22. """Exception class to use when two KernelSessionRecords cannot
  23. merge because of conflicting data.
  24. """
  25. @dataclass
  26. class KernelSessionRecord:
  27. """A record object for tracking a Jupyter Server Kernel Session.
  28. Two records that share a session_id must also share a kernel_id, while
  29. kernels can have multiple session (and thereby) session_ids
  30. associated with them.
  31. """
  32. session_id: Optional[str] = None
  33. kernel_id: Optional[str] = None
  34. def __eq__(self, other: object) -> bool:
  35. """Whether a record equals another."""
  36. if isinstance(other, KernelSessionRecord):
  37. condition1 = self.kernel_id and self.kernel_id == other.kernel_id
  38. condition2 = all(
  39. [
  40. self.session_id == other.session_id,
  41. self.kernel_id is None or other.kernel_id is None,
  42. ]
  43. )
  44. if any([condition1, condition2]):
  45. return True
  46. # If two records share session_id but have different kernels, this is
  47. # and ill-posed expression. This should never be true. Raise an exception
  48. # to inform the user.
  49. if all(
  50. [
  51. self.session_id,
  52. self.session_id == other.session_id,
  53. self.kernel_id != other.kernel_id,
  54. ]
  55. ):
  56. msg = (
  57. "A single session_id can only have one kernel_id "
  58. "associated with. These two KernelSessionRecords share the same "
  59. "session_id but have different kernel_ids. This should "
  60. "not be possible and is likely an issue with the session "
  61. "records."
  62. )
  63. raise KernelSessionRecordConflict(msg)
  64. return False
  65. def update(self, other: "KernelSessionRecord") -> None:
  66. """Updates in-place a kernel from other (only accepts positive updates"""
  67. if not isinstance(other, KernelSessionRecord):
  68. msg = "'other' must be an instance of KernelSessionRecord." # type:ignore[unreachable]
  69. raise TypeError(msg)
  70. if other.kernel_id and self.kernel_id and other.kernel_id != self.kernel_id:
  71. msg = "Could not update the record from 'other' because the two records conflict."
  72. raise KernelSessionRecordConflict(msg)
  73. for field in fields(self):
  74. if hasattr(other, field.name) and getattr(other, field.name):
  75. setattr(self, field.name, getattr(other, field.name))
  76. class KernelSessionRecordList:
  77. """An object for storing and managing a list of KernelSessionRecords.
  78. When adding a record to the list, the KernelSessionRecordList
  79. first checks if the record already exists in the list. If it does,
  80. the record will be updated with the new information; otherwise,
  81. it will be appended.
  82. """
  83. _records: list[KernelSessionRecord]
  84. def __init__(self, *records: KernelSessionRecord):
  85. """Initialize a record list."""
  86. self._records = []
  87. for record in records:
  88. self.update(record)
  89. def __str__(self):
  90. """The string representation of a record list."""
  91. return str(self._records)
  92. def __contains__(self, record: Union[KernelSessionRecord, str]) -> bool:
  93. """Search for records by kernel_id and session_id"""
  94. if isinstance(record, KernelSessionRecord) and record in self._records:
  95. return True
  96. if isinstance(record, str):
  97. for r in self._records:
  98. if record in [r.session_id, r.kernel_id]:
  99. return True
  100. return False
  101. def __len__(self):
  102. """The length of the record list."""
  103. return len(self._records)
  104. def get(self, record: Union[KernelSessionRecord, str]) -> KernelSessionRecord:
  105. """Return a full KernelSessionRecord from a session_id, kernel_id, or
  106. incomplete KernelSessionRecord.
  107. """
  108. if isinstance(record, str):
  109. for r in self._records:
  110. if record in (r.kernel_id, r.session_id):
  111. return r
  112. elif isinstance(record, KernelSessionRecord):
  113. for r in self._records:
  114. if record == r:
  115. return record
  116. msg = f"{record} not found in KernelSessionRecordList."
  117. raise ValueError(msg)
  118. def update(self, record: KernelSessionRecord) -> None:
  119. """Update a record in-place or append it if not in the list."""
  120. try:
  121. idx = self._records.index(record)
  122. self._records[idx].update(record)
  123. except ValueError:
  124. self._records.append(record)
  125. def remove(self, record: KernelSessionRecord) -> None:
  126. """Remove a record if its found in the list. If it's not found,
  127. do nothing.
  128. """
  129. if record in self._records:
  130. self._records.remove(record)
  131. class SessionManager(LoggingConfigurable):
  132. """A session manager."""
  133. database_filepath = Unicode(
  134. default_value=":memory:",
  135. help=(
  136. "The filesystem path to SQLite Database file "
  137. "(e.g. /path/to/session_database.db). By default, the session "
  138. "database is stored in-memory (i.e. `:memory:` setting from sqlite3) "
  139. "and does not persist when the current Jupyter Server shuts down."
  140. ),
  141. ).tag(config=True)
  142. @validate("database_filepath")
  143. def _validate_database_filepath(self, proposal):
  144. """Validate a database file path."""
  145. value = proposal["value"]
  146. if value == ":memory:":
  147. return value
  148. path = pathlib.Path(value)
  149. if path.exists():
  150. # Verify that the database path is not a directory.
  151. if path.is_dir():
  152. msg = "`database_filepath` expected a file path, but the given path is a directory."
  153. raise TraitError(msg)
  154. # Verify that database path is an SQLite 3 Database by checking its header.
  155. with open(value, "rb") as f:
  156. header = f.read(100)
  157. if not header.startswith(b"SQLite format 3") and header != b"":
  158. msg = "The given file is not an SQLite database file."
  159. raise TraitError(msg)
  160. return value
  161. kernel_manager = Instance("jupyter_server.services.kernels.kernelmanager.MappingKernelManager")
  162. contents_manager = InstanceFromClasses(
  163. [
  164. "jupyter_server.services.contents.manager.ContentsManager",
  165. "notebook.services.contents.manager.ContentsManager",
  166. ]
  167. )
  168. def __init__(self, *args, **kwargs):
  169. """Initialize a record list."""
  170. super().__init__(*args, **kwargs)
  171. self._pending_sessions = KernelSessionRecordList()
  172. # Session database initialized below
  173. _cursor = None
  174. _connection = None
  175. _columns = {"session_id", "path", "name", "type", "kernel_id"}
  176. @property
  177. def cursor(self):
  178. """Start a cursor and create a database called 'session'"""
  179. if self._cursor is None:
  180. self._cursor = self.connection.cursor()
  181. self._cursor.execute(
  182. """CREATE TABLE IF NOT EXISTS session
  183. (session_id, path, name, type, kernel_id)"""
  184. )
  185. return self._cursor
  186. @property
  187. def connection(self):
  188. """Start a database connection"""
  189. if self._connection is None:
  190. # Set isolation level to None to autocommit all changes to the database.
  191. self._connection = sqlite3.connect(self.database_filepath, isolation_level=None)
  192. self._connection.row_factory = sqlite3.Row
  193. return self._connection
  194. def close(self):
  195. """Close the sqlite connection"""
  196. if self._cursor is not None:
  197. self._cursor.close()
  198. self._cursor = None
  199. def __del__(self):
  200. """Close connection once SessionManager closes"""
  201. self.close()
  202. async def session_exists(self, path):
  203. """Check to see if the session of a given name exists"""
  204. exists = False
  205. self.cursor.execute("SELECT * FROM session WHERE path=?", (path,))
  206. row = self.cursor.fetchone()
  207. if row is not None:
  208. # Note, although we found a row for the session, the associated kernel may have
  209. # been culled or died unexpectedly. If that's the case, we should delete the
  210. # row, thereby terminating the session. This can be done via a call to
  211. # row_to_model that tolerates that condition. If row_to_model returns None,
  212. # we'll return false, since, at that point, the session doesn't exist anyway.
  213. model = await self.row_to_model(row, tolerate_culled=True)
  214. if model is not None:
  215. exists = True
  216. return exists
  217. def new_session_id(self) -> str:
  218. """Create a uuid for a new session"""
  219. return str(uuid.uuid4())
  220. async def create_session(
  221. self,
  222. path: Optional[str] = None,
  223. name: Optional[ModelName] = None,
  224. type: Optional[str] = None,
  225. kernel_name: Optional[KernelName] = None,
  226. kernel_id: Optional[str] = None,
  227. ) -> dict[str, Any]:
  228. """Creates a session and returns its model
  229. Parameters
  230. ----------
  231. name: ModelName(str)
  232. Usually the model name, like the filename associated with current
  233. kernel.
  234. """
  235. session_id = self.new_session_id()
  236. record = KernelSessionRecord(session_id=session_id)
  237. self._pending_sessions.update(record)
  238. if kernel_id is not None and kernel_id in self.kernel_manager:
  239. pass
  240. else:
  241. kernel_id = await self.start_kernel_for_session(
  242. session_id, path, name, type, kernel_name
  243. )
  244. record.kernel_id = kernel_id
  245. self._pending_sessions.update(record)
  246. result = await self.save_session(
  247. session_id, path=path, name=name, type=type, kernel_id=kernel_id
  248. )
  249. self._pending_sessions.remove(record)
  250. return cast(dict[str, Any], result)
  251. def get_kernel_env(
  252. self, path: Optional[str], name: Optional[ModelName] = None
  253. ) -> dict[str, str]:
  254. """Return the environment variables that need to be set in the kernel
  255. Parameters
  256. ----------
  257. path : str
  258. the url path for the given session.
  259. name: ModelName(str), optional
  260. Here the name is likely to be the name of the associated file
  261. with the current kernel at startup time.
  262. """
  263. if name is not None:
  264. cwd = self.kernel_manager.cwd_for_path(path)
  265. path = os.path.join(cwd, name)
  266. assert isinstance(path, str)
  267. return {**os.environ, "JPY_SESSION_NAME": path}
  268. async def start_kernel_for_session(
  269. self,
  270. session_id: str,
  271. path: Optional[str],
  272. name: Optional[ModelName],
  273. type: Optional[str],
  274. kernel_name: Optional[KernelName],
  275. ) -> str:
  276. """Start a new kernel for a given session.
  277. Parameters
  278. ----------
  279. session_id : str
  280. uuid for the session; this method must be given a session_id
  281. path : str
  282. the path for the given session - seem to be a session id sometime.
  283. name : str
  284. Usually the model name, like the filename associated with current
  285. kernel.
  286. type : str
  287. the type of the session
  288. kernel_name : str
  289. the name of the kernel specification to use. The default kernel name will be used if not provided.
  290. """
  291. # allow contents manager to specify kernels cwd
  292. kernel_path = await ensure_async(self.contents_manager.get_kernel_path(path=path))
  293. kernel_env = self.get_kernel_env(path, name)
  294. kernel_id = await self.kernel_manager.start_kernel(
  295. path=kernel_path,
  296. kernel_name=kernel_name,
  297. env=kernel_env,
  298. )
  299. return cast(str, kernel_id)
  300. async def save_session(self, session_id, path=None, name=None, type=None, kernel_id=None):
  301. """Saves the items for the session with the given session_id
  302. Given a session_id (and any other of the arguments), this method
  303. creates a row in the sqlite session database that holds the information
  304. for a session.
  305. Parameters
  306. ----------
  307. session_id : str
  308. uuid for the session; this method must be given a session_id
  309. path : str
  310. the path for the given session
  311. name : str
  312. the name of the session
  313. type : str
  314. the type of the session
  315. kernel_id : str
  316. a uuid for the kernel associated with this session
  317. Returns
  318. -------
  319. model : dict
  320. a dictionary of the session model
  321. """
  322. self.cursor.execute(
  323. "INSERT INTO session VALUES (?,?,?,?,?)",
  324. (session_id, path, name, type, kernel_id),
  325. )
  326. result = await self.get_session(session_id=session_id)
  327. return result
  328. async def get_session(self, **kwargs):
  329. """Returns the model for a particular session.
  330. Takes a keyword argument and searches for the value in the session
  331. database, then returns the rest of the session's info.
  332. Parameters
  333. ----------
  334. **kwargs : dict
  335. must be given one of the keywords and values from the session database
  336. (i.e. session_id, path, name, type, kernel_id)
  337. Returns
  338. -------
  339. model : dict
  340. returns a dictionary that includes all the information from the
  341. session described by the kwarg.
  342. """
  343. if not kwargs:
  344. msg = "must specify a column to query"
  345. raise TypeError(msg)
  346. conditions = []
  347. for column in kwargs:
  348. if column not in self._columns:
  349. msg = f"No such column: {column}"
  350. raise TypeError(msg)
  351. conditions.append("%s=?" % column)
  352. query = "SELECT * FROM session WHERE %s" % (" AND ".join(conditions)) # noqa: S608
  353. self.cursor.execute(query, list(kwargs.values()))
  354. try:
  355. row = self.cursor.fetchone()
  356. except KeyError:
  357. # The kernel is missing, so the session just got deleted.
  358. row = None
  359. if row is None:
  360. q = []
  361. for key, value in kwargs.items():
  362. q.append(f"{key}={value!r}")
  363. raise web.HTTPError(404, "Session not found: %s" % (", ".join(q)))
  364. try:
  365. model = await self.row_to_model(row)
  366. except KeyError as e:
  367. raise web.HTTPError(404, "Session not found: %s" % str(e)) from e
  368. return model
  369. async def update_session(self, session_id, **kwargs):
  370. """Updates the values in the session database.
  371. Changes the values of the session with the given session_id
  372. with the values from the keyword arguments.
  373. Parameters
  374. ----------
  375. session_id : str
  376. a uuid that identifies a session in the sqlite3 database
  377. **kwargs : str
  378. the key must correspond to a column title in session database,
  379. and the value replaces the current value in the session
  380. with session_id.
  381. """
  382. await self.get_session(session_id=session_id)
  383. if not kwargs:
  384. # no changes
  385. return
  386. sets = []
  387. for column in kwargs:
  388. if column not in self._columns:
  389. raise TypeError("No such column: %r" % column)
  390. sets.append("%s=?" % column)
  391. query = "UPDATE session SET %s WHERE session_id=?" % (", ".join(sets)) # noqa: S608
  392. self.cursor.execute(query, [*list(kwargs.values()), session_id])
  393. if hasattr(self.kernel_manager, "update_env"):
  394. self.cursor.execute(
  395. "SELECT path, name, kernel_id FROM session WHERE session_id=?", [session_id]
  396. )
  397. path, name, kernel_id = self.cursor.fetchone()
  398. self.kernel_manager.update_env(kernel_id=kernel_id, env=self.get_kernel_env(path, name))
  399. async def kernel_culled(self, kernel_id: str) -> bool:
  400. """Checks if the kernel is still considered alive and returns true if its not found."""
  401. return kernel_id not in self.kernel_manager
  402. async def row_to_model(self, row, tolerate_culled=False):
  403. """Takes sqlite database session row and turns it into a dictionary"""
  404. kernel_culled: bool = await ensure_async(self.kernel_culled(row["kernel_id"]))
  405. if kernel_culled:
  406. # The kernel was culled or died without deleting the session.
  407. # We can't use delete_session here because that tries to find
  408. # and shut down the kernel - so we'll delete the row directly.
  409. #
  410. # If caller wishes to tolerate culled kernels, log a warning
  411. # and return None. Otherwise, raise KeyError with a similar
  412. # message.
  413. self.cursor.execute("DELETE FROM session WHERE session_id=?", (row["session_id"],))
  414. msg = (
  415. "Kernel '{kernel_id}' appears to have been culled or died unexpectedly, "
  416. "invalidating session '{session_id}'. The session has been removed.".format(
  417. kernel_id=row["kernel_id"], session_id=row["session_id"]
  418. )
  419. )
  420. if tolerate_culled:
  421. self.log.warning(f"{msg} Continuing...")
  422. return None
  423. raise KeyError(msg)
  424. kernel_model = await ensure_async(self.kernel_manager.kernel_model(row["kernel_id"]))
  425. model = {
  426. "id": row["session_id"],
  427. "path": row["path"],
  428. "name": row["name"],
  429. "type": row["type"],
  430. "kernel": kernel_model,
  431. }
  432. if row["type"] == "notebook":
  433. # Provide the deprecated API.
  434. model["notebook"] = {"path": row["path"], "name": row["name"]}
  435. return model
  436. async def list_sessions(self):
  437. """Returns a list of dictionaries containing all the information from
  438. the session database"""
  439. c = self.cursor.execute("SELECT * FROM session")
  440. result = []
  441. # We need to use fetchall() here, because row_to_model can delete rows,
  442. # which messes up the cursor if we're iterating over rows.
  443. for row in c.fetchall():
  444. try:
  445. model = await self.row_to_model(row)
  446. result.append(model)
  447. except KeyError:
  448. pass
  449. return result
  450. async def delete_session(self, session_id):
  451. """Deletes the row in the session database with given session_id"""
  452. record = KernelSessionRecord(session_id=session_id)
  453. self._pending_sessions.update(record)
  454. session = await self.get_session(session_id=session_id)
  455. await ensure_async(self.kernel_manager.shutdown_kernel(session["kernel"]["id"]))
  456. self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,))
  457. self._pending_sessions.remove(record)