local.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from __future__ import annotations
  2. import threading
  3. from contextlib import contextmanager
  4. from typing import TYPE_CHECKING
  5. if TYPE_CHECKING:
  6. from collections.abc import Iterator
  7. # Simple dynamic scoping implementation. The name "parametrize" comes
  8. # from Racket.
  9. #
  10. # WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about
  11. # why you need to add a toggle to the global behavior of code
  12. # generation. The parameters here should really only be used
  13. # for "temporary" situations, where we need to temporarily change
  14. # the codegen in some cases because we cannot conveniently update
  15. # all call sites, and are slated to be eliminated once all call
  16. # sites are eliminated. If you don't have a plan for how to get there,
  17. # DON'T add a new entry here.
  18. class Locals(threading.local):
  19. use_const_ref_for_mutable_tensors: bool | None = None
  20. use_ilistref_for_tensor_lists: bool | None = None
  21. _locals = Locals()
  22. def use_const_ref_for_mutable_tensors() -> bool:
  23. if _locals.use_const_ref_for_mutable_tensors is None:
  24. raise AssertionError(
  25. "need to initialize local.use_const_ref_for_mutable_tensors with "
  26. "local.parametrize"
  27. )
  28. return _locals.use_const_ref_for_mutable_tensors
  29. def use_ilistref_for_tensor_lists() -> bool:
  30. if _locals.use_ilistref_for_tensor_lists is None:
  31. raise AssertionError(
  32. "need to initialize local.use_ilistref_for_tensor_lists with local.parametrize"
  33. )
  34. return _locals.use_ilistref_for_tensor_lists
  35. @contextmanager
  36. def parametrize(
  37. *, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool
  38. ) -> Iterator[None]:
  39. old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors
  40. old_use_ilistref_for_tensor_lists = _locals.use_ilistref_for_tensor_lists
  41. try:
  42. _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
  43. _locals.use_ilistref_for_tensor_lists = use_ilistref_for_tensor_lists
  44. yield
  45. finally:
  46. _locals.use_const_ref_for_mutable_tensors = (
  47. old_use_const_ref_for_mutable_tensors
  48. )
  49. _locals.use_ilistref_for_tensor_lists = old_use_ilistref_for_tensor_lists