usage.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. import collections
  2. import json
  3. import os
  4. from enum import Enum
  5. from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
  6. from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag
  7. if TYPE_CHECKING:
  8. from ray.train._internal.storage import StorageContext
  9. from ray.train.trainer import BaseTrainer
  10. from ray.tune import Callback
  11. from ray.tune.schedulers import TrialScheduler
  12. from ray.tune.search import BasicVariantGenerator, Searcher
  13. AIR_TRAINERS = {
  14. "HorovodTrainer",
  15. "LightGBMTrainer",
  16. "TensorflowTrainer",
  17. "TorchTrainer",
  18. "XGBoostTrainer",
  19. }
  20. TRAIN_V2_TRAINERS = {
  21. "DataParallelTrainer",
  22. "JaxTrainer",
  23. "LightGBMTrainer",
  24. "TensorflowTrainer",
  25. "TorchTrainer",
  26. "XGBoostTrainer",
  27. }
  28. # searchers implemented by Ray Tune.
  29. TUNE_SEARCHERS = {
  30. "AxSearch",
  31. "BayesOptSearch",
  32. "TuneBOHB",
  33. "HEBOSearch",
  34. "HyperOptSearch",
  35. "NevergradSearch",
  36. "OptunaSearch",
  37. "ZOOptSearch",
  38. }
  39. # These are just wrappers around real searchers.
  40. # We don't want to double tag in this case, otherwise, the real tag
  41. # will be overwritten.
  42. TUNE_SEARCHER_WRAPPERS = {
  43. "ConcurrencyLimiter",
  44. "Repeater",
  45. }
  46. TUNE_SCHEDULERS = {
  47. "FIFOScheduler",
  48. "AsyncHyperBandScheduler",
  49. "MedianStoppingRule",
  50. "HyperBandScheduler",
  51. "HyperBandForBOHB",
  52. "PopulationBasedTraining",
  53. "PopulationBasedTrainingReplay",
  54. "PB2",
  55. "ResourceChangingScheduler",
  56. }
  57. class AirEntrypoint(Enum):
  58. TUNER = "Tuner.fit"
  59. TRAINER = "Trainer.fit"
  60. TUNE_RUN = "tune.run"
  61. TUNE_RUN_EXPERIMENTS = "tune.run_experiments"
  62. def _find_class_name(obj, allowed_module_path_prefix: str, whitelist: Set[str]):
  63. """Find the class name of the object. If the object is not
  64. under `allowed_module_path_prefix` or if its class is not in the whitelist,
  65. return "Custom".
  66. Args:
  67. obj: The object under inspection.
  68. allowed_module_path_prefix: If the `obj`'s class is not under
  69. the `allowed_module_path_prefix`, its class name will be anonymized.
  70. whitelist: If the `obj`'s class is not in the `whitelist`,
  71. it will be anonymized.
  72. Returns:
  73. The class name to be tagged with telemetry.
  74. """
  75. module_path = obj.__module__
  76. cls_name = obj.__class__.__name__
  77. if module_path.startswith(allowed_module_path_prefix) and cls_name in whitelist:
  78. return cls_name
  79. else:
  80. return "Custom"
  81. def tag_air_trainer(trainer: "BaseTrainer"):
  82. from ray.train.trainer import BaseTrainer
  83. assert isinstance(trainer, BaseTrainer)
  84. trainer_name = _find_class_name(trainer, "ray.train", AIR_TRAINERS)
  85. record_extra_usage_tag(TagKey.AIR_TRAINER, trainer_name)
  86. def tag_train_v2_trainer(trainer):
  87. from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
  88. assert isinstance(trainer, DataParallelTrainer)
  89. trainer_name = _find_class_name(trainer, "ray.train", TRAIN_V2_TRAINERS)
  90. record_extra_usage_tag(TagKey.TRAIN_TRAINER, trainer_name)
  91. def tag_searcher(searcher: Union["BasicVariantGenerator", "Searcher"]):
  92. from ray.tune.search import BasicVariantGenerator, Searcher
  93. if isinstance(searcher, BasicVariantGenerator):
  94. # Note this could be highly inflated as all train flows are treated
  95. # as using BasicVariantGenerator.
  96. record_extra_usage_tag(TagKey.TUNE_SEARCHER, "BasicVariantGenerator")
  97. elif isinstance(searcher, Searcher):
  98. searcher_name = _find_class_name(
  99. searcher, "ray.tune.search", TUNE_SEARCHERS.union(TUNE_SEARCHER_WRAPPERS)
  100. )
  101. if searcher_name in TUNE_SEARCHER_WRAPPERS:
  102. # ignore to avoid double tagging with wrapper name.
  103. return
  104. record_extra_usage_tag(TagKey.TUNE_SEARCHER, searcher_name)
  105. else:
  106. assert False, (
  107. "Not expecting a non-BasicVariantGenerator, "
  108. "non-Searcher type passed in for `tag_searcher`."
  109. )
  110. def tag_scheduler(scheduler: "TrialScheduler"):
  111. from ray.tune.schedulers import TrialScheduler
  112. assert isinstance(scheduler, TrialScheduler)
  113. scheduler_name = _find_class_name(scheduler, "ray.tune.schedulers", TUNE_SCHEDULERS)
  114. record_extra_usage_tag(TagKey.TUNE_SCHEDULER, scheduler_name)
  115. def tag_setup_wandb():
  116. record_extra_usage_tag(TagKey.AIR_SETUP_WANDB_INTEGRATION_USED, "1")
  117. def tag_setup_mlflow():
  118. record_extra_usage_tag(TagKey.AIR_SETUP_MLFLOW_INTEGRATION_USED, "1")
  119. def _count_callbacks(callbacks: Optional[List["Callback"]]) -> Dict[str, int]:
  120. """Creates a map of callback class name -> count given a list of callbacks."""
  121. from ray.air.integrations.comet import CometLoggerCallback
  122. from ray.air.integrations.mlflow import MLflowLoggerCallback
  123. from ray.air.integrations.wandb import WandbLoggerCallback
  124. from ray.tune import Callback
  125. from ray.tune.logger import LoggerCallback
  126. from ray.tune.logger.aim import AimLoggerCallback
  127. from ray.tune.utils.callback import DEFAULT_CALLBACK_CLASSES
  128. built_in_callbacks = (
  129. WandbLoggerCallback,
  130. MLflowLoggerCallback,
  131. CometLoggerCallback,
  132. AimLoggerCallback,
  133. ) + DEFAULT_CALLBACK_CLASSES
  134. callback_names = [callback_cls.__name__ for callback_cls in built_in_callbacks]
  135. callback_counts = collections.defaultdict(int)
  136. callbacks = callbacks or []
  137. for callback in callbacks:
  138. if not isinstance(callback, Callback):
  139. # This will error later, but don't include this as custom usage.
  140. continue
  141. callback_name = callback.__class__.__name__
  142. if callback_name in callback_names:
  143. callback_counts[callback_name] += 1
  144. elif isinstance(callback, LoggerCallback):
  145. callback_counts["CustomLoggerCallback"] += 1
  146. else:
  147. callback_counts["CustomCallback"] += 1
  148. return callback_counts
  149. def tag_callbacks(callbacks: Optional[List["Callback"]]) -> bool:
  150. """Records built-in callback usage via a JSON str representing a
  151. dictionary mapping callback class name -> counts.
  152. User-defined callbacks will increment the count under the `CustomLoggerCallback`
  153. or `CustomCallback` key depending on which of the provided interfaces they subclass.
  154. NOTE: This will NOT track the name of the user-defined callback,
  155. nor its implementation.
  156. This will NOT report telemetry if no callbacks are provided by the user.
  157. Returns:
  158. bool: True if usage was recorded, False otherwise.
  159. """
  160. if not callbacks:
  161. # User didn't pass in any callbacks -> no usage recorded.
  162. return False
  163. callback_counts = _count_callbacks(callbacks)
  164. if callback_counts:
  165. callback_counts_str = json.dumps(callback_counts)
  166. record_extra_usage_tag(TagKey.AIR_CALLBACKS, callback_counts_str)
  167. def tag_storage_type(storage: "StorageContext"):
  168. """Records the storage configuration of an experiment.
  169. The storage configuration is set by `RunConfig(storage_path, storage_filesystem)`.
  170. The possible storage types (defined by `pyarrow.fs.FileSystem.type_name`) are:
  171. - 'local' = pyarrow.fs.LocalFileSystem. This includes NFS usage.
  172. - 'mock' = pyarrow.fs._MockFileSystem. This is used for testing.
  173. - ('s3', 'gcs', 'abfs', 'hdfs'): Various remote storage schemes
  174. with default implementations in pyarrow.
  175. - 'custom' = All other storage schemes, which includes ALL cases where a
  176. custom `storage_filesystem` is provided.
  177. - 'other' = catches any other cases not explicitly handled above.
  178. """
  179. whitelist = {"local", "mock", "s3", "gcs", "abfs", "hdfs"}
  180. if storage.custom_fs_provided:
  181. storage_config_tag = "custom"
  182. elif storage.storage_filesystem.type_name in whitelist:
  183. storage_config_tag = storage.storage_filesystem.type_name
  184. else:
  185. storage_config_tag = "other"
  186. record_extra_usage_tag(TagKey.AIR_STORAGE_CONFIGURATION, storage_config_tag)
  187. def tag_ray_air_env_vars() -> bool:
  188. """Records usage of environment variables exposed by the Ray AIR libraries.
  189. NOTE: This does not track the values of the environment variables, nor
  190. does this track environment variables not explicitly included in the
  191. `all_ray_air_env_vars` allow-list.
  192. Returns:
  193. bool: True if at least one environment var is supplied by the user.
  194. """
  195. from ray.air.constants import AIR_ENV_VARS
  196. from ray.train.constants import TRAIN_ENV_VARS
  197. from ray.tune.constants import TUNE_ENV_VARS
  198. all_ray_air_env_vars = sorted(
  199. set().union(AIR_ENV_VARS, TUNE_ENV_VARS, TRAIN_ENV_VARS)
  200. )
  201. user_supplied_env_vars = []
  202. for env_var in all_ray_air_env_vars:
  203. if env_var in os.environ:
  204. user_supplied_env_vars.append(env_var)
  205. if user_supplied_env_vars:
  206. env_vars_str = json.dumps(user_supplied_env_vars)
  207. record_extra_usage_tag(TagKey.AIR_ENV_VARS, env_vars_str)
  208. return True
  209. return False
  210. def tag_air_entrypoint(entrypoint: AirEntrypoint) -> None:
  211. """Records the entrypoint to an AIR training run."""
  212. assert entrypoint in AirEntrypoint
  213. record_extra_usage_tag(TagKey.AIR_ENTRYPOINT, entrypoint.value)