string_file_wrapper.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import os
  2. from typing import TextIO
  3. class StringFileWrapper:
  4. # This is a trick to simplify the code, transform the filedescriptor handling into a string handling
  5. def __init__(self, fd: TextIO, chunk_length: int) -> None:
  6. """
  7. Initialize the StringFileWrapper with a file descriptor and chunk length.
  8. Args:
  9. fd (TextIO): The file descriptor to wrap.
  10. CHUNK_LENGTH (int): The length of each chunk to read from the file.
  11. Attributes:
  12. fd (TextIO): The wrapped file descriptor.
  13. length (int): The total length of the file content.
  14. buffers (dict[int, str]): Dictionary to store chunks of file content.
  15. buffer_length (int): The length of each buffer chunk.
  16. """
  17. self.fd = fd
  18. # Buffers are chunks of text read from the file and cached to reduce disk access.
  19. self.buffers: dict[int, str] = {}
  20. if not chunk_length or chunk_length < 2:
  21. chunk_length = 1_000_000
  22. # chunk_length now refers to the number of characters per chunk.
  23. self.buffer_length = chunk_length
  24. # Keep track of the starting file position ("cookie") for each chunk so we can
  25. # seek safely without landing in the middle of a multibyte code point.
  26. self._chunk_positions: list[int] = [0]
  27. self.length: int | None = None
  28. def get_buffer(self, index: int) -> str:
  29. """
  30. Retrieve or load a buffer chunk from the file.
  31. Args:
  32. index (int): The index of the buffer chunk to retrieve.
  33. Returns:
  34. str: The buffer chunk at the specified index.
  35. """
  36. if index < 0:
  37. raise IndexError("Negative indexing is not supported")
  38. cached = self.buffers.get(index)
  39. if cached is not None:
  40. return cached
  41. self._ensure_chunk_position(index)
  42. start_pos = self._chunk_positions[index]
  43. self.fd.seek(start_pos)
  44. chunk = self.fd.read(self.buffer_length)
  45. if not chunk:
  46. raise IndexError("Chunk index out of range")
  47. end_pos = self.fd.tell()
  48. if len(self._chunk_positions) <= index + 1:
  49. self._chunk_positions.append(end_pos)
  50. if len(chunk) < self.buffer_length:
  51. self.length = index * self.buffer_length + len(chunk)
  52. self.buffers[index] = chunk
  53. # Save memory by keeping max 2MB buffer chunks and min 2 chunks
  54. max_buffers = max(2, int(2_000_000 / self.buffer_length))
  55. if len(self.buffers) > max_buffers:
  56. oldest_key = next(iter(self.buffers))
  57. if oldest_key != index:
  58. self.buffers.pop(oldest_key)
  59. return chunk
  60. def __getitem__(self, index: int | slice) -> str:
  61. """
  62. Retrieve a character or a slice of characters from the file.
  63. Args:
  64. index (Union[int, slice]): The index or slice of characters to retrieve.
  65. Returns:
  66. str: The character(s) at the specified index or slice.
  67. """
  68. # The buffer is an array that is seek like a RAM:
  69. # self.buffers[index]: the row in the array of length 1MB, index is `i` modulo CHUNK_LENGTH
  70. # self.buffures[index][j]: the column of the row that is `i` remainder CHUNK_LENGTH
  71. if isinstance(index, slice):
  72. start, stop, step = self._normalize_slice(index)
  73. if step == 0:
  74. raise ValueError("slice step cannot be zero")
  75. if step != 1:
  76. return "".join(self[i] for i in range(start, stop, step))
  77. if start >= stop:
  78. return ""
  79. return self._slice_from_buffers(start, stop)
  80. if index < 0:
  81. index += len(self)
  82. if index < 0:
  83. raise IndexError("string index out of range")
  84. buffer_index = index // self.buffer_length
  85. buffer = self.get_buffer(buffer_index)
  86. return buffer[index % self.buffer_length]
  87. def __len__(self) -> int:
  88. """
  89. Get the total length of the file.
  90. Returns:
  91. int: The total number of characters in the file.
  92. """
  93. if self.length is None:
  94. while self.length is None:
  95. chunk_index = len(self._chunk_positions)
  96. self._ensure_chunk_position(chunk_index)
  97. assert self.length is not None
  98. return self.length
  99. def _normalize_slice(self, index: slice) -> tuple[int, int, int]:
  100. total_len = len(self)
  101. start = 0 if index.start is None else index.start
  102. stop = total_len if index.stop is None else index.stop
  103. step = 1 if index.step is None else index.step
  104. if start < 0:
  105. start += total_len
  106. if stop < 0:
  107. stop += total_len
  108. start = max(start, 0)
  109. stop = min(stop, total_len)
  110. return start, stop, step
  111. def _slice_from_buffers(self, start: int, stop: int) -> str:
  112. buffer_index = start // self.buffer_length
  113. buffer_end = (stop - 1) // self.buffer_length
  114. start_mod = start % self.buffer_length
  115. stop_mod = stop % self.buffer_length
  116. if stop_mod == 0 and stop > start:
  117. stop_mod = self.buffer_length
  118. if buffer_index == buffer_end:
  119. buffer = self.get_buffer(buffer_index)
  120. return buffer[start_mod:stop_mod]
  121. start_slice = self.get_buffer(buffer_index)[start_mod:]
  122. end_slice = self.get_buffer(buffer_end)[:stop_mod]
  123. middle_slices = [self.get_buffer(i) for i in range(buffer_index + 1, buffer_end)]
  124. return start_slice + "".join(middle_slices) + end_slice
  125. def __setitem__(self, index: int | slice, value: str) -> None: # pragma: no cover
  126. """
  127. Set a character or a slice of characters in the file.
  128. Args:
  129. index (slice): The slice of characters to set.
  130. value (str): The value to set at the specified index or slice.
  131. """
  132. start = index.start or 0 if isinstance(index, slice) else index or 0
  133. if start < 0:
  134. start += len(self)
  135. current_position = self.fd.tell()
  136. self.fd.seek(start)
  137. self.fd.write(value)
  138. self.fd.seek(current_position)
  139. def _ensure_chunk_position(self, chunk_index: int) -> None:
  140. """
  141. Ensure that we know the starting file position for the given chunk index.
  142. """
  143. while len(self._chunk_positions) <= chunk_index:
  144. prev_index = len(self._chunk_positions) - 1
  145. start_pos = self._chunk_positions[-1]
  146. self.fd.seek(start_pos, os.SEEK_SET)
  147. chunk = self.fd.read(self.buffer_length)
  148. end_pos = self.fd.tell()
  149. if len(chunk) < self.buffer_length:
  150. self.length = prev_index * self.buffer_length + len(chunk)
  151. self._chunk_positions.append(end_pos)
  152. if not chunk:
  153. break
  154. if len(self._chunk_positions) <= chunk_index:
  155. raise IndexError("Chunk index out of range")