| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- """leveldb log datastore.
- Format is described at:
- https://github.com/google/leveldb/blob/master/doc/log_format.md
- block := record* trailer?
- record :=
- checksum: uint32 // crc32c of type and data[] ; little-endian
- length: uint16 // little-endian
- type: uint8 // One of FULL, FIRST, MIDDLE, LAST
- data: uint8[length]
- header :=
- ident: char[4]
- magic: uint16
- version: uint8
- """
- from __future__ import annotations
- import logging
- import os
- import struct
- import zlib
- from typing import TYPE_CHECKING
- if TYPE_CHECKING:
- from typing import IO, Any
- from wandb.proto.wandb_internal_pb2 import Record
logger = logging.getLogger(__name__)

# --- on-disk geometry -------------------------------------------------------
# Both the file header and every record header occupy 7 bytes.
LEVELDBLOG_HEADER_LEN = 7
# The file is framed in fixed-size 32 KiB blocks.
LEVELDBLOG_BLOCK_LEN = 32 * 1024
# Largest payload one record can carry: a whole block minus its record header.
LEVELDBLOG_DATA_LEN = LEVELDBLOG_BLOCK_LEN - LEVELDBLOG_HEADER_LEN

# --- record types -----------------------------------------------------------
LEVELDBLOG_FULL = 1  # payload stored in a single record
LEVELDBLOG_FIRST = 2  # first fragment of a payload split across records
LEVELDBLOG_MIDDLE = 3  # interior fragment of a split payload
LEVELDBLOG_LAST = 4  # final fragment of a split payload

# --- file header fields -----------------------------------------------------
LEVELDBLOG_HEADER_IDENT = ":W&B"
# zlib.crc32(bytes("Weights & Biases", 'iso8859-1')) & 0xffff
LEVELDBLOG_HEADER_MAGIC = 0xBEE1
LEVELDBLOG_HEADER_VERSION = 0
def strtobytes(x: str) -> bytes:
    """Encode *x* as ISO-8859-1 bytes.

    ISO-8859-1 maps every code point below 256 to the byte with the same
    value, so this turns strings such as ``"\\x00"`` or ``":W&B"`` into the
    equivalent raw bytes.

    Args:
        x: Text whose characters all have code points below 256.

    Returns:
        The encoded bytes, one byte per input character.

    Raises:
        UnicodeEncodeError: If *x* contains a character outside ISO-8859-1.
    """
    return bytes(x, "iso8859-1")
