__init__.py 656 B

12345678910111213141516171819202122232425262728293031
  1. from torch._functorch.apis import grad, grad_and_value, vmap
  2. from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
  3. from torch._functorch.eager_transforms import (
  4. debug_unwrap,
  5. functionalize,
  6. hessian,
  7. jacfwd,
  8. jacrev,
  9. jvp,
  10. linearize,
  11. vjp,
  12. )
  13. from torch._functorch.functional_call import functional_call, stack_module_state
  14. __all__ = [
  15. "grad",
  16. "grad_and_value",
  17. "vmap",
  18. "replace_all_batch_norm_modules_",
  19. "functionalize",
  20. "hessian",
  21. "jacfwd",
  22. "jacrev",
  23. "jvp",
  24. "linearize",
  25. "vjp",
  26. "functional_call",
  27. "stack_module_state",
  28. "debug_unwrap",
  29. ]