microbatch.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import logging
  4. import operator
  5. from collections.abc import Sequence
  6. from typing import Any
  7. import torch
  8. from torch.fx.node import map_aggregate
  9. from torch.nn.attention.flex_attention import BlockMask
  10. from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
  11. __all__ = [
  12. "TensorChunkSpec",
  13. "split_args_kwargs_into_chunks",
  14. "merge_chunks",
  15. ]
  16. logger = logging.getLogger(__name__)
  17. """
  18. _debug_mask_minibatches specifies to send masked versions of the mini-batch
  19. through instead of micro-batch slices--this can be used for more stable
  20. numerical testing (see [A Note About Correctness Testing])
  21. """
  22. _debug_mask_minibatches = False
  23. class _CustomReducer:
  24. """
  25. Custom reducer class that can be used to specify a custom operation that
  26. reduces losses of multiple microbatches into one value.
  27. Example:
  28. >>> # xdoctest: +SKIP
  29. >>> sum_reducer = _CustomReducer(
  30. >>> torch.tensor(0.0),
  31. >>> lambda a, b: a + b
  32. >>> )
  33. """
  34. def __init__(self, init_value, reduce_fn):
  35. self.init_value = init_value
  36. self.reduce_fn = reduce_fn
  37. class _LossReducer(_CustomReducer):
  38. pass
  39. sum_reducer = _LossReducer(torch.tensor(0.0), operator.add)
  40. # Default chunking dimension is 0. This is used for the case where the user did
  41. # not specify a chunking dimension.
  42. DEFAULT_CHUNK_DIM = 0
  43. class TensorChunkSpec:
  44. """
  45. Class used to specify chunking of inputs
  46. """
  47. def __init__(self, split_dim):
  48. self.split_dim = split_dim
  49. split_dim: int
  50. def __repr__(self):
  51. return (
  52. f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})"
  53. )
  54. def __str__(self):
  55. return f"TensorChunkSpec({self.split_dim})"
  56. @staticmethod
  57. def from_tuple(
  58. chunk_dims: tuple[int, ...],
  59. ):
  60. """
  61. A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk
  62. dimensions (int's).
  63. Example:
  64. >>> # xdoctest: +SKIP
  65. >>> # There are three positional arguments to the model, and
  66. >>> # we are chunking them along dimension 0, 0 and 1, respectively
  67. >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1))
  68. """
  69. args_chunk_spec = map_aggregate(
  70. chunk_dims,
  71. lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
  72. )
  73. return args_chunk_spec
  74. @staticmethod
  75. def from_dict(
  76. chunk_dims: dict[str, int],
  77. ):
  78. """
  79. A helper for creating a dictionary of `TensorChunkSpec` from a
  80. dictionary of chunk dimensions (int's).
  81. Example:
  82. >>> # xdoctest: +SKIP
  83. >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument
  84. >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1})
  85. """
  86. kwargs_chunk_spec = map_aggregate(
  87. chunk_dims,
  88. lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value]
  89. )
  90. return kwargs_chunk_spec
  91. # Class used to specify replication of inputs
  92. class _Replicate:
  93. pass
  94. def _split_block_mask(
  95. block_mask: BlockMask,
  96. num_chunks: int,
  97. ) -> list[BlockMask]:
  98. """Given a block mask, split the block mask along the batch dimension (dim0).
  99. Args:
  100. block_mask: Block mask to split
  101. num_chunks: Number of chunks to split the block mask into
  102. Returns:
  103. chunk_block_masks: List of chunked block masks
  104. """
  105. # BlockMask will broadcast if B is 1.
  106. if block_mask.kv_num_blocks.size(0) == 1:
  107. return [block_mask] * num_chunks
  108. if not block_mask.kv_num_blocks.size(0) >= num_chunks:
  109. raise AssertionError(
  110. "Block mask has fewer batch size than the number of chunks. "
  111. )
  112. batch_dim = 0
  113. kv_num_blocks_chunks = torch.tensor_split(
  114. block_mask.kv_num_blocks, num_chunks, batch_dim
  115. )
  116. kv_indices_chunks = torch.tensor_split(block_mask.kv_indices, num_chunks, batch_dim)
  117. full_kv_num_blocks_chunks = (
  118. torch.tensor_split(block_mask.full_kv_num_blocks, num_chunks, batch_dim)
  119. if block_mask.full_kv_num_blocks is not None
  120. else [None] * num_chunks
  121. )
  122. full_kv_indices_chunks = (
  123. torch.tensor_split(block_mask.full_kv_indices, num_chunks, batch_dim)
  124. if block_mask.full_kv_indices is not None
  125. else [None] * num_chunks
  126. )
  127. chunk_block_masks = []
  128. batch_offset = 0
  129. for chunk_idx in range(num_chunks):
  130. def create_mask_mod(idx):
  131. def batch_offset_mask_mod(b, h, q_idx, kv_idx):
  132. b_offset = torch.full_like(b, idx)
  133. return block_mask.mask_mod(b + b_offset, h, q_idx, kv_idx)
  134. return batch_offset_mask_mod
  135. chunk_block_masks.append(
  136. BlockMask.from_kv_blocks(
  137. kv_num_blocks=kv_num_blocks_chunks[chunk_idx],
  138. kv_indices=kv_indices_chunks[chunk_idx],
  139. full_kv_num_blocks=full_kv_num_blocks_chunks[chunk_idx],
  140. full_kv_indices=full_kv_indices_chunks[chunk_idx],
  141. BLOCK_SIZE=block_mask.BLOCK_SIZE,
  142. mask_mod=create_mask_mod(batch_offset),
  143. seq_lengths=block_mask.seq_lengths,
  144. )
  145. )
  146. batch_offset += kv_num_blocks_chunks[chunk_idx].size(0)
  147. return chunk_block_masks
  148. def _split_tensor(
  149. tensor: torch.Tensor,
  150. spec: TensorChunkSpec,
  151. num_chunks: int,
  152. ) -> Sequence[torch.Tensor]:
  153. """Given a tensor, and a chunking spec, split the tensor.
  154. Args:
  155. tensor: Tensor to split
  156. spec: Chunking spec
  157. num_chunks: Number of chunks to split the tensor into
  158. Returns:
  159. chunk_tensors: List of chunked tensors
  160. """
  161. if not tensor.size(spec.split_dim) >= num_chunks:
  162. raise AssertionError(
  163. f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks"
  164. )
  165. chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim)
  166. if not _debug_mask_minibatches:
  167. return chunk_tensors
  168. expanded_chunks = []
  169. split_dim_idx = 0
  170. for chunk_tensor in chunk_tensors:
  171. new_val = torch.zeros_like(tensor)
  172. upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim)
  173. slice_indices = [slice(None, None, None)] * new_val.ndim
  174. slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx)
  175. new_val[slice_indices] = chunk_tensor
  176. expanded_chunks.append(new_val)
  177. split_dim_idx += chunk_tensor.size(spec.split_dim)
  178. return expanded_chunks
  179. def _shard_dict_of_args(
  180. args_dict,
  181. args_chunk_spec,
  182. num_chunks,
  183. ):
  184. """
  185. Given a dictionary of args, and a dictionary of chunking specs, shard the
  186. args according to the chunking specs.
  187. Args:
  188. args_dict: Dictionary of args
  189. args_chunk_spec: Dictionary of chunking specs
  190. num_chunks: Number of chunks to shard the args into
  191. Returns:
  192. args_split: List of sharded args
  193. """
  194. if not args_dict:
  195. return [{} for _ in range(num_chunks)]
  196. if not len(args_dict) == len(args_chunk_spec):
  197. raise AssertionError(
  198. f"args_dict.keys() = {list(args_dict.keys())} "
  199. f"args_chunk_spec.keys() = {list(args_chunk_spec.keys())}"
  200. )
  201. if args_chunk_spec is None:
  202. raise AssertionError("args_chunk_spec should have been set by caller")
  203. values, tree_spec = tree_flatten(
  204. args_dict, is_leaf=lambda x: isinstance(x, BlockMask)
  205. )
  206. chunk_specs, _ = tree_flatten(
  207. args_chunk_spec, is_leaf=lambda x: isinstance(x, BlockMask)
  208. )
  209. # First check and find the actual number of chunks
  210. split_sizes = []
  211. for v, spec in zip(values, chunk_specs, strict=True):
  212. # The original logic is "spec is _Replicate". This doesn't seem to be
  213. # correct. But we keep it for backward compatibility.
  214. if spec is _Replicate or isinstance(spec, _Replicate):
  215. split_sizes.append(num_chunks)
  216. elif isinstance(v, torch.Tensor):
  217. if not isinstance(spec, TensorChunkSpec):
  218. raise AssertionError(f"Expected TensorChunkSpec, got {type(spec)}")
  219. split_sizes.append(v.size(spec.split_dim))
  220. elif isinstance(v, BlockMask):
  221. if not isinstance(spec, TensorChunkSpec):
  222. raise AssertionError(f"Expected TensorChunkSpec, got {type(spec)}")
  223. if not spec.split_dim == 0:
  224. raise AssertionError("BlockMask only supports split_dim=0")
  225. # BlockMask will broadcast if B is 1.
  226. if v.kv_num_blocks.size(0) == 1:
  227. split_sizes.append(num_chunks)
  228. else:
  229. split_sizes.append(v.kv_num_blocks.size(0))
  230. else:
  231. raise ValueError(
  232. f"Unsupported chunk spec: {spec} and value: {v} combination."
  233. )
  234. result_num_chunks = min(*split_sizes, num_chunks)
  235. flat_split_results: list[Any] = [[] for _ in range(result_num_chunks)]
  236. for v, spec in zip(values, chunk_specs, strict=True):
  237. v_splits: Sequence[Any] = []
  238. if spec is _Replicate or isinstance(spec, _Replicate):
  239. v_splits = [v] * result_num_chunks
  240. elif isinstance(v, torch.Tensor):
  241. v_splits = _split_tensor(v, spec, result_num_chunks)
  242. elif isinstance(v, BlockMask):
  243. v_splits = _split_block_mask(v, result_num_chunks)
  244. else:
  245. raise ValueError(
  246. f"Unsupported chunk spec: {spec} and value: {v} combination."
  247. )
  248. for _flat_split_result, _v_split in zip(
  249. flat_split_results, v_splits, strict=True
  250. ):
  251. _flat_split_result.append(_v_split)
  252. return [
  253. tree_unflatten(_flat_split_result, tree_spec)
  254. for _flat_split_result in flat_split_results
  255. ]
  256. def split_args_kwargs_into_chunks(
  257. args: tuple[Any, ...],
  258. kwargs: dict[str, Any] | None,
  259. chunks: int,
  260. args_chunk_spec: tuple[TensorChunkSpec, ...] | None = None,
  261. kwargs_chunk_spec: dict[str, TensorChunkSpec] | None = None,
  262. ) -> tuple[list[tuple], list[dict]]:
  263. """
  264. Given a sequence of args and kwargs, split them into a number of chunks
  265. according to their respective chunking specs.
  266. Args:
  267. args: Tuple of args
  268. kwargs: Dict of kwargs
  269. chunks: Number of chunks to split the args and kwargs into
  270. args_chunk_spec: chunking specs for args, in same shape as args
  271. kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs
  272. Returns:
  273. args_split: List of sharded args
  274. kwargs_split: List of sharded kwargs
  275. """
  276. # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that
  277. # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec`
  278. # and `kwargs_chunk_spec` specifications. The steps are as follows:
  279. #
  280. # 1. Use pytree.tree_flatten to flatten each arg and its spec into nto a 1d array of values.
  281. # To use a running example: suppose our inputs look like
  282. #
  283. # args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None)
  284. # (kwargs not shown but it's a similar process)
  285. #
  286. # Then for this step we would end up with
  287. #
  288. # args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None)
  289. #
  290. # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2
  291. #
  292. # args = ([[A, A], [B, B], [C_1, C_2]], [D, D])
  293. #
  294. # 3. Rotate the nesting order such that chunks are the outer dimension
  295. #
  296. # args_chunks = [
  297. # ([A, B, C_1], D),
  298. # ([A, B, C_2], D),
  299. # ]
  300. #
  301. # 4. Unflatten each chunk according to the spec
  302. #
  303. # args_chunks = [
  304. # ([A, [B, C_1]], D),
  305. # ([A, [B, C_2]], D),
  306. # ]
  307. # TODO: _debug_mask_minibatches
  308. # Handle the case where kwargs is None
  309. if kwargs is None:
  310. kwargs = {}
  311. # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend
  312. # their format and use default chunking along dim 0
  313. def default_spec(v):
  314. if isinstance(v, torch.Tensor | BlockMask):
  315. return TensorChunkSpec(DEFAULT_CHUNK_DIM)
  316. else:
  317. return _Replicate()
  318. if args_chunk_spec is None:
  319. args_chunk_spec = tree_map(
  320. default_spec, args, is_leaf=lambda v: isinstance(v, BlockMask)
  321. )
  322. if kwargs_chunk_spec is None:
  323. kwargs_chunk_spec = tree_map(
  324. default_spec, kwargs, is_leaf=lambda v: isinstance(v, BlockMask)
  325. )
  326. args_split_dict = _shard_dict_of_args(
  327. dict(enumerate(args)),
  328. dict(enumerate(args_chunk_spec)),
  329. chunks,
  330. )
  331. real_num_chunks = len(args_split_dict)
  332. kwargs_split = _shard_dict_of_args(
  333. kwargs,
  334. kwargs_chunk_spec,
  335. real_num_chunks,
  336. )
  337. if len(kwargs_split) < real_num_chunks:
  338. # In case kwargs are sharded into less chunks
  339. # e.g. when `args` has no tensor, just values
  340. real_num_chunks = len(kwargs_split)
  341. # Re-shard args
  342. args_split_dict = _shard_dict_of_args(
  343. dict(enumerate(args)),
  344. dict(enumerate(args_chunk_spec)),
  345. real_num_chunks,
  346. )
  347. if len(args_split_dict) != len(kwargs_split):
  348. raise RuntimeError(
  349. "args and kwargs are split into different number of chunks: "
  350. f"{len(args_split_dict)}, {len(kwargs_split)}"
  351. )
  352. args_split = [
  353. tuple(chunk_args[i] for i in range(len(chunk_args)))
  354. for chunk_args in args_split_dict
  355. ]
  356. return args_split, kwargs_split
  357. def merge_chunks(
  358. chunks: list[Any],
  359. chunk_spec,
  360. ):
  361. """
  362. Given a list of chunks, merge them into a single value according to
  363. the chunk spec.
  364. Args:
  365. chunks: list of chunks
  366. chunk_spec: Chunking spec for the chunks
  367. Returns:
  368. value: Merged value
  369. """
  370. # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the
  371. # steps are similar to the steps in that function but in reverse. Given the
  372. # input values:
  373. #
  374. # chunks = [
  375. # ([A, [B, C_1]], D),
  376. # ([A, [B, C_2]], D),
  377. # ]
  378. # args_spec = ([None, [None, TensorChunkSpec]], None)
  379. #
  380. # 1. Flatten the chunks according to the chunk_spec
  381. #
  382. # chunks_flat = [
  383. # ([A, B, C_1], D),
  384. # ([A, B, C_2], D),
  385. # ]
  386. #
  387. # 2. Rotate the nesting order such that chunks are the inner dimension
  388. #
  389. # value_inner = ([A, B, [C_1, C_2]], D)
  390. #
  391. # 3. Concatenate sharded arguments
  392. #
  393. # value_combined = ([A, B, C], D)
  394. #
  395. # 4. Unflatten the combined args given the spec
  396. #
  397. # value = ([A, [B, C]], D)
  398. # Preliminary: flatten the chunk spec
  399. if chunk_spec is not None:
  400. spec_flattened, flatten_spec = tree_flatten(chunk_spec)
  401. else:
  402. # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields
  403. # We obtain the output structure by flattening chunk 0 and generate the chunk_spec
  404. chunk0_flat, flatten_spec = tree_flatten(chunks[0])
  405. spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat)
  406. # Stage 1: flatten chunks
  407. # chunks_flattened : [num chunks, num args]
  408. chunks_flattened = []
  409. for chunk in chunks:
  410. chunk_flattened, _ = tree_flatten(chunk)
  411. if len(chunk_flattened) != len(spec_flattened):
  412. raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}")
  413. chunks_flattened.append(chunk_flattened)
  414. # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and
  415. # concatenate sharded operands
  416. # args_flattened : [num args]
  417. args_flattened = []
  418. for arg_idx, arg in enumerate(spec_flattened):
  419. if isinstance(arg, TensorChunkSpec):
  420. partial_values = [
  421. chunks_flattened[chunk_idx][arg_idx]
  422. for chunk_idx in range(len(chunks_flattened))
  423. ]
  424. if _debug_mask_minibatches:
  425. # Infer size of individual chunks by running `tensor_split` again
  426. overall_shape = partial_values[0].shape
  427. for val in partial_values[1:]:
  428. if not val.shape == overall_shape:
  429. raise AssertionError(
  430. f"Expected shape {overall_shape}, got {val.shape}"
  431. )
  432. meta_chunks = torch.tensor_split(
  433. torch.empty(*overall_shape, device="meta"),
  434. sections=len(partial_values),
  435. dim=arg.split_dim,
  436. )
  437. values_to_cat = []
  438. chunk_start_idx = 0
  439. if not len(partial_values) == len(meta_chunks):
  440. raise AssertionError(
  441. f"Expected len(partial_values) == len(meta_chunks), got {len(partial_values)} != {len(meta_chunks)}"
  442. )
  443. for partial_value, meta_chunk in zip(
  444. partial_values, meta_chunks, strict=True
  445. ):
  446. chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)
  447. slice_indices = [slice(None, None, None)] * partial_value.ndim
  448. slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx)
  449. sliced = partial_value[slice_indices]
  450. values_to_cat.append(sliced)
  451. chunk_start_idx = chunk_end_idx
  452. else:
  453. values_to_cat = partial_values
  454. args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim))
  455. elif isinstance(arg, _CustomReducer):
  456. reduced_val = arg.init_value
  457. for chunk_idx in range(len(chunks_flattened)):
  458. reduced_val = arg.reduce_fn(
  459. reduced_val, chunks_flattened[chunk_idx][arg_idx]
  460. )
  461. args_flattened.append(reduced_val)
  462. else:
  463. value = chunks_flattened[0][arg_idx]
  464. for chunk_idx in range(1, len(chunks_flattened)):
  465. if not chunks_flattened[chunk_idx][arg_idx] == value:
  466. raise AssertionError(
  467. f"Expected {value}, got {chunks_flattened[chunk_idx][arg_idx]}"
  468. )
  469. args_flattened.append(value)
  470. # Stage 4: Unflatten combined args
  471. return tree_unflatten(args_flattened, flatten_spec)