config.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. from __future__ import annotations
  2. from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup
  3. def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str:
  4. if isinstance(g, NativeFunctionsGroup):
  5. return str(g.functional.func.name.name.base)
  6. else:
  7. return str(g.view.root_name)
  8. is_hand_written_ops_ = frozenset(
  9. (
  10. "abs",
  11. "add",
  12. "addmm",
  13. "all",
  14. "any",
  15. "argmin",
  16. "bmm",
  17. "clamp",
  18. "clamp_min",
  19. "cumsum",
  20. "div",
  21. "fmod",
  22. "index_select",
  23. "leaky_relu",
  24. "linear",
  25. "log",
  26. "matmul",
  27. "mul",
  28. "narrow_copy",
  29. "nonzero",
  30. "pow",
  31. "remainder",
  32. "sigmoid",
  33. "sign",
  34. "sub",
  35. "tanh",
  36. "detach",
  37. "expand_as",
  38. "flatten",
  39. "narrow",
  40. "reshape_as",
  41. "select",
  42. "slice",
  43. "softmax",
  44. "split",
  45. "squeeze",
  46. "transpose",
  47. "view",
  48. "where",
  49. )
  50. )
  51. def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
  52. name_base = func_name_base_str(g)
  53. return name_base in is_hand_written_ops_
  54. def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None:
  55. if index not in (0, 1):
  56. raise AssertionError(f"index must be 0 or 1, got {index}")
  57. if op_name == "addr":
  58. if index == 0:
  59. arg_map["self"] = "at::rand({6, 6})"
  60. arg_map["vec1"] = "at::rand({6})"
  61. arg_map["vec2"] = "at::rand({6})"
  62. else:
  63. arg_map["self"] = "at::rand({22, 22})"
  64. arg_map["vec1"] = "at::rand({22})"
  65. arg_map["vec2"] = "at::rand({22})"
  66. return
  67. if op_name == "mv":
  68. if index == 0:
  69. arg_map["self"] = "at::rand({6, 6})"
  70. arg_map["vec"] = "at::rand({6})"
  71. else:
  72. arg_map["self"] = "at::rand({22, 22})"
  73. arg_map["vec"] = "at::rand({22})"
  74. return
  75. if op_name == "addbmm":
  76. if index == 0:
  77. arg_map["self"] = "at::rand({6, 6})"
  78. else:
  79. arg_map["self"] = "at::rand({22, 22})"
  80. return
  81. if op_name == "cross":
  82. if index == 0:
  83. arg_map["self"] = "at::rand({3, 3, 3})"
  84. arg_map["other"] = "at::rand({3, 3, 3})"
  85. else:
  86. arg_map["self"] = "at::rand({22, 3, 22})"
  87. arg_map["other"] = "at::rand({22, 3, 22})"
  88. return
  89. if op_name == "take":
  90. if index == 0:
  91. arg_map["index"] = "at::randint(0, 216, {20}, torch::kInt64)"
  92. else:
  93. arg_map["index"] = "at::randint(0, 1000, {100}, torch::kInt64)"
  94. return
  95. if op_name == "take_along_dim":
  96. if index == 0:
  97. arg_map["indices"] = "at::argsort(self0, 1, true)"
  98. else:
  99. arg_map["indices"] = "at::argsort(self1, 1, true)"
  100. return
  101. if op_name == "masked_select":
  102. if index == 0:
  103. arg_map["mask"] = "at::randn({6, 6, 6}) > 0.5"
  104. else:
  105. arg_map["mask"] = "at::rand({22, 22, 22}) > 0.5"
  106. return
  107. if op_name == "orgqr":
  108. if index == 0:
  109. arg_map["input2"] = "at::rand({6, 6})"
  110. else:
  111. arg_map["input2"] = "at::rand({22, 22})"
  112. return
  113. if op_name == "ormqr":
  114. if index == 0:
  115. arg_map["input2"] = "at::rand({6, 6})"
  116. else:
  117. arg_map["input2"] = "at::rand({22, 22})"
  118. return
  119. if op_name == "quantile":
  120. if index == 0:
  121. arg_map["q"] = "at::rand({6})"
  122. arg_map["interpolation"] = '"linear"'
  123. else:
  124. arg_map["q"] = "at::rand({22})"
  125. arg_map["interpolation"] = '"linear"'
  126. return
  127. if op_name == "nanquantile":
  128. if index == 0:
  129. arg_map["q"] = "at::rand({6})"
  130. arg_map["interpolation"] = '"linear"'
  131. else:
  132. arg_map["q"] = "at::rand({22})"
  133. arg_map["interpolation"] = '"linear"'
  134. return
  135. if op_name == "multi_margin_loss":
  136. if index == 0:
  137. arg_map["self"] = "at::rand({6, 6})"
  138. arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
  139. arg_map["weight"] = "at::rand({6})"
  140. else:
  141. arg_map["self"] = "at::rand({22, 22})"
  142. arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
  143. arg_map["weight"] = "at::rand({22})"
  144. return
  145. if op_name == "multilabel_margin_loss":
  146. if index == 0:
  147. arg_map["self"] = "at::rand({6, 6})"
  148. arg_map["target"] = "at::randint(6, {6, 6}, torch::kInt64)"
  149. else:
  150. arg_map["self"] = "at::rand({22, 22})"
  151. arg_map["target"] = "at::randint(22, {22, 22}, torch::kInt64)"
  152. return
  153. if op_name == "nll_loss":
  154. if index == 0:
  155. arg_map["self"] = "at::rand({6, 6})"
  156. arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
  157. arg_map["weight"] = "at::rand({6})"
  158. else:
  159. arg_map["self"] = "at::rand({22, 22})"
  160. arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
  161. arg_map["weight"] = "at::rand({22})"
  162. return
  163. if op_name == "nll_loss2d":
  164. if index == 0:
  165. arg_map["self"] = "at::rand({6, 6, 6, 6})"
  166. arg_map["target"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
  167. arg_map["weight"] = "at::rand({6})"
  168. else:
  169. arg_map["self"] = "at::rand({22, 22, 22, 22})"
  170. arg_map["target"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
  171. arg_map["weight"] = "at::rand({22})"
  172. return
  173. if op_name in (
  174. "fft_fft",
  175. "fft_ifft",
  176. "fft_rfft",
  177. "fft_irfft",
  178. "fft_hfft",
  179. "fft_ihfft",
  180. ):
  181. arg_map["norm"] = '"forward"'
  182. return
  183. if op_name == "linalg_tensorinv":
  184. if index == 0:
  185. arg_map["self"] = "at::rand({6, 6, 6, 6})"
  186. arg_map["ind"] = "2"
  187. else:
  188. arg_map["self"] = "at::rand({22, 22, 22, 22})"
  189. arg_map["ind"] = "2"
  190. return
  191. if op_name == "addmv":
  192. if index == 0:
  193. arg_map["self"] = "at::rand({2})"
  194. arg_map["mat"] = "at::rand({2, 2})"
  195. arg_map["vec"] = "at::rand({2})"
  196. else:
  197. arg_map["self"] = "at::rand({35})"
  198. arg_map["mat"] = "at::rand({35, 35})"
  199. arg_map["vec"] = "at::rand({35})"
  200. return
  201. if op_name == "acosh":
  202. if index == 0:
  203. arg_map["self"] = "at::rand({2, 2, 2}) + at::ones({2, 2, 2})"
  204. else:
  205. arg_map["self"] = "at::rand({5, 5, 5}) + at::ones({5, 5, 5})"
  206. return
  207. if op_name == "adaptive_max_pool2d_backward":
  208. if index == 0:
  209. arg_map["grad_output"] = "at::rand({2, 2, 2}, at::kFloat)"
  210. arg_map["self"] = "at::rand({2, 2, 2}, at::kFloat)"
  211. arg_map["indices"] = "at::randint(0, 1, {2, 2, 2}, at::kLong)"
  212. else:
  213. arg_map["grad_output"] = "at::rand({3, 3, 3}, at::kFloat)"
  214. arg_map["self"] = "at::rand({3, 3, 3}, at::kFloat)"
  215. arg_map["indices"] = "at::randint(0, 1, {3, 3, 3}, at::kLong)"
  216. return
  217. if op_name == "adaptive_max_pool3d_backward":
  218. if index == 0:
  219. arg_map["grad_output"] = "at::rand({2, 2, 2, 2}, at::kFloat)"
  220. arg_map["self"] = "at::rand({2, 2, 2, 2}, at::kFloat)"
  221. arg_map["indices"] = "at::randint(0, 1, {2, 2, 2, 2}, at::kLong)"
  222. else:
  223. arg_map["grad_output"] = "at::rand({3, 3, 3, 3}, at::kFloat)"
  224. arg_map["self"] = "at::rand({3, 3, 3, 3}, at::kFloat)"
  225. arg_map["indices"] = "at::randint(0, 1, {3, 3, 3, 3}, at::kLong)"
  226. return
  227. if op_name == "bitwise_left_shift":
  228. if index == 0:
  229. arg_map["self"] = "at::randint(1, 1 << 4, {6, 6, 6}, at::kInt)"
  230. arg_map["other"] = "at::randint(1, 26, {6, 6, 6}, at::kInt)"
  231. else:
  232. arg_map["self"] = "at::randint(1, 1 << 4, {22, 22, 22}, at::kInt)"
  233. arg_map["other"] = "at::randint(1, 26, {22, 22, 22}, at::kInt)"
  234. return
  235. if op_name == "bitwise_right_shift":
  236. if index == 0:
  237. arg_map["self"] = "at::randint(1 << 21, 1 << 30, {6, 6, 6}, at::kInt)"
  238. arg_map["other"] = "at::randint(1, 22, {6, 6, 6}, at::kInt)"
  239. else:
  240. arg_map["self"] = "at::randint(1 << 21, 1 << 30, {22, 22, 22}, at::kInt)"
  241. arg_map["other"] = "at::randint(1, 22, {22, 22, 22}, at::kInt)"
  242. return
  243. if op_name == "gather":
  244. if index == 0:
  245. arg_map["self"] = "at::randint(1, 100, {2,2,2}, at::kInt)"
  246. arg_map["dim"] = "1"
  247. arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)"
  248. arg_map["sparse_grad"] = "false"
  249. else:
  250. arg_map["self"] = "at::randint(1, 100, {5,5,5}, at::kInt)"
  251. arg_map["dim"] = "1"
  252. arg_map["index"] = "at::randint(0, 4, {5,5,5}, torch::kInt64)"
  253. arg_map["sparse_grad"] = "false"
  254. return
  255. if op_name == "gelu":
  256. if index == 0:
  257. arg_map["self"] = "at::rand({6, 6, 6})"
  258. arg_map["approximate"] = '"tanh"'
  259. else:
  260. arg_map["self"] = "at::rand({22, 22, 22})"
  261. arg_map["approximate"] = '"tanh"'
  262. return
  263. if op_name == "gelu_backward":
  264. if index == 0:
  265. arg_map["grad_output"] = "at::rand({6, 6, 6})"
  266. arg_map["self"] = "at::rand({6, 6, 6})"
  267. arg_map["approximate"] = '"tanh"'
  268. else:
  269. arg_map["grad_output"] = "at::rand({22, 22, 22})"
  270. arg_map["self"] = "at::rand({22, 22, 22})"
  271. arg_map["approximate"] = '"tanh"'
  272. return
  273. if op_name == "index_add":
  274. if index == 0:
  275. arg_map["self"] = "at::rand({2})"
  276. arg_map["dim"] = "0"
  277. arg_map["index"] = "at::randint(0, 1, {2}, at::kInt)"
  278. arg_map["source"] = "at::rand({2})"
  279. arg_map["alpha"] = "2"
  280. else:
  281. arg_map["self"] = "at::rand({16})"
  282. arg_map["dim"] = "0"
  283. arg_map["index"] = "at::randint(0, 10, {16}, at::kInt)"
  284. arg_map["source"] = "at::rand({16})"
  285. arg_map["alpha"] = "2"
  286. return
  287. if op_name == "index_copy":
  288. if index == 0:
  289. arg_map["self"] = "at::rand({2})"
  290. arg_map["dim"] = "0"
  291. arg_map["index"] = "at::randint(0, 1, {2}, at::kLong)"
  292. arg_map["source"] = "at::rand({2})"
  293. else:
  294. arg_map["self"] = "at::rand({32})"
  295. arg_map["dim"] = "0"
  296. arg_map["index"] = "at::randint(0, 10, {32}, at::kLong)"
  297. arg_map["source"] = "at::rand({32})"
  298. return
  299. if op_name == "linalg_cross":
  300. if index == 0:
  301. arg_map["self"] = "at::rand({6, 3, 6})"
  302. arg_map["other"] = "at::rand({6, 3, 6})"
  303. arg_map["dim"] = "1"
  304. else:
  305. arg_map["self"] = "at::rand({22, 3, 22})"
  306. arg_map["other"] = "at::rand({22, 3, 22})"
  307. arg_map["dim"] = "1"
  308. return
  309. if op_name == "nll_loss_backward":
  310. if index == 0:
  311. arg_map["grad_output"] = "at::rand({})"
  312. arg_map["self"] = "at::rand({6})"
  313. arg_map["target"] = "at::randint(0, 5, {6}, torch::kInt64)"
  314. arg_map["weight"] = "at::rand({6})"
  315. arg_map["reduction"] = "1"
  316. arg_map["ignore_index"] = "1"
  317. arg_map["total_weight"] = "at::rand({})"
  318. else:
  319. arg_map["grad_output"] = "at::rand({})"
  320. arg_map["self"] = "at::rand({36})"
  321. arg_map["target"] = "at::randint(0, 11, {36}, torch::kInt64)"
  322. arg_map["weight"] = "at::rand({36})"
  323. arg_map["reduction"] = "1"
  324. arg_map["ignore_index"] = "1"
  325. arg_map["total_weight"] = "at::rand({})"
  326. return
  327. if op_name in ["scatter", "scatter_add", "_scatter_reduce"]:
  328. if index == 0:
  329. arg_map["self"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)"
  330. arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)"
  331. arg_map["src"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)"
  332. else:
  333. arg_map["self"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)"
  334. arg_map["index"] = "at::randint(0, 1, {5,5,5}, torch::kInt64)"
  335. arg_map["src"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)"
  336. if "reduce" in arg_map:
  337. arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"'
  338. return
  339. if op_name == "scatter_reduce":
  340. arg_map["reduce"] = '"mean"'
  341. if index == 0:
  342. arg_map["index"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
  343. else:
  344. arg_map["index"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
  345. return
  346. if op_name == "special_zeta":
  347. if index == 0:
  348. arg_map["self"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
  349. arg_map["other"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
  350. else:
  351. arg_map["self"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})"
  352. arg_map["other"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})"
  353. return
  354. if op_name == "_convert_indices_from_csr_to_coo":
  355. if index == 0:
  356. arg_map["crow_indices"] = "torch::tensor({1}, torch::kInt32)"
  357. arg_map["col_indices"] = "torch::tensor({0, 1, 0}, torch::kInt32)"
  358. arg_map["out_int32"] = "false"
  359. else:
  360. arg_map["crow_indices"] = "torch::tensor({0}, torch::kInt32)"
  361. arg_map["col_indices"] = (
  362. "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)"
  363. )
  364. arg_map["out_int32"] = "false"
  365. return
  366. if op_name == "_convert_indices_from_coo_to_csr":
  367. if index == 0:
  368. arg_map["self"] = "at::randint(0, 3, {2}, at::kInt)"
  369. arg_map["size"] = "10"
  370. arg_map["out_int32"] = "false"
  371. else:
  372. arg_map["self"] = "at::randint(0, 3, {12}, at::kInt)"
  373. arg_map["size"] = "24"
  374. arg_map["out_int32"] = "false"
  375. return
  376. if op_name in ("diagonal", "linalg_diagonal"):
  377. arg_map["offset"] = "0"
  378. arg_map["dim1"] = "2"
  379. arg_map["dim2"] = "1"
  380. return