django.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. """
  2. Module is used to infer Django model fields.
  3. """
  4. from inspect import Parameter
  5. from jedi import debug
  6. from jedi.inference.cache import inference_state_function_cache
  7. from jedi.inference.base_value import ValueSet, iterator_to_value_set, ValueWrapper
  8. from jedi.inference.filters import DictFilter, AttributeOverwrite
  9. from jedi.inference.names import NameWrapper, BaseTreeParamName
  10. from jedi.inference.compiled.value import EmptyCompiledName
  11. from jedi.inference.value.instance import TreeInstance
  12. from jedi.inference.value.klass import ClassMixin
  13. from jedi.inference.gradual.base import GenericClass
  14. from jedi.inference.gradual.generics import TupleGenericManager
  15. from jedi.inference.signature import AbstractSignature
# Maps a Django field class name to a ``(module_name, attribute_name)`` pair
# naming the Python type used for that field on a model instance.  A
# ``module_name`` of ``None`` makes ``_infer_scalar_field`` look the attribute
# up in the builtins module instead of importing a module.
mapping = {
    'IntegerField': (None, 'int'),
    'BigIntegerField': (None, 'int'),
    'PositiveIntegerField': (None, 'int'),
    'SmallIntegerField': (None, 'int'),
    'CharField': (None, 'str'),
    'TextField': (None, 'str'),
    'EmailField': (None, 'str'),
    'GenericIPAddressField': (None, 'str'),
    'URLField': (None, 'str'),
    'FloatField': (None, 'float'),
    'BinaryField': (None, 'bytes'),
    'BooleanField': (None, 'bool'),
    'DecimalField': ('decimal', 'Decimal'),
    'TimeField': ('datetime', 'time'),
    'DurationField': ('datetime', 'timedelta'),
    'DateField': ('datetime', 'date'),
    'DateTimeField': ('datetime', 'datetime'),
    'UUIDField': ('uuid', 'UUID'),
}
# Method names that ``tree_name_to_values`` checks for; when such a name
# resolves to a ``_BaseQuerySet`` method in ``django.db.models.query`` it is
# replaced by a ``QuerySetMethodWrapper``.
_FILTER_LIKE_METHODS = ('create', 'filter', 'exclude', 'update', 'get',
                        'get_or_create', 'update_or_create')
  38. @inference_state_function_cache()
  39. def _get_deferred_attributes(inference_state):
  40. return inference_state.import_module(
  41. ('django', 'db', 'models', 'query_utils')
  42. ).py__getattribute__('DeferredAttribute').execute_annotation()
  43. def _infer_scalar_field(inference_state, field_name, field_tree_instance, is_instance):
  44. try:
  45. module_name, attribute_name = mapping[field_tree_instance.py__name__()]
  46. except KeyError:
  47. return None
  48. if not is_instance:
  49. return _get_deferred_attributes(inference_state)
  50. if module_name is None:
  51. module = inference_state.builtins_module
  52. else:
  53. module = inference_state.import_module((module_name,))
  54. for attribute in module.py__getattribute__(attribute_name):
  55. return attribute.execute_with_values()
  56. @iterator_to_value_set
  57. def _get_foreign_key_values(cls, field_tree_instance):
  58. if isinstance(field_tree_instance, TreeInstance):
  59. # TODO private access..
  60. argument_iterator = field_tree_instance._arguments.unpack()
  61. key, lazy_values = next(argument_iterator, (None, None))
  62. if key is None and lazy_values is not None:
  63. for value in lazy_values.infer():
  64. if value.py__name__() == 'str':
  65. foreign_key_class_name = value.get_safe_value()
  66. module = cls.get_root_context()
  67. for v in module.py__getattribute__(foreign_key_class_name):
  68. if v.is_class():
  69. yield v
  70. elif value.is_class():
  71. yield value
  72. def _infer_field(cls, field_name, is_instance):
  73. inference_state = cls.inference_state
  74. result = field_name.infer()
  75. for field_tree_instance in result:
  76. scalar_field = _infer_scalar_field(
  77. inference_state, field_name, field_tree_instance, is_instance)
  78. if scalar_field is not None:
  79. return scalar_field
  80. name = field_tree_instance.py__name__()
  81. is_many_to_many = name == 'ManyToManyField'
  82. if name in ('ForeignKey', 'OneToOneField') or is_many_to_many:
  83. if not is_instance:
  84. return _get_deferred_attributes(inference_state)
  85. values = _get_foreign_key_values(cls, field_tree_instance)
  86. if is_many_to_many:
  87. return ValueSet(filter(None, [
  88. _create_manager_for(v, 'RelatedManager') for v in values
  89. ]))
  90. else:
  91. return values.execute_with_values()
  92. debug.dbg('django plugin: fail to infer `%s` from class `%s`',
  93. field_name.string_name, cls.py__name__())
  94. return result
  95. class DjangoModelName(NameWrapper):
  96. def __init__(self, cls, name, is_instance):
  97. super().__init__(name)
  98. self._cls = cls
  99. self._is_instance = is_instance
  100. def infer(self):
  101. return _infer_field(self._cls, self._wrapped_name, self._is_instance)
  102. def _create_manager_for(cls, manager_cls='BaseManager'):
  103. managers = cls.inference_state.import_module(
  104. ('django', 'db', 'models', 'manager')
  105. ).py__getattribute__(manager_cls)
  106. for m in managers:
  107. if m.is_class_mixin():
  108. generics_manager = TupleGenericManager((ValueSet([cls]),))
  109. for c in GenericClass(m, generics_manager).execute_annotation():
  110. return c
  111. return None
  112. def _new_dict_filter(cls, is_instance):
  113. filters = list(cls.get_filters(
  114. is_instance=is_instance,
  115. include_metaclasses=False,
  116. include_type_when_class=False)
  117. )
  118. dct = {
  119. name.string_name: DjangoModelName(cls, name, is_instance)
  120. for filter_ in reversed(filters)
  121. for name in filter_.values()
  122. }
  123. if is_instance:
  124. # Replace the objects with a name that amounts to nothing when accessed
  125. # in an instance. This is not perfect and still completes "objects" in
  126. # that case, but it at least not inferes stuff like `.objects.filter`.
  127. # It would be nicer to do that in a better way, so that it also doesn't
  128. # show up in completions, but it's probably just not worth doing that
  129. # for the extra amount of work.
  130. dct['objects'] = EmptyCompiledName(cls.inference_state, 'objects')
  131. return DictFilter(dct)
  132. def is_django_model_base(value):
  133. return value.py__name__() == 'ModelBase' \
  134. and value.get_root_context().py__name__() == 'django.db.models.base'
  135. def get_metaclass_filters(func):
  136. def wrapper(cls, metaclasses, is_instance):
  137. for metaclass in metaclasses:
  138. if is_django_model_base(metaclass):
  139. return [_new_dict_filter(cls, is_instance)]
  140. return func(cls, metaclasses, is_instance)
  141. return wrapper
  142. def tree_name_to_values(func):
  143. def wrapper(inference_state, context, tree_name):
  144. result = func(inference_state, context, tree_name)
  145. if tree_name.value in _FILTER_LIKE_METHODS:
  146. # Here we try to overwrite stuff like User.objects.filter. We need
  147. # this to make sure that keyword param completion works on these
  148. # kind of methods.
  149. for v in result:
  150. if v.get_qualified_names() == ('_BaseQuerySet', tree_name.value) \
  151. and v.parent_context.is_module() \
  152. and v.parent_context.py__name__() == 'django.db.models.query':
  153. qs = context.get_value()
  154. generics = qs.get_generics()
  155. if len(generics) >= 1:
  156. return ValueSet(QuerySetMethodWrapper(v, model)
  157. for model in generics[0])
  158. elif tree_name.value == 'BaseManager' and context.is_module() \
  159. and context.py__name__() == 'django.db.models.manager':
  160. return ValueSet(ManagerWrapper(r) for r in result)
  161. elif tree_name.value == 'Field' and context.is_module() \
  162. and context.py__name__() == 'django.db.models.fields':
  163. return ValueSet(FieldWrapper(r) for r in result)
  164. return result
  165. return wrapper
  166. def _find_fields(cls):
  167. for name in _new_dict_filter(cls, is_instance=False).values():
  168. for value in name.infer():
  169. if value.name.get_qualified_names(include_module_names=True) \
  170. == ('django', 'db', 'models', 'query_utils', 'DeferredAttribute'):
  171. yield name
  172. def _get_signatures(cls):
  173. return [DjangoModelSignature(cls, field_names=list(_find_fields(cls)))]
  174. def get_metaclass_signatures(func):
  175. def wrapper(cls, metaclasses):
  176. for metaclass in metaclasses:
  177. if is_django_model_base(metaclass):
  178. return _get_signatures(cls)
  179. return func(cls, metaclass)
  180. return wrapper
  181. class ManagerWrapper(ValueWrapper):
  182. def py__getitem__(self, index_value_set, contextualized_node):
  183. return ValueSet(
  184. GenericManagerWrapper(generic)
  185. for generic in self._wrapped_value.py__getitem__(
  186. index_value_set, contextualized_node)
  187. )
  188. class GenericManagerWrapper(AttributeOverwrite, ClassMixin):
  189. def py__get__on_class(self, calling_instance, instance, class_value):
  190. return calling_instance.class_value.with_generics(
  191. (ValueSet({class_value}),)
  192. ).py__call__(calling_instance._arguments)
  193. def with_generics(self, generics_tuple):
  194. return self._wrapped_value.with_generics(generics_tuple)
  195. class FieldWrapper(ValueWrapper):
  196. def py__getitem__(self, index_value_set, contextualized_node):
  197. return ValueSet(
  198. GenericFieldWrapper(generic)
  199. for generic in self._wrapped_value.py__getitem__(
  200. index_value_set, contextualized_node)
  201. )
  202. class GenericFieldWrapper(AttributeOverwrite, ClassMixin):
  203. def py__get__on_class(self, calling_instance, instance, class_value):
  204. # This is mostly an optimization to avoid Jedi aborting inference,
  205. # because of too many function executions of Field.__get__.
  206. return ValueSet({calling_instance})
  207. class DjangoModelSignature(AbstractSignature):
  208. def __init__(self, value, field_names):
  209. super().__init__(value)
  210. self._field_names = field_names
  211. def get_param_names(self, resolve_stars=False):
  212. return [DjangoParamName(name) for name in self._field_names]
  213. class DjangoParamName(BaseTreeParamName):
  214. def __init__(self, field_name):
  215. super().__init__(field_name.parent_context, field_name.tree_name)
  216. self._field_name = field_name
  217. def get_kind(self):
  218. return Parameter.KEYWORD_ONLY
  219. def infer(self):
  220. return self._field_name.infer()
  221. class QuerySetMethodWrapper(ValueWrapper):
  222. def __init__(self, method, model_cls):
  223. super().__init__(method)
  224. self._model_cls = model_cls
  225. def py__get__(self, instance, class_value):
  226. return ValueSet({QuerySetBoundMethodWrapper(v, self._model_cls)
  227. for v in self._wrapped_value.py__get__(instance, class_value)})
  228. class QuerySetBoundMethodWrapper(ValueWrapper):
  229. def __init__(self, method, model_cls):
  230. super().__init__(method)
  231. self._model_cls = model_cls
  232. def get_signatures(self):
  233. return _get_signatures(self._model_cls)