hipify_python.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. """ The Python Hipify script.
  4. ##
  5. # Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved.
  6. # 2017-2018 Advanced Micro Devices, Inc. and
  7. # Facebook Inc. All rights reserved.
  8. #
  9. # Permission is hereby granted, free of charge, to any person obtaining a copy
  10. # of this software and associated documentation files (the "Software"), to deal
  11. # in the Software without restriction, including without limitation the rights
  12. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  13. # copies of the Software, and to permit persons to whom the Software is
  14. # furnished to do so, subject to the following conditions:
  15. #
  16. # The above copyright notice and this permission notice shall be included in
  17. # all copies or substantial portions of the Software.
  18. #
  19. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  20. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  21. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  22. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  23. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  24. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  25. # THE SOFTWARE.
  26. """
  27. import argparse
  28. import fnmatch
  29. import re
  30. import shutil
  31. import sys
  32. import os
  33. import warnings
  34. from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
  35. from .cuda_to_hip_mappings import MATH_TRANSPILATIONS
  36. from .cuda_to_hip_mappings import CAFFE2_PATH_MAPPINGS
  37. from collections.abc import Iterator
  38. from collections.abc import Mapping, Iterable
  39. from enum import Enum
  40. import functools
  41. import hashlib
  42. def _deprecated(name):
  43. warnings.warn(f"hipify version 2.0.0 no longer uses function {name}", FutureWarning, stacklevel=2)
  44. class CurrentState(Enum):
  45. INITIALIZED = 1
  46. DONE = 2
  47. class HipifyResult:
  48. def __init__(self, current_state, hipified_path) -> None:
  49. self.current_state = current_state
  50. self.hipified_path = hipified_path
  51. self.status = ""
  52. def __str__(self) -> str:
  53. return (f"HipifyResult:: current_state: {self.current_state}, hipified_path : {self.hipified_path}, status: {self.status}")
  54. HipifyFinalResult = dict[str, HipifyResult]
  55. HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
  56. HIPIFY_FINAL_RESULT: HipifyFinalResult = {}
  57. # Hardcode the PyTorch template map
  58. """This dictionary provides the mapping from PyTorch kernel template types
  59. to their actual types."""
  60. PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}
  61. __all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter',
  62. 'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
  63. 'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
  64. 'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_special_file', 'is_caffe2_gpu_file',
  65. 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
  66. 'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'CurrentState', 'HipifyResult', 'hipify']
  67. class InputError(Exception):
  68. # Exception raised for errors in the input.
  69. def __init__(self, message) -> None:
  70. super().__init__(message)
  71. self.message = message
  72. def __str__(self) -> str:
  73. return f"Input error: {self.message}"
  74. def openf(filename, mode):
  75. return open(filename, mode, errors='ignore')
  76. # Color coding for printing
  77. class bcolors:
  78. HEADER = '\033[95m'
  79. OKBLUE = '\033[94m'
  80. OKGREEN = '\033[92m'
  81. WARNING = '\033[93m'
  82. FAIL = '\033[91m'
  83. ENDC = '\033[0m'
  84. BOLD = '\033[1m'
  85. UNDERLINE = '\033[4m'
  86. # To the programmer, the output of hipify most likely are intermediates.
  87. # This class allows users of hipify to ask for a cleanup by running the
  88. # hipify and compilation in a with instantiating this context manager class
  89. # with keep_intermediates=False.
  90. # The main usecase is the cpp_extensions, specifically the load method.
  91. # It is a good idea to keep intermediates (in case of errors or to
  92. # not recompile unchanged files), but in cases where you don't want to
  93. # keep them (e.g. in the CI), this can be used to remove files.
  94. class GeneratedFileCleaner:
  95. """Context Manager to clean up generated files"""
  96. def __init__(self, keep_intermediates=False) -> None:
  97. self.keep_intermediates = keep_intermediates
  98. self.files_to_clean = set()
  99. self.dirs_to_clean = []
  100. def __enter__(self):
  101. return self
  102. def open(self, fn, *args, **kwargs):
  103. if not os.path.exists(fn):
  104. self.files_to_clean.add(os.path.abspath(fn))
  105. return open(fn, *args, **kwargs)
  106. def makedirs(self, dn, exist_ok=False) -> None:
  107. parent, n = os.path.split(dn)
  108. if not n:
  109. parent, n = os.path.split(parent)
  110. if parent and n and not os.path.exists(parent):
  111. self.makedirs(parent, exist_ok=True)
  112. if not os.path.isdir(dn) or not exist_ok:
  113. os.mkdir(dn)
  114. self.dirs_to_clean.append(os.path.abspath(dn))
  115. def __exit__(self, type, value, traceback):
  116. if not self.keep_intermediates:
  117. for f in self.files_to_clean:
  118. os.unlink(f)
  119. for d in self.dirs_to_clean[::-1]:
  120. os.rmdir(d)
  121. # Follow UNIX convention for paths to use '/' instead of '\\' on Windows
  122. def _to_unix_path(path: str) -> str:
  123. return path.replace(os.sep, '/')
  124. def match_extensions(filename: str, extensions: Iterable) -> bool:
  125. """Helper method to see if filename ends with certain extension"""
  126. return any(filename.endswith(e) for e in extensions)
  127. def _fnmatch(filepath, patterns):
  128. return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)
  129. def matched_files_iter(
  130. root_path: str,
  131. includes: Iterable = (),
  132. ignores: Iterable = (),
  133. extensions: Iterable = (),
  134. out_of_place_only: bool = False,
  135. is_pytorch_extension: bool = False) -> Iterator[str]:
  136. exact_matches = set(includes)
  137. # This is a very rough heuristic; really, we want to avoid scanning
  138. # any file which is not checked into source control, but this script
  139. # needs to work even if you're in a Git or Hg checkout, so easier to
  140. # just block the biggest time sinks that won't matter in the
  141. # end.
  142. for (abs_dirpath, dirs, filenames) in os.walk(root_path, topdown=True):
  143. rel_dirpath = os.path.relpath(abs_dirpath, root_path)
  144. if rel_dirpath == '.':
  145. # Blah blah blah O(n) blah blah
  146. if ".git" in dirs:
  147. dirs.remove(".git")
  148. if "build" in dirs:
  149. dirs.remove("build")
  150. if "third_party" in dirs:
  151. dirs.remove("third_party")
  152. dirs.append("third_party/nvfuser")
  153. for filename in filenames:
  154. filepath = _to_unix_path(os.path.join(abs_dirpath, filename))
  155. # We respect extensions, UNLESS you wrote the entire
  156. # filename verbatim, in which case we always accept it
  157. if (
  158. _fnmatch(filepath, includes)
  159. and (not _fnmatch(filepath, ignores))
  160. and (match_extensions(filepath, extensions) or filepath in exact_matches)
  161. ):
  162. yield filepath
  163. def preprocess_file_and_save_result(
  164. output_directory: str,
  165. filepath: str,
  166. all_files: Iterable,
  167. header_include_dirs: Iterable,
  168. stats: dict[str, list],
  169. hip_clang_launch: bool,
  170. is_pytorch_extension: bool,
  171. clean_ctx: GeneratedFileCleaner,
  172. show_progress: bool) -> None:
  173. fin_path = os.path.abspath(os.path.join(output_directory, filepath))
  174. hipify_result = HipifyResult(current_state=CurrentState.INITIALIZED, hipified_path=fin_path)
  175. HIPIFY_FINAL_RESULT[fin_path] = hipify_result
  176. result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats,
  177. hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
  178. # Show what happened
  179. if show_progress and "ignored" not in result.status:
  180. print(
  181. fin_path, "->",
  182. result.hipified_path, result.status, flush=True)
  183. HIPIFY_FINAL_RESULT[fin_path] = result
  184. def compute_stats(stats) -> None:
  185. unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]}
  186. # Print the number of unsupported calls
  187. print(f"Total number of unsupported CUDA function calls: {len(unsupported_calls):d}")
  188. # Print the list of unsupported calls
  189. print(", ".join(unsupported_calls))
  190. # Print the number of kernel launches
  191. print(f"\nTotal number of replaced kernel launches: {len(stats['kernel_launches']):d}")
  192. def add_dim3(kernel_string, cuda_kernel):
  193. '''adds dim3() to the second and third arguments in the kernel launch'''
  194. count = 0
  195. closure = 0
  196. kernel_string = kernel_string.replace("<<<", "").replace(">>>", "")
  197. arg_locs: list[dict[str, int]] = [{} for _ in range(2)]
  198. arg_locs[count]['start'] = 0
  199. for ind, c in enumerate(kernel_string):
  200. if count > 1:
  201. break
  202. if c == "(":
  203. closure += 1
  204. elif c == ")":
  205. closure -= 1
  206. if (c == "," or ind == len(kernel_string) - 1) and closure == 0:
  207. arg_locs[count]['end'] = ind + (c != ",")
  208. count += 1
  209. if count < 2:
  210. arg_locs[count]['start'] = ind + 1
  211. first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1]
  212. second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']]
  213. first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ")
  214. second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ")
  215. first_arg_dim3 = f"dim3({first_arg_clean})"
  216. second_arg_dim3 = f"dim3({second_arg_clean})"
  217. first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
  218. second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)
  219. cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3)
  220. return cuda_kernel
  221. RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+')
  222. def processKernelLaunches(string, stats):
  223. """ Replace the CUDA style Kernel launches with the HIP style kernel launches."""
  224. # Concat the namespace with the kernel names. (Find cleaner way of doing this later).
  225. string = RE_KERNEL_LAUNCH.sub(lambda inp: f"{inp.group(1)}{inp.group(2)}::", string)
  226. def grab_method_and_template(in_kernel):
  227. # The positions for relevant kernel components.
  228. pos = {
  229. "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]},
  230. "kernel_name": {"start": -1, "end": -1},
  231. "template": {"start": -1, "end": -1}
  232. }
  233. # Count for balancing template
  234. count = {"<>": 0}
  235. # Status for whether we are parsing a certain item.
  236. START = 0
  237. AT_TEMPLATE = 1
  238. AFTER_TEMPLATE = 2
  239. AT_KERNEL_NAME = 3
  240. status = START
  241. # Parse the string character by character
  242. for i in range(pos["kernel_launch"]["start"] - 1, -1, -1):
  243. char = string[i]
  244. # Handle Templating Arguments
  245. if status in (START, AT_TEMPLATE):
  246. if char == ">":
  247. if status == START:
  248. status = AT_TEMPLATE
  249. pos["template"]["end"] = i
  250. count["<>"] += 1
  251. if char == "<":
  252. count["<>"] -= 1
  253. if count["<>"] == 0 and (status == AT_TEMPLATE):
  254. pos["template"]["start"] = i
  255. status = AFTER_TEMPLATE
  256. # Handle Kernel Name
  257. if status != AT_TEMPLATE:
  258. if string[i].isalnum() or string[i] in {'(', ')', '_', ':', '#'}:
  259. if status != AT_KERNEL_NAME:
  260. status = AT_KERNEL_NAME
  261. pos["kernel_name"]["end"] = i
  262. # Case: Kernel name starts the string.
  263. if i == 0:
  264. pos["kernel_name"]["start"] = 0
  265. # Finished
  266. return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]
  267. else:
  268. # Potential ending point if we're already traversing a kernel's name.
  269. if status == AT_KERNEL_NAME:
  270. pos["kernel_name"]["start"] = i
  271. # Finished
  272. return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]
  273. def find_kernel_bounds(string):
  274. """Finds the starting and ending points for all kernel launches in the string."""
  275. kernel_end = 0
  276. kernel_positions = []
  277. # Continue until we cannot find any more kernels anymore.
  278. while string.find("<<<", kernel_end) != -1:
  279. # Get kernel starting position (starting from the previous ending point)
  280. kernel_start = string.find("<<<", kernel_end)
  281. # Get kernel ending position (adjust end point past the >>>)
  282. kernel_end = string.find(">>>", kernel_start) + 3
  283. if kernel_end <= 0:
  284. raise InputError("no kernel end found")
  285. # Add to list of traversed kernels
  286. kernel_positions.append({"start": kernel_start, "end": kernel_end,
  287. "group": string[kernel_start: kernel_end]})
  288. return kernel_positions
  289. # Replace comments and string literals from the code so that find_kernel_bounds does not
  290. # wrongly capture kernels in comments and string literals.
  291. # This function replaces them with "x" to keep positions.
  292. def mask_comments(string):
  293. in_comment = ''
  294. prev_c = ''
  295. new_string = ''
  296. for c in string:
  297. if in_comment == '':
  298. # Outside comments
  299. if c == '/' and prev_c == '/':
  300. in_comment = '//'
  301. elif c == '*' and prev_c == '/':
  302. in_comment = '/*'
  303. elif c == '"' and prev_c != '\\' and prev_c != "'":
  304. in_comment = '"'
  305. elif in_comment == '//':
  306. # In // xxx
  307. if c == '\r' or c == '\n':
  308. in_comment = ''
  309. elif in_comment == '/*':
  310. # In /* xxx */
  311. if c == '/' and prev_c == '*':
  312. in_comment = ''
  313. elif in_comment == '"':
  314. # In ""
  315. if c == '"' and prev_c != '\\':
  316. in_comment = ''
  317. prev_c = c
  318. if in_comment == '':
  319. new_string += c
  320. else:
  321. new_string += 'x'
  322. return new_string
  323. # Grab positional ranges of all kernel launches
  324. get_kernel_positions = list(find_kernel_bounds(mask_comments(string)))
  325. output_string = string
  326. # Replace each CUDA kernel with a HIP kernel.
  327. for kernel in get_kernel_positions:
  328. # Get kernel components
  329. params = grab_method_and_template(kernel)
  330. # Find parenthesis after kernel launch
  331. parenthesis = string.find("(", kernel["end"])
  332. # Extract cuda kernel
  333. cuda_kernel = string[params[0]["start"]:parenthesis + 1]
  334. kernel_string = string[kernel['start']:kernel['end']]
  335. end_param_index = 0 if params[1]['end'] == -1 else 1
  336. kernel_name_with_template = string[params[0]['start']:params[end_param_index]['end'] + 1]
  337. cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel)
  338. # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
  339. num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")")))
  340. hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace(
  341. ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace(
  342. ">>>", ", ").replace(kernel_name_with_template, "(" + kernel_name_with_template + ")")
  343. # Replace cuda kernel with hip kernel
  344. output_string = output_string.replace(cuda_kernel, hip_kernel)
  345. # Update the statistics
  346. stats["kernel_launches"].append(hip_kernel)
  347. return output_string
  348. def find_closure_group(input_string, start, group):
  349. """Generalization for finding a balancing closure group
  350. if group = ["(", ")"], then finds the first balanced parentheses.
  351. if group = ["{", "}"], then finds the first balanced bracket.
  352. Given an input string, a starting position in the input string, and the group type,
  353. find_closure_group returns the positions of group[0] and group[1] as a tuple.
  354. Example:
  355. >>> find_closure_group("(hi)", 0, ["(", ")"])
  356. (0, 3)
  357. """
  358. inside_parenthesis = False
  359. parens = 0
  360. pos = start
  361. p_start, p_end = -1, -1
  362. while pos < len(input_string):
  363. if input_string[pos] == group[0]:
  364. if inside_parenthesis is False:
  365. inside_parenthesis = True
  366. parens = 1
  367. p_start = pos
  368. else:
  369. parens += 1
  370. elif input_string[pos] == group[1] and inside_parenthesis:
  371. parens -= 1
  372. if parens == 0:
  373. p_end = pos
  374. return p_start, p_end
  375. pos += 1
  376. return None, None
  377. def find_bracket_group(input_string, start):
  378. """Finds the first balanced parentheses."""
  379. return find_closure_group(input_string, start, group=["{", "}"])
  380. def find_parentheses_group(input_string, start):
  381. """Finds the first balanced bracket."""
  382. return find_closure_group(input_string, start, group=["(", ")"])
# Matches the word "assert" followed by optional spaces and an opening
# parenthesis. NOTE(review): not referenced in the visible part of this file.
RE_ASSERT = re.compile(r"\bassert[ ]*\(")
  384. def replace_math_functions(input_string):
  385. """FIXME: Temporarily replace std:: invocations of math functions
  386. with non-std:: versions to prevent linker errors NOTE: This
  387. can lead to correctness issues when running tests, since the
  388. correct version of the math function (exp/expf) might not get
  389. called. Plan is to remove this function once HIP supports
  390. std:: math function calls inside device code
  391. """
  392. output_string = input_string
  393. for func in MATH_TRANSPILATIONS:
  394. output_string = output_string.replace(fr'{func}(', f'{MATH_TRANSPILATIONS[func]}(')
  395. return output_string
  396. RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()")
  397. def hip_header_magic(input_string):
  398. """If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
  399. then automatically add an #include to match the "magic" includes provided by NVCC.
  400. TODO:
  401. Update logic to ignore cases where the cuda_runtime.h is included by another file.
  402. """
  403. # Copy the input.
  404. output_string = input_string
  405. # Check if one of the following headers is already included.
  406. headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"]
  407. if any(re.search(fr'#include ("{ext}"|<{ext}>)', output_string) for ext in headers):
  408. return output_string
  409. # Rough logic to detect if we're inside device code
  410. hasDeviceLogic: int
  411. hasDeviceLogic = "hipLaunchKernelGGL" in output_string
  412. hasDeviceLogic += "__global__" in output_string
  413. hasDeviceLogic += "__shared__" in output_string
  414. hasDeviceLogic += RE_SYNCTHREADS.search(output_string) is not None
  415. # If device logic found, provide the necessary header.
  416. if hasDeviceLogic:
  417. output_string = '#include "hip/hip_runtime.h"\n' + input_string
  418. return output_string
  419. RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;")
  420. def replace_extern_shared(input_string):
  421. """
  422. Match 'extern __shared__ type foo[];' syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
  423. See: https://github.com/ROCm/hip/blob/master/docs/markdown/hip_kernel_language.md#__shared__
  424. Examples:
  425. "extern __shared__ char smemChar[];"
  426. => "HIP_DYNAMIC_SHARED( char, smemChar)"
  427. "extern __shared__ unsigned char smem[];"
  428. => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
  429. """
  430. output_string = input_string
  431. output_string = RE_EXTERN_SHARED.sub(
  432. lambda inp: f"HIP_DYNAMIC_SHARED({inp.group(1) or ''} {inp.group(2)}, {inp.group(3)})", output_string)
  433. return output_string
  434. def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
  435. """
  436. Returns the new name of the hipified file
  437. """
  438. # At the moment, some PyTorch source files are HIPified in place. The predicate
  439. # is_out_of_place tells us if this is the case or not.
  440. if os.path.isabs(rel_filepath):
  441. raise AssertionError("rel_filepath must be a relative path")
  442. if not is_pytorch_extension and not is_out_of_place(rel_filepath):
  443. return rel_filepath
  444. dirpath, filename = os.path.split(rel_filepath)
  445. root, ext = os.path.splitext(filename)
  446. # Here's the plan:
  447. #
  448. # In general, we need to disambiguate the HIPified filename so that
  449. # it gets a different name from the original filename, so
  450. # that we don't overwrite the original file
  451. #
  452. # There's a lot of different naming conventions across PyTorch,
  453. # but the general recipe is to convert occurrences
  454. # of cuda/gpu to hip, and add hip if there are no occurrences
  455. # of cuda/gpu anywhere.
  456. #
  457. # Concretely, we do the following:
  458. #
  459. # - If there is a directory component named "cuda", replace
  460. # it with "hip", AND
  461. #
  462. # - If the file name contains "CUDA", replace it with "HIP", AND
  463. #
  464. # - ALWAYS replace '.cu' with '.hip', because those files
  465. # contain CUDA kernels that needs to be hipified and processed with
  466. # hip compiler
  467. #
  468. # - If we are not hipifying a PyTorch extension, and the parent
  469. # directory name did not change as a result of the above
  470. # transformations, insert "hip" in the file path
  471. # as the direct parent folder of the file
  472. #
  473. # - If we are hipifying a PyTorch extension, and the parent directory
  474. # name as well as the filename (incl. extension) did not change as
  475. # a result of the above transformations, insert "_hip" in the filename
  476. #
  477. # This isn't set in stone; we might adjust this to support other
  478. # naming conventions.
  479. if ext == '.cu':
  480. ext = '.hip'
  481. orig_filename = filename
  482. orig_dirpath = dirpath
  483. dirpath = dirpath.replace('cuda', 'hip')
  484. dirpath = dirpath.replace('CUDA', 'HIP')
  485. dirpath = dirpath.replace('THC', 'THH')
  486. root = root.replace('cuda', 'hip')
  487. root = root.replace('CUDA', 'HIP')
  488. # Special case to handle caffe2/core/THCCachingAllocator
  489. if dirpath != "caffe2/core":
  490. root = root.replace('THC', 'THH')
  491. if not is_pytorch_extension and dirpath == orig_dirpath:
  492. dirpath = os.path.join(dirpath, 'hip')
  493. if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename:
  494. root = root + "_hip"
  495. return os.path.join(dirpath, root + ext)
  496. def is_out_of_place(rel_filepath) -> bool:
  497. if os.path.isabs(rel_filepath):
  498. raise AssertionError("rel_filepath must be a relative path")
  499. if rel_filepath.startswith("torch/"):
  500. return False
  501. if rel_filepath.startswith("third_party/nvfuser/"):
  502. return False
  503. if rel_filepath.startswith("tools/autograd/templates/"):
  504. return False
  505. return True
  506. # Keep this synchronized with includes/ignores in build_amd.py
  507. def is_pytorch_file(rel_filepath) -> bool:
  508. _deprecated("is_pytorch_file")
  509. if os.path.isabs(rel_filepath):
  510. raise AssertionError("rel_filepath must be a relative path")
  511. if rel_filepath.startswith("aten/"):
  512. if rel_filepath.startswith("aten/src/ATen/core/"):
  513. return False
  514. return True
  515. if rel_filepath.startswith("torch/"):
  516. return True
  517. if rel_filepath.startswith("third_party/nvfuser/"):
  518. return True
  519. if rel_filepath.startswith("third_party/fbgemm/"):
  520. return True
  521. if rel_filepath.startswith("third_party/mslk/"):
  522. return True
  523. if rel_filepath.startswith("tools/autograd/templates/"):
  524. return True
  525. return False
  526. def is_cusparse_file(rel_filepath):
  527. _deprecated("is_cusparse_file")
  528. if is_pytorch_file(rel_filepath):
  529. return "sparse" in rel_filepath.lower()
  530. return False
  531. def is_special_file(rel_filepath) -> bool:
  532. _deprecated("is_special_file")
  533. if is_pytorch_file(rel_filepath):
  534. if "sparse" in rel_filepath.lower():
  535. return True
  536. elif "linalg" in rel_filepath.lower():
  537. if "batchlinearalgebralibblas" in rel_filepath.lower():
  538. return False # don't use "special" mappings for this specific linalg cublas file
  539. return True
  540. return False
  541. def is_caffe2_gpu_file(rel_filepath):
  542. _deprecated("is_caffe2_gpu_file")
  543. if os.path.isabs(rel_filepath):
  544. raise AssertionError("rel_filepath must be a relative path")
  545. if rel_filepath.startswith("c10/cuda"):
  546. return True
  547. filename = os.path.basename(rel_filepath)
  548. _, ext = os.path.splitext(filename)
  549. return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
  550. class TrieNode:
  551. """A Trie node whose children are represented as a directory of char: TrieNode.
  552. A special char '' represents end of word
  553. """
  554. def __init__(self) -> None:
  555. self.children = {}
  556. class Trie:
  557. """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
  558. The corresponding Regex should match much faster than a simple Regex union."""
  559. def __init__(self) -> None:
  560. """Initialize the trie with an empty root node."""
  561. self.root = TrieNode()
  562. self._hash = hashlib.md5(usedforsecurity=False)
  563. self._digest = self._hash.digest()
  564. def add(self, word) -> None:
  565. """Add a word to the Trie. """
  566. self._hash.update(word.encode())
  567. self._digest = self._hash.digest()
  568. node = self.root
  569. for char in word:
  570. node.children.setdefault(char, TrieNode())
  571. node = node.children[char]
  572. node.children[''] = True # Mark the end of the word
  573. def dump(self):
  574. """Return the root node of Trie. """
  575. return self.root
  576. def quote(self, char):
  577. """ Escape a char for regex. """
  578. return re.escape(char)
  579. def search(self, word):
  580. """Search whether word is present in the Trie.
  581. Returns True if yes, else return False"""
  582. node = self.root
  583. for char in word:
  584. if char in node.children:
  585. node = node.children[char]
  586. else:
  587. return False
  588. # make sure to check the end-of-word marker present
  589. return '' in node.children
  590. @functools.lru_cache # noqa: B019
  591. def _pattern(self, root, digest):
  592. """Convert a Trie into a regular expression pattern
  593. Memoized on the hash digest of the trie, which is built incrementally
  594. during add().
  595. """
  596. node = root
  597. if "" in node.children and len(node.children.keys()) == 1:
  598. return None
  599. alt = [] # store alternative patterns
  600. cc = [] # store char to char classes
  601. q = 0 # for node representing the end of word
  602. for char in sorted(node.children.keys()):
  603. if isinstance(node.children[char], TrieNode):
  604. try:
  605. recurse = self._pattern(node.children[char], self._digest)
  606. alt.append(self.quote(char) + recurse)
  607. except Exception:
  608. cc.append(self.quote(char))
  609. else:
  610. q = 1
  611. cconly = not len(alt) > 0
  612. if len(cc) > 0:
  613. if len(cc) == 1:
  614. alt.append(cc[0])
  615. else:
  616. alt.append('[' + ''.join(cc) + ']')
  617. if len(alt) == 1:
  618. result = alt[0]
  619. else:
  620. result = "(?:" + "|".join(alt) + ")"
  621. if q:
  622. if cconly:
  623. result += "?"
  624. else:
  625. result = f"(?:{result})?"
  626. return result
  627. def pattern(self):
  628. """Export the Trie to a regex pattern."""
  629. return self._pattern(self.root, self._digest)
  630. def export_to_regex(self):
  631. """Export the Trie to a regex pattern."""
  632. return self._pattern(self.root, self._digest)
# Trie holding every key found in the CUDA_TO_HIP_MAPPINGS tables; it is
# exported below as one regex alternation.
PYTORCH_TRIE = Trie()
# Replacement lookup: matched key -> value from CUDA_TO_HIP_MAPPINGS. When a
# key occurs in several mappings, the value from the last one is kept.
PYTORCH_MAP: dict[str, object] = {}
for mapping in CUDA_TO_HIP_MAPPINGS:
    if not isinstance(mapping, Mapping):
        raise TypeError("Expected each mapping in CUDA_TO_HIP_MAPPINGS to be a Mapping")
    for src, dst in mapping.items():
        PYTORCH_TRIE.add(src)
        PYTORCH_MAP[src] = dst
# Matches any trie key that has a non-word character on both sides. Both
# lookarounds require an actual character, so a key at the very start or very
# end of the searched text does not match.
RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.export_to_regex()})(?=\W)')
# Capture the path of a quoted include, an angle-bracket include, and a
# THC_GENERIC_FILE definition respectively.
RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
RE_CU_SUFFIX = re.compile(r'\.cu\b') # be careful not to pick up .cuh
"""
preprocessor() returns a HipifyResult object with the following details:
    "hipified_path" : absolute path of the hipified source file (None if the file was ignored)
    "status"        : "[ok]" if the hipified file was written out
                      "[skipped, ...]" if an identical hipified file already existed or the hipified file couldn't be written out
                      "[ignored, ...]" if the source file was a hipified file itself or not meant to be hipified
    "current_state" : CurrentState.INITIALIZED if the source file is first ready to be hipified
                      CurrentState.DONE if the source file is done with the hipification process
"""
  655. def preprocessor(
  656. output_directory: str,
  657. filepath: str,
  658. all_files: Iterable,
  659. header_include_dirs: Iterable,
  660. stats: dict[str, list],
  661. hip_clang_launch: bool,
  662. is_pytorch_extension: bool,
  663. clean_ctx: GeneratedFileCleaner,
  664. show_progress: bool) -> HipifyResult:
  665. """ Executes the CUDA -> HIP conversion on the specified file. """
  666. fin_path = os.path.abspath(os.path.join(output_directory, filepath))
  667. filepath = _to_unix_path(filepath)
  668. hipify_result = HIPIFY_FINAL_RESULT[fin_path]
  669. if filepath not in all_files:
  670. hipify_result.hipified_path = None
  671. hipify_result.status = "[ignored, not to be hipified]"
  672. hipify_result.current_state = CurrentState.DONE
  673. return hipify_result
  674. rel_filepath = _to_unix_path(os.path.relpath(filepath, output_directory))
  675. with open(fin_path, encoding='utf-8') as fin:
  676. if fin.readline() == HIPIFY_C_BREADCRUMB:
  677. hipify_result.hipified_path = None
  678. hipify_result.status = "[ignored, input is hipified output]"
  679. hipify_result.current_state = CurrentState.DONE
  680. return hipify_result
  681. fin.seek(0)
  682. output_source = fin.read()
  683. orig_output_source = output_source
  684. # get_hip_file_path needs a relative path to work correctly
  685. fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension)))
  686. if not os.path.exists(os.path.dirname(fout_path)):
  687. clean_ctx.makedirs(os.path.dirname(fout_path))
  688. # unsupported_calls statistics reporting is broken atm
  689. def pt_repl(m):
  690. return PYTORCH_MAP[m.group(0)]
  691. output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
  692. # TODO: Remove CAFFE2_PATH_MAPPINGS. They were necessary for Meta-internal builds.
  693. # Apply CAFFE2 path mappings (simple string replacement for paths containing slashes)
  694. # Need to be careful to avoid double-transformations when source file has #ifdef blocks
  695. # with HIP-specific paths already in them (e.g., caffe2/core/hip/context_gpu.h)
  696. for cuda_path, hip_path in CAFFE2_PATH_MAPPINGS.items():
  697. # Use regex to ensure we don't match paths that already have been hipified
  698. # We need to avoid transforming "caffe2/core/hip/context_gpu.h" when looking for "caffe2/core/context_gpu.h"
  699. # The key insight: if hip_path contains /hip/ and cuda_path doesn't, we need to be careful
  700. if "/hip/" in hip_path and "/hip/" not in cuda_path:
  701. # Only replace cuda_path if it's not preceded by "/hip/"
  702. # Use negative lookbehind to prevent matching already-hipified paths
  703. # The pattern checks that the cuda_path is not immediately preceded by "/hip/"
  704. pattern = r'(?<!/hip/)' + re.escape(cuda_path)
  705. output_source = re.sub(pattern, hip_path, output_source)
  706. else:
  707. # Simple replacement when no /hip/ involved or both have it
  708. output_source = output_source.replace(cuda_path, hip_path)
  709. # Header rewrites
  710. def mk_repl(templ, include_current_dir=True):
  711. def repl(m):
  712. f = m.group(1)
  713. filename = os.path.basename(f)
  714. if (
  715. f.startswith(("ATen/cuda",
  716. "ATen/native/cuda",
  717. "ATen/native/nested/cuda",
  718. "ATen/native/quantized/cuda",
  719. "ATen/native/sparse/cuda",
  720. "ATen/native/transformers/cuda",
  721. "THC/")) or
  722. (f.startswith("THC") and not f.startswith("THCP"))
  723. ):
  724. return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
  725. # if filename is one of the files being hipified for this extension
  726. if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)):
  727. header_dir = None
  728. header_filepath = None
  729. # If include_current_dir True, look first in same dir as the including source file
  730. if include_current_dir:
  731. header_dir_to_check = os.path.dirname(fin_path)
  732. header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
  733. if os.path.exists(header_path_to_check):
  734. header_dir = header_dir_to_check
  735. header_filepath = header_path_to_check
  736. # If not found, look in include dirs one by one and first match wins
  737. if header_filepath is None:
  738. for header_include_dir in header_include_dirs:
  739. header_dir_to_check = os.path.join(output_directory, header_include_dir)
  740. header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
  741. if os.path.exists(header_path_to_check):
  742. header_dir = header_dir_to_check
  743. header_filepath = header_path_to_check
  744. # If header file not found, keep as is
  745. if header_filepath is None:
  746. return m.group(0)
  747. # Hipify header file first if needed
  748. if header_filepath not in HIPIFY_FINAL_RESULT:
  749. preprocess_file_and_save_result(output_directory,
  750. header_filepath,
  751. all_files, header_include_dirs, stats, hip_clang_launch,
  752. is_pytorch_extension, clean_ctx, show_progress)
  753. elif header_filepath in HIPIFY_FINAL_RESULT:
  754. header_result = HIPIFY_FINAL_RESULT[header_filepath]
  755. if header_result.current_state == CurrentState.INITIALIZED:
  756. # get_hip_file_path needs a relative path to work correctly
  757. header_rel_path = os.path.relpath(header_filepath, output_directory)
  758. header_fout_path = os.path.abspath(os.path.join(output_directory,
  759. get_hip_file_path(header_rel_path, is_pytorch_extension)))
  760. header_result.hipified_path = header_fout_path
  761. HIPIFY_FINAL_RESULT[header_filepath] = header_result
  762. return templ.format(os.path.relpath(header_fout_path if header_fout_path is not None
  763. else header_filepath, header_dir))
  764. hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath].hipified_path
  765. return templ.format(_to_unix_path(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
  766. else header_filepath, header_dir)))
  767. return m.group(0)
  768. return repl
  769. output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source)
  770. output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source)
  771. output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source)
  772. # CMakeLists.txt rewrites
  773. if filepath.endswith('CMakeLists.txt'):
  774. output_source = output_source.replace('CUDA', 'HIP')
  775. output_source = output_source.replace('THC', 'THH')
  776. output_source = RE_CU_SUFFIX.sub('.hip', output_source)
  777. # Perform Kernel Launch Replacements
  778. if not hip_clang_launch:
  779. output_source = processKernelLaunches(output_source, stats)
  780. # Replace std:: with non-std:: versions
  781. if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath:
  782. output_source = replace_math_functions(output_source)
  783. # Include header if device code is contained.
  784. output_source = hip_header_magic(output_source)
  785. # Replace the extern __shared__
  786. # NOTE: No longer needed after transition from hcc to hipclang.
  787. # output_source = replace_extern_shared(output_source)
  788. # Don't write out identical hipified files for extensions if dirpath has not changed
  789. if (
  790. is_pytorch_extension
  791. and orig_output_source == output_source
  792. and os.path.dirname(fin_path) == os.path.dirname(fout_path)
  793. ):
  794. hipify_result.hipified_path = fin_path
  795. hipify_result.status = "[skipped, no changes]"
  796. hipify_result.current_state = CurrentState.DONE
  797. return hipify_result
  798. # Add hipify breadcrumb for C-style files to avoid re-hipification
  799. if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
  800. output_source = HIPIFY_C_BREADCRUMB + output_source
  801. do_write = True
  802. if os.path.exists(fout_path):
  803. with open(fout_path, encoding='utf-8') as fout_old:
  804. do_write = fout_old.read() != output_source
  805. if do_write:
  806. try:
  807. with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout:
  808. fout.write(output_source)
  809. hipify_result.hipified_path = fout_path
  810. hipify_result.status = "[ok]"
  811. hipify_result.current_state = CurrentState.DONE
  812. return hipify_result
  813. except OSError as e:
  814. print(f'{bcolors.WARNING}Failed to save {fout_path} with "{e.strerror}", leaving {fin_path} unchanged.{bcolors.ENDC}',
  815. file=sys.stderr)
  816. hipify_result.hipified_path = fin_path
  817. hipify_result.status = "[skipped, no permissions]"
  818. hipify_result.current_state = CurrentState.DONE
  819. return hipify_result
  820. else:
  821. hipify_result.hipified_path = fout_path
  822. hipify_result.status = "[skipped, already hipified]"
  823. hipify_result.current_state = CurrentState.DONE
  824. return hipify_result
  825. def file_specific_replacement(filepath, search_string, replace_string, strict=False) -> None:
  826. with openf(filepath, "r+") as f:
  827. contents = f.read()
  828. if strict:
  829. contents = re.sub(fr'\b({re.escape(search_string)})\b', lambda x: replace_string, contents)
  830. else:
  831. contents = contents.replace(search_string, replace_string)
  832. f.seek(0)
  833. f.write(contents)
  834. f.truncate()
  835. def file_add_header(filepath, header) -> None:
  836. with openf(filepath, "r+") as f:
  837. contents = f.read()
  838. if header[0] != "<" and header[-1] != ">":
  839. header = f'"{header}"'
  840. contents = (f'#include {header} \n') + contents
  841. f.seek(0)
  842. f.write(contents)
  843. f.truncate()
  844. def fix_static_global_kernels(in_txt):
  845. """Static global kernels in HIP results in a compilation error."""
  846. in_txt = in_txt.replace(" __global__ static", "__global__")
  847. return in_txt