class DataStore:
    """Record log stored with LevelDB-style block framing.

    The file begins with a 7-byte header (ident, magic, version) and is
    framed in ``LEVELDBLOG_BLOCK_LEN``-byte blocks counted from the start of
    the file.  A payload is stored either as one FULL record or as a
    FIRST / MIDDLE* / LAST chain so that no record crosses a block boundary.
    Each record's checksum is ``zlib.crc32`` over the type byte followed by
    the record data.

    Use ``open_for_write`` or ``open_for_append`` followed by ``write`` to
    produce a log, and ``open_for_scan`` followed by repeated ``scan_data``
    calls to read one back.  Corruption found while scanning is reported as
    ``AssertionError``.
    """

    _index: int
    _flush_offset: int

    def __init__(self) -> None:
        """Create a datastore that is not yet attached to any file."""
        self._opened_for_scan = False
        self._fp: IO[Any] | None = None
        # Absolute position in the file; block offsets are derived from it.
        self._index = 0
        self._flush_offset = 0
        self._size_bytes = 0

        # Pre-compute the CRC of each record-type byte so a record checksum
        # is just crc32(data, seed).  Slot 0 is unused and stays 0.
        self._crc = [0] + [
            zlib.crc32(bytes([record_type])) & 0xFFFFFFFF
            for record_type in range(1, LEVELDBLOG_LAST + 1)
        ]

    def open_for_write(self, fname: str) -> None:
        """Create *fname* and write the file header.

        Raises:
            FileExistsError: If *fname* already exists (exclusive creation).
        """
        self._fname = fname
        logger.info("open: %s", fname)
        open_flags = "xb"
        self._fp = open(fname, open_flags)
        self._write_header()

    def open_for_append(self, fname: str) -> None:
        """Open *fname* for appending records, creating it if needed.

        Existing content is preserved.  The write position is set to the
        current end of the file so that block framing of new records lines up
        with what is already stored; a file header is written only when the
        file is empty.  The existing header is not validated.
        """
        self._fname = fname
        logger.info("open: %s", fname)
        self._fp = open(fname, "ab")
        # Block offsets are derived from the absolute file position, so the
        # index must start at the current file size rather than 0.
        self._index = os.fstat(self._fp.fileno()).st_size
        self._flush_offset = self._index
        if self._index == 0:
            self._write_header()

    def open_for_scan(self, fname: str) -> None:
        """Open *fname* read-only and validate its file header.

        Raises:
            AssertionError: If the file is too short to hold a header.
            Exception: If the header ident, magic or version is wrong.
        """
        self._fname = fname
        logger.info("open for scan: %s", fname)
        # Scanning never writes, so do not require write permission.
        self._fp = open(fname, "rb")
        self._index = 0
        self._size_bytes = os.stat(fname).st_size
        self._opened_for_scan = True
        self._read_header()

    def seek(self, offset: int) -> None:
        """Move the file position and the tracked index to *offset*."""
        self._fp.seek(offset)  # type: ignore
        self._index = offset

    def get_offset(self) -> int:
        """Return the current position of the underlying file object."""
        offset = self._fp.tell()  # type: ignore
        return offset

    def in_last_block(self) -> bool:
        """Determine if we're in the last block to handle in-progress writes."""
        return self._index > self._size_bytes - LEVELDBLOG_DATA_LEN

    def scan_record(self) -> tuple[int, bytes] | None:
        """Read one physical record at the current position.

        Returns:
            ``(record_type, data)``, or ``None`` at end of file.

        Raises:
            AssertionError: If the store is not open for scanning, the record
                header is truncated, the record type is out of range, or the
                checksum does not match.
        """
        if not self._opened_for_scan:
            raise AssertionError("file not open for scanning")
        # TODO(jhr): handle some assertions as file corruption issues
        # assume we have enough room to read header, checked by caller?
        header = self._fp.read(LEVELDBLOG_HEADER_LEN)  # type: ignore
        if len(header) == 0:
            return None
        if len(header) != LEVELDBLOG_HEADER_LEN:
            raise AssertionError(
                f"record header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
            )
        checksum, dlength, dtype = struct.unpack("<IHB", header)
        # check len, better fit in the block
        self._index += LEVELDBLOG_HEADER_LEN
        data = self._fp.read(dlength)  # type: ignore
        # A corrupt type byte must surface as corruption, not as an
        # IndexError from the CRC seed table.
        if dtype >= len(self._crc):
            raise AssertionError(
                f"record type {dtype} is invalid, data may be corrupt"
            )
        checksum_computed = zlib.crc32(data, self._crc[dtype]) & 0xFFFFFFFF
        if checksum != checksum_computed:
            raise AssertionError("record checksum is invalid, data may be corrupt")
        self._index += dlength
        return dtype, data

    def scan_data(self) -> bytes | None:
        """Read the next complete payload, reassembling split records.

        Returns:
            The payload bytes, or ``None`` if end of file is reached before a
            complete payload could be read.

        Raises:
            AssertionError: On invalid block padding, an unexpected record
                type, or any error raised by ``scan_record``.
        """
        # TODO(jhr): handle some assertions as file corruption issues
        # A block tail too small for a record header is zero padding; skip it.
        space_left = LEVELDBLOG_BLOCK_LEN - self._index % LEVELDBLOG_BLOCK_LEN
        if space_left < LEVELDBLOG_HEADER_LEN:
            pad = self._fp.read(space_left)  # type: ignore
            if pad != b"\x00" * space_left:
                raise AssertionError("invalid padding")
            self._index += space_left

        record = self.scan_record()
        if record is None:  # eof
            return None
        dtype, data = record
        if dtype == LEVELDBLOG_FULL:
            return data
        if dtype != LEVELDBLOG_FIRST:
            raise AssertionError(
                f"expected record to be type {LEVELDBLOG_FIRST} but found {dtype}"
            )

        # Collect the fragments and join once instead of growing a bytes
        # object with repeated concatenation.
        chunks = [data]
        while True:
            record = self.scan_record()
            if record is None:  # eof in the middle of a chain
                return None
            dtype, new_data = record
            if dtype not in (LEVELDBLOG_MIDDLE, LEVELDBLOG_LAST):
                raise AssertionError(
                    f"expected record to be type {LEVELDBLOG_MIDDLE} but found {dtype}"
                )
            chunks.append(new_data)
            if dtype == LEVELDBLOG_LAST:
                return b"".join(chunks)

    def _write_header(self):
        """Write the file header (ident, magic, version) and advance the index."""
        data = struct.pack(
            "<4sHB",
            strtobytes(LEVELDBLOG_HEADER_IDENT),
            LEVELDBLOG_HEADER_MAGIC,
            LEVELDBLOG_HEADER_VERSION,
        )
        assert len(data) == LEVELDBLOG_HEADER_LEN, (
            f"header size is {len(data)} bytes, expected {LEVELDBLOG_HEADER_LEN}"
        )
        self._fp.write(data)
        self._index += len(data)

    def _read_header(self):
        """Read the file header, validate it and advance the index.

        Raises:
            AssertionError: If fewer than ``LEVELDBLOG_HEADER_LEN`` bytes
                could be read.
            Exception: If ident, magic or version differ from the expected
                values.
        """
        header = self._fp.read(LEVELDBLOG_HEADER_LEN)
        if len(header) != LEVELDBLOG_HEADER_LEN:
            raise AssertionError(
                f"header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
            )
        ident, magic, version = struct.unpack("<4sHB", header)
        if ident != strtobytes(LEVELDBLOG_HEADER_IDENT):
            raise Exception("Invalid header")
        if magic != LEVELDBLOG_HEADER_MAGIC:
            raise Exception("Invalid header")
        if version != LEVELDBLOG_HEADER_VERSION:
            raise Exception("Invalid header")
        self._index += len(header)

    def _write_record(self, s, dtype=None):
        """Write record that must fit into a block.

        Args:
            s: Record data as bytes.
            dtype: Record type; a falsy value means ``LEVELDBLOG_FULL``.
        """
        # double check that there is enough space
        # (this is a precondition to calling this method)
        assert len(s) + LEVELDBLOG_HEADER_LEN <= (
            LEVELDBLOG_BLOCK_LEN - self._index % LEVELDBLOG_BLOCK_LEN
        ), "not enough space to write new records"

        dlength = len(s)
        dtype = dtype or LEVELDBLOG_FULL
        checksum = zlib.crc32(s, self._crc[dtype]) & 0xFFFFFFFF
        self._fp.write(struct.pack("<IHB", checksum, dlength, dtype))
        if dlength:
            self._fp.write(s)
        self._index += LEVELDBLOG_HEADER_LEN + len(s)

    def _write_data(self, s):
        """Write payload *s*, splitting it across blocks when necessary.

        Returns:
            ``(start_offset, end_offset, flush_offset)``, where the first two
            are the index before and after the write and the last is the
            index at the most recent flush-and-fsync performed here.
        """
        start_offset = self._index
        space_left = LEVELDBLOG_BLOCK_LEN - self._index % LEVELDBLOG_BLOCK_LEN
        data_used = 0
        data_left = len(s)

        # A block tail too small for a record header is filled with zeros.
        if space_left < LEVELDBLOG_HEADER_LEN:
            self._fp.write(b"\x00" * space_left)
            self._index += space_left
            space_left = LEVELDBLOG_BLOCK_LEN

        # does it fit in first (possibly partial) block?
        if data_left + LEVELDBLOG_HEADER_LEN <= space_left:
            self._write_record(s)
        else:
            # write first record (we could still be in the middle of a block,
            # but this write will end on a block boundary)
            data_room = space_left - LEVELDBLOG_HEADER_LEN
            self._write_record(s[:data_room], LEVELDBLOG_FIRST)
            data_used += data_room
            data_left -= data_room
            assert data_left, "data_left should be non-zero"

            # write middles (if any)
            while data_left > LEVELDBLOG_DATA_LEN:
                self._write_record(
                    s[data_used : data_used + LEVELDBLOG_DATA_LEN],
                    LEVELDBLOG_MIDDLE,
                )
                data_used += LEVELDBLOG_DATA_LEN
                data_left -= LEVELDBLOG_DATA_LEN

            # write last and flush the entire block to disk
            self._write_record(s[data_used:], LEVELDBLOG_LAST)
            self._fp.flush()
            os.fsync(self._fp.fileno())
            self._flush_offset = self._index

        return start_offset, self._index, self._flush_offset

    def ensure_flushed(self, off: int) -> None:
        """Flush buffered writes to the file; *off* is currently unused."""
        self._fp.flush()  # type: ignore

    def write(self, obj: Record) -> tuple[int, int, int]:
        """Write a protocol buffer.

        Args:
            obj: Protocol buffer to write.

        Returns:
            (start_offset, end_offset, flush_offset)

        Raises:
            AssertionError: If the serialized size differs from
                ``obj.ByteSize()``.
        """
        raw_size = obj.ByteSize()
        s = obj.SerializeToString()
        assert len(s) == raw_size, "invalid serialization"
        ret = self._write_data(s)
        return ret

    def close(self) -> None:
        """Close the underlying file if one has been opened."""
        if self._fp is not None:
            logger.info("close: %s", self._fname)
            self._fp.close()
|