memoization.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from functools import wraps
  2. def recurrence_memo(initial):
  3. """
  4. Memo decorator for sequences defined by recurrence
  5. Examples
  6. ========
  7. >>> from sympy.utilities.memoization import recurrence_memo
  8. >>> @recurrence_memo([1]) # 0! = 1
  9. ... def factorial(n, prev):
  10. ... return n * prev[-1]
  11. >>> factorial(4)
  12. 24
  13. >>> factorial(3) # use cache values
  14. 6
  15. >>> factorial.cache_length() # cache length can be obtained
  16. 5
  17. >>> factorial.fetch_item(slice(2, 4))
  18. [2, 6]
  19. """
  20. cache = initial
  21. def decorator(f):
  22. @wraps(f)
  23. def g(n):
  24. L = len(cache)
  25. if n < L:
  26. return cache[n]
  27. for i in range(L, n + 1):
  28. cache.append(f(i, cache))
  29. return cache[-1]
  30. g.cache_length = lambda: len(cache)
  31. g.fetch_item = lambda x: cache[x]
  32. return g
  33. return decorator
  34. def assoc_recurrence_memo(base_seq):
  35. """
  36. Memo decorator for associated sequences defined by recurrence starting from base
  37. base_seq(n) -- callable to get base sequence elements
  38. XXX works only for Pn0 = base_seq(0) cases
  39. XXX works only for m <= n cases
  40. """
  41. cache = []
  42. def decorator(f):
  43. @wraps(f)
  44. def g(n, m):
  45. L = len(cache)
  46. if n < L:
  47. return cache[n][m]
  48. for i in range(L, n + 1):
  49. # get base sequence
  50. F_i0 = base_seq(i)
  51. F_i_cache = [F_i0]
  52. cache.append(F_i_cache)
  53. # XXX only works for m <= n cases
  54. # generate assoc sequence
  55. for j in range(1, i + 1):
  56. F_ij = f(i, j, cache)
  57. F_i_cache.append(F_ij)
  58. return cache[n][m]
  59. return g
  60. return decorator