# Matches "#include " plus the rest of that line, including its newline.
RE_INCLUDE = re.compile(r"#include .*\n")
  849. def extract_arguments(start, string):
  850. """
  851. Return the list of arguments in the upcoming function parameter closure.
  852. Example:
  853. string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
  854. arguments (output): [{'start': 1, 'end': 7}, {'start': 8, 'end': 16}, \
  855. {'start': 17, 'end': 19}, {'start': 20, 'end': 53}]
  856. """
  857. arguments = []
  858. closures = {
  859. "<": 0,
  860. "(": 0
  861. }
  862. current_position = start
  863. argument_start_pos = current_position + 1
  864. # Search for final parenthesis
  865. while current_position < len(string):
  866. if string[current_position] == "(":
  867. closures["("] += 1
  868. elif string[current_position] == ")":
  869. closures["("] -= 1
  870. elif string[current_position] == "<":
  871. closures["<"] += 1
  872. elif string[current_position] == ">" and string[current_position - 1] != "-" and closures["<"] > 0:
  873. closures["<"] -= 1
  874. # Finished all arguments
  875. if closures["("] == 0 and closures["<"] == 0:
  876. # Add final argument
  877. arguments.append({"start": argument_start_pos, "end": current_position})
  878. break
  879. # Finished current argument
  880. if closures["("] == 1 and closures["<"] == 0 and string[current_position] == ",":
  881. arguments.append({"start": argument_start_pos, "end": current_position})
  882. argument_start_pos = current_position + 1
  883. current_position += 1
  884. return arguments
  885. def str2bool(v : str) -> bool:
  886. """ArgumentParser doesn't support type=bool. Thus, this helper method will convert
  887. from possible string types to True / False."""
  888. if v.lower() in ('yes', 'true', 't', 'y', '1'):
  889. return True
  890. elif v.lower() in ('no', 'false', 'f', 'n', '0'):
  891. return False
  892. else:
  893. raise argparse.ArgumentTypeError('Boolean value expected.')
  894. def hipify(
  895. project_directory: str,
  896. show_detailed: bool = False,
  897. extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
  898. header_extensions: Iterable = (".cuh", ".h", ".hpp"),
  899. output_directory: str = "",
  900. header_include_dirs: Iterable = (),
  901. includes: Iterable = ('*',),
  902. extra_files: Iterable = (),
  903. out_of_place_only: bool = False,
  904. ignores: Iterable = (),
  905. show_progress: bool = True,
  906. hip_clang_launch: bool = False,
  907. is_pytorch_extension: bool = False,
  908. hipify_extra_files_only: bool = False,
  909. clean_ctx: GeneratedFileCleaner | None = None
  910. ) -> HipifyFinalResult:
  911. if project_directory == "":
  912. project_directory = os.getcwd()
  913. # Verify the project directory exists.
  914. if not os.path.exists(project_directory):
  915. print("The project folder specified does not exist.")
  916. sys.exit(1)
  917. # If no output directory, provide a default one.
  918. if not output_directory:
  919. project_directory.rstrip("/")
  920. output_directory = project_directory + "_amd"
  921. if project_directory != output_directory:
  922. includes = [include.replace(project_directory, output_directory) for include in includes]
  923. ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores]
  924. # Copy from project directory to output directory if not done already.
  925. if not os.path.exists(output_directory):
  926. shutil.copytree(project_directory, output_directory)
  927. includes = list(map(_to_unix_path, includes))
  928. ignores = list(map(_to_unix_path, ignores))
  929. all_files = list(matched_files_iter(output_directory, includes=includes,
  930. ignores=ignores, extensions=extensions,
  931. out_of_place_only=out_of_place_only,
  932. is_pytorch_extension=is_pytorch_extension))
  933. all_files_set = set(all_files)
  934. for f in extra_files:
  935. if not os.path.isabs(f):
  936. f = os.path.join(output_directory, f)
  937. if f not in all_files_set:
  938. all_files.append(f)
  939. # List all files in header_include_paths to ensure they are hipified
  940. from pathlib import Path
  941. for header_include_dir in header_include_dirs:
  942. if os.path.isabs(header_include_dir):
  943. header_include_dir_path = Path(header_include_dir)
  944. else:
  945. header_include_dir_path = Path(os.path.join(output_directory, header_include_dir))
  946. all_files.extend(
  947. str(path) for path in header_include_dir_path.rglob('*') if path.is_file()
  948. and _fnmatch(str(path), includes)
  949. and (not _fnmatch(str(path), ignores))
  950. and match_extensions(path.name, header_extensions)
  951. )
  952. if clean_ctx is None:
  953. clean_ctx = GeneratedFileCleaner(keep_intermediates=True)
  954. # Preprocessing statistics.
  955. stats: dict[str, list] = {"unsupported_calls": [], "kernel_launches": []}
  956. for filepath in (all_files if not hipify_extra_files_only else extra_files):
  957. preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs,
  958. stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
  959. print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)
  960. # Show detailed summary
  961. if show_detailed:
  962. compute_stats(stats)
  963. return HIPIFY_FINAL_RESULT