| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
- """
- Module is used to infer Django model fields.
- """
- from inspect import Parameter
- from jedi import debug
- from jedi.inference.cache import inference_state_function_cache
- from jedi.inference.base_value import ValueSet, iterator_to_value_set, ValueWrapper
- from jedi.inference.filters import DictFilter, AttributeOverwrite
- from jedi.inference.names import NameWrapper, BaseTreeParamName
- from jedi.inference.compiled.value import EmptyCompiledName
- from jedi.inference.value.instance import TreeInstance
- from jedi.inference.value.klass import ClassMixin
- from jedi.inference.gradual.base import GenericClass
- from jedi.inference.gradual.generics import TupleGenericManager
- from jedi.inference.signature import AbstractSignature
# Django field class name -> (module name, attribute name) of the Python type a
# value of that field has on a model instance. A module name of ``None`` means
# the type is looked up in the builtins module.
mapping = dict(
    IntegerField=(None, 'int'),
    BigIntegerField=(None, 'int'),
    PositiveIntegerField=(None, 'int'),
    SmallIntegerField=(None, 'int'),
    CharField=(None, 'str'),
    TextField=(None, 'str'),
    EmailField=(None, 'str'),
    GenericIPAddressField=(None, 'str'),
    URLField=(None, 'str'),
    FloatField=(None, 'float'),
    BinaryField=(None, 'bytes'),
    BooleanField=(None, 'bool'),
    DecimalField=('decimal', 'Decimal'),
    TimeField=('datetime', 'time'),
    DurationField=('datetime', 'timedelta'),
    DateField=('datetime', 'date'),
    DateTimeField=('datetime', 'datetime'),
    UUIDField=('uuid', 'UUID'),
)

# QuerySet methods that accept model fields as keyword arguments; their
# signatures are replaced with one derived from the model's fields.
_FILTER_LIKE_METHODS = (
    'create',
    'filter',
    'exclude',
    'update',
    'get',
    'get_or_create',
    'update_or_create',
)
@inference_state_function_cache()
def _get_deferred_attributes(inference_state):
    """Return the values of ``django.db.models.query_utils.DeferredAttribute``
    executed as an annotation (i.e. instances of that descriptor class).

    The result is cached per inference state by ``inference_state_function_cache``.
    """
    query_utils = inference_state.import_module(
        ('django', 'db', 'models', 'query_utils'))
    deferred_attribute_class = query_utils.py__getattribute__('DeferredAttribute')
    return deferred_attribute_class.execute_annotation()
def _infer_scalar_field(inference_state, field_name, field_tree_instance, is_instance):
    """Infer a Django field whose Python type is listed in ``mapping``.

    :param inference_state: the current inference state.
    :param field_name: the field's name (not used by this function).
    :param field_tree_instance: the inferred field object (e.g. ``CharField()``);
        its class name is the ``mapping`` key.
    :param is_instance: whether the attribute is accessed on a model instance
        rather than on the model class.
    :returns: ``None`` if the field class is not in ``mapping``. Otherwise the
        ``DeferredAttribute`` values for class access, or an instance of the
        mapped type for instance access. That is also ``None`` if the type name
        cannot be found in its module.
    """
    target = mapping.get(field_tree_instance.py__name__())
    if target is None:
        return None
    module_name, attribute_name = target

    if not is_instance:
        return _get_deferred_attributes(inference_state)

    if module_name is None:
        module = inference_state.builtins_module
    else:
        module = inference_state.import_module((module_name,))

    # Only the first value found for the type name is used.
    instances = (attr.execute_with_values()
                 for attr in module.py__getattribute__(attribute_name))
    return next(instances, None)
