caching.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004
from __future__ import annotations

import collections
import functools
import logging
import math
import os
import threading
from collections import OrderedDict
from collections.abc import Callable, Iterable
from concurrent.futures import Future, ThreadPoolExecutor
from itertools import groupby
from operator import itemgetter
from typing import TYPE_CHECKING, Any, ClassVar, Generic, NamedTuple, TypeVar
  14. if TYPE_CHECKING:
  15. import mmap
  16. from typing_extensions import ParamSpec
  17. P = ParamSpec("P")
  18. else:
  19. P = TypeVar("P")
  20. T = TypeVar("T")
  21. logger = logging.getLogger("fsspec.caching")
  22. Fetcher = Callable[[int, int], bytes] # Maps (start, end) to bytes
  23. MultiFetcher = Callable[[list[int, int]], bytes] # Maps [(start, end)] to bytes
  24. class BaseCache:
  25. """Pass-though cache: doesn't keep anything, calls every time
  26. Acts as base class for other cachers
  27. Parameters
  28. ----------
  29. blocksize: int
  30. How far to read ahead in numbers of bytes
  31. fetcher: func
  32. Function of the form f(start, end) which gets bytes from remote as
  33. specified
  34. size: int
  35. How big this file is
  36. """
  37. name: ClassVar[str] = "none"
  38. def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
  39. self.blocksize = blocksize
  40. self.nblocks = 0
  41. self.fetcher = fetcher
  42. self.size = size
  43. self.hit_count = 0
  44. self.miss_count = 0
  45. # the bytes that we actually requested
  46. self.total_requested_bytes = 0
  47. def _fetch(self, start: int | None, stop: int | None) -> bytes:
  48. if start is None:
  49. start = 0
  50. if stop is None:
  51. stop = self.size
  52. if start >= self.size or start >= stop:
  53. return b""
  54. return self.fetcher(start, stop)
  55. def _reset_stats(self) -> None:
  56. """Reset hit and miss counts for a more ganular report e.g. by file."""
  57. self.hit_count = 0
  58. self.miss_count = 0
  59. self.total_requested_bytes = 0
  60. def _log_stats(self) -> str:
  61. """Return a formatted string of the cache statistics."""
  62. if self.hit_count == 0 and self.miss_count == 0:
  63. # a cache that does nothing, this is for logs only
  64. return ""
  65. return f" , {self.name}: {self.hit_count} hits, {self.miss_count} misses, {self.total_requested_bytes} total requested bytes"
  66. def __repr__(self) -> str:
  67. # TODO: use rich for better formatting
  68. return f"""
  69. <{self.__class__.__name__}:
  70. block size : {self.blocksize}
  71. block count : {self.nblocks}
  72. file size : {self.size}
  73. cache hits : {self.hit_count}
  74. cache misses: {self.miss_count}
  75. total requested bytes: {self.total_requested_bytes}>
  76. """
  77. class MMapCache(BaseCache):
  78. """memory-mapped sparse file cache
  79. Opens temporary file, which is filled blocks-wise when data is requested.
  80. Ensure there is enough disc space in the temporary location.
  81. This cache method might only work on posix
  82. Parameters
  83. ----------
  84. blocksize: int
  85. How far to read ahead in numbers of bytes
  86. fetcher: Fetcher
  87. Function of the form f(start, end) which gets bytes from remote as
  88. specified
  89. size: int
  90. How big this file is
  91. location: str
  92. Where to create the temporary file. If None, a temporary file is
  93. created using tempfile.TemporaryFile().
  94. blocks: set[int]
  95. Set of block numbers that have already been fetched. If None, an empty
  96. set is created.
  97. multi_fetcher: MultiFetcher
  98. Function of the form f([(start, end)]) which gets bytes from remote
  99. as specified. This function is used to fetch multiple blocks at once.
  100. If not specified, the fetcher function is used instead.
  101. """
  102. name = "mmap"
  103. def __init__(
  104. self,
  105. blocksize: int,
  106. fetcher: Fetcher,
  107. size: int,
  108. location: str | None = None,
  109. blocks: set[int] | None = None,
  110. multi_fetcher: MultiFetcher | None = None,
  111. ) -> None:
  112. super().__init__(blocksize, fetcher, size)
  113. self.blocks = set() if blocks is None else blocks
  114. self.location = location
  115. self.multi_fetcher = multi_fetcher
  116. self.cache = self._makefile()
  117. def _makefile(self) -> mmap.mmap | bytearray:
  118. import mmap
  119. import tempfile
  120. if self.size == 0:
  121. return bytearray()
  122. # posix version
  123. if self.location is None or not os.path.exists(self.location):
  124. if self.location is None:
  125. fd = tempfile.TemporaryFile()
  126. self.blocks = set()
  127. else:
  128. fd = open(self.location, "wb+")
  129. fd.seek(self.size - 1)
  130. fd.write(b"1")
  131. fd.flush()
  132. else:
  133. fd = open(self.location, "r+b")
  134. return mmap.mmap(fd.fileno(), self.size)
  135. def _fetch(self, start: int | None, end: int | None) -> bytes:
  136. logger.debug(f"MMap cache fetching {start}-{end}")
  137. if start is None:
  138. start = 0
  139. if end is None:
  140. end = self.size
  141. if start >= self.size or start >= end:
  142. return b""
  143. start_block = start // self.blocksize
  144. end_block = end // self.blocksize
  145. block_range = range(start_block, end_block + 1)
  146. # Determine which blocks need to be fetched. This sequence is sorted by construction.
  147. need = (i for i in block_range if i not in self.blocks)
  148. # Count the number of blocks already cached
  149. self.hit_count += sum(1 for i in block_range if i in self.blocks)
  150. ranges = []
  151. # Consolidate needed blocks.
  152. # Algorithm adapted from Python 2.x itertools documentation.
  153. # We are grouping an enumerated sequence of blocks. By comparing when the difference
  154. # between an ascending range (provided by enumerate) and the needed block numbers
  155. # we can detect when the block number skips values. The key computes this difference.
  156. # Whenever the difference changes, we know that we have previously cached block(s),
  157. # and a new group is started. In other words, this algorithm neatly groups
  158. # runs of consecutive block numbers so they can be fetched together.
  159. for _, _blocks in groupby(enumerate(need), key=lambda x: x[0] - x[1]):
  160. # Extract the blocks from the enumerated sequence
  161. _blocks = tuple(map(itemgetter(1), _blocks))
  162. # Compute start of first block
  163. sstart = _blocks[0] * self.blocksize
  164. # Compute the end of the last block. Last block may not be full size.
  165. send = min(_blocks[-1] * self.blocksize + self.blocksize, self.size)
  166. # Fetch bytes (could be multiple consecutive blocks)
  167. self.total_requested_bytes += send - sstart
  168. logger.debug(
  169. f"MMap get blocks {_blocks[0]}-{_blocks[-1]} ({sstart}-{send})"
  170. )
  171. ranges.append((sstart, send))
  172. # Update set of cached blocks
  173. self.blocks.update(_blocks)
  174. # Update cache statistics with number of blocks we had to cache
  175. self.miss_count += len(_blocks)
  176. if not ranges:
  177. return self.cache[start:end]
  178. if self.multi_fetcher:
  179. logger.debug(f"MMap get blocks {ranges}")
  180. for idx, r in enumerate(self.multi_fetcher(ranges)):
  181. sstart, send = ranges[idx]
  182. logger.debug(f"MMap copy block ({sstart}-{send}")
  183. self.cache[sstart:send] = r
  184. else:
  185. for sstart, send in ranges:
  186. logger.debug(f"MMap get block ({sstart}-{send}")
  187. self.cache[sstart:send] = self.fetcher(sstart, send)
  188. return self.cache[start:end]
  189. def __getstate__(self) -> dict[str, Any]:
  190. state = self.__dict__.copy()
  191. # Remove the unpicklable entries.
  192. del state["cache"]
  193. return state
  194. def __setstate__(self, state: dict[str, Any]) -> None:
  195. # Restore instance attributes
  196. self.__dict__.update(state)
  197. self.cache = self._makefile()
  198. class ReadAheadCache(BaseCache):
  199. """Cache which reads only when we get beyond a block of data
  200. This is a much simpler version of BytesCache, and does not attempt to
  201. fill holes in the cache or keep fragments alive. It is best suited to
  202. many small reads in a sequential order (e.g., reading lines from a file).
  203. """
  204. name = "readahead"
  205. def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
  206. super().__init__(blocksize, fetcher, size)
  207. self.cache = b""
  208. self.start = 0
  209. self.end = 0
  210. def _fetch(self, start: int | None, end: int | None) -> bytes:
  211. if start is None:
  212. start = 0
  213. if end is None or end > self.size:
  214. end = self.size
  215. if start >= self.size or start >= end:
  216. return b""
  217. l = end - start
  218. if start >= self.start and end <= self.end:
  219. # cache hit
  220. self.hit_count += 1
  221. return self.cache[start - self.start : end - self.start]
  222. elif self.start <= start < self.end:
  223. # partial hit
  224. self.miss_count += 1
  225. part = self.cache[start - self.start :]
  226. l -= len(part)
  227. start = self.end
  228. else:
  229. # miss
  230. self.miss_count += 1
  231. part = b""
  232. end = min(self.size, end + self.blocksize)
  233. self.total_requested_bytes += end - start
  234. self.cache = self.fetcher(start, end) # new block replaces old
  235. self.start = start
  236. self.end = self.start + len(self.cache)
  237. return part + self.cache[:l]
  238. class FirstChunkCache(BaseCache):
  239. """Caches the first block of a file only
  240. This may be useful for file types where the metadata is stored in the header,
  241. but is randomly accessed.
  242. """
  243. name = "first"
  244. def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
  245. if blocksize > size:
  246. # this will buffer the whole thing
  247. blocksize = size
  248. super().__init__(blocksize, fetcher, size)
  249. self.cache: bytes | None = None
  250. def _fetch(self, start: int | None, end: int | None) -> bytes:
  251. start = start or 0
  252. if start > self.size:
  253. logger.debug("FirstChunkCache: requested start > file size")
  254. return b""
  255. end = min(end, self.size)
  256. if start < self.blocksize:
  257. if self.cache is None:
  258. self.miss_count += 1
  259. if end > self.blocksize:
  260. self.total_requested_bytes += end
  261. data = self.fetcher(0, end)
  262. self.cache = data[: self.blocksize]
  263. return data[start:]
  264. self.cache = self.fetcher(0, self.blocksize)
  265. self.total_requested_bytes += self.blocksize
  266. part = self.cache[start:end]
  267. if end > self.blocksize:
  268. self.total_requested_bytes += end - self.blocksize
  269. part += self.fetcher(self.blocksize, end)
  270. self.hit_count += 1
  271. return part
  272. else:
  273. self.miss_count += 1
  274. self.total_requested_bytes += end - start
  275. return self.fetcher(start, end)
  276. class BlockCache(BaseCache):
  277. """
  278. Cache holding memory as a set of blocks.
  279. Requests are only ever made ``blocksize`` at a time, and are
  280. stored in an LRU cache. The least recently accessed block is
  281. discarded when more than ``maxblocks`` are stored.
  282. Parameters
  283. ----------
  284. blocksize : int
  285. The number of bytes to store in each block.
  286. Requests are only ever made for ``blocksize``, so this
  287. should balance the overhead of making a request against
  288. the granularity of the blocks.
  289. fetcher : Callable
  290. size : int
  291. The total size of the file being cached.
  292. maxblocks : int
  293. The maximum number of blocks to cache for. The maximum memory
  294. use for this cache is then ``blocksize * maxblocks``.
  295. """
  296. name = "blockcache"
  297. def __init__(
  298. self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
  299. ) -> None:
  300. super().__init__(blocksize, fetcher, size)
  301. self.nblocks = math.ceil(size / blocksize)
  302. self.maxblocks = maxblocks
  303. self._fetch_block_cached = functools.lru_cache(maxblocks)(self._fetch_block)
  304. def cache_info(self):
  305. """
  306. The statistics on the block cache.
  307. Returns
  308. -------
  309. NamedTuple
  310. Returned directly from the LRU Cache used internally.
  311. """
  312. return self._fetch_block_cached.cache_info()
  313. def __getstate__(self) -> dict[str, Any]:
  314. state = self.__dict__
  315. del state["_fetch_block_cached"]
  316. return state
  317. def __setstate__(self, state: dict[str, Any]) -> None:
  318. self.__dict__.update(state)
  319. self._fetch_block_cached = functools.lru_cache(state["maxblocks"])(
  320. self._fetch_block
  321. )
  322. def _fetch(self, start: int | None, end: int | None) -> bytes:
  323. if start is None:
  324. start = 0
  325. if end is None:
  326. end = self.size
  327. if start >= self.size or start >= end:
  328. return b""
  329. return self._read_cache(
  330. start, end, start // self.blocksize, (end - 1) // self.blocksize
  331. )
  332. def _fetch_block(self, block_number: int) -> bytes:
  333. """
  334. Fetch the block of data for `block_number`.
  335. """
  336. if block_number > self.nblocks:
  337. raise ValueError(
  338. f"'block_number={block_number}' is greater than "
  339. f"the number of blocks ({self.nblocks})"
  340. )
  341. start = block_number * self.blocksize
  342. end = start + self.blocksize
  343. self.total_requested_bytes += end - start
  344. self.miss_count += 1
  345. logger.info("BlockCache fetching block %d", block_number)
  346. block_contents = super()._fetch(start, end)
  347. return block_contents
  348. def _read_cache(
  349. self, start: int, end: int, start_block_number: int, end_block_number: int
  350. ) -> bytes:
  351. """
  352. Read from our block cache.
  353. Parameters
  354. ----------
  355. start, end : int
  356. The start and end byte positions.
  357. start_block_number, end_block_number : int
  358. The start and end block numbers.
  359. """
  360. start_pos = start % self.blocksize
  361. end_pos = end % self.blocksize
  362. if end_pos == 0:
  363. end_pos = self.blocksize
  364. self.hit_count += 1
  365. if start_block_number == end_block_number:
  366. block: bytes = self._fetch_block_cached(start_block_number)
  367. return block[start_pos:end_pos]
  368. else:
  369. # read from the initial
  370. out = [self._fetch_block_cached(start_block_number)[start_pos:]]
  371. # intermediate blocks
  372. # Note: it'd be nice to combine these into one big request. However
  373. # that doesn't play nicely with our LRU cache.
  374. out.extend(
  375. map(
  376. self._fetch_block_cached,
  377. range(start_block_number + 1, end_block_number),
  378. )
  379. )
  380. # final block
  381. out.append(self._fetch_block_cached(end_block_number)[:end_pos])
  382. return b"".join(out)
  383. class BytesCache(BaseCache):
  384. """Cache which holds data in a in-memory bytes object
  385. Implements read-ahead by the block size, for semi-random reads progressing
  386. through the file.
  387. Parameters
  388. ----------
  389. trim: bool
  390. As we read more data, whether to discard the start of the buffer when
  391. we are more than a blocksize ahead of it.
  392. """
  393. name: ClassVar[str] = "bytes"
  394. def __init__(
  395. self, blocksize: int, fetcher: Fetcher, size: int, trim: bool = True
  396. ) -> None:
  397. super().__init__(blocksize, fetcher, size)
  398. self.cache = b""
  399. self.start: int | None = None
  400. self.end: int | None = None
  401. self.trim = trim
  402. def _fetch(self, start: int | None, end: int | None) -> bytes:
  403. # TODO: only set start/end after fetch, in case it fails?
  404. # is this where retry logic might go?
  405. if start is None:
  406. start = 0
  407. if end is None:
  408. end = self.size
  409. if start >= self.size or start >= end:
  410. return b""
  411. if (
  412. self.start is not None
  413. and start >= self.start
  414. and self.end is not None
  415. and end < self.end
  416. ):
  417. # cache hit: we have all the required data
  418. offset = start - self.start
  419. self.hit_count += 1
  420. return self.cache[offset : offset + end - start]
  421. if self.blocksize:
  422. bend = min(self.size, end + self.blocksize)
  423. else:
  424. bend = end
  425. if bend == start or start > self.size:
  426. return b""
  427. if (self.start is None or start < self.start) and (
  428. self.end is None or end > self.end
  429. ):
  430. # First read, or extending both before and after
  431. self.total_requested_bytes += bend - start
  432. self.miss_count += 1
  433. self.cache = self.fetcher(start, bend)
  434. self.start = start
  435. else:
  436. assert self.start is not None
  437. assert self.end is not None
  438. self.miss_count += 1
  439. if start < self.start:
  440. if self.end is None or self.end - end > self.blocksize:
  441. self.total_requested_bytes += bend - start
  442. self.cache = self.fetcher(start, bend)
  443. self.start = start
  444. else:
  445. self.total_requested_bytes += self.start - start
  446. new = self.fetcher(start, self.start)
  447. self.start = start
  448. self.cache = new + self.cache
  449. elif self.end is not None and bend > self.end:
  450. if self.end > self.size:
  451. pass
  452. elif end - self.end > self.blocksize:
  453. self.total_requested_bytes += bend - start
  454. self.cache = self.fetcher(start, bend)
  455. self.start = start
  456. else:
  457. self.total_requested_bytes += bend - self.end
  458. new = self.fetcher(self.end, bend)
  459. self.cache = self.cache + new
  460. self.end = self.start + len(self.cache)
  461. offset = start - self.start
  462. out = self.cache[offset : offset + end - start]
  463. if self.trim:
  464. num = (self.end - self.start) // (self.blocksize + 1)
  465. if num > 1:
  466. self.start += self.blocksize * num
  467. self.cache = self.cache[self.blocksize * num :]
  468. return out
  469. def __len__(self) -> int:
  470. return len(self.cache)
  471. class AllBytes(BaseCache):
  472. """Cache entire contents of the file"""
  473. name: ClassVar[str] = "all"
  474. def __init__(
  475. self,
  476. blocksize: int | None = None,
  477. fetcher: Fetcher | None = None,
  478. size: int | None = None,
  479. data: bytes | None = None,
  480. ) -> None:
  481. super().__init__(blocksize, fetcher, size) # type: ignore[arg-type]
  482. if data is None:
  483. self.miss_count += 1
  484. self.total_requested_bytes += self.size
  485. data = self.fetcher(0, self.size)
  486. self.data = data
  487. def _fetch(self, start: int | None, stop: int | None) -> bytes:
  488. self.hit_count += 1
  489. return self.data[start:stop]
  490. class KnownPartsOfAFile(BaseCache):
  491. """
  492. Cache holding known file parts.
  493. Parameters
  494. ----------
  495. blocksize: int
  496. How far to read ahead in numbers of bytes
  497. fetcher: func
  498. Function of the form f(start, end) which gets bytes from remote as
  499. specified
  500. size: int
  501. How big this file is
  502. data: dict
  503. A dictionary mapping explicit `(start, stop)` file-offset tuples
  504. with known bytes.
  505. strict: bool, default True
  506. Whether to fetch reads that go beyond a known byte-range boundary.
  507. If `False`, any read that ends outside a known part will be zero
  508. padded. Note that zero padding will not be used for reads that
  509. begin outside a known byte-range.
  510. """
  511. name: ClassVar[str] = "parts"
  512. def __init__(
  513. self,
  514. blocksize: int,
  515. fetcher: Fetcher,
  516. size: int,
  517. data: dict[tuple[int, int], bytes] | None = None,
  518. strict: bool = False,
  519. **_: Any,
  520. ):
  521. super().__init__(blocksize, fetcher, size)
  522. self.strict = strict
  523. # simple consolidation of contiguous blocks
  524. if data:
  525. old_offsets = sorted(data.keys())
  526. offsets = [old_offsets[0]]
  527. blocks = [data.pop(old_offsets[0])]
  528. for start, stop in old_offsets[1:]:
  529. start0, stop0 = offsets[-1]
  530. if start == stop0:
  531. offsets[-1] = (start0, stop)
  532. blocks[-1] += data.pop((start, stop))
  533. else:
  534. offsets.append((start, stop))
  535. blocks.append(data.pop((start, stop)))
  536. self.data = dict(zip(offsets, blocks))
  537. else:
  538. self.data = {}
  539. @property
  540. def size(self):
  541. return sum(_[1] - _[0] for _ in self.data)
  542. @size.setter
  543. def size(self, value):
  544. pass
  545. @property
  546. def nblocks(self):
  547. return len(self.data)
  548. @nblocks.setter
  549. def nblocks(self, value):
  550. pass
  551. def _fetch(self, start: int | None, stop: int | None) -> bytes:
  552. logger.debug("Known parts request %s %s", start, stop)
  553. if start is None:
  554. start = 0
  555. if stop is None:
  556. stop = self.size
  557. self.total_requested_bytes += stop - start
  558. out = b""
  559. started = False
  560. loc_old = 0
  561. for loc0, loc1 in sorted(self.data):
  562. if (loc0 <= start < loc1) and (loc0 <= stop <= loc1):
  563. # entirely within the block
  564. off = start - loc0
  565. self.hit_count += 1
  566. return self.data[(loc0, loc1)][off : off + stop - start]
  567. if stop <= loc0:
  568. break
  569. if started and loc0 > loc_old:
  570. # a gap where we need data
  571. self.miss_count += 1
  572. if self.strict:
  573. raise ValueError
  574. out += b"\x00" * (loc0 - loc_old)
  575. if loc0 <= start < loc1:
  576. # found the start
  577. self.hit_count += 1
  578. off = start - loc0
  579. out = self.data[(loc0, loc1)][off : off + stop - start]
  580. started = True
  581. elif start < loc0 and stop > loc1:
  582. # the whole block
  583. self.hit_count += 1
  584. out += self.data[(loc0, loc1)]
  585. elif loc0 <= stop <= loc1:
  586. # end block
  587. self.hit_count += 1
  588. out = out + self.data[(loc0, loc1)][: stop - loc0]
  589. return out
  590. loc_old = loc1
  591. self.miss_count += 1
  592. if started and not self.strict:
  593. out = out + b"\x00" * (stop - loc_old)
  594. return out
  595. raise ValueError
  596. class UpdatableLRU(Generic[P, T]):
  597. """
  598. Custom implementation of LRU cache that allows updating keys
  599. Used by BackgroundBlockCache
  600. """
  601. class CacheInfo(NamedTuple):
  602. hits: int
  603. misses: int
  604. maxsize: int
  605. currsize: int
  606. def __init__(self, func: Callable[P, T], max_size: int = 128) -> None:
  607. self._cache: OrderedDict[Any, T] = collections.OrderedDict()
  608. self._func = func
  609. self._max_size = max_size
  610. self._hits = 0
  611. self._misses = 0
  612. self._lock = threading.Lock()
  613. def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
  614. if kwargs:
  615. raise TypeError(f"Got unexpected keyword argument {kwargs.keys()}")
  616. with self._lock:
  617. if args in self._cache:
  618. self._cache.move_to_end(args)
  619. self._hits += 1
  620. return self._cache[args]
  621. result = self._func(*args, **kwargs)
  622. with self._lock:
  623. self._cache[args] = result
  624. self._misses += 1
  625. if len(self._cache) > self._max_size:
  626. self._cache.popitem(last=False)
  627. return result
  628. def is_key_cached(self, *args: Any) -> bool:
  629. with self._lock:
  630. return args in self._cache
  631. def add_key(self, result: T, *args: Any) -> None:
  632. with self._lock:
  633. self._cache[args] = result
  634. if len(self._cache) > self._max_size:
  635. self._cache.popitem(last=False)
  636. def cache_info(self) -> UpdatableLRU.CacheInfo:
  637. with self._lock:
  638. return self.CacheInfo(
  639. maxsize=self._max_size,
  640. currsize=len(self._cache),
  641. hits=self._hits,
  642. misses=self._misses,
  643. )
  644. class BackgroundBlockCache(BaseCache):
  645. """
  646. Cache holding memory as a set of blocks with pre-loading of
  647. the next block in the background.
  648. Requests are only ever made ``blocksize`` at a time, and are
  649. stored in an LRU cache. The least recently accessed block is
  650. discarded when more than ``maxblocks`` are stored. If the
  651. next block is not in cache, it is loaded in a separate thread
  652. in non-blocking way.
  653. Parameters
  654. ----------
  655. blocksize : int
  656. The number of bytes to store in each block.
  657. Requests are only ever made for ``blocksize``, so this
  658. should balance the overhead of making a request against
  659. the granularity of the blocks.
  660. fetcher : Callable
  661. size : int
  662. The total size of the file being cached.
  663. maxblocks : int
  664. The maximum number of blocks to cache for. The maximum memory
  665. use for this cache is then ``blocksize * maxblocks``.
  666. """
  667. name: ClassVar[str] = "background"
  668. def __init__(
  669. self, blocksize: int, fetcher: Fetcher, size: int, maxblocks: int = 32
  670. ) -> None:
  671. super().__init__(blocksize, fetcher, size)
  672. self.nblocks = math.ceil(size / blocksize)
  673. self.maxblocks = maxblocks
  674. self._fetch_block_cached = UpdatableLRU(self._fetch_block, maxblocks)
  675. self._thread_executor = ThreadPoolExecutor(max_workers=1)
  676. self._fetch_future_block_number: int | None = None
  677. self._fetch_future: Future[bytes] | None = None
  678. self._fetch_future_lock = threading.Lock()
  679. def cache_info(self) -> UpdatableLRU.CacheInfo:
  680. """
  681. The statistics on the block cache.
  682. Returns
  683. -------
  684. NamedTuple
  685. Returned directly from the LRU Cache used internally.
  686. """
  687. return self._fetch_block_cached.cache_info()
  688. def __getstate__(self) -> dict[str, Any]:
  689. state = self.__dict__
  690. del state["_fetch_block_cached"]
  691. del state["_thread_executor"]
  692. del state["_fetch_future_block_number"]
  693. del state["_fetch_future"]
  694. del state["_fetch_future_lock"]
  695. return state
  696. def __setstate__(self, state) -> None:
  697. self.__dict__.update(state)
  698. self._fetch_block_cached = UpdatableLRU(self._fetch_block, state["maxblocks"])
  699. self._thread_executor = ThreadPoolExecutor(max_workers=1)
  700. self._fetch_future_block_number = None
  701. self._fetch_future = None
  702. self._fetch_future_lock = threading.Lock()
  703. def _fetch(self, start: int | None, end: int | None) -> bytes:
  704. if start is None:
  705. start = 0
  706. if end is None:
  707. end = self.size
  708. if start >= self.size or start >= end:
  709. return b""
  710. # byte position -> block numbers
  711. start_block_number = start // self.blocksize
  712. end_block_number = end // self.blocksize
  713. fetch_future_block_number = None
  714. fetch_future = None
  715. with self._fetch_future_lock:
  716. # Background thread is running. Check we we can or must join it.
  717. if self._fetch_future is not None:
  718. assert self._fetch_future_block_number is not None
  719. if self._fetch_future.done():
  720. logger.info("BlockCache joined background fetch without waiting.")
  721. self._fetch_block_cached.add_key(
  722. self._fetch_future.result(), self._fetch_future_block_number
  723. )
  724. # Cleanup the fetch variables. Done with fetching the block.
  725. self._fetch_future_block_number = None
  726. self._fetch_future = None
  727. else:
  728. # Must join if we need the block for the current fetch
  729. must_join = bool(
  730. start_block_number
  731. <= self._fetch_future_block_number
  732. <= end_block_number
  733. )
  734. if must_join:
  735. # Copy to the local variables to release lock
  736. # before waiting for result
  737. fetch_future_block_number = self._fetch_future_block_number
  738. fetch_future = self._fetch_future
  739. # Cleanup the fetch variables. Have a local copy.
  740. self._fetch_future_block_number = None
  741. self._fetch_future = None
  742. # Need to wait for the future for the current read
  743. if fetch_future is not None:
  744. logger.info("BlockCache waiting for background fetch.")
  745. # Wait until result and put it in cache
  746. self._fetch_block_cached.add_key(
  747. fetch_future.result(), fetch_future_block_number
  748. )
  749. # these are cached, so safe to do multiple calls for the same start and end.
  750. for block_number in range(start_block_number, end_block_number + 1):
  751. self._fetch_block_cached(block_number)
  752. # fetch next block in the background if nothing is running in the background,
  753. # the block is within file and it is not already cached
  754. end_block_plus_1 = end_block_number + 1
  755. with self._fetch_future_lock:
  756. if (
  757. self._fetch_future is None
  758. and end_block_plus_1 <= self.nblocks
  759. and not self._fetch_block_cached.is_key_cached(end_block_plus_1)
  760. ):
  761. self._fetch_future_block_number = end_block_plus_1
  762. self._fetch_future = self._thread_executor.submit(
  763. self._fetch_block, end_block_plus_1, "async"
  764. )
  765. return self._read_cache(
  766. start,
  767. end,
  768. start_block_number=start_block_number,
  769. end_block_number=end_block_number,
  770. )
  771. def _fetch_block(self, block_number: int, log_info: str = "sync") -> bytes:
  772. """
  773. Fetch the block of data for `block_number`.
  774. """
  775. if block_number > self.nblocks:
  776. raise ValueError(
  777. f"'block_number={block_number}' is greater than "
  778. f"the number of blocks ({self.nblocks})"
  779. )
  780. start = block_number * self.blocksize
  781. end = start + self.blocksize
  782. logger.info("BlockCache fetching block (%s) %d", log_info, block_number)
  783. self.total_requested_bytes += end - start
  784. self.miss_count += 1
  785. block_contents = super()._fetch(start, end)
  786. return block_contents
  787. def _read_cache(
  788. self, start: int, end: int, start_block_number: int, end_block_number: int
  789. ) -> bytes:
  790. """
  791. Read from our block cache.
  792. Parameters
  793. ----------
  794. start, end : int
  795. The start and end byte positions.
  796. start_block_number, end_block_number : int
  797. The start and end block numbers.
  798. """
  799. start_pos = start % self.blocksize
  800. end_pos = end % self.blocksize
  801. # kind of pointless to count this as a hit, but it is
  802. self.hit_count += 1
  803. if start_block_number == end_block_number:
  804. block = self._fetch_block_cached(start_block_number)
  805. return block[start_pos:end_pos]
  806. else:
  807. # read from the initial
  808. out = [self._fetch_block_cached(start_block_number)[start_pos:]]
  809. # intermediate blocks
  810. # Note: it'd be nice to combine these into one big request. However
  811. # that doesn't play nicely with our LRU cache.
  812. out.extend(
  813. map(
  814. self._fetch_block_cached,
  815. range(start_block_number + 1, end_block_number),
  816. )
  817. )
  818. # final block
  819. out.append(self._fetch_block_cached(end_block_number)[:end_pos])
  820. return b"".join(out)
  821. caches: dict[str | None, type[BaseCache]] = {
  822. # one custom case
  823. None: BaseCache,
  824. }
  825. def register_cache(cls: type[BaseCache], clobber: bool = False) -> None:
  826. """'Register' cache implementation.
  827. Parameters
  828. ----------
  829. clobber: bool, optional
  830. If set to True (default is False) - allow to overwrite existing
  831. entry.
  832. Raises
  833. ------
  834. ValueError
  835. """
  836. name = cls.name
  837. if not clobber and name in caches:
  838. raise ValueError(f"Cache with name {name!r} is already known: {caches[name]}")
  839. caches[name] = cls
  840. for c in (
  841. BaseCache,
  842. MMapCache,
  843. BytesCache,
  844. ReadAheadCache,
  845. BlockCache,
  846. FirstChunkCache,
  847. AllBytes,
  848. KnownPartsOfAFile,
  849. BackgroundBlockCache,
  850. ):
  851. register_cache(c)