# json_util.py (2.6 KB)

  1. from __future__ import annotations
  2. import json
  3. import logging
  4. import os
  5. from typing import Any
  6. from wandb import env
  7. logger = logging.getLogger(__name__)
  8. try:
  9. import wandb.vendor.wandb_orjson as orjson
  10. # Allow disabling orjson for compatibility and safety.
  11. if not os.environ.get(env.DISABLE_ORJSON):
  12. def dumps(obj: Any, **kwargs: Any) -> str:
  13. """Wrapper for <json|orjson>.dumps."""
  14. cls = kwargs.pop("cls", None)
  15. try:
  16. _kwargs = kwargs.copy()
  17. if cls:
  18. _kwargs["default"] = cls.default
  19. encoded = orjson.dumps(
  20. obj,
  21. option=orjson.OPT_NON_STR_KEYS | orjson.OPT_FAIL_ON_INVALID_FLOAT, # type: ignore[attr-defined]
  22. **_kwargs,
  23. ).decode()
  24. except Exception:
  25. logger.exception("Error using orjson.dumps")
  26. if cls:
  27. kwargs["cls"] = cls
  28. encoded = json.dumps(obj, **kwargs)
  29. return encoded # type: ignore[no-any-return]
  30. def dump(obj: Any, fp: Any, **kwargs: Any) -> None:
  31. """Wrapper for <json|orjson>.dump."""
  32. cls = kwargs.pop("cls", None)
  33. try:
  34. _kwargs = kwargs.copy()
  35. if cls:
  36. _kwargs["default"] = cls.default
  37. encoded = orjson.dumps(
  38. obj,
  39. option=orjson.OPT_NON_STR_KEYS | orjson.OPT_FAIL_ON_INVALID_FLOAT, # type: ignore[attr-defined]
  40. **_kwargs,
  41. )
  42. fp.write(encoded.decode())
  43. except Exception:
  44. logger.exception("Error using orjson.dump")
  45. if cls:
  46. kwargs["cls"] = cls
  47. json.dump(obj, fp, **kwargs)
  48. def loads(obj: str | bytes) -> Any:
  49. """Wrapper for orjson.loads."""
  50. try:
  51. decoded = orjson.loads(obj)
  52. except Exception:
  53. logger.exception("Error using orjson.loads")
  54. decoded = json.loads(obj)
  55. return decoded
  56. def load(fp: Any) -> Any:
  57. """Wrapper for orjson.load."""
  58. try:
  59. decoded = orjson.loads(fp.read())
  60. except Exception:
  61. logger.exception("Error using orjson.load")
  62. decoded = json.load(fp)
  63. return decoded
  64. else:
  65. from json import dump, dumps, load, loads # type: ignore[assignment]
  66. except ImportError:
  67. from json import dump, dumps, load, loads # type: ignore[assignment] # noqa: F401