@iterator_to_value_set
def _get_foreign_key_values(cls, field_tree_instance):
    """Yield the model classes a relation field (ForeignKey etc.) points to.

    Only the first positional argument of the field call is inspected. It can be:

    * a class, which is yielded directly;
    * the string ``'self'``, which yields ``cls`` (Django's recursive
      relationship idiom);
    * a string naming a class, optionally prefixed with an app label
      (``'app_label.ModelName'``). The model name is looked up in the module
      that defines ``cls``, and only class values found there are yielded.

    :param cls: the model class that declares the field.
    :param field_tree_instance: the inferred field object.
    """
    if not isinstance(field_tree_instance, TreeInstance):
        return

    # TODO private access..
    argument_iterator = field_tree_instance._arguments.unpack()
    key, lazy_values = next(argument_iterator, (None, None))
    # Only a positional (non-keyword) first argument names the target model.
    if key is not None or lazy_values is None:
        return

    for value in lazy_values.infer():
        if value.py__name__() == 'str':
            # Only strings with a statically known value can be resolved.
            # Without a default, get_safe_value raises ValueError for e.g. a
            # str returned by a function call.
            foreign_key_class_name = value.get_safe_value(default=None)
            if not isinstance(foreign_key_class_name, str):
                continue
            if foreign_key_class_name == 'self':
                yield cls
                continue
            # Lazy references may be written as 'app_label.ModelName'. Only the
            # model name part can be looked up in the current module.
            foreign_key_class_name = foreign_key_class_name.rpartition('.')[2]
            module = cls.get_root_context()
            for v in module.py__getattribute__(foreign_key_class_name):
                if v.is_class():
                    yield v
        elif value.is_class():
            yield value
def _infer_field(cls, field_name, is_instance):
    """Infer the values of a Django model attribute.

    The values inferred for ``field_name`` are inspected in order, and the
    first recognized field decides the result:

    * scalar fields listed in ``mapping``: see ``_infer_scalar_field``;
    * ``ForeignKey`` / ``OneToOneField`` / ``ManyToManyField``:
      ``DeferredAttribute`` values on class access. On instance access they
      give instances of the related model, or, for ``ManyToManyField``,
      ``RelatedManager`` values for the related model.

    If nothing is recognized, a debug message is logged and the plain
    inference result of ``field_name`` is returned.
    """
    inference_state = cls.inference_state
    result = field_name.infer()

    for field_tree_instance in result:
        scalar_field = _infer_scalar_field(
            inference_state, field_name, field_tree_instance, is_instance)
        if scalar_field is not None:
            return scalar_field

        field_class_name = field_tree_instance.py__name__()
        is_many_to_many = field_class_name == 'ManyToManyField'
        is_relation = is_many_to_many or field_class_name in ('ForeignKey', 'OneToOneField')
        if not is_relation:
            continue

        if not is_instance:
            return _get_deferred_attributes(inference_state)

        related_classes = _get_foreign_key_values(cls, field_tree_instance)
        if not is_many_to_many:
            return related_classes.execute_with_values()

        managers = [_create_manager_for(related, 'RelatedManager')
                    for related in related_classes]
        # Drop relations for which no manager could be created.
        return ValueSet(manager for manager in managers if manager)

    debug.dbg('django plugin: fail to infer `%s` from class `%s`',
              field_name.string_name, cls.py__name__())
    return result
class DjangoModelName(NameWrapper):
    """A name on a Django model whose inference is delegated to ``_infer_field``.

    :param cls: the model class the name belongs to.
    :param name: the wrapped name.
    :param is_instance: whether the name is accessed on an instance (``True``)
        or on the class (``False``).
    """

    def __init__(self, cls, name, is_instance):
        super().__init__(name)
        self._cls, self._is_instance = cls, is_instance

    def infer(self):
        model_class = self._cls
        wrapped = self._wrapped_name
        on_instance = self._is_instance
        return _infer_field(model_class, wrapped, on_instance)
def _create_manager_for(cls, manager_cls='BaseManager'):
    """Return an instance of ``django.db.models.manager.<manager_cls>[cls]``.

    :param cls: the model class used as the manager's generic parameter.
    :param manager_cls: name of the manager class in ``django.db.models.manager``.
    :returns: the first value produced by executing the parametrized manager
        class, or ``None`` if no class value (or no instance) was found.
    """
    manager_module = cls.inference_state.import_module(
        ('django', 'db', 'models', 'manager'))
    candidates = manager_module.py__getattribute__(manager_cls)

    for candidate in candidates:
        # Only class values can be parametrized with generics.
        if not candidate.is_class_mixin():
            continue
        generics_manager = TupleGenericManager((ValueSet([cls]),))
        parametrized = GenericClass(candidate, generics_manager)
        for instance in parametrized.execute_annotation():
            return instance
    return None
