# pydantic_compat.py (3.2 KB)
  1. # ruff: noqa
  2. from typing import Optional
  3. import packaging.version
  4. # Pydantic is a dependency of `ray["default"]` but not the minimal installation,
  5. # so handle the case where it isn't installed.
  6. try:
  7. import pydantic
  8. PYDANTIC_INSTALLED = True
  9. except ImportError:
  10. pydantic = None
  11. PYDANTIC_INSTALLED = False
  12. PYDANTIC_MAJOR_VERSION: Optional[int] = (
  13. packaging.version.parse(pydantic.__version__).major
  14. if PYDANTIC_INSTALLED and hasattr(pydantic, "__version__")
  15. else None
  16. )
  17. if not PYDANTIC_INSTALLED:
  18. IS_PYDANTIC_2 = False
  19. BaseModel = None
  20. Extra = None
  21. Field = None
  22. NonNegativeFloat = None
  23. NonNegativeInt = None
  24. PositiveFloat = None
  25. PositiveInt = None
  26. PrivateAttr = None
  27. StrictInt = None
  28. ValidationError = None
  29. root_validator = None
  30. validator = None
  31. is_subclass_of_base_model = lambda obj: False
  32. # In pydantic <1.9.0, __version__ attribute is missing, issue ref:
  33. # https://github.com/pydantic/pydantic/issues/2572, so we need to check
  34. # the existence prior to comparison.
  35. elif not hasattr(pydantic, "__version__") or packaging.version.parse(
  36. pydantic.__version__
  37. ) < packaging.version.parse("2.0"):
  38. IS_PYDANTIC_2 = False
  39. from pydantic import (
  40. BaseModel,
  41. Extra,
  42. Field,
  43. NonNegativeFloat,
  44. NonNegativeInt,
  45. PositiveFloat,
  46. PositiveInt,
  47. PrivateAttr,
  48. StrictInt,
  49. ValidationError,
  50. root_validator,
  51. validator,
  52. )
  53. def is_subclass_of_base_model(obj):
  54. return issubclass(obj, BaseModel)
  55. else:
  56. IS_PYDANTIC_2 = True
  57. from pydantic.v1 import (
  58. BaseModel,
  59. Extra,
  60. Field,
  61. NonNegativeFloat,
  62. NonNegativeInt,
  63. PositiveFloat,
  64. PositiveInt,
  65. PrivateAttr,
  66. StrictInt,
  67. ValidationError,
  68. root_validator,
  69. validator,
  70. )
  71. def is_subclass_of_base_model(obj):
  72. from pydantic import BaseModel as BaseModelV2
  73. from pydantic.v1 import BaseModel as BaseModelV1
  74. return issubclass(obj, BaseModelV1) or issubclass(obj, BaseModelV2)
  75. def register_pydantic_serializers(serialization_context):
  76. if not PYDANTIC_INSTALLED:
  77. return
  78. if IS_PYDANTIC_2:
  79. # TODO(edoakes): compare against the version that has the fixes.
  80. from pydantic.v1.fields import ModelField
  81. else:
  82. from pydantic.fields import ModelField
  83. # Pydantic's Cython validators are not serializable.
  84. # https://github.com/cloudpipe/cloudpickle/issues/408
  85. serialization_context._register_cloudpickle_serializer(
  86. ModelField,
  87. custom_serializer=lambda o: {
  88. "name": o.name,
  89. # outer_type_ is the original type for ModelFields,
  90. # while type_ can be updated later with the nested type
  91. # like int for List[int].
  92. "type_": o.outer_type_,
  93. "class_validators": o.class_validators,
  94. "model_config": o.model_config,
  95. "default": o.default,
  96. "default_factory": o.default_factory,
  97. "required": o.required,
  98. "alias": o.alias,
  99. "field_info": o.field_info,
  100. },
  101. custom_deserializer=lambda kwargs: ModelField(**kwargs),
  102. )