gen_autograd.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. """
  2. To run this file by hand from the root of the PyTorch
  3. repository, run:
  4. python -m tools.autograd.gen_autograd \
  5. aten/src/ATen/native/native_functions.yaml \
  6. aten/src/ATen/native/tags.yaml \
  7. $OUTPUT_DIR \
  8. tools/autograd
  9. Where $OUTPUT_DIR is where you would like the files to be
  10. generated. In the full build system, OUTPUT_DIR is
  11. torch/csrc/autograd/generated/
  12. """
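# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
# load_derivatives.py: parses derivatives.yaml
# gen_autograd_functions.py: generates subclasses of torch::autograd::Node
#   (Functions.h/cpp) and their Python-side wrappers
# gen_variable_type.py: generates VariableType.h/cpp, which contains all tensor methods
# gen_inplace_or_view_type.py, gen_trace_type.py: in-place/view and tracing generators
# gen_variable_factories.py: generates variable_factories.h
# gen_view_funcs.py: generates ViewFuncs.h/cpp
# gen_python_functions.py: generates Python bindings to THPVariable
#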
  21. from __future__ import annotations
  22. import argparse
  23. import os
  24. from torchgen.api import cpp
  25. from torchgen.api.autograd import (
  26. match_differentiability_info,
  27. NativeFunctionWithDifferentiabilityInfo,
  28. )
  29. from torchgen.gen import parse_native_yaml
  30. from torchgen.selective_build.selector import SelectiveBuilder
  31. from . import gen_python_functions
  32. from .gen_autograd_functions import (
  33. gen_autograd_functions_lib,
  34. gen_autograd_functions_python,
  35. )
  36. from .gen_inplace_or_view_type import gen_inplace_or_view_type
  37. from .gen_trace_type import gen_trace_type
  38. from .gen_variable_factories import gen_variable_factories
  39. from .gen_variable_type import gen_variable_type
  40. from .gen_view_funcs import gen_view_funcs
  41. from .load_derivatives import load_derivatives
  42. def gen_autograd(
  43. native_functions_path: str,
  44. tags_path: str,
  45. out: str,
  46. autograd_dir: str,
  47. operator_selector: SelectiveBuilder,
  48. disable_autograd: bool = False,
  49. ) -> None:
  50. # Parse and load derivatives.yaml
  51. differentiability_infos, used_dispatch_keys = load_derivatives(
  52. os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
  53. )
  54. template_path = os.path.join(autograd_dir, "templates")
  55. native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions
  56. fns = sorted(
  57. filter(
  58. operator_selector.is_native_function_selected_for_training, native_funcs
  59. ),
  60. key=lambda f: cpp.name(f.func),
  61. )
  62. fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo] = (
  63. match_differentiability_info(fns, differentiability_infos)
  64. )
  65. # Generate VariableType.h/cpp
  66. if not disable_autograd:
  67. gen_variable_type(
  68. out,
  69. native_functions_path,
  70. tags_path,
  71. fns_with_diff_infos,
  72. template_path,
  73. used_dispatch_keys,
  74. )
  75. gen_inplace_or_view_type(
  76. out, native_functions_path, tags_path, fns_with_diff_infos, template_path
  77. )
  78. # operator filter not applied as tracing sources are excluded in selective build
  79. gen_trace_type(out, native_funcs, template_path)
  80. # Generate Functions.h/cpp
  81. gen_autograd_functions_lib(out, differentiability_infos, template_path)
  82. # Generate variable_factories.h
  83. gen_variable_factories(out, native_functions_path, tags_path, template_path)
  84. # Generate ViewFuncs.h/cpp
  85. gen_view_funcs(out, fns_with_diff_infos, template_path)
  86. def gen_autograd_python(
  87. native_functions_path: str,
  88. tags_path: str,
  89. out: str,
  90. autograd_dir: str,
  91. ) -> None:
  92. differentiability_infos, _ = load_derivatives(
  93. os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
  94. )
  95. template_path = os.path.join(autograd_dir, "templates")
  96. # Generate Functions.h/cpp
  97. gen_autograd_functions_python(out, differentiability_infos, template_path)
  98. # Generate Python bindings
  99. deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")
  100. gen_python_functions.gen(
  101. out, native_functions_path, tags_path, deprecated_path, template_path
  102. )
  103. def main() -> None:
  104. parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
  105. parser.add_argument(
  106. "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
  107. )
  108. parser.add_argument("tags", metavar="NATIVE", help="path to tags.yaml")
  109. parser.add_argument("out", metavar="OUT", help="path to output directory")
  110. parser.add_argument(
  111. "autograd", metavar="AUTOGRAD", help="path to autograd directory"
  112. )
  113. args = parser.parse_args()
  114. gen_autograd(
  115. args.native_functions,
  116. args.tags,
  117. args.out,
  118. args.autograd,
  119. SelectiveBuilder.get_nop_selector(),
  120. )
  121. if __name__ == "__main__":
  122. main()