| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- # ruff: noqa
- from typing import Optional
- import packaging.version
- # Pydantic is a dependency of `ray["default"]` but not the minimal installation,
- # so handle the case where it isn't installed.
- try:
- import pydantic
- PYDANTIC_INSTALLED = True
- except ImportError:
- pydantic = None
- PYDANTIC_INSTALLED = False
- PYDANTIC_MAJOR_VERSION: Optional[int] = (
- packaging.version.parse(pydantic.__version__).major
- if PYDANTIC_INSTALLED and hasattr(pydantic, "__version__")
- else None
- )
if not PYDANTIC_INSTALLED:
    # Without pydantic, every re-exported name is bound to ``None`` so that the
    # names still exist on this module.
    IS_PYDANTIC_2 = False
    BaseModel = Extra = Field = None
    NonNegativeFloat = NonNegativeInt = PositiveFloat = PositiveInt = None
    PrivateAttr = StrictInt = ValidationError = None
    root_validator = validator = None

    def is_subclass_of_base_model(obj):
        """Return ``False``: nothing can subclass a missing ``BaseModel``."""
        return False

# pydantic < 1.9.0 lacks the ``__version__`` attribute (issue ref:
# https://github.com/pydantic/pydantic/issues/2572). Such releases are 1.x, so
# only a present ``__version__`` of at least 2.0 selects the pydantic-2 branch.
elif hasattr(pydantic, "__version__") and packaging.version.parse(
    pydantic.__version__
) >= packaging.version.parse("2.0"):
    IS_PYDANTIC_2 = True
    # pydantic 2 bundles the 1.x API as ``pydantic.v1``; re-export from there.
    from pydantic.v1 import (
        BaseModel,
        Extra,
        Field,
        NonNegativeFloat,
        NonNegativeInt,
        PositiveFloat,
        PositiveInt,
        PrivateAttr,
        StrictInt,
        ValidationError,
        root_validator,
        validator,
    )

    def is_subclass_of_base_model(obj):
        """Return whether class ``obj`` derives from the v1 or v2 ``BaseModel``.

        Raises:
            TypeError: If ``obj`` is not a class.
        """
        # Resolved at call time rather than at module import.
        # NOTE(review): presumably to avoid loading pydantic 2's BaseModel
        # eagerly; confirm before hoisting.
        from pydantic import BaseModel as BaseModelV2
        from pydantic.v1 import BaseModel as BaseModelV1

        return issubclass(obj, (BaseModelV1, BaseModelV2))

else:
    IS_PYDANTIC_2 = False
    from pydantic import (
        BaseModel,
        Extra,
        Field,
        NonNegativeFloat,
        NonNegativeInt,
        PositiveFloat,
        PositiveInt,
        PrivateAttr,
        StrictInt,
        ValidationError,
        root_validator,
        validator,
    )

    def is_subclass_of_base_model(obj):
        """Return whether class ``obj`` derives from pydantic's ``BaseModel``.

        Raises:
            TypeError: If ``obj`` is not a class.
        """
        return issubclass(obj, BaseModel)
def register_pydantic_serializers(serialization_context):
    """Register a cloudpickle serializer for pydantic 1.x ``ModelField``.

    Pydantic's Cython validators are not serializable
    (https://github.com/cloudpipe/cloudpickle/issues/408). Instead, each
    ``ModelField`` is reduced to a dict of keyword arguments and rebuilt by
    calling ``ModelField(**kwargs)``. Returns without registering anything
    when pydantic is not installed.

    Args:
        serialization_context: Object providing
            ``_register_cloudpickle_serializer(cls, custom_serializer=...,
            custom_deserializer=...)``.
    """
    if not PYDANTIC_INSTALLED:
        return None

    if not IS_PYDANTIC_2:
        from pydantic.fields import ModelField
    else:
        # With pydantic 2 installed, the 1.x ModelField lives in pydantic.v1.
        # TODO(edoakes): compare against the version that has the fixes.
        from pydantic.v1.fields import ModelField

    def field_to_kwargs(field):
        # outer_type_ is the original type for ModelFields, whereas type_ can
        # be updated later with the nested type (e.g. int for List[int]).
        return dict(
            name=field.name,
            type_=field.outer_type_,
            class_validators=field.class_validators,
            model_config=field.model_config,
            default=field.default,
            default_factory=field.default_factory,
            required=field.required,
            alias=field.alias,
            field_info=field.field_info,
        )

    def kwargs_to_field(kwargs):
        return ModelField(**kwargs)

    serialization_context._register_cloudpickle_serializer(
        ModelField,
        custom_serializer=field_to_kwargs,
        custom_deserializer=kwargs_to_field,
    )
|