def _new_dict_filter(cls, is_instance):
    """Build a ``DictFilter`` that exposes the names of ``cls`` as ``DjangoModelName`` objects.

    Metaclass filters and ``type`` filters are excluded. When two filters
    define the same name, the one yielded first by ``cls.get_filters`` wins,
    because the filters are processed in reverse order and later assignments
    overwrite earlier ones.

    :param cls: the model class.
    :param is_instance: whether the names are looked up on an instance.
    """
    all_filters = list(cls.get_filters(
        is_instance=is_instance,
        include_metaclasses=False,
        include_type_when_class=False,
    ))

    names_by_string = {}
    for filter_ in reversed(all_filters):
        for name in filter_.values():
            names_by_string[name.string_name] = DjangoModelName(cls, name, is_instance)

    if is_instance:
        # On an instance, replace ``objects`` with a name that infers to
        # nothing. This is not perfect, since "objects" is still offered as a
        # completion, but it at least stops inference of things like
        # ``.objects.filter`` on instances. Hiding it from completions as well
        # would be nicer, but is probably not worth the extra work.
        names_by_string['objects'] = EmptyCompiledName(cls.inference_state, 'objects')
    return DictFilter(names_by_string)
def is_django_model_base(value):
    """Return whether ``value`` is ``ModelBase`` from ``django.db.models.base``.

    The root context is only consulted when the name already matches.
    """
    if value.py__name__() != 'ModelBase':
        return False
    return value.get_root_context().py__name__() == 'django.db.models.base'
def get_metaclass_filters(func):
    """Plugin decorator for ``get_metaclass_filters``.

    If any metaclass is Django's ``ModelBase``, the model's names are served by
    a single ``_new_dict_filter``. Otherwise the call is delegated unchanged to
    ``func``.
    """
    def wrapper(cls, metaclasses, is_instance):
        if any(is_django_model_base(metaclass) for metaclass in metaclasses):
            return [_new_dict_filter(cls, is_instance)]
        return func(cls, metaclasses, is_instance)
    return wrapper
def tree_name_to_values(func):
    """Plugin decorator that post-processes values inferred for tree names.

    Three cases are rewritten, and everything else is returned unchanged:

    * Methods listed in ``_FILTER_LIKE_METHODS`` that are defined on
      ``_BaseQuerySet`` in ``django.db.models.query`` are wrapped in
      ``QuerySetMethodWrapper``, once per model in the first generic of the
      lookup context's value. This is what makes keyword-parameter completion
      work for calls like ``User.objects.filter(...)``.
    * ``BaseManager`` in ``django.db.models.manager`` is wrapped in ``ManagerWrapper``.
    * ``Field`` in ``django.db.models.fields`` is wrapped in ``FieldWrapper``.
    """
    def is_module_named(context, module_name):
        return context.is_module() and context.py__name__() == module_name

    def is_base_queryset_method(value, method_name):
        return value.get_qualified_names() == ('_BaseQuerySet', method_name) \
            and is_module_named(value.parent_context, 'django.db.models.query')

    def wrapper(inference_state, context, tree_name):
        result = func(inference_state, context, tree_name)
        name = tree_name.value

        if name in _FILTER_LIKE_METHODS:
            for method in result:
                if not is_base_queryset_method(method, name):
                    continue
                generics = context.get_value().get_generics()
                if len(generics) >= 1:
                    return ValueSet(QuerySetMethodWrapper(method, model)
                                    for model in generics[0])
            return result

        if name == 'BaseManager' and is_module_named(context, 'django.db.models.manager'):
            return ValueSet(ManagerWrapper(value) for value in result)

        if name == 'Field' and is_module_named(context, 'django.db.models.fields'):
            return ValueSet(FieldWrapper(value) for value in result)

        return result

    return wrapper
def _find_fields(cls):
    """Yield the class-level names of ``cls`` that infer to ``DeferredAttribute``.

    A name is yielded once for every inferred value that matches.
    """
    deferred_attribute_path = (
        'django', 'db', 'models', 'query_utils', 'DeferredAttribute')
    class_names = _new_dict_filter(cls, is_instance=False).values()
    for name in class_names:
        for value in name.infer():
            qualified = value.name.get_qualified_names(include_module_names=True)
            if qualified == deferred_attribute_path:
                yield name
def _get_signatures(cls):
    """Return a one-element list holding the field-based signature of model ``cls``."""
    field_names = [*_find_fields(cls)]
    signature = DjangoModelSignature(cls, field_names=field_names)
    return [signature]
