pt_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. import numpy as np
  2. import torch
  3. from torch.utils.data import Dataset, IterableDataset
  4. from ..utils.generic import ModelOutput
  5. class PipelineDataset(Dataset):
  6. def __init__(self, dataset, process, params):
  7. self.dataset = dataset
  8. self.process = process
  9. self.params = params
  10. def __len__(self):
  11. return len(self.dataset)
  12. def __getitem__(self, i):
  13. item = self.dataset[i]
  14. processed = self.process(item, **self.params)
  15. return processed
  16. class PipelineIterator(IterableDataset):
  17. def __init__(self, loader, infer, params, loader_batch_size=None):
  18. """
  19. Roughly equivalent to
  20. ```
  21. for item in loader:
  22. yield infer(item, **params)
  23. ```
  24. Arguments:
  25. loader (`torch.utils.data.DataLoader` or `Iterable`):
  26. The iterator that will be used to apply `infer` on.
  27. infer (any function):
  28. The function to apply of each element of `loader`.
  29. params (`dict`):
  30. The parameters passed to `infer` along with every item
  31. loader_batch_size (`int`, *optional*):
  32. If specified, the items of `loader` are supposed to come as batch, and are loader_batched here
  33. making it roughly behave as
  34. ```
  35. for items in loader:
  36. for i in loader_batch_size:
  37. item = items[i]
  38. yield infer(item, **params)
  39. ```"""
  40. self.loader = loader
  41. self.infer = infer
  42. self.params = params
  43. if loader_batch_size == 1:
  44. # Let's spare some time by deactivating altogether
  45. loader_batch_size = None
  46. self.loader_batch_size = loader_batch_size
  47. # Internal bookkeeping
  48. self._loader_batch_index = None
  49. self._loader_batch_data = None
  50. def __len__(self):
  51. return len(self.loader)
  52. def __iter__(self):
  53. self.iterator = iter(self.loader)
  54. return self
  55. def loader_batch_item(self):
  56. """
  57. Return item located at `loader_batch_index` within the current `loader_batch_data`.
  58. """
  59. if isinstance(self._loader_batch_data, torch.Tensor):
  60. # Batch data is simple tensor, just fetch the slice
  61. result = self._loader_batch_data[self._loader_batch_index].unsqueeze(0)
  62. else:
  63. # Batch data is assumed to be BaseModelOutput (or dict)
  64. loader_batched = {}
  65. for k, element in self._loader_batch_data.items():
  66. if isinstance(element, ModelOutput):
  67. # Convert ModelOutput to tuple first
  68. element = element.to_tuple()
  69. if isinstance(element[0], torch.Tensor):
  70. loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
  71. elif isinstance(element[0], np.ndarray):
  72. loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
  73. continue
  74. if k in {"hidden_states", "attentions"} and isinstance(element, tuple):
  75. # Those are stored as lists of tensors so need specific unbatching.
  76. if isinstance(element[0], torch.Tensor):
  77. loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
  78. elif isinstance(element[0], np.ndarray):
  79. loader_batched[k] = tuple(np.expand_dims(el[self._loader_batch_index], 0) for el in element)
  80. continue
  81. if k == "past_key_values":
  82. continue
  83. if element is None:
  84. # This can happen for optional data that get passed around
  85. loader_batched[k] = None
  86. elif isinstance(element[self._loader_batch_index], torch.Tensor):
  87. # Take correct batch data, but make it looked like batch_size=1
  88. # For compatibility with other methods within transformers
  89. loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
  90. elif isinstance(element[self._loader_batch_index], np.ndarray):
  91. # Take correct batch data, but make it looked like batch_size=1
  92. # For compatibility with other methods within transformers
  93. loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
  94. else:
  95. # This is typically a list, so no need to `unsqueeze`.
  96. loader_batched[k] = element[self._loader_batch_index]
  97. # Recreate the element by reusing the original class to make it look
  98. # batch_size=1
  99. result = self._loader_batch_data.__class__(loader_batched)
  100. self._loader_batch_index += 1
  101. return result
  102. def __next__(self):
  103. if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
  104. # We are currently unrolling a batch so we just need to return
  105. # the current item within a batch
  106. return self.loader_batch_item()
  107. # We're out of items within a batch
  108. item = next(self.iterator)
  109. processed = self.infer(item, **self.params)
  110. # We now have a batch of "inferred things".
  111. if self.loader_batch_size is not None:
  112. # Try to infer the size of the batch
  113. if isinstance(processed, torch.Tensor):
  114. first_tensor = processed
  115. elif isinstance(processed, tuple):
  116. first_tensor = processed[0]
  117. else:
  118. key = list(processed.keys())[0]
  119. first_tensor = processed[key]
  120. if isinstance(first_tensor, list):
  121. observed_batch_size = len(first_tensor)
  122. else:
  123. observed_batch_size = first_tensor.shape[0]
  124. if 0 < observed_batch_size < self.loader_batch_size:
  125. # could be last batch so we can't unroll as many
  126. # elements.
  127. self.loader_batch_size = observed_batch_size
  128. # Setting internal index to unwrap the batch
  129. self._loader_batch_data = processed[0] if isinstance(processed, tuple) else processed
  130. self._loader_batch_index = 0
  131. return self.loader_batch_item()
  132. else:
  133. # We're not unrolling batches
  134. return processed
  135. class PipelineChunkIterator(PipelineIterator):
  136. def __init__(self, loader, infer, params, loader_batch_size=None):
  137. """
  138. Roughly equivalent to
  139. ```
  140. for iterator in loader:
  141. for item in iterator:
  142. yield infer(item, **params)
  143. ```
  144. Arguments:
  145. loader (`torch.utils.data.DataLoader` or `Iterable`):
  146. The iterator that will be used to apply `infer` on.
  147. infer (any function):
  148. The function to apply of each element of `loader`.
  149. params (`dict`):
  150. The parameters passed to `infer` along with every item
  151. """
  152. super().__init__(loader, infer, params)
  153. def __iter__(self):
  154. self.iterator = iter(self.loader)
  155. self.subiterator = None
  156. return self
  157. def __next__(self):
  158. if self.subiterator is None:
  159. "Subiterator None means we haven't started a `preprocess` iterator. so start it"
  160. self.subiterator = self.infer(next(self.iterator), **self.params)
  161. try:
  162. # Try to return next item
  163. processed = next(self.subiterator)
  164. except StopIteration:
  165. # When a preprocess iterator ends, we can start looking at the next item
  166. # ChunkIterator will keep feeding until ALL elements of iterator
  167. # all have created their subiterator and have been iterating against.
  168. #
  169. # Another way to look at it, is we're basically flattening lists of lists
  170. # into a single list, but with generators
  171. self.subiterator = self.infer(next(self.iterator), **self.params)
  172. processed = next(self.subiterator)
  173. return processed
  174. class PipelinePackIterator(PipelineIterator):
  175. """
  176. Roughly equivalent to
  177. ```
  178. packed = []
  179. for item in loader:
  180. packed.append(item)
  181. if item["is_last"]:
  182. yield packed
  183. packed = []
  184. ```
  185. but it also handles cases where `item` are batched (meaning it's a dict of Tensor with first dimension > 1. In
  186. that case it does
  187. ```
  188. packed = []
  189. for batch in loader:
  190. # item is batched
  191. for item in batch:
  192. packed.append(item)
  193. if item["is_last"]:
  194. yield packed
  195. packed = []
  196. ```
  197. Arguments:
  198. loader (`torch.utils.data.DataLoader` or `Iterable`):
  199. The iterator that will be used to apply `infer` on.
  200. infer (any function):
  201. The function to apply of each element of `loader`.
  202. params (`dict`):
  203. The parameters passed to `infer` along with every item
  204. loader_batch_size (`int`, *optional*):
  205. If specified, the items of `loader` are supposed to come as batch, and are loader_batched here making
  206. it roughly behave as
  207. ```
  208. for items in loader:
  209. for i in loader_batch_size:
  210. item = items[i]
  211. yield infer(item, **params)
  212. ```"""
  213. def __iter__(self):
  214. self.iterator = iter(self.loader)
  215. return self
  216. def __next__(self):
  217. # Extremely similar to PipelineIterator in its unpacking mechanism
  218. # BUT, we have an extra required item which is the presence of `is_last`
  219. # That is because everything is flattened by `PipelineChunkIterator` we
  220. # need to keep track of how to regroup here in the original `process`
  221. # boundaries so that `process` and `postprocess` see the same data.
  222. # This iterator accumulates items (possibly while unbatching) until it
  223. # its a `is_last` and then just passes it on to the caller.
  224. is_last = False
  225. accumulator = []
  226. if self._loader_batch_index is not None and self._loader_batch_index < self.loader_batch_size:
  227. while self._loader_batch_index < self.loader_batch_size:
  228. item = self.loader_batch_item()
  229. is_last = item.pop("is_last")
  230. accumulator.append(item)
  231. if is_last:
  232. return accumulator
  233. while not is_last:
  234. processed = self.infer(next(self.iterator), **self.params)
  235. if self.loader_batch_size is not None:
  236. if isinstance(processed, torch.Tensor):
  237. first_tensor = processed
  238. else:
  239. key = list(processed.keys())[0]
  240. first_tensor = processed[key]
  241. if isinstance(first_tensor, list):
  242. observed_batch_size = len(first_tensor)
  243. else:
  244. observed_batch_size = first_tensor.shape[0]
  245. if 0 < observed_batch_size < self.loader_batch_size:
  246. # could be last batch so we can't unroll as many
  247. # elements.
  248. self.loader_batch_size = observed_batch_size
  249. self._loader_batch_data = processed
  250. self._loader_batch_index = 0
  251. while self._loader_batch_index < self.loader_batch_size:
  252. item = self.loader_batch_item()
  253. is_last = item.pop("is_last")
  254. accumulator.append(item)
  255. if is_last:
  256. return accumulator
  257. else:
  258. item = processed
  259. is_last = item.pop("is_last")
  260. accumulator.append(item)
  261. return accumulator
  262. class KeyDataset(Dataset):
  263. def __init__(self, dataset: Dataset, key: str):
  264. self.dataset = dataset
  265. self.key = key
  266. def __len__(self):
  267. return len(self.dataset)
  268. def __getitem__(self, i):
  269. return self.dataset[i][self.key]
  270. class KeyPairDataset(Dataset):
  271. def __init__(self, dataset: Dataset, key1: str, key2: str):
  272. self.dataset = dataset
  273. self.key1 = key1
  274. self.key2 = key2
  275. def __len__(self):
  276. return len(self.dataset)
  277. def __getitem__(self, i):
  278. return {"text": self.dataset[i][self.key1], "text_pair": self.dataset[i][self.key2]}