curiosity.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. from typing import Optional, Tuple, Union
  2. import numpy as np
  3. from gymnasium.spaces import Discrete, MultiDiscrete, Space
  4. from ray.rllib.models.action_dist import ActionDistribution
  5. from ray.rllib.models.catalog import ModelCatalog
  6. from ray.rllib.models.modelv2 import ModelV2
  7. from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical
  8. from ray.rllib.models.torch.misc import SlimFC
  9. from ray.rllib.models.torch.torch_action_dist import (
  10. TorchCategorical,
  11. TorchMultiCategorical,
  12. )
  13. from ray.rllib.models.utils import get_activation_fn
  14. from ray.rllib.policy.sample_batch import SampleBatch
  15. from ray.rllib.utils import NullContextManager
  16. from ray.rllib.utils.annotations import OldAPIStack, override
  17. from ray.rllib.utils.exploration.exploration import Exploration
  18. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  19. from ray.rllib.utils.from_config import from_config
  20. from ray.rllib.utils.tf_utils import get_placeholder, one_hot as tf_one_hot
  21. from ray.rllib.utils.torch_utils import one_hot
  22. from ray.rllib.utils.typing import FromConfigSpec, ModelConfigDict, TensorType
  # Framework handles used throughout this module. `nn` may be None when
  # torch is unavailable, in which case `F` stays None as well.
  tf1, tf, tfv = try_import_tf()
  torch, nn = try_import_torch()
  # Shortcut to `torch.nn.functional` (None if torch could not be imported).
  F = nn.functional if nn is not None else None
  @OldAPIStack
  class Curiosity(Exploration):
      """Implementation of:
      [1] Curiosity-driven Exploration by Self-supervised Prediction
      Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017.
      https://arxiv.org/pdf/1705.05363.pdf

      Learns a simplified model of the environment based on three networks:
      1) Embedding observations into latent space ("feature" network).
      2) Predicting the action, given two consecutive embedded observations
      ("inverse" network).
      3) Predicting the next embedded obs, given an obs and action
      ("forward" network).

      The less the agent is able to predict the actually observed next feature
      vector, given obs and action (through the forward network), the larger
      the "intrinsic reward", which will be added to the extrinsic reward.
      Therefore, if a state transition was unexpected, the agent becomes
      "curious" and will further explore this transition leading to better
      exploration in sparse rewards environments.

      Action selection itself is delegated to a sub-Exploration object; this
      class adds intrinsic rewards and trains the three curiosity networks
      inside `postprocess_trajectory`.
      """

      def __init__(
          self,
          action_space: Space,
          *,
          framework: str,
          model: ModelV2,
          feature_dim: int = 288,
          feature_net_config: Optional[ModelConfigDict] = None,
          inverse_net_hiddens: Tuple[int, ...] = (256,),
          inverse_net_activation: str = "relu",
          forward_net_hiddens: Tuple[int, ...] = (256,),
          forward_net_activation: str = "relu",
          beta: float = 0.2,
          eta: float = 1.0,
          lr: float = 1e-3,
          sub_exploration: Optional[FromConfigSpec] = None,
          **kwargs
      ):
          """Initializes a Curiosity object.

          Uses as defaults the hyperparameters described in [1].

          Args:
              action_space: The action space. Must be Discrete or MultiDiscrete.
              framework: The DL framework. "torch" uses the torch code paths,
                  "tf" the tf1 static-graph paths; any other value is handled
                  by the TF eager code paths.
              model: The Policy's ModelV2. The curiosity sub-networks are
                  attached to it in `get_exploration_optimizer`.
              feature_dim: The dimensionality of the feature (phi) vectors.
              feature_net_config: Optional model configuration for the feature
                  network, producing feature vectors (phi) from observations.
                  This can be used to configure fcnet- or conv_net setups to
                  properly process any observation space. If None, a copy of
                  the Policy's "model" config is used.
              inverse_net_hiddens: Tuple of the layer sizes of the inverse
                  (action predicting) NN head (on top of the feature outputs
                  for phi and phi').
              inverse_net_activation: Activation specifier for the inverse net.
              forward_net_hiddens: Tuple of the layer sizes of the forward
                  (phi' predicting) NN head.
              forward_net_activation: Activation specifier for the forward net.
              beta: Weight for the forward loss (over the inverse loss, which
                  gets weight=1.0-beta) in the common loss term.
              eta: Weight for intrinsic rewards before being added to
                  extrinsic ones.
              lr: The learning rate for the curiosity-specific optimizer,
                  optimizing feature-, inverse-, and forward nets.
              sub_exploration: The config dict for the underlying Exploration
                  to use (e.g. epsilon-greedy for DQN). Currently required;
                  there is no fallback to the Algorithm's default yet.
              **kwargs: Forwarded to the `Exploration` base constructor.

          Raises:
              ValueError: If the action space is not (Multi)Discrete, or if
                  the policy config's "num_env_runners" is not 0.
              NotImplementedError: If `sub_exploration` is None.
          """
          if not isinstance(action_space, (Discrete, MultiDiscrete)):
              raise ValueError(
                  "Only (Multi)Discrete action spaces supported for Curiosity so far!"
              )

          super().__init__(action_space, model=model, framework=framework, **kwargs)

          if self.policy_config["num_env_runners"] != 0:
              raise ValueError(
                  "Curiosity exploration currently does not support parallelism."
                  " `num_env_runners` must be 0!"
              )

          self.feature_dim = feature_dim
          if feature_net_config is None:
              feature_net_config = self.policy_config["model"].copy()
          self.feature_net_config = feature_net_config
          self.inverse_net_hiddens = inverse_net_hiddens
          self.inverse_net_activation = inverse_net_activation
          self.forward_net_hiddens = forward_net_hiddens
          self.forward_net_activation = forward_net_activation

          # Width of the one-hot action encoding: `n` for Discrete, the sum of
          # all sub-space sizes for MultiDiscrete. Cast to a plain int, since
          # it is used as a layer size and numpy returns numpy integers here.
          self.action_dim = int(
              self.action_space.n
              if isinstance(self.action_space, Discrete)
              else np.sum(self.action_space.nvec)
          )

          self.beta = beta
          self.eta = eta
          self.lr = lr
          # TODO: (sven) if sub_exploration is None, use Algorithm's default
          # Exploration config.
          if sub_exploration is None:
              raise NotImplementedError(
                  "Curiosity currently requires an explicit `sub_exploration` "
                  "config."
              )
          self.sub_exploration = sub_exploration

          # Build the three curiosity networks. They are attached to the
          # Policy's model only later (in `get_exploration_optimizer`), after
          # the Policy's own optimizer(s) have been created without them.
          self._curiosity_feature_net = ModelCatalog.get_model_v2(
              self.model.obs_space,
              self.action_space,
              self.feature_dim,
              model_config=self.feature_net_config,
              framework=self.framework,
              name="feature_net",
          )
          # Inverse net: concat(phi, phi') -> action distribution inputs.
          self._curiosity_inverse_fcnet = self._create_fc_net(
              [2 * self.feature_dim] + list(self.inverse_net_hiddens) + [self.action_dim],
              self.inverse_net_activation,
              name="inverse_net",
          )
          # Forward net: concat(phi, one-hot(action)) -> predicted phi'.
          self._curiosity_forward_fcnet = self._create_fc_net(
              [self.feature_dim + self.action_dim]
              + list(self.forward_net_hiddens)
              + [self.feature_dim],
              self.forward_net_activation,
              name="forward_net",
          )

          # This is only used to select the correct action.
          self.exploration_submodule = from_config(
              cls=Exploration,
              config=self.sub_exploration,
              action_space=self.action_space,
              framework=self.framework,
              policy_config=self.policy_config,
              model=self.model,
              num_workers=self.num_workers,
              worker_index=self.worker_index,
          )

      @override(Exploration)
      def get_exploration_action(
          self,
          *,
          action_distribution: ActionDistribution,
          timestep: Union[int, TensorType],
          explore: bool = True
      ):
          """Delegates action selection to the sub-Exploration module."""
          return self.exploration_submodule.get_exploration_action(
              action_distribution=action_distribution, timestep=timestep, explore=explore
          )

      @override(Exploration)
      def get_exploration_optimizer(self, optimizers):
          """Creates the curiosity optimizer and attaches the curiosity nets.

          The Adam optimizer for the curiosity nets is stored on `self`, but
          NOT added to the returned list: if it were returned, it would be used
          in the Policy's update loop, which we don't want (curiosity updating
          happens inside `postprocess_trajectory`).

          Args:
              optimizers: The Policy's already created optimizer(s).

          Returns:
              `optimizers`, unchanged.
          """
          if self.framework == "torch":
              feature_params = list(self._curiosity_feature_net.parameters())
              inverse_params = list(self._curiosity_inverse_fcnet.parameters())
              forward_params = list(self._curiosity_forward_fcnet.parameters())

              # Now that the Policy's own optimizer(s) have been created (from
              # the Model parameters, IMPORTANT: w/o(!) the curiosity params),
              # we can add our curiosity sub-modules to the Policy's Model.
              self.model._curiosity_feature_net = self._curiosity_feature_net.to(
                  self.device
              )
              self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet.to(
                  self.device
              )
              self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet.to(
                  self.device
              )
              self._optimizer = torch.optim.Adam(
                  forward_params + inverse_params + feature_params, lr=self.lr
              )
          else:
              self.model._curiosity_feature_net = self._curiosity_feature_net
              self.model._curiosity_inverse_fcnet = self._curiosity_inverse_fcnet
              self.model._curiosity_forward_fcnet = self._curiosity_forward_fcnet
              # Feature net is a RLlib ModelV2, the other 2 are keras Models.
              self._optimizer_var_list = (
                  self._curiosity_feature_net.base_model.variables
                  + self._curiosity_inverse_fcnet.variables
                  + self._curiosity_forward_fcnet.variables
              )
              self._optimizer = tf1.train.AdamOptimizer(learning_rate=self.lr)
              # tf1 static graph: create placeholders and build the reward and
              # update ops once; `_postprocess_tf` feeds them per batch.
              if self.framework == "tf":
                  self._obs_ph = get_placeholder(
                      space=self.model.obs_space, name="_curiosity_obs"
                  )
                  self._next_obs_ph = get_placeholder(
                      space=self.model.obs_space, name="_curiosity_next_obs"
                  )
                  self._action_ph = get_placeholder(
                      space=self.model.action_space, name="_curiosity_action"
                  )
                  (
                      self._forward_l2_norm_squared,
                      self._update_op,
                  ) = self._postprocess_helper_tf(
                      self._obs_ph, self._next_obs_ph, self._action_ph
                  )

          return optimizers

      @override(Exploration)
      def postprocess_trajectory(self, policy, sample_batch, tf_sess=None):
          """Calculates phi values (obs, obs', and predicted obs') and ri.

          Also calculates forward and inverse losses and updates the curiosity
          module on the provided batch using our optimizer.

          Args:
              policy: The Policy that collected `sample_batch`.
              sample_batch: The batch to postprocess. Its rewards column is
                  replaced by extrinsic + eta * intrinsic rewards.
              tf_sess: The tf1 session; only used by the "tf" (static-graph)
                  framework.

          Returns:
              The same `sample_batch` object, with intrinsic rewards added.
          """
          if self.framework != "torch":
              return self._postprocess_tf(policy, sample_batch, tf_sess)
          return self._postprocess_torch(policy, sample_batch)

      def _postprocess_tf(self, policy, sample_batch, tf_sess):
          """TF (static-graph or eager) implementation of the postprocessing."""
          if self.framework == "tf":
              # tf1 static-graph: Perform session call on our loss and update ops.
              forward_l2_norm_squared, _ = tf_sess.run(
                  [self._forward_l2_norm_squared, self._update_op],
                  feed_dict={
                      self._obs_ph: sample_batch[SampleBatch.OBS],
                      self._next_obs_ph: sample_batch[SampleBatch.NEXT_OBS],
                      self._action_ph: sample_batch[SampleBatch.ACTIONS],
                  },
              )
          else:
              # tf-eager: Perform model calls, loss calculations, and optimizer
              # stepping on the fly.
              forward_l2_norm_squared, _ = self._postprocess_helper_tf(
                  sample_batch[SampleBatch.OBS],
                  sample_batch[SampleBatch.NEXT_OBS],
                  sample_batch[SampleBatch.ACTIONS],
              )
              # Convert the eager tensor to numpy so the rewards column stays a
              # numpy array, as in the tf1 (`sess.run`) and torch code paths.
              forward_l2_norm_squared = forward_l2_norm_squared.numpy()

          # Scale intrinsic reward by eta hyper-parameter.
          sample_batch[SampleBatch.REWARDS] = (
              sample_batch[SampleBatch.REWARDS] + self.eta * forward_l2_norm_squared
          )
          return sample_batch

      def _postprocess_helper_tf(self, obs, next_obs, actions):
          """Builds/runs the ICM losses and an optimizer update (TF).

          In "tf" (static-graph) mode the inputs are placeholders and the
          returned update op must be run in a session. In eager mode the
          gradients are computed with a GradientTape and applied right away.

          Returns:
              Tuple of (per-sample 0.5 * squared L2 forward-prediction error,
              optimizer update op).
          """
          with (
              tf.GradientTape() if self.framework != "tf" else NullContextManager()
          ) as tape:
              # Push both observations through feature net (as one batch) to
              # get both phis.
              phis, _ = self.model._curiosity_feature_net(
                  {SampleBatch.OBS: tf.concat([obs, next_obs], axis=0)}
              )
              phi, next_phi = tf.split(phis, 2)

              # Predict next phi with forward model.
              predicted_next_phi = self.model._curiosity_forward_fcnet(
                  tf.concat([phi, tf_one_hot(actions, self.action_space)], axis=-1)
              )

              # Forward loss term (predicted phi', given phi and action vs
              # actually observed phi').
              forward_l2_norm_squared = 0.5 * tf.reduce_sum(
                  tf.square(predicted_next_phi - next_phi), axis=-1
              )
              forward_loss = tf.reduce_mean(forward_l2_norm_squared)

              # Inverse loss term (predicted action that led from phi to phi' vs
              # actual action taken).
              phi_cat_next_phi = tf.concat([phi, next_phi], axis=-1)
              dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
              action_dist = (
                  Categorical(dist_inputs, self.model)
                  if isinstance(self.action_space, Discrete)
                  else MultiCategorical(dist_inputs, self.model, self.action_space.nvec)
              )
              # Neg log(p); p=probability of observed action given the inverse-NN
              # predicted action distribution.
              inverse_loss = -action_dist.logp(tf.convert_to_tensor(actions))
              inverse_loss = tf.reduce_mean(inverse_loss)

              # Calculate the ICM loss.
              loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss

          # Step the optimizer.
          if self.framework != "tf":
              grads = tape.gradient(loss, self._optimizer_var_list)
              grads_and_vars = [
                  (g, v) for g, v in zip(grads, self._optimizer_var_list) if g is not None
              ]
              update_op = self._optimizer.apply_gradients(grads_and_vars)
          else:
              update_op = self._optimizer.minimize(
                  loss, var_list=self._optimizer_var_list
              )

          # Return the squared l2 norm and the optimizer update op.
          return forward_l2_norm_squared, update_op

      def _postprocess_torch(self, policy, sample_batch):
          """Torch implementation of the postprocessing.

          Adds eta-scaled intrinsic rewards to `sample_batch`, then takes one
          optimizer step on the ICM loss.
          """
          device = policy.device
          obs = torch.from_numpy(sample_batch[SampleBatch.OBS]).to(device)
          next_obs = torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS]).to(device)

          # Push both observations through feature net (as one batch) to get
          # both phis.
          phis, _ = self.model._curiosity_feature_net(
              {SampleBatch.OBS: torch.cat([obs, next_obs])}
          )
          phi, next_phi = torch.chunk(phis, 2)
          actions_tensor = (
              torch.from_numpy(sample_batch[SampleBatch.ACTIONS]).long().to(device)
          )

          # Predict next phi with forward model.
          predicted_next_phi = self.model._curiosity_forward_fcnet(
              torch.cat([phi, one_hot(actions_tensor, self.action_space).float()], dim=-1)
          )

          # Forward loss term (predicted phi', given phi and action vs actually
          # observed phi').
          forward_l2_norm_squared = 0.5 * torch.sum(
              torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1
          )
          forward_loss = torch.mean(forward_l2_norm_squared)

          # Scale intrinsic reward by eta hyper-parameter.
          sample_batch[SampleBatch.REWARDS] = (
              sample_batch[SampleBatch.REWARDS]
              + self.eta * forward_l2_norm_squared.detach().cpu().numpy()
          )

          # Inverse loss term (predicted action that led from phi to phi' vs
          # actual action taken).
          phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1)
          dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi)
          action_dist = (
              TorchCategorical(dist_inputs, self.model)
              if isinstance(self.action_space, Discrete)
              else TorchMultiCategorical(dist_inputs, self.model, self.action_space.nvec)
          )
          # Neg log(p); p=probability of observed action given the inverse-NN
          # predicted action distribution.
          inverse_loss = -action_dist.logp(actions_tensor)
          inverse_loss = torch.mean(inverse_loss)

          # Calculate the ICM loss.
          loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss
          # Perform an optimizer step.
          self._optimizer.zero_grad()
          loss.backward()
          self._optimizer.step()

          # Return the postprocessed sample batch (with the corrected rewards).
          return sample_batch

      def _create_fc_net(self, layer_dims, activation, name=None):
          """Given a list of layer dimensions (incl. input-dim), creates FC-net.

          Args:
              layer_dims (Tuple[int]): Tuple of layer dims, including the input
                  dimension.
              activation: An activation specifier string (e.g. "relu"), applied
                  to every layer except the last one.
              name: Name prefix for the keras Input/Dense layers (not used by
                  the torch path).

          Returns:
              A `nn.Sequential` for torch, a `tf.keras.Sequential` otherwise.

          Examples:
              If layer_dims is [4,8,6] we'll have a two layer net: 4->8 (8 nodes)
              and 8->6 (6 nodes), where the second layer (6 nodes) does not have
              an activation anymore. 4 is the input dimension.
          """
          last_layer_idx = len(layer_dims) - 2
          layers = (
              [tf.keras.layers.Input(shape=(layer_dims[0],), name="{}_in".format(name))]
              if self.framework != "torch"
              else []
          )
          for i, (in_size, out_size) in enumerate(zip(layer_dims[:-1], layer_dims[1:])):
              # The output layer stays linear.
              act = activation if i < last_layer_idx else None
              if self.framework == "torch":
                  layers.append(
                      SlimFC(
                          in_size=in_size,
                          out_size=out_size,
                          initializer=torch.nn.init.xavier_uniform_,
                          activation_fn=act,
                      )
                  )
              else:
                  layers.append(
                      tf.keras.layers.Dense(
                          units=out_size,
                          activation=get_activation_fn(act),
                          name="{}_{}".format(name, i),
                      )
                  )

          if self.framework == "torch":
              return nn.Sequential(*layers)
          return tf.keras.Sequential(layers)