misc.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. """ Code adapted from https://github.com/ikostrikov/pytorch-a3c"""
  2. from typing import Any, List, Tuple, Union
  3. import numpy as np
  4. from ray.rllib.models.utils import get_activation_fn
  5. from ray.rllib.utils.annotations import DeveloperAPI
  6. from ray.rllib.utils.framework import try_import_torch
  7. from ray.rllib.utils.typing import TensorType
  8. torch, nn = try_import_torch()
  9. @DeveloperAPI
  10. def normc_initializer(std: float = 1.0) -> Any:
  11. def initializer(tensor):
  12. tensor.data.normal_(0, 1)
  13. tensor.data *= std / torch.sqrt(tensor.data.pow(2).sum(1, keepdim=True))
  14. return initializer
  15. @DeveloperAPI
  16. def same_padding(
  17. in_size: Tuple[int, int],
  18. filter_size: Union[int, Tuple[int, int]],
  19. stride_size: Union[int, Tuple[int, int]],
  20. ) -> (Union[int, Tuple[int, int]], Tuple[int, int]):
  21. """Note: Padding is added to match TF conv2d `same` padding.
  22. See www.tensorflow.org/versions/r0.12/api_docs/python/nn/convolution
  23. Args:
  24. in_size: Rows (Height), Column (Width) for input
  25. stride_size (Union[int,Tuple[int, int]]): Rows (Height), column (Width)
  26. for stride. If int, height == width.
  27. filter_size: Rows (Height), column (Width) for filter
  28. Returns:
  29. padding: For input into torch.nn.ZeroPad2d.
  30. output: Output shape after padding and convolution.
  31. """
  32. in_height, in_width = in_size
  33. if isinstance(filter_size, int):
  34. filter_height, filter_width = filter_size, filter_size
  35. else:
  36. filter_height, filter_width = filter_size
  37. if isinstance(stride_size, (int, float)):
  38. stride_height, stride_width = int(stride_size), int(stride_size)
  39. else:
  40. stride_height, stride_width = int(stride_size[0]), int(stride_size[1])
  41. out_height = int(np.ceil(float(in_height) / float(stride_height)))
  42. out_width = int(np.ceil(float(in_width) / float(stride_width)))
  43. pad_along_height = int((out_height - 1) * stride_height + filter_height - in_height)
  44. pad_along_width = int((out_width - 1) * stride_width + filter_width - in_width)
  45. pad_top = pad_along_height // 2
  46. pad_bottom = pad_along_height - pad_top
  47. pad_left = pad_along_width // 2
  48. pad_right = pad_along_width - pad_left
  49. padding = (pad_left, pad_right, pad_top, pad_bottom)
  50. output = (out_height, out_width)
  51. return padding, output
  52. @DeveloperAPI
  53. def same_padding_transpose_after_stride(
  54. strided_size: Tuple[int, int],
  55. kernel: Tuple[int, int],
  56. stride: Union[int, Tuple[int, int]],
  57. ) -> (Union[int, Tuple[int, int]], Tuple[int, int]):
  58. """Computes padding and output size such that TF Conv2DTranspose `same` is matched.
  59. Note that when padding="same", TensorFlow's Conv2DTranspose makes sure that
  60. 0-padding is added to the already strided image in such a way that the output image
  61. has the same size as the input image times the stride (and no matter the
  62. kernel size).
  63. For example: Input image is (4, 4, 24) (not yet strided), padding is "same",
  64. stride=2, kernel=5.
  65. First, the input image is strided (with stride=2):
  66. Input image (4x4):
  67. A B C D
  68. E F G H
  69. I J K L
  70. M N O P
  71. Stride with stride=2 -> (7x7)
  72. A 0 B 0 C 0 D
  73. 0 0 0 0 0 0 0
  74. E 0 F 0 G 0 H
  75. 0 0 0 0 0 0 0
  76. I 0 J 0 K 0 L
  77. 0 0 0 0 0 0 0
  78. M 0 N 0 O 0 P
  79. Then this strided image (strided_size=7x7) is padded (exact padding values will be
  80. output by this function):
  81. padding -> (left=3, right=2, top=3, bottom=2)
  82. 0 0 0 0 0 0 0 0 0 0 0 0
  83. 0 0 0 0 0 0 0 0 0 0 0 0
  84. 0 0 0 0 0 0 0 0 0 0 0 0
  85. 0 0 0 A 0 B 0 C 0 D 0 0
  86. 0 0 0 0 0 0 0 0 0 0 0 0
  87. 0 0 0 E 0 F 0 G 0 H 0 0
  88. 0 0 0 0 0 0 0 0 0 0 0 0
  89. 0 0 0 I 0 J 0 K 0 L 0 0
  90. 0 0 0 0 0 0 0 0 0 0 0 0
  91. 0 0 0 M 0 N 0 O 0 P 0 0
  92. 0 0 0 0 0 0 0 0 0 0 0 0
  93. 0 0 0 0 0 0 0 0 0 0 0 0
  94. Then deconvolution with kernel=5 yields an output image of 8x8 (x num output
  95. filters).
  96. Args:
  97. strided_size: The size (width x height) of the already strided image.
  98. kernel: Either width x height (tuple of ints) or - if a square kernel is used -
  99. a single int for both width and height.
  100. stride: Either stride width x stride height (tuple of ints) or - if square
  101. striding is used - a single int for both width- and height striding.
  102. Returns:
  103. Tuple consisting of 1) `padding`: A 4-tuple to pad the input after(!) striding.
  104. The values are for left, right, top, and bottom padding, individually.
  105. This 4-tuple can be used in a torch.nn.ZeroPad2d layer, and 2) the output shape
  106. after striding, padding, and the conv transpose layer.
  107. """
  108. # Solve single int (squared) inputs for kernel and/or stride.
  109. k_w, k_h = (kernel, kernel) if isinstance(kernel, int) else kernel
  110. s_w, s_h = (stride, stride) if isinstance(stride, int) else stride
  111. # Compute the total size of the 0-padding on both axes. If results are odd numbers,
  112. # the padding on e.g. left and right (or top and bottom) side will have to differ
  113. # by 1.
  114. pad_total_w, pad_total_h = k_w - 1 + s_w - 1, k_h - 1 + s_h - 1
  115. pad_right = pad_total_w // 2
  116. pad_left = pad_right + (1 if pad_total_w % 2 == 1 else 0)
  117. pad_bottom = pad_total_h // 2
  118. pad_top = pad_bottom + (1 if pad_total_h % 2 == 1 else 0)
  119. # Compute the output size.
  120. output_shape = (
  121. strided_size[0] + pad_total_w - k_w + 1,
  122. strided_size[1] + pad_total_h - k_h + 1,
  123. )
  124. # Return padding and output shape.
  125. return (pad_left, pad_right, pad_top, pad_bottom), output_shape
  126. @DeveloperAPI
  127. def valid_padding(
  128. in_size: Tuple[int, int],
  129. filter_size: Union[int, Tuple[int, int]],
  130. stride_size: Union[int, Tuple[int, int]],
  131. ) -> Tuple[int, int]:
  132. """Emulates TF Conv2DLayer "valid" padding (no padding) and computes output dims.
  133. This method, analogous to its "same" counterpart, but it only computes the output
  134. image size, since valid padding means (0, 0, 0, 0).
  135. See www.tensorflow.org/versions/r0.12/api_docs/python/nn/convolution
  136. Args:
  137. in_size: Rows (Height), Column (Width) for input
  138. stride_size (Union[int,Tuple[int, int]]): Rows (Height), column (Width)
  139. for stride. If int, height == width.
  140. filter_size: Rows (Height), column (Width) for filter
  141. Returns:
  142. The output shape after padding and convolution.
  143. """
  144. in_height, in_width = in_size
  145. if isinstance(filter_size, int):
  146. filter_height, filter_width = filter_size, filter_size
  147. else:
  148. filter_height, filter_width = filter_size
  149. if isinstance(stride_size, (int, float)):
  150. stride_height, stride_width = int(stride_size), int(stride_size)
  151. else:
  152. stride_height, stride_width = int(stride_size[0]), int(stride_size[1])
  153. out_height = int(np.ceil((in_height - filter_height + 1) / float(stride_height)))
  154. out_width = int(np.ceil((in_width - filter_width + 1) / float(stride_width)))
  155. return (out_height, out_width)
  156. @DeveloperAPI
  157. class SlimConv2d(nn.Module):
  158. """Simple mock of tf.slim Conv2d"""
  159. def __init__(
  160. self,
  161. in_channels: int,
  162. out_channels: int,
  163. kernel: Union[int, Tuple[int, int]],
  164. stride: Union[int, Tuple[int, int]],
  165. padding: Union[int, Tuple[int, int]],
  166. # Defaulting these to nn.[..] will break soft torch import.
  167. initializer: Any = "default",
  168. activation_fn: Any = "default",
  169. bias_init: float = 0,
  170. ):
  171. """Creates a standard Conv2d layer, similar to torch.nn.Conv2d
  172. Args:
  173. in_channels: Number of input channels
  174. out_channels: Number of output channels
  175. kernel: If int, the kernel is
  176. a tuple(x,x). Elsewise, the tuple can be specified
  177. stride: Controls the stride
  178. for the cross-correlation. If int, the stride is a
  179. tuple(x,x). Elsewise, the tuple can be specified
  180. padding: Controls the amount
  181. of implicit zero-paddings during the conv operation
  182. initializer: Initializer function for kernel weights
  183. activation_fn: Activation function at the end of layer
  184. bias_init: Initialize bias weights to bias_init const
  185. """
  186. super(SlimConv2d, self).__init__()
  187. layers = []
  188. # Padding layer.
  189. if padding:
  190. layers.append(nn.ZeroPad2d(padding))
  191. # Actual Conv2D layer (including correct initialization logic).
  192. conv = nn.Conv2d(in_channels, out_channels, kernel, stride)
  193. if initializer:
  194. if initializer == "default":
  195. initializer = nn.init.xavier_uniform_
  196. initializer(conv.weight)
  197. nn.init.constant_(conv.bias, bias_init)
  198. layers.append(conv)
  199. # Activation function (if any; default=ReLu).
  200. if isinstance(activation_fn, str):
  201. if activation_fn == "default":
  202. activation_fn = nn.ReLU
  203. else:
  204. activation_fn = get_activation_fn(activation_fn, "torch")
  205. if activation_fn is not None:
  206. layers.append(activation_fn())
  207. # Put everything in sequence.
  208. self._model = nn.Sequential(*layers)
  209. def forward(self, x: TensorType) -> TensorType:
  210. return self._model(x)
  211. @DeveloperAPI
  212. class SlimFC(nn.Module):
  213. """Simple PyTorch version of `linear` function"""
  214. def __init__(
  215. self,
  216. in_size: int,
  217. out_size: int,
  218. initializer: Any = None,
  219. activation_fn: Any = None,
  220. use_bias: bool = True,
  221. bias_init: float = 0.0,
  222. ):
  223. """Creates a standard FC layer, similar to torch.nn.Linear
  224. Args:
  225. in_size: Input size for FC Layer
  226. out_size: Output size for FC Layer
  227. initializer: Initializer function for FC layer weights
  228. activation_fn: Activation function at the end of layer
  229. use_bias: Whether to add bias weights or not
  230. bias_init: Initialize bias weights to bias_init const
  231. """
  232. super(SlimFC, self).__init__()
  233. layers = []
  234. # Actual nn.Linear layer (including correct initialization logic).
  235. linear = nn.Linear(in_size, out_size, bias=use_bias)
  236. if initializer is None:
  237. initializer = nn.init.xavier_uniform_
  238. initializer(linear.weight)
  239. if use_bias is True:
  240. nn.init.constant_(linear.bias, bias_init)
  241. layers.append(linear)
  242. # Activation function (if any; default=None (linear)).
  243. if isinstance(activation_fn, str):
  244. activation_fn = get_activation_fn(activation_fn, "torch")
  245. if activation_fn is not None:
  246. layers.append(activation_fn())
  247. # Put everything in sequence.
  248. self._model = nn.Sequential(*layers)
  249. def forward(self, x: TensorType) -> TensorType:
  250. return self._model(x)
  251. @DeveloperAPI
  252. class AppendBiasLayer(nn.Module):
  253. """Simple bias appending layer for free_log_std."""
  254. def __init__(self, num_bias_vars: int):
  255. super().__init__()
  256. self.log_std = torch.nn.Parameter(torch.as_tensor([0.0] * num_bias_vars))
  257. self.register_parameter("log_std", self.log_std)
  258. def forward(self, x: TensorType) -> TensorType:
  259. out = torch.cat([x, self.log_std.unsqueeze(0).repeat([len(x), 1])], axis=1)
  260. return out
  261. @DeveloperAPI
  262. class Reshape(nn.Module):
  263. """Standard module that reshapes/views a tensor"""
  264. def __init__(self, shape: List):
  265. super().__init__()
  266. self.shape = shape
  267. def forward(self, x):
  268. return x.view(*self.shape)