| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- import os
- from typing import TextIO
- class StringFileWrapper:
- # This is a trick to simplify the code, transform the filedescriptor handling into a string handling
- def __init__(self, fd: TextIO, chunk_length: int) -> None:
- """
- Initialize the StringFileWrapper with a file descriptor and chunk length.
- Args:
- fd (TextIO): The file descriptor to wrap.
- CHUNK_LENGTH (int): The length of each chunk to read from the file.
- Attributes:
- fd (TextIO): The wrapped file descriptor.
- length (int): The total length of the file content.
- buffers (dict[int, str]): Dictionary to store chunks of file content.
- buffer_length (int): The length of each buffer chunk.
- """
- self.fd = fd
- # Buffers are chunks of text read from the file and cached to reduce disk access.
- self.buffers: dict[int, str] = {}
- if not chunk_length or chunk_length < 2:
- chunk_length = 1_000_000
- # chunk_length now refers to the number of characters per chunk.
- self.buffer_length = chunk_length
- # Keep track of the starting file position ("cookie") for each chunk so we can
- # seek safely without landing in the middle of a multibyte code point.
- self._chunk_positions: list[int] = [0]
- self.length: int | None = None
- def get_buffer(self, index: int) -> str:
- """
- Retrieve or load a buffer chunk from the file.
- Args:
- index (int): The index of the buffer chunk to retrieve.
- Returns:
- str: The buffer chunk at the specified index.
- """
- if index < 0:
- raise IndexError("Negative indexing is not supported")
- cached = self.buffers.get(index)
- if cached is not None:
- return cached
- self._ensure_chunk_position(index)
- start_pos = self._chunk_positions[index]
- self.fd.seek(start_pos)
- chunk = self.fd.read(self.buffer_length)
- if not chunk:
- raise IndexError("Chunk index out of range")
- end_pos = self.fd.tell()
- if len(self._chunk_positions) <= index + 1:
- self._chunk_positions.append(end_pos)
- if len(chunk) < self.buffer_length:
- self.length = index * self.buffer_length + len(chunk)
- self.buffers[index] = chunk
- # Save memory by keeping max 2MB buffer chunks and min 2 chunks
- max_buffers = max(2, int(2_000_000 / self.buffer_length))
- if len(self.buffers) > max_buffers:
- oldest_key = next(iter(self.buffers))
- if oldest_key != index:
- self.buffers.pop(oldest_key)
- return chunk
- def __getitem__(self, index: int | slice) -> str:
- """
- Retrieve a character or a slice of characters from the file.
- Args:
- index (Union[int, slice]): The index or slice of characters to retrieve.
- Returns:
- str: The character(s) at the specified index or slice.
- """
- # The buffer is an array that is seek like a RAM:
- # self.buffers[index]: the row in the array of length 1MB, index is `i` modulo CHUNK_LENGTH
- # self.buffures[index][j]: the column of the row that is `i` remainder CHUNK_LENGTH
- if isinstance(index, slice):
- start, stop, step = self._normalize_slice(index)
- if step == 0:
- raise ValueError("slice step cannot be zero")
- if step != 1:
- return "".join(self[i] for i in range(start, stop, step))
- if start >= stop:
- return ""
- return self._slice_from_buffers(start, stop)
- if index < 0:
- index += len(self)
- if index < 0:
- raise IndexError("string index out of range")
- buffer_index = index // self.buffer_length
- buffer = self.get_buffer(buffer_index)
- return buffer[index % self.buffer_length]
- def __len__(self) -> int:
- """
- Get the total length of the file.
- Returns:
- int: The total number of characters in the file.
- """
- if self.length is None:
- while self.length is None:
- chunk_index = len(self._chunk_positions)
- self._ensure_chunk_position(chunk_index)
- assert self.length is not None
- return self.length
- def _normalize_slice(self, index: slice) -> tuple[int, int, int]:
- total_len = len(self)
- start = 0 if index.start is None else index.start
- stop = total_len if index.stop is None else index.stop
- step = 1 if index.step is None else index.step
- if start < 0:
- start += total_len
- if stop < 0:
- stop += total_len
- start = max(start, 0)
- stop = min(stop, total_len)
- return start, stop, step
- def _slice_from_buffers(self, start: int, stop: int) -> str:
- buffer_index = start // self.buffer_length
- buffer_end = (stop - 1) // self.buffer_length
- start_mod = start % self.buffer_length
- stop_mod = stop % self.buffer_length
- if stop_mod == 0 and stop > start:
- stop_mod = self.buffer_length
- if buffer_index == buffer_end:
- buffer = self.get_buffer(buffer_index)
- return buffer[start_mod:stop_mod]
- start_slice = self.get_buffer(buffer_index)[start_mod:]
- end_slice = self.get_buffer(buffer_end)[:stop_mod]
- middle_slices = [self.get_buffer(i) for i in range(buffer_index + 1, buffer_end)]
- return start_slice + "".join(middle_slices) + end_slice
- def __setitem__(self, index: int | slice, value: str) -> None: # pragma: no cover
- """
- Set a character or a slice of characters in the file.
- Args:
- index (slice): The slice of characters to set.
- value (str): The value to set at the specified index or slice.
- """
- start = index.start or 0 if isinstance(index, slice) else index or 0
- if start < 0:
- start += len(self)
- current_position = self.fd.tell()
- self.fd.seek(start)
- self.fd.write(value)
- self.fd.seek(current_position)
- def _ensure_chunk_position(self, chunk_index: int) -> None:
- """
- Ensure that we know the starting file position for the given chunk index.
- """
- while len(self._chunk_positions) <= chunk_index:
- prev_index = len(self._chunk_positions) - 1
- start_pos = self._chunk_positions[-1]
- self.fd.seek(start_pos, os.SEEK_SET)
- chunk = self.fd.read(self.buffer_length)
- end_pos = self.fd.tell()
- if len(chunk) < self.buffer_length:
- self.length = prev_index * self.buffer_length + len(chunk)
- self._chunk_positions.append(end_pos)
- if not chunk:
- break
- if len(self._chunk_positions) <= chunk_index:
- raise IndexError("Chunk index out of range")
|