data_logging.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. # wandb.integrations.data_logging.py
  2. #
  3. # Contains common utility functions that enable
  4. # logging datasets and predictions to wandb.
  5. from __future__ import annotations
  6. import sys
  7. from collections.abc import Sequence
  8. from typing import TYPE_CHECKING, Any, Callable
  9. import wandb
  10. if TYPE_CHECKING:
  11. from wandb.data_types import _TableIndex
  12. CAN_INFER_IMAGE_AND_VIDEO = sys.version_info.major == 3 and sys.version_info.minor >= 5
  13. class ValidationDataLogger:
  14. """Logs validation data as a wandb.Table.
  15. ValidationDataLogger is intended to be used inside of library integrations
  16. in order to facilitate the process of optionally building a validation dataset
  17. and logging periodic predictions against such validation data using WandB best
  18. practices.
  19. """
  20. validation_inputs: Sequence | dict[str, Sequence]
  21. validation_targets: Sequence | dict[str, Sequence] | None
  22. validation_indexes: list[_TableIndex]
  23. prediction_row_processor: Callable | None
  24. class_labels_table: wandb.Table | None
  25. infer_missing_processors: bool
  26. def __init__(
  27. self,
  28. inputs: Sequence | dict[str, Sequence],
  29. targets: Sequence | dict[str, Sequence] | None = None,
  30. indexes: list[_TableIndex] | None = None,
  31. validation_row_processor: Callable | None = None,
  32. prediction_row_processor: Callable | None = None,
  33. input_col_name: str = "input",
  34. target_col_name: str = "target",
  35. table_name: str = "wb_validation_data",
  36. artifact_type: str = "validation_dataset",
  37. class_labels: list[str] | None = None,
  38. infer_missing_processors: bool = True,
  39. ) -> None:
  40. """Initialize a new ValidationDataLogger.
  41. Args:
  42. inputs: A list of input vectors or dictionary of lists of input vectors
  43. (used if the model has multiple named inputs)
  44. targets: A list of target vectors or dictionary of lists of target vectors
  45. (used if the model has multiple named targets/putputs). Defaults to `None`.
  46. `targets` and `indexes` cannot both be `None`.
  47. indexes: An ordered list of `wandb.data_types._TableIndex` mapping the
  48. input items to their source table. This is most commonly retrieved by using
  49. `indexes = my_data_table.get_index()`. Defaults to `None`. `targets`
  50. and `indexes` cannot both be `None`.
  51. validation_row_processor: A function to apply to the validation data,
  52. commonly used to visualize the data. The function will receive an `ndx` (`int`)
  53. and a `row` (`dict`). If `inputs` is a list, then `row["input"]` will be the input
  54. data for the row. Else, it will be keyed based on the name of the input slot
  55. (corresponding to `inputs`). If `targets` is a list, then
  56. `row["target"]` will be the target data for the row. Else, it will
  57. be keyed based on `targets`. For example, if your input data is a
  58. single ndarray, but you wish to visualize the data as an image,
  59. then you can provide `lambda ndx, row: {"img": wandb.Image(row["input"])}`
  60. as the processor. If `None`, we will try to guess the appropriate processor.
  61. Ignored if `log_evaluation` is `False` or `val_keys` are present. Defaults to `None`.
  62. prediction_row_processor: Same as validation_row_processor, but applied to the
  63. model's output. `row["output"]` will contain the results of the model output.
  64. Defaults to `None`.
  65. input_col_name: The name to use for the input column.
  66. Defaults to `"input"`.
  67. target_col_name: The name to use for the target column.
  68. Defaults to `"target"`.
  69. table_name: The name to use for the validation table.
  70. Defaults to `"wb_validation_data"`.
  71. artifact_type: The artifact type to use for the validation data.
  72. Defaults to `"validation_dataset"`.
  73. class_labels: Optional list of labels to use in the inferred
  74. processors. If the model's `target` or `output` is inferred to be a class,
  75. we will attempt to map the class to these labels. Defaults to `None`.
  76. infer_missing_processors: Determines if processors are inferred if
  77. they are missing. Defaults to True.
  78. """
  79. class_labels_table: wandb.Table | None
  80. if isinstance(class_labels, list) and len(class_labels) > 0:
  81. class_labels_table = wandb.Table(
  82. columns=["label"], data=[[label] for label in class_labels]
  83. )
  84. else:
  85. class_labels_table = None
  86. if indexes is None:
  87. assert targets is not None
  88. local_validation_table = wandb.Table(columns=[], data=[])
  89. if isinstance(targets, dict):
  90. for col_name in targets:
  91. local_validation_table.add_column(col_name, targets[col_name])
  92. else:
  93. local_validation_table.add_column(target_col_name, targets)
  94. if isinstance(inputs, dict):
  95. for col_name in inputs:
  96. local_validation_table.add_column(col_name, inputs[col_name])
  97. else:
  98. local_validation_table.add_column(input_col_name, inputs)
  99. if validation_row_processor is None and infer_missing_processors:
  100. example_input = _make_example(inputs)
  101. example_target = _make_example(targets)
  102. if example_input is not None and example_target is not None:
  103. validation_row_processor = _infer_validation_row_processor(
  104. example_input,
  105. example_target,
  106. class_labels_table,
  107. input_col_name,
  108. target_col_name,
  109. )
  110. if validation_row_processor is not None:
  111. local_validation_table.add_computed_columns(validation_row_processor)
  112. local_validation_artifact = wandb.Artifact(table_name, artifact_type)
  113. local_validation_artifact.add(local_validation_table, "validation_data")
  114. if wandb.run:
  115. wandb.run.use_artifact(local_validation_artifact)
  116. indexes = local_validation_table.get_index()
  117. else:
  118. local_validation_artifact = None
  119. self.class_labels_table = class_labels_table
  120. self.validation_inputs = inputs
  121. self.validation_targets = targets
  122. self.validation_indexes = indexes
  123. self.prediction_row_processor = prediction_row_processor
  124. self.infer_missing_processors = infer_missing_processors
  125. self.local_validation_artifact = local_validation_artifact
  126. self.input_col_name = input_col_name
  127. def make_predictions(self, predict_fn: Callable) -> Sequence | dict[str, Sequence]:
  128. """Produce predictions by passing `validation_inputs` to `predict_fn`.
  129. Args:
  130. predict_fn (Callable): Any function which can accept `validation_inputs` and produce
  131. a list of vectors or dictionary of lists of vectors
  132. Returns:
  133. (Sequence | Dict[str, Sequence]): The returned value of predict_fn
  134. """
  135. return predict_fn(self.validation_inputs)
  136. def log_predictions(
  137. self,
  138. predictions: Sequence | dict[str, Sequence],
  139. prediction_col_name: str = "output",
  140. val_ndx_col_name: str = "val_row",
  141. table_name: str = "validation_predictions",
  142. commit: bool = True,
  143. ) -> wandb.data_types.Table:
  144. """Log a set of predictions.
  145. Intended usage:
  146. vl.log_predictions(vl.make_predictions(self.model.predict))
  147. Args:
  148. predictions (Sequence | Dict[str, Sequence]): A list of prediction vectors or dictionary
  149. of lists of prediction vectors
  150. prediction_col_name (str, optional): the name of the prediction column. Defaults to "output".
  151. val_ndx_col_name (str, optional): The name of the column linking prediction table
  152. to the validation ata table. Defaults to "val_row".
  153. table_name (str, optional): name of the prediction table. Defaults to "validation_predictions".
  154. commit (bool, optional): determines if commit should be called on the logged data. Defaults to False.
  155. """
  156. pred_table = wandb.Table(columns=[], data=[])
  157. if isinstance(predictions, dict):
  158. for col_name in predictions:
  159. pred_table.add_column(col_name, predictions[col_name])
  160. else:
  161. pred_table.add_column(prediction_col_name, predictions)
  162. pred_table.add_column(val_ndx_col_name, self.validation_indexes)
  163. if self.prediction_row_processor is None and self.infer_missing_processors:
  164. example_prediction = _make_example(predictions)
  165. example_input = _make_example(self.validation_inputs)
  166. if example_prediction is not None and example_input is not None:
  167. self.prediction_row_processor = _infer_prediction_row_processor(
  168. example_prediction,
  169. example_input,
  170. self.class_labels_table,
  171. self.input_col_name,
  172. prediction_col_name,
  173. )
  174. if self.prediction_row_processor is not None:
  175. pred_table.add_computed_columns(self.prediction_row_processor)
  176. wandb.log({table_name: pred_table}, commit=commit)
  177. return pred_table
  178. def _make_example(data: Any) -> dict | Sequence | Any | None:
  179. """Used to make an example input, target, or output."""
  180. example: dict | Sequence | Any | None
  181. if isinstance(data, dict):
  182. example = {}
  183. for key in data:
  184. example[key] = data[key][0]
  185. elif hasattr(data, "__len__"):
  186. example = data[0]
  187. else:
  188. example = None
  189. return example
  190. def _get_example_shape(example: Sequence | Any):
  191. """Get the shape of an object if applicable."""
  192. shape = []
  193. if not isinstance(example, str) and hasattr(example, "__len__"):
  194. length = len(example)
  195. shape = [length]
  196. if length > 0:
  197. shape += _get_example_shape(example[0])
  198. return shape
  199. def _bind(lambda_fn: Callable, **closure_kwargs: Any) -> Callable:
  200. """Create a closure around a lambda function by binding `closure_kwargs` to the function."""
  201. def closure(*args: Any, **kwargs: Any) -> Any:
  202. _k = {}
  203. _k.update(kwargs)
  204. _k.update(closure_kwargs)
  205. return lambda_fn(*args, **_k)
  206. return closure
  207. def _infer_single_example_keyed_processor(
  208. example: Sequence | Any,
  209. class_labels_table: wandb.Table | None = None,
  210. possible_base_example: Sequence | Any | None = None,
  211. ) -> dict[str, Callable]:
  212. """Infers a processor from a single example.
  213. Infers a processor from a single example with optional class_labels_table
  214. and base_example. Base example is useful for cases such as segmentation masks
  215. """
  216. shape = _get_example_shape(example)
  217. processors: dict[str, Callable] = {}
  218. if (
  219. class_labels_table is not None
  220. and len(shape) == 1
  221. and shape[0] == len(class_labels_table.data)
  222. ):
  223. np = wandb.util.get_module(
  224. "numpy",
  225. required="Inferring processors require numpy",
  226. )
  227. # Assume these are logits
  228. class_names = class_labels_table.get_column("label")
  229. processors["max_class"] = lambda n, d, p: class_labels_table.index_ref( # type: ignore
  230. np.argmax(d)
  231. )
  232. # TODO: Consider adding back if users ask
  233. # processors["min_class"] = lambda n, d, p: class_labels_table.index_ref( # type: ignore
  234. # np.argmin(d)
  235. # )
  236. values = np.unique(example)
  237. is_one_hot = len(values) == 2 and set(values) == {0, 1}
  238. if not is_one_hot:
  239. processors["score"] = lambda n, d, p: {
  240. class_names[i]: d[i] for i in range(shape[0])
  241. }
  242. elif (
  243. len(shape) == 1
  244. and shape[0] == 1
  245. and (
  246. isinstance(example[0], int)
  247. or (hasattr(example, "tolist") and isinstance(example.tolist()[0], int)) # type: ignore
  248. )
  249. ):
  250. # assume this is a class
  251. if class_labels_table is not None:
  252. processors["class"] = (
  253. lambda n, d, p: class_labels_table.index_ref(d[0])
  254. if d[0] < len(class_labels_table.data)
  255. else d[0]
  256. ) # type: ignore
  257. else:
  258. processors["val"] = lambda n, d, p: d[0]
  259. elif len(shape) == 1:
  260. np = wandb.util.get_module(
  261. "numpy",
  262. required="Inferring processors require numpy",
  263. )
  264. # This could be anything
  265. if shape[0] <= 10:
  266. # if less than 10, fan out the results
  267. # processors["node"] = lambda n, d, p: {i: d[i] for i in range(shape[0])}
  268. processors["node"] = lambda n, d, p: [
  269. d[i].tolist() if hasattr(d[i], "tolist") else d[i]
  270. for i in range(shape[0])
  271. ]
  272. # just report the argmax and argmin
  273. processors["argmax"] = lambda n, d, p: np.argmax(d)
  274. values = np.unique(example)
  275. is_one_hot = len(values) == 2 and set(values) == {0, 1}
  276. if not is_one_hot:
  277. processors["argmin"] = lambda n, d, p: np.argmin(d)
  278. elif len(shape) == 2 and CAN_INFER_IMAGE_AND_VIDEO:
  279. if (
  280. class_labels_table is not None
  281. and possible_base_example is not None
  282. and shape == _get_example_shape(possible_base_example)
  283. ):
  284. # consider this a segmentation mask
  285. processors["image"] = lambda n, d, p: wandb.Image(
  286. p,
  287. masks={
  288. "masks": {
  289. "mask_data": d,
  290. "class_labels": class_labels_table.get_column("label"), # type: ignore
  291. }
  292. },
  293. )
  294. else:
  295. # consider this a 2d image
  296. processors["image"] = lambda n, d, p: wandb.Image(d)
  297. elif len(shape) == 3 and CAN_INFER_IMAGE_AND_VIDEO:
  298. # consider this an image
  299. processors["image"] = lambda n, d, p: wandb.Image(d)
  300. elif len(shape) == 4 and CAN_INFER_IMAGE_AND_VIDEO:
  301. # consider this a video
  302. processors["video"] = lambda n, d, p: wandb.Video(d)
  303. return processors
  304. def _infer_validation_row_processor(
  305. example_input: dict | Sequence,
  306. example_target: dict | Sequence | Any,
  307. class_labels_table: wandb.Table | None = None,
  308. input_col_name: str = "input",
  309. target_col_name: str = "target",
  310. ) -> Callable:
  311. """Infers the composite processor for the validation data."""
  312. single_processors = {}
  313. if isinstance(example_input, dict):
  314. for key in example_input:
  315. key_processors = _infer_single_example_keyed_processor(example_input[key])
  316. for p_key in key_processors:
  317. single_processors[f"{key}:{p_key}"] = _bind(
  318. lambda ndx, row, key_processor, key: key_processor(
  319. ndx,
  320. row[key],
  321. None,
  322. ),
  323. key_processor=key_processors[p_key],
  324. key=key,
  325. )
  326. else:
  327. key = input_col_name
  328. key_processors = _infer_single_example_keyed_processor(example_input)
  329. for p_key in key_processors:
  330. single_processors[f"{key}:{p_key}"] = _bind(
  331. lambda ndx, row, key_processor, key: key_processor(
  332. ndx,
  333. row[key],
  334. None,
  335. ),
  336. key_processor=key_processors[p_key],
  337. key=key,
  338. )
  339. if isinstance(example_target, dict):
  340. for key in example_target:
  341. key_processors = _infer_single_example_keyed_processor(
  342. example_target[key], class_labels_table
  343. )
  344. for p_key in key_processors:
  345. single_processors[f"{key}:{p_key}"] = _bind(
  346. lambda ndx, row, key_processor, key: key_processor(
  347. ndx,
  348. row[key],
  349. None,
  350. ),
  351. key_processor=key_processors[p_key],
  352. key=key,
  353. )
  354. else:
  355. key = target_col_name
  356. key_processors = _infer_single_example_keyed_processor(
  357. example_target,
  358. class_labels_table,
  359. example_input if not isinstance(example_input, dict) else None,
  360. )
  361. for p_key in key_processors:
  362. single_processors[f"{key}:{p_key}"] = _bind(
  363. lambda ndx, row, key_processor, key: key_processor(
  364. ndx,
  365. row[key],
  366. row[input_col_name]
  367. if not isinstance(example_input, dict)
  368. else None,
  369. ),
  370. key_processor=key_processors[p_key],
  371. key=key,
  372. )
  373. def processor(ndx, row):
  374. return {key: single_processors[key](ndx, row) for key in single_processors}
  375. return processor
  376. def _infer_prediction_row_processor(
  377. example_prediction: dict | Sequence,
  378. example_input: dict | Sequence,
  379. class_labels_table: wandb.Table | None = None,
  380. input_col_name: str = "input",
  381. output_col_name: str = "output",
  382. ) -> Callable:
  383. """Infers the composite processor for the prediction output data."""
  384. single_processors = {}
  385. if isinstance(example_prediction, dict):
  386. for key in example_prediction:
  387. key_processors = _infer_single_example_keyed_processor(
  388. example_prediction[key], class_labels_table
  389. )
  390. for p_key in key_processors:
  391. single_processors[f"{key}:{p_key}"] = _bind(
  392. lambda ndx, row, key_processor, key: key_processor(
  393. ndx,
  394. row[key],
  395. None,
  396. ),
  397. key_processor=key_processors[p_key],
  398. key=key,
  399. )
  400. else:
  401. key = output_col_name
  402. key_processors = _infer_single_example_keyed_processor(
  403. example_prediction,
  404. class_labels_table,
  405. example_input if not isinstance(example_input, dict) else None,
  406. )
  407. for p_key in key_processors:
  408. single_processors[f"{key}:{p_key}"] = _bind(
  409. lambda ndx, row, key_processor, key: key_processor(
  410. ndx,
  411. row[key],
  412. ndx.get_row().get("val_row").get_row().get(input_col_name)
  413. if not isinstance(example_input, dict)
  414. else None,
  415. ),
  416. key_processor=key_processors[p_key],
  417. key=key,
  418. )
  419. def processor(ndx, row):
  420. return {key: single_processors[key](ndx, row) for key in single_processors}
  421. return processor