selective_scan_interface.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. # Copyright (c) 2023, Tri Dao, Albert Gu.
  2. import torch
  3. import torch.nn.functional as F
  4. from mamba_ssm.utils.torch import custom_bwd, custom_fwd
  5. from einops import rearrange, repeat
  6. try:
  7. from causal_conv1d import causal_conv1d_fn
  8. import causal_conv1d_cuda
  9. except ImportError:
  10. causal_conv1d_fn = None
  11. causal_conv1d_cuda = None
  12. from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
  13. import selective_scan_cuda
  14. class SelectiveScanFn(torch.autograd.Function):
  15. @staticmethod
  16. def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
  17. return_last_state=False):
  18. if u.stride(-1) != 1:
  19. u = u.contiguous()
  20. if delta.stride(-1) != 1:
  21. delta = delta.contiguous()
  22. if D is not None:
  23. D = D.contiguous()
  24. if B.stride(-1) != 1:
  25. B = B.contiguous()
  26. if C.stride(-1) != 1:
  27. C = C.contiguous()
  28. if z is not None and z.stride(-1) != 1:
  29. z = z.contiguous()
  30. if B.dim() == 3:
  31. B = rearrange(B, "b dstate l -> b 1 dstate l")
  32. ctx.squeeze_B = True
  33. if C.dim() == 3:
  34. C = rearrange(C, "b dstate l -> b 1 dstate l")
  35. ctx.squeeze_C = True
  36. out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
  37. ctx.delta_softplus = delta_softplus
  38. ctx.has_z = z is not None
  39. last_state = x[:, :, -1, 1::2] # (batch, dim, dstate)
  40. if not ctx.has_z:
  41. ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
  42. return out if not return_last_state else (out, last_state)
  43. else:
  44. ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
  45. out_z = rest[0]
  46. return out_z if not return_last_state else (out_z, last_state)
  47. @staticmethod
  48. def backward(ctx, dout, *args):
  49. if not ctx.has_z:
  50. u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
  51. z = None
  52. out = None
  53. else:
  54. u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
  55. if dout.stride(-1) != 1:
  56. dout = dout.contiguous()
  57. # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
  58. # backward of selective_scan_cuda with the backward of chunk).
  59. # Here we just pass in None and dz will be allocated in the C++ code.
  60. du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
  61. u, delta, A, B, C, D, z, delta_bias, dout, x, out, None, ctx.delta_softplus,
  62. False # option to recompute out_z, not used here
  63. )
  64. dz = rest[0] if ctx.has_z else None
  65. dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
  66. dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
  67. return (du, ddelta, dA, dB, dC,
  68. dD if D is not None else None,
  69. dz,
  70. ddelta_bias if delta_bias is not None else None,
  71. None,
  72. None)
  73. def rms_norm_forward(
  74. x,
  75. weight,
  76. bias,
  77. eps=1e-6,
  78. is_rms_norm=True,
  79. ):
  80. # x (b l) d
  81. if x.stride(-1) != 1:
  82. x = x.contiguous()
  83. weight = weight.contiguous()
  84. if bias is not None:
  85. bias = bias.contiguous()
  86. y = _layer_norm_fwd(
  87. x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm
  88. )[0]
  89. # y (b l) d
  90. return y
  91. def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
  92. return_last_state=False):
  93. """if return_last_state is True, returns (out, last_state)
  94. last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
  95. not considered in the backward pass.
  96. """
  97. return selective_scan_ref(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
  98. def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
  99. return_last_state=False):
  100. """
  101. u: r(B D L)
  102. delta: r(B D L)
  103. A: c(D N) or r(D N)
  104. B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
  105. C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
  106. D: r(D)
  107. z: r(B D L)
  108. delta_bias: r(D), fp32
  109. out: r(B D L)
  110. last_state (optional): r(B D dstate) or c(B D dstate)
  111. """
  112. dtype_in = u.dtype
  113. u = u.float()
  114. delta = delta.float()
  115. if delta_bias is not None:
  116. delta = delta + delta_bias[..., None].float()
  117. if delta_softplus:
  118. delta = F.softplus(delta)
  119. batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
  120. is_variable_B = B.dim() >= 3
  121. is_variable_C = C.dim() >= 3
  122. if A.is_complex():
  123. if is_variable_B:
  124. B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
  125. if is_variable_C:
  126. C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
  127. else:
  128. B = B.float()
  129. C = C.float()
  130. x = A.new_zeros((batch, dim, dstate))
  131. ys = []
  132. deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
  133. if not is_variable_B:
  134. deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
  135. else:
  136. if B.dim() == 3:
  137. deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
  138. else:
  139. B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
  140. deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
  141. if is_variable_C and C.dim() == 4:
  142. C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
  143. last_state = None
  144. for i in range(u.shape[2]):
  145. x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
  146. if not is_variable_C:
  147. y = torch.einsum('bdn,dn->bd', x, C)
  148. else:
  149. if C.dim() == 3:
  150. y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
  151. else:
  152. y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
  153. if i == u.shape[2] - 1:
  154. last_state = x
  155. if y.is_complex():
  156. y = y.real * 2
  157. ys.append(y)
  158. y = torch.stack(ys, dim=2) # (batch dim L)
  159. out = y if D is None else y + u * rearrange(D, "d -> d 1")
  160. if z is not None:
  161. out = out * F.silu(z)
  162. out = out.to(dtype=dtype_in)
  163. return out if not return_last_state else (out, last_state)
  164. class MambaInnerFn(torch.autograd.Function):
  165. @staticmethod
  166. @custom_fwd
  167. def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  168. out_proj_weight, out_proj_bias,
  169. A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
  170. C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight=None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6):
  171. """
  172. xz: (batch, dim, seqlen)
  173. """
  174. assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
  175. assert checkpoint_lvl in [0, 1]
  176. L = xz.shape[-1]
  177. delta_rank = delta_proj_weight.shape[1]
  178. d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
  179. if torch.is_autocast_enabled():
  180. x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
  181. delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
  182. out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
  183. out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
  184. if out_proj_bias is not None else None)
  185. if xz.stride(-1) != 1:
  186. xz = xz.contiguous()
  187. conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
  188. x, z = xz.chunk(2, dim=1)
  189. conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
  190. conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
  191. x, conv1d_weight, conv1d_bias, None, None, None, True
  192. )
  193. # We're being very careful here about the layout, to avoid extra transposes.
  194. # We want delta to have d as the slowest moving dimension
  195. # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
  196. x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
  197. delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
  198. ctx.is_variable_B = B is None
  199. ctx.is_variable_C = C is None
  200. ctx.B_proj_bias_is_None = B_proj_bias is None
  201. ctx.C_proj_bias_is_None = C_proj_bias is None
  202. if B is None: # variable B
  203. B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl dstate)
  204. if B_proj_bias is not None:
  205. B = B + B_proj_bias.to(dtype=B.dtype)
  206. if not A.is_complex():
  207. # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
  208. B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
  209. else:
  210. B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
  211. else:
  212. if B.stride(-1) != 1:
  213. B = B.contiguous()
  214. if C is None: # variable C
  215. C = x_dbl[:, -d_state:] # (bl dstate)
  216. if C_proj_bias is not None:
  217. C = C + C_proj_bias.to(dtype=C.dtype)
  218. if not A.is_complex():
  219. # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
  220. C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
  221. else:
  222. C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
  223. else:
  224. if C.stride(-1) != 1:
  225. C = C.contiguous()
  226. if D is not None:
  227. D = D.contiguous()
  228. if b_rms_weight is not None:
  229. B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
  230. B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
  231. B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
  232. if c_rms_weight is not None:
  233. C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
  234. C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
  235. C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
  236. if dt_rms_weight is not None:
  237. delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
  238. delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
  239. delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
  240. out, scan_intermediates, out_z = selective_scan_cuda.fwd(
  241. conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
  242. )
  243. ctx.delta_softplus = delta_softplus
  244. ctx.out_proj_bias_is_None = out_proj_bias is None
  245. ctx.checkpoint_lvl = checkpoint_lvl
  246. ctx.b_rms_weight = b_rms_weight
  247. ctx.c_rms_weight = c_rms_weight
  248. ctx.dt_rms_weight = dt_rms_weight
  249. ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
  250. if checkpoint_lvl >= 1: # Will recompute conv1d_out and delta in the backward pass
  251. conv1d_out, delta = None, None
  252. ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
  253. delta_proj_weight, out_proj_weight, conv1d_out, delta,
  254. A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out)
  255. return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
  256. @staticmethod
  257. @custom_bwd
  258. def backward(ctx, dout):
  259. # dout: (batch, seqlen, dim)
  260. assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
  261. (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
  262. conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, b_rms_weight, c_rms_weight, dt_rms_weight, out) = ctx.saved_tensors
  263. L = xz.shape[-1]
  264. delta_rank = delta_proj_weight.shape[1]
  265. d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
  266. x, z = xz.chunk(2, dim=1)
  267. if dout.stride(-1) != 1:
  268. dout = dout.contiguous()
  269. if ctx.checkpoint_lvl == 1:
  270. conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
  271. x, conv1d_weight, conv1d_bias, None, None, None, True
  272. )
  273. delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
  274. "d (b l) -> b d l", l = L)
  275. if dt_rms_weight is not None:
  276. delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
  277. delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
  278. delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
  279. if b_rms_weight is not None:
  280. # Recompute & RMSNorm B
  281. B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
  282. B = rms_norm_forward(
  283. B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps
  284. )
  285. B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
  286. if c_rms_weight is not None:
  287. # Recompute & RMSNorm C
  288. C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
  289. C = rms_norm_forward(
  290. C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps
  291. )
  292. C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
  293. # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
  294. # backward of selective_scan_cuda with the backward of chunk).
  295. dxz = torch.empty_like(xz) # (batch, dim, seqlen)
  296. dx, dz = dxz.chunk(2, dim=1)
  297. dout = rearrange(dout, "b l e -> e (b l)")
  298. dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
  299. dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
  300. conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
  301. ctx.delta_softplus,
  302. True # option to recompute out_z
  303. )
  304. dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
  305. dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
  306. dD = dD if D is not None else None
  307. dx_dbl = torch.empty_like(x_dbl)
  308. dB_proj_bias = None
  309. if ctx.is_variable_B:
  310. if not A.is_complex():
  311. dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
  312. else:
  313. dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
  314. dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
  315. dx_dbl[:, delta_rank:delta_rank + d_state] = dB # (bl d)
  316. dB = None
  317. dC_proj_bias = None
  318. if ctx.is_variable_C:
  319. if not A.is_complex():
  320. dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
  321. else:
  322. dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
  323. dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
  324. dx_dbl[:, -d_state:] = dC # (bl d)
  325. dC = None
  326. ddelta = rearrange(ddelta, "b d l -> d (b l)")
  327. ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
  328. dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
  329. dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
  330. dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
  331. dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
  332. dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
  333. # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
  334. # backward of conv1d with the backward of chunk).
  335. dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
  336. x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
  337. )
  338. dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
  339. dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
  340. return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
  341. dout_proj_weight, dout_proj_bias,
  342. dA, dB, dC, dD,
  343. ddelta_bias if delta_bias is not None else None,
  344. # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
  345. dB_proj_bias, dC_proj_bias, None, None, None, None, None, None)
  346. def mamba_inner_fn(
  347. xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  348. out_proj_weight, out_proj_bias,
  349. A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
  350. C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1, b_rms_weight= None, c_rms_weight= None, dt_rms_weight= None, b_c_dt_rms_eps=1e-6
  351. ):
  352. return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  353. out_proj_weight, out_proj_bias,
  354. A, B, C, D, delta_bias, B_proj_bias, C_proj_bias, delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps)
  355. def mamba_inner_ref(
  356. xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  357. out_proj_weight, out_proj_bias,
  358. A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
  359. C_proj_bias=None, delta_softplus=True
  360. ):
  361. assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
  362. L = xz.shape[-1]
  363. delta_rank = delta_proj_weight.shape[1]
  364. d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
  365. x, z = xz.chunk(2, dim=1)
  366. x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
  367. # We're being very careful here about the layout, to avoid extra transposes.
  368. # We want delta to have d as the slowest moving dimension
  369. # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
  370. x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight) # (bl d)
  371. delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
  372. delta = rearrange(delta, "d (b l) -> b d l", l=L)
  373. if B is None: # variable B
  374. B = x_dbl[:, delta_rank:delta_rank + d_state] # (bl d)
  375. if B_proj_bias is not None:
  376. B = B + B_proj_bias.to(dtype=B.dtype)
  377. if not A.is_complex():
  378. B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
  379. else:
  380. B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
  381. if C is None: # variable B
  382. C = x_dbl[:, -d_state:] # (bl d)
  383. if C_proj_bias is not None:
  384. C = C + C_proj_bias.to(dtype=C.dtype)
  385. if not A.is_complex():
  386. C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
  387. else:
  388. C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
  389. y = selective_scan_fn(x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=True)
  390. return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)