# rnn.py (4.2 KB)
  1. # mypy: allow-untyped-defs
  2. import sys
  3. import torch._C
  4. import torch.cuda
  5. from torch.backends import (
  6. _get_fp32_precision_getter,
  7. _set_fp32_precision_setter,
  8. PropModule,
  9. )
  10. try:
  11. from torch._C import _cudnn
  12. except ImportError:
  13. # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
  14. # so it's safe to not emit any checks here.
  15. _cudnn = None # type: ignore[assignment]
  16. def get_cudnn_mode(mode):
  17. if mode == "RNN_RELU":
  18. # pyrefly: ignore [missing-attribute]
  19. return int(_cudnn.RNNMode.rnn_relu)
  20. elif mode == "RNN_TANH":
  21. # pyrefly: ignore [missing-attribute]
  22. return int(_cudnn.RNNMode.rnn_tanh)
  23. elif mode == "LSTM":
  24. # pyrefly: ignore [missing-attribute]
  25. return int(_cudnn.RNNMode.lstm)
  26. elif mode == "GRU":
  27. # pyrefly: ignore [missing-attribute]
  28. return int(_cudnn.RNNMode.gru)
  29. else:
  30. raise ValueError(f"Unknown mode: {mode}") # noqa: TRY002
  31. # NB: We don't actually need this class anymore (in fact, we could serialize the
  32. # dropout state for even better reproducibility), but it is kept for backwards
  33. # compatibility for old models.
  34. class Unserializable:
  35. def __init__(self, inner):
  36. self.inner = inner
  37. def get(self):
  38. return self.inner
  39. def __getstate__(self):
  40. # Note: can't return {}, because python2 won't call __setstate__
  41. # if the value evaluates to False
  42. return "<unserializable>"
  43. def __setstate__(self, state):
  44. self.inner = None
  45. # we would like to use ContextProp from backends here but the
  46. # frozen flags appears to be overzealous
  47. class ContextProp:
  48. def __init__(self, getter, setter):
  49. self.getter = getter
  50. self.setter = setter
  51. def __get__(self, obj, objtype):
  52. return self.getter()
  53. def __set__(self, obj, val):
  54. self.setter(val)
  55. def init_dropout_state(dropout, train, dropout_seed, dropout_state):
  56. dropout_desc_name = "desc_" + str(torch.cuda.current_device())
  57. dropout_p = dropout if train else 0
  58. if (dropout_desc_name not in dropout_state) or (
  59. dropout_state[dropout_desc_name].get() is None
  60. ):
  61. if dropout_p == 0:
  62. dropout_state[dropout_desc_name] = Unserializable(None)
  63. else:
  64. dropout_state[dropout_desc_name] = Unserializable(
  65. torch._cudnn_init_dropout_state( # type: ignore[call-arg]
  66. dropout_p,
  67. train,
  68. dropout_seed,
  69. # pyrefly: ignore [unexpected-keyword]
  70. self_ty=torch.uint8,
  71. device=torch.device("cuda"),
  72. )
  73. )
  74. dropout_ts = dropout_state[dropout_desc_name].get()
  75. return dropout_ts
  76. class CudnnRNNModule(PropModule):
  77. def __init__(self, m, name):
  78. super().__init__(m, name)
  79. self.m.Unserializable = Unserializable
  80. self.m.get_cudnn_mode = get_cudnn_mode
  81. self.m.init_dropout_state = init_dropout_state
  82. @staticmethod
  83. def init_dropout_state(dropout, train, dropout_seed, dropout_state):
  84. dropout_desc_name = "desc_" + str(torch.cuda.current_device())
  85. dropout_p = dropout if train else 0
  86. if (dropout_desc_name not in dropout_state) or (
  87. dropout_state[dropout_desc_name].get() is None
  88. ):
  89. if dropout_p == 0:
  90. dropout_state[dropout_desc_name] = Unserializable(None)
  91. else:
  92. dropout_state[dropout_desc_name] = Unserializable(
  93. torch._cudnn_init_dropout_state( # type: ignore[call-arg]
  94. dropout_p,
  95. train,
  96. dropout_seed,
  97. # pyrefly: ignore [unexpected-keyword]
  98. self_ty=torch.uint8,
  99. device=torch.device("cuda"),
  100. )
  101. )
  102. dropout_ts = dropout_state[dropout_desc_name].get()
  103. return dropout_ts
  104. fp32_precision = ContextProp(
  105. _get_fp32_precision_getter("cuda", "rnn"),
  106. _set_fp32_precision_setter("cuda", "rnn"),
  107. )
# Replace this module's entry in ``sys.modules`` with a ``CudnnRNNModule`` that
# wraps the original module object. Later imports of this module name then get
# the wrapper, whose class-level descriptors (e.g. ``fp32_precision``) act as
# module attributes.
sys.modules[__name__] = CudnnRNNModule(sys.modules[__name__], __name__)