utils.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  # Docstrings in this module are written in reStructuredText markup.
  __docformat__ = "restructuredtext en"

  # Empty on purpose: ``from <this module> import *`` exports nothing, so the
  # helpers defined here have to be imported by name.
  __all__ = []
  3. from numpy import asanyarray, asarray, array, zeros
  4. from scipy.sparse.linalg._interface import aslinearoperator, LinearOperator, \
  5. IdentityOperator
  # Result type code for every ordered pair of floating-point dtype characters:
  # 'f' = float32, 'd' = float64, 'F' = complex64, 'D' = complex128.
  _coerce_rules = {
      ('f', 'f'): 'f', ('f', 'd'): 'd', ('f', 'F'): 'F', ('f', 'D'): 'D',
      ('d', 'f'): 'd', ('d', 'd'): 'd', ('d', 'F'): 'D', ('d', 'D'): 'D',
      ('F', 'f'): 'F', ('F', 'd'): 'D', ('F', 'F'): 'F', ('F', 'D'): 'D',
      ('D', 'f'): 'D', ('D', 'd'): 'D', ('D', 'F'): 'D', ('D', 'D'): 'D',
  }


  def coerce(x, y):
      """Return the dtype character that two dtype characters coerce to.

      Parameters
      ----------
      x, y : str
          dtype type characters. Anything other than one of ``'f'``, ``'d'``,
          ``'F'``, ``'D'`` is treated as ``'d'`` (double precision).

      Returns
      -------
      str
          One of ``'f'``, ``'d'``, ``'F'``, ``'D'``, looked up in
          ``_coerce_rules``.
      """
      # Membership is tested against a tuple, not the string 'fdFD': ``in`` on
      # a string is a substring search, so '' or 'fd' would pass the check and
      # then fail with KeyError in the table lookup below.
      valid = ('f', 'd', 'F', 'D')
      if x not in valid:
          x = 'd'
      if y not in valid:
          y = 'd'
      return _coerce_rules[x, y]
  def id(x):
      """Return `x` unchanged.

      This deliberately shadows the builtin ``id`` within this module:
      `make_system` uses this function object as the "no preconditioner
      solve supplied" placeholder and checks for it with ``is``, so the
      name is kept as-is.
      """
      return x
  def make_system(A, M, x0, b):
      """Make a linear system Ax=b

      Parameters
      ----------
      A : LinearOperator
          sparse or dense matrix (or any valid input to aslinearoperator)
      M : {LinearOperator, None}
          preconditioner
          sparse or dense matrix (or any valid input to aslinearoperator)
      x0 : {array_like, str, None}
          initial guess to iterative method.
          ``x0 = 'Mb'`` means using the nonzero initial guess ``M @ b``.
          `None` means using the zero initial guess.
      b : array_like
          right hand side

      Returns
      -------
      (A, M, x, b)
          A : LinearOperator
              matrix of the linear system
          M : LinearOperator
              preconditioner
          x : rank 1 ndarray
              initial guess
          b : rank 1 ndarray
              right hand side

      Raises
      ------
      ValueError
          If `A` is not square, if the shape of `b` or of an array-like `x0`
          is neither ``(N,)`` nor ``(N, 1)``, if `M` and `A` have different
          shapes, or if `x0` is a string other than ``'Mb'``.
      """
      # Keep the caller's object: its optional psolve/rpsolve attributes are
      # looked up on it below, not on the LinearOperator wrapper.
      A_ = A
      A = aslinearoperator(A)

      if A.shape[0] != A.shape[1]:
          raise ValueError(f'expected square matrix, but got shape={A.shape}')

      N = A.shape[0]

      b = asanyarray(b)

      if not (b.shape == (N, 1) or b.shape == (N,)):
          raise ValueError(f'shapes of A {A.shape} and b {b.shape} are '
                           'incompatible')

      if b.dtype.char not in 'fdFD':
          b = b.astype('d')  # upcast non-FP types to double

      if hasattr(A, 'dtype'):
          xtype = A.dtype.char
      else:
          # No declared dtype: apply the operator once and use the result's.
          xtype = A.matvec(b).dtype.char
      xtype = coerce(xtype, b.dtype.char)

      b = asarray(b, dtype=xtype)  # make b the same type as x
      b = b.ravel()

      # process preconditioner
      if M is None:
          # Fall back to solves attached to the original A, if any; the
          # module-level ``id`` marks "nothing supplied".
          psolve = A_.psolve if hasattr(A_, 'psolve') else id
          rpsolve = A_.rpsolve if hasattr(A_, 'rpsolve') else id
          if psolve is id and rpsolve is id:
              M = IdentityOperator(shape=A.shape, dtype=A.dtype)
          else:
              M = LinearOperator(A.shape, matvec=psolve, rmatvec=rpsolve,
                                 dtype=A.dtype)
      else:
          M = aslinearoperator(M)
          if A.shape != M.shape:
              raise ValueError('matrix and preconditioner have different shapes')

      # set initial guess
      if x0 is None:
          x = zeros(N, dtype=xtype)
      elif isinstance(x0, str):
          if x0 != 'Mb':
              # Without this check ``x`` would be left unbound and the return
              # statement would fail with an UnboundLocalError.
              raise ValueError(f"unrecognized string for x0: {x0!r}; "
                               "the only supported string is 'Mb'")
          # use nonzero initial guess ``M @ b``; M is handed a copy of b
          x = M.matvec(b.copy())
      else:
          x = array(x0, dtype=xtype)
          if not (x.shape == (N, 1) or x.shape == (N,)):
              raise ValueError(f'shapes of A {A.shape} and '
                               f'x0 {x.shape} are incompatible')
          x = x.ravel()

      return A, M, x, b