datastore.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. """leveldb log datastore.
  2. Format is described at:
  3. https://github.com/google/leveldb/blob/master/doc/log_format.md
  4. block := record* trailer?
  5. record :=
  6. checksum: uint32 // CRC-32 (zlib.crc32, not leveldb's masked crc32c) of type and data[]; little-endian
  7. length: uint16 // little-endian
  8. type: uint8 // One of FULL, FIRST, MIDDLE, LAST
  9. data: uint8[length]
  10. header := // W&B file header: the first 7 bytes of the file, counted as part of block 0
  11. ident: char[4]
  12. magic: uint16
  13. version: uint8
  14. """
  15. from __future__ import annotations
  16. import logging
  17. import os
  18. import struct
  19. import zlib
  20. from typing import TYPE_CHECKING
  21. if TYPE_CHECKING:
  22. from typing import IO, Any
  23. from wandb.proto.wandb_internal_pb2 import Record
  logger = logging.getLogger(__name__)

  # Physical layout: the file is a sequence of fixed-size blocks; each record
  # inside a block starts with a 7-byte header packed as "<IHB"
  # (uint32 checksum, uint16 payload length, uint8 record type).
  LEVELDBLOG_HEADER_LEN = 7
  LEVELDBLOG_BLOCK_LEN = 32 * 1024
  # Largest payload a single record can carry when it occupies a whole block.
  LEVELDBLOG_DATA_LEN = LEVELDBLOG_BLOCK_LEN - LEVELDBLOG_HEADER_LEN

  # Record types: a payload is stored whole (FULL) or split into a FIRST
  # fragment, zero or more MIDDLE fragments, and a LAST fragment.
  LEVELDBLOG_FULL, LEVELDBLOG_FIRST, LEVELDBLOG_MIDDLE, LEVELDBLOG_LAST = 1, 2, 3, 4

  # W&B file header fields ("<4sHB": ident, magic, version).
  LEVELDBLOG_HEADER_IDENT = ":W&B"
  # == zlib.crc32(bytes("Weights & Biases", 'iso8859-1')) & 0xffff
  LEVELDBLOG_HEADER_MAGIC = 0xBEE1
  LEVELDBLOG_HEADER_VERSION = 0
  def strtobytes(x: str) -> bytes:
      """Encode ``x`` as ISO-8859-1 (latin-1) bytes.

      Each code point below 256 maps to the single byte of the same value,
      which makes this suitable for turning one-character record types and
      the ASCII header ident into raw bytes.

      Raises:
          UnicodeEncodeError: if ``x`` contains a code point above U+00FF.
          TypeError: if ``x`` is not a ``str``.
      """
      return bytes(x, "iso8859-1")
  class DataStore:
      """Record log in the leveldb log format, preceded by a W&B file header.

      An instance is opened either for writing (``open_for_write`` or
      ``open_for_append``, then ``write``) or for reading (``open_for_scan``,
      then ``scan_data`` / ``scan_record``), and released with ``close``.
      """

      # Logical byte offset into the file; drives all block-boundary math.
      _index: int
      # Value of ``_index`` at the last flush + fsync done by ``_write_data``.
      _flush_offset: int

      def __init__(self) -> None:
          self._opened_for_scan = False
          self._fp: IO[Any] | None = None
          self._index = 0
          self._flush_offset = 0
          self._size_bytes = 0
          # CRC-32 of each one-byte record type. A record checksum continues
          # from this seed over the payload, i.e. crc32(type_byte + data).
          self._crc = [0] * (LEVELDBLOG_LAST + 1)
          for x in range(1, LEVELDBLOG_LAST + 1):
              self._crc[x] = zlib.crc32(strtobytes(chr(x))) & 0xFFFFFFFF

      def open_for_write(self, fname: str) -> None:
          """Create ``fname`` and write the W&B file header.

          Raises:
              FileExistsError: if ``fname`` already exists ("xb" mode).
          """
          self._fname = fname
          logger.info("open: %s", fname)
          open_flags = "xb"
          self._fp = open(fname, open_flags)
          self._write_header()

      def open_for_append(self, fname: str) -> None:
          """Open ``fname`` to add records after its existing contents.

          The file is created if missing. It is opened in ``"ab"`` mode so
          existing records are never truncated, and block accounting resumes
          from the current end of file so new records stay aligned to
          LEVELDBLOG_BLOCK_LEN boundaries. An empty file gets the W&B header
          first so it remains readable by ``open_for_scan``.

          NOTE(review): the existing header is not validated here, and a
          partially written trailing record is not repaired.
          """
          self._fname = fname
          logger.info("open: %s", fname)
          self._fp = open(fname, "ab")
          # Append mode always writes at EOF; make _index agree with that.
          self._index = self._fp.seek(0, os.SEEK_END)
          self._flush_offset = self._index
          if self._index == 0:
              self._write_header()

      def open_for_scan(self, fname: str) -> None:
          """Open ``fname`` for reading and validate its W&B file header.

          Records the file size at open time for ``in_last_block``.

          NOTE(review): nothing in this class writes while scanning, yet
          ``"r+b"`` requires write permission on the file; consider ``"rb"``
          if no external caller relies on a writable handle.
          """
          self._fname = fname
          logger.info("open for scan: %s", fname)
          self._fp = open(fname, "r+b")
          self._index = 0
          self._size_bytes = os.stat(fname).st_size
          self._opened_for_scan = True
          self._read_header()

      def seek(self, offset: int) -> None:
          """Move the file position (and the block-accounting index) to ``offset``."""
          self._fp.seek(offset)  # type: ignore
          self._index = offset

      def get_offset(self) -> int:
          """Return the underlying file object's current position."""
          offset = self._fp.tell()  # type: ignore
          return offset

      def in_last_block(self) -> bool:
          """Determine if we're in the last block to handle in-progress writes.

          True when the scan index is within LEVELDBLOG_DATA_LEN bytes of the
          file size recorded by ``open_for_scan``.
          """
          return self._index > self._size_bytes - LEVELDBLOG_DATA_LEN

      def scan_record(self) -> tuple[int, bytes] | None:
          """Read one physical record at the current position.

          Returns:
              ``(dtype, data)``, or ``None`` if no bytes remain.

          Raises:
              AssertionError: if not opened for scanning, the header is short,
                  or the checksum does not match.
          """
          assert self._opened_for_scan, "file not open for scanning"
          # TODO(jhr): handle some assertions as file corruption issues
          # assume we have enough room to read header, checked by caller?
          header = self._fp.read(LEVELDBLOG_HEADER_LEN)
          if len(header) == 0:
              return None
          assert len(header) == LEVELDBLOG_HEADER_LEN, (
              f"record header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
          )
          checksum, dlength, dtype = struct.unpack("<IHB", header)
          # check len, better fit in the block
          self._index += LEVELDBLOG_HEADER_LEN
          data = self._fp.read(dlength)
          # NOTE(review): a corrupt dtype > LEVELDBLOG_LAST raises IndexError
          # here rather than AssertionError.
          checksum_computed = zlib.crc32(data, self._crc[dtype]) & 0xFFFFFFFF
          assert checksum == checksum_computed, (
              "record checksum is invalid, data may be corrupt"
          )
          self._index += dlength
          return dtype, data

      def scan_data(self) -> bytes | None:
          """Return the next logical payload, reassembling split records.

          Skips the zero padding at the end of a block, then reads either one
          FULL record or a FIRST, MIDDLE*, LAST sequence.

          Returns:
              The payload bytes, or ``None`` at end of file (including EOF in
              the middle of a split payload).

          Raises:
              AssertionError: on non-zero padding, an unexpected record type,
                  or any assertion raised by ``scan_record``.
          """
          # TODO(jhr): handle some assertions as file corruption issues
          # A block tail shorter than a record header cannot hold a record and
          # is zero-filled by the writer; consume and verify it.
          offset = self._index % LEVELDBLOG_BLOCK_LEN
          space_left = LEVELDBLOG_BLOCK_LEN - offset
          if space_left < LEVELDBLOG_HEADER_LEN:
              pad_check = strtobytes("\x00" * space_left)
              pad = self._fp.read(space_left)
              assert pad == pad_check, "invalid padding"
              self._index += space_left
          record = self.scan_record()
          if record is None:  # eof
              return None
          dtype, data = record
          if dtype == LEVELDBLOG_FULL:
              return data
          assert dtype == LEVELDBLOG_FIRST, (
              f"expected record to be type {LEVELDBLOG_FIRST} but found {dtype}"
          )
          # Collect fragments and join once: repeated bytes concatenation
          # would copy the growing payload on every record (quadratic).
          chunks = [data]
          while True:
              record = self.scan_record()
              if record is None:  # eof
                  return None
              dtype, new_data = record
              if dtype == LEVELDBLOG_LAST:
                  chunks.append(new_data)
                  break
              assert dtype == LEVELDBLOG_MIDDLE, (
                  f"expected record to be type {LEVELDBLOG_MIDDLE} but found {dtype}"
              )
              chunks.append(new_data)
          return b"".join(chunks)

      def _write_header(self) -> None:
          """Write the 7-byte W&B file header (ident, magic, version)."""
          data = struct.pack(
              "<4sHB",
              strtobytes(LEVELDBLOG_HEADER_IDENT),
              LEVELDBLOG_HEADER_MAGIC,
              LEVELDBLOG_HEADER_VERSION,
          )
          assert len(data) == LEVELDBLOG_HEADER_LEN, (
              f"header size is {len(data)} bytes, expected {LEVELDBLOG_HEADER_LEN}"
          )
          self._fp.write(data)
          self._index += len(data)

      def _read_header(self) -> None:
          """Read and validate the W&B file header.

          Raises:
              AssertionError: if fewer than LEVELDBLOG_HEADER_LEN bytes exist.
              Exception: "Invalid header" on an ident, magic, or version mismatch.
          """
          header = self._fp.read(LEVELDBLOG_HEADER_LEN)
          assert len(header) == LEVELDBLOG_HEADER_LEN, (
              f"header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
          )
          ident, magic, version = struct.unpack("<4sHB", header)
          if ident != strtobytes(LEVELDBLOG_HEADER_IDENT):
              raise Exception("Invalid header")
          if magic != LEVELDBLOG_HEADER_MAGIC:
              raise Exception("Invalid header")
          if version != LEVELDBLOG_HEADER_VERSION:
              raise Exception("Invalid header")
          self._index += len(header)

      def _write_record(self, s: bytes, dtype: int | None = None) -> None:
          """Write record that must fit into a block.

          Args:
              s: Payload bytes for this physical record.
              dtype: Record type; ``None`` (or 0) means LEVELDBLOG_FULL.
          """
          # Precondition guaranteed by _write_data: header + payload fit in
          # the space remaining in the current block.
          assert len(s) + LEVELDBLOG_HEADER_LEN <= (
              LEVELDBLOG_BLOCK_LEN - self._index % LEVELDBLOG_BLOCK_LEN
          ), "not enough space to write new records"
          dlength = len(s)
          dtype = dtype or LEVELDBLOG_FULL
          checksum = zlib.crc32(s, self._crc[dtype]) & 0xFFFFFFFF
          self._fp.write(struct.pack("<IHB", checksum, dlength, dtype))
          if dlength:
              self._fp.write(s)
          self._index += LEVELDBLOG_HEADER_LEN + dlength

      def _write_data(self, s: bytes) -> tuple[int, int, int]:
          """Write payload ``s``, splitting it across blocks as needed, then fsync.

          Returns:
              ``(start_offset, end_offset, flush_offset)``: ``_index`` before
              the write (before any padding), and ``_index`` after it; the
              last two are equal because every write is fsync'd.
          """
          start_offset = self._index
          offset = self._index % LEVELDBLOG_BLOCK_LEN
          space_left = LEVELDBLOG_BLOCK_LEN - offset
          data_used = 0
          data_left = len(s)
          if space_left < LEVELDBLOG_HEADER_LEN:
              # Too little room for even a record header: zero-fill the rest
              # of the block and start at the next block boundary.
              pad = "\x00" * space_left
              self._fp.write(strtobytes(pad))
              self._index += space_left
              space_left = LEVELDBLOG_BLOCK_LEN
          # does it fit in first (possibly partial) block?
          if data_left + LEVELDBLOG_HEADER_LEN <= space_left:
              self._write_record(s)
          else:
              # write first record (we could still be in the middle of a block,
              # but this write will end on a block boundary)
              data_room = space_left - LEVELDBLOG_HEADER_LEN
              self._write_record(s[:data_room], LEVELDBLOG_FIRST)
              data_used += data_room
              data_left -= data_room
              assert data_left, "data_left should be non-zero"
              # write middles (if any): each fills an entire block
              while data_left > LEVELDBLOG_DATA_LEN:
                  self._write_record(
                      s[data_used : data_used + LEVELDBLOG_DATA_LEN],
                      LEVELDBLOG_MIDDLE,
                  )
                  data_used += LEVELDBLOG_DATA_LEN
                  data_left -= LEVELDBLOG_DATA_LEN
              # write last
              self._write_record(s[data_used:], LEVELDBLOG_LAST)
          # flush the written data all the way to disk
          self._fp.flush()
          os.fsync(self._fp.fileno())
          self._flush_offset = self._index
          return start_offset, self._index, self._flush_offset

      def ensure_flushed(self, off: int) -> None:
          """Flush the file object's buffers; ``off`` is currently ignored."""
          self._fp.flush()  # type: ignore

      def write(self, obj: Record) -> tuple[int, int, int]:
          """Write a protocol buffer as one logical payload.

          Args:
              obj: Protocol buffer to write.

          Returns:
              ``(start_offset, end_offset, flush_offset)``; the data has been
              flushed and fsync'd before this returns.

          Raises:
              AssertionError: if the serialized length differs from ``ByteSize()``.
          """
          raw_size = obj.ByteSize()
          s = obj.SerializeToString()
          assert len(s) == raw_size, "invalid serialization"
          return self._write_data(s)

      def close(self) -> None:
          """Close the underlying file if one was opened."""
          if self._fp is not None:
              logger.info("close: %s", self._fname)
              self._fp.close()