def get_metaclass_signatures(func):
    """Plugin decorator for ``get_metaclass_signatures``.

    If any metaclass is Django's ``ModelBase``, the class gets a signature
    synthesized from its model fields (see ``_get_signatures``). Otherwise the
    call is delegated to ``func`` with exactly the arguments this wrapper
    received.
    """
    def wrapper(cls, metaclasses):
        for metaclass in metaclasses:
            if is_django_model_base(metaclass):
                return _get_signatures(cls)
        # Forward the whole metaclass collection, not the loop variable. The
        # loop variable would hand only the last metaclass to the next plugin,
        # and would be unbound when ``metaclasses`` is empty.
        return func(cls, metaclasses)
    return wrapper
class ManagerWrapper(ValueWrapper):
    """Wraps ``BaseManager`` so that subscripting it yields ``GenericManagerWrapper`` values."""

    def py__getitem__(self, index_value_set, contextualized_node):
        subscripted = self._wrapped_value.py__getitem__(
            index_value_set, contextualized_node)
        return ValueSet(map(GenericManagerWrapper, subscripted))
class GenericManagerWrapper(AttributeOverwrite, ClassMixin):
    """A subscripted manager class that, when accessed through a model class,
    is re-parametrized with that model class."""

    def py__get__on_class(self, calling_instance, instance, class_value):
        manager_class = calling_instance.class_value
        parametrized = manager_class.with_generics((ValueSet([class_value]),))
        # Re-instantiate with the arguments the accessed manager instance was
        # created with (private attribute access).
        return parametrized.py__call__(calling_instance._arguments)

    def with_generics(self, generics_tuple):
        wrapped = self._wrapped_value
        return wrapped.with_generics(generics_tuple)
class FieldWrapper(ValueWrapper):
    """Wraps ``Field`` so that subscripting it yields ``GenericFieldWrapper`` values."""

    def py__getitem__(self, index_value_set, contextualized_node):
        subscripted = self._wrapped_value.py__getitem__(
            index_value_set, contextualized_node)
        return ValueSet(map(GenericFieldWrapper, subscripted))
class GenericFieldWrapper(AttributeOverwrite, ClassMixin):
    """A subscripted ``Field`` class whose class-level ``__get__`` returns the
    accessed instance itself."""

    def py__get__on_class(self, calling_instance, instance, class_value):
        # Mostly an optimization: it avoids Jedi aborting inference because
        # Field.__get__ would otherwise be executed too many times.
        return ValueSet([calling_instance])
class DjangoModelSignature(AbstractSignature):
    """Signature of a Django model with one ``DjangoParamName`` per field name.

    :param value: the model class the signature belongs to.
    :param field_names: names of the model's fields, in parameter order.
    """

    def __init__(self, value, field_names):
        super().__init__(value)
        self._field_names = field_names

    def get_param_names(self, resolve_stars=False):
        # ``resolve_stars`` is part of the signature interface; it is not used here.
        return list(map(DjangoParamName, self._field_names))
class DjangoParamName(BaseTreeParamName):
    """Keyword-only parameter backed by a Django model field name.

    :param field_name: the field's name; its context and tree name locate the
        parameter, and its inference gives the parameter's values.
    """

    def __init__(self, field_name):
        parent_context, tree_name = field_name.parent_context, field_name.tree_name
        super().__init__(parent_context, tree_name)
        self._field_name = field_name

    def get_kind(self):
        # Model fields are offered as keyword-only parameters.
        return Parameter.KEYWORD_ONLY

    def infer(self):
        field_name = self._field_name
        return field_name.infer()
class QuerySetMethodWrapper(ValueWrapper):
    """An unbound queryset method that remembers the model class it serves.

    When bound via ``__get__``, each bound method is wrapped in
    ``QuerySetBoundMethodWrapper`` for the same model class.
    """

    def __init__(self, method, model_cls):
        super().__init__(method)
        self._model_cls = model_cls

    def py__get__(self, instance, class_value):
        model_cls = self._model_cls
        bound_methods = self._wrapped_value.py__get__(instance, class_value)
        return ValueSet(QuerySetBoundMethodWrapper(bound, model_cls)
                        for bound in bound_methods)
class QuerySetBoundMethodWrapper(ValueWrapper):
    """A bound queryset method whose signatures come from the model's fields."""

    def __init__(self, method, model_cls):
        super().__init__(method)
        self._model_cls = model_cls

    def get_signatures(self):
        model_cls = self._model_cls
        return _get_signatures(model_cls)
|