json_compat.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. """
  2. Common validator wrapper to provide a uniform usage of other schema validation
  3. libraries.
  4. """
  5. # Copyright (c) Jupyter Development Team.
  6. # Distributed under the terms of the Modified BSD License.
  7. from __future__ import annotations
  8. import os
  9. import fastjsonschema
  10. import jsonschema
  11. from fastjsonschema import JsonSchemaException as _JsonSchemaException
  12. from jsonschema import Draft4Validator as _JsonSchemaValidator
  13. from jsonschema.exceptions import ErrorTree, ValidationError
  # Public API of this module.
  __all__ = [
      # re-exported from jsonschema.exceptions
      "ValidationError",
      # validator back-ends
      "JsonSchemaValidator",
      "FastJsonSchemaValidator",
      # back-end selection helpers
      "get_current_validator",
      "VALIDATORS",
  ]
  class JsonSchemaValidator:
      """Validate data against a JSON schema using the ``jsonschema`` package.

      A Draft 4 ``jsonschema`` validator is always built and kept as
      ``_default_validator``. Subclasses may swap ``_validator`` for another
      implementation while still relying on the default one.
      """

      name = "jsonschema"

      def __init__(self, schema):
          """Store *schema* and build a Draft 4 validator for it."""
          self._schema = schema
          draft4 = _JsonSchemaValidator(schema)
          self._default_validator = draft4  # Default
          # Exposed separately so subclasses can override it.
          self._validator = draft4

      def validate(self, data):
          """Validate *data* with the default validator (raises on failure)."""
          self._default_validator.validate(data)

      def iter_errors(self, data, schema=None):
          """Return the validation errors found in *data*.

          If *schema* is given, *data* is checked against it instead of the
          schema supplied at construction time.
          """
          base = self._default_validator
          if schema is None:
              return base.iter_errors(data)
          # Use ``evolve`` when the installed jsonschema provides it.
          # Otherwise, fall back to passing the schema as a second argument.
          if not hasattr(base, "evolve"):
              return base.iter_errors(data, schema)
          return base.evolve(schema=schema).iter_errors(data)

      def error_tree(self, errors):
          """Wrap *errors* in a ``jsonschema`` ``ErrorTree``."""
          return ErrorTree(errors=errors)
  class FastJsonSchemaValidator(JsonSchemaValidator):
      """Validate data against a JSON schema using ``fastjsonschema``.

      Plain validation goes through a function compiled by
      ``fastjsonschema.compile``. Validation against an alternative schema is
      delegated to the ``jsonschema``-based parent implementation.
      """

      name = "fastjsonschema"

      def __init__(self, schema):
          """Set up the jsonschema fallback, then compile *schema* with fastjsonschema."""
          super().__init__(schema)
          self._validator = fastjsonschema.compile(schema)

      @staticmethod
      def _to_validation_error(error):
          """Convert a fastjsonschema exception into a jsonschema ``ValidationError``."""
          # ``path`` is not guaranteed on every ``JsonSchemaException``: in
          # fastjsonschema it is defined on ``JsonSchemaValueException``, while
          # this module catches the base class. Reading it with ``getattr``
          # prevents an AttributeError from hiding the real failure.
          # NOTE(review): fastjsonschema's ``path`` locates the failing item in
          # the *data*, yet it is stored as ``schema_path``. This is kept as-is
          # for backward compatibility with existing callers.
          return ValidationError(str(error), schema_path=getattr(error, "path", ()))

      def validate(self, data):
          """Validate *data*.

          Raises:
              ValidationError: if *data* does not match the compiled schema.
          """
          try:
              self._validator(data)
          except _JsonSchemaException as error:
              raise self._to_validation_error(error) from error

      def iter_errors(self, data, schema=None):
          """Return a list of validation errors in *data*.

          The compiled validator reports a failure by raising, so the list
          holds at most one error. When *schema* is given, the work is
          delegated to the jsonschema-based parent implementation.
          """
          if schema is not None:
              return super().iter_errors(data, schema)
          try:
              self._validator(data)
          except _JsonSchemaException as error:
              return [self._to_validation_error(error)]
          return []

      def error_tree(self, errors):
          """Not supported for fastjsonschema; always raises NotImplementedError."""
          # fastjsonschema's exceptions don't contain the same information that
          # jsonschema's ValidationErrors do. This method is primarily used for
          # introspecting metadata schema failures, so that they can be stripped
          # when asked to do so in `nbformat.validate`.
          # A possible way forward for compatibility: distill both validators'
          # errors into a custom collection. Since the implementation details of
          # ValidationError are used elsewhere, that collection would probably be
          # used only for schema introspection.
          msg = "JSON schema error introspection not enabled for fastjsonschema"
          raise NotImplementedError(msg)
  # Registered validator back-ends as (name, module, validator class) triples.
  # ``_validator_for_name`` skips entries whose module is falsy.
  _VALIDATOR_MAP = [
      ("fastjsonschema", fastjsonschema, FastJsonSchemaValidator),
      ("jsonschema", jsonschema, JsonSchemaValidator),
  ]
  # Names accepted by ``_validator_for_name`` (and thus by NBFORMAT_VALIDATOR).
  VALIDATORS = [backend_name for backend_name, _module, _cls in _VALIDATOR_MAP]
  def _validator_for_name(validator_name):
      """Return the validator class registered under *validator_name*.

      Raises:
          ValueError: if *validator_name* is not listed in ``VALIDATORS``, or if
              its registry entry has a falsy module.
      """
      if validator_name not in VALIDATORS:
          raise ValueError(
              f"Invalid validator '{validator_name}' value!\nValid values are: {VALIDATORS}"
          )

      # First registry entry whose module is truthy and whose name matches.
      found = next(
          (cls for name, module, cls in _VALIDATOR_MAP if module and validator_name == name),
          None,
      )
      if found is not None:
          return found

      # Reached only when the matching entry's module is falsy.
      raise ValueError(f"Missing validator for {validator_name!r}")
  def get_current_validator():
      """
      Return the validator class selected by the ``NBFORMAT_VALIDATOR``
      environment variable, defaulting to ``"fastjsonschema"`` when it is unset.

      Unknown names raise ``ValueError`` (from ``_validator_for_name``).
      """
      return _validator_for_name(os.getenv("NBFORMAT_VALIDATOR", "fastjsonschema"))