code_context.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
"""
This module provides code context management for TorchDynamo using weak references.

The CodeContextDict class maintains a mapping between Python code objects and their
associated context data. It uses weak references so that entries are cleaned up
automatically when their code objects are garbage collected. This prevents memory
leaks while still allowing context data to stay associated with a code object
throughout its lifecycle.

Key features:
- Context storage and retrieval keyed by code object (see the thread-safety note below)
- Automatic cleanup using weak references
- On-demand context creation: get_context returns a fresh empty dict the first time
  a code object is seen
- Memory-leak prevention

Thread-safety note: CodeContextDict itself takes no locks. Any thread-safety guarantee
therefore rests on the underlying ExactWeakKeyDictionary; confirm this before relying
on concurrent mutation.

Example usage:
    code_obj = compile('x = 1', '<string>', 'exec')

    # Store context
    context = code_context.get_context(code_obj)
    context['metadata'] = {'optimized': True}

    # Retrieve context
    if code_context.has_context(code_obj):
        ctx = code_context.get_context(code_obj)
        # Use context data...

    # Remove context
    ctx = code_context.pop_context(code_obj)
"""
  24. import types
  25. from typing import Any
  26. from .utils import ExactWeakKeyDictionary
  27. class CodeContextDict:
  28. def __init__(self) -> None:
  29. self.code_context: ExactWeakKeyDictionary = ExactWeakKeyDictionary()
  30. def has_context(self, code: types.CodeType) -> bool:
  31. return code in self.code_context
  32. def get_context(self, code: types.CodeType) -> dict[str, Any]:
  33. ctx = self.code_context.get(code)
  34. if ctx is None:
  35. # pyrefly: ignore [implicit-any]
  36. ctx = {}
  37. self.code_context[code] = ctx
  38. return ctx
  39. def pop_context(self, code: types.CodeType) -> dict[str, Any]:
  40. ctx = self.get_context(code)
  41. self.code_context._remove_id(id(code))
  42. return ctx
  43. def clear(self) -> None:
  44. self.code_context.clear()
  45. code_context: CodeContextDict = CodeContextDict()