hash.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/util/Exception.h>
  4. #include <cstddef>
  5. #include <functional>
  6. #include <iomanip>
  7. #include <ios>
  8. #include <sstream>
  9. #include <string>
  10. #include <tuple>
  11. #include <type_traits>
  12. #include <utility>
  13. #include <vector>
  14. #include <c10/util/ArrayRef.h>
  15. #include <c10/util/complex.h>
  16. namespace c10 {
  // NOTE: hash_combine and the SHA1 hashing are based on the implementation from Boost
  18. //
  19. // Boost Software License - Version 1.0 - August 17th, 2003
  20. //
  21. // Permission is hereby granted, free of charge, to any person or organization
  22. // obtaining a copy of the software and accompanying documentation covered by
  23. // this license (the "Software") to use, reproduce, display, distribute,
  24. // execute, and transmit the Software, and to prepare derivative works of the
  25. // Software, and to permit third-parties to whom the Software is furnished to
  26. // do so, all subject to the following:
  27. //
  28. // The copyright notices in the Software and this entire statement, including
  29. // the above license grant, this restriction and the following disclaimer,
  30. // must be included in all copies of the Software, in whole or in part, and
  31. // all derivative works of the Software, unless such copies or derivative
  32. // works are solely in the form of machine-executable object code generated by
  33. // a source language processor.
  34. //
  35. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  36. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  37. // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
  38. // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
  39. // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
  40. // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
  41. // DEALINGS IN THE SOFTWARE.
  // Folds `value` into `seed` and returns the combined hash. The mixing step
  // adds the constant 0x9e3779b9 and two shifted copies of the seed to the
  // value, then XORs the sum back into the seed (Boost-style hash_combine; see
  // the license notice above).
  inline size_t hash_combine(size_t seed, size_t value) {
    const size_t mixed = value + 0x9e3779b9 + (seed << 6u) + (seed >> 2u);
    return seed ^ mixed;
  }
  45. // Creates the SHA1 hash of a string. A 160-bit hash.
  46. // Based on the implementation in Boost (see notice above).
  47. // Note that SHA1 hashes are no longer considered cryptographically
  48. // secure, but are the standard hash for generating unique ids.
  49. // Usage:
  50. // // Let 'code' be a std::string
  51. // c10::sha1 sha1_hash{code};
  52. // const auto hash_code = sha1_hash.str();
  53. // TODO: Compare vs OpenSSL and/or CryptoPP implementations
  54. struct sha1 {
  55. typedef unsigned int(digest_type)[5];
  56. sha1(const std::string& s = "") {
  57. if (!s.empty()) {
  58. reset();
  59. process_bytes(s.c_str(), s.size());
  60. }
  61. }
  62. void reset() {
  63. h_[0] = 0x67452301;
  64. h_[1] = 0xEFCDAB89;
  65. h_[2] = 0x98BADCFE;
  66. h_[3] = 0x10325476;
  67. h_[4] = 0xC3D2E1F0;
  68. block_byte_index_ = 0;
  69. bit_count_low = 0;
  70. bit_count_high = 0;
  71. }
  72. std::string str() {
  73. unsigned int digest[5];
  74. get_digest(digest);
  75. std::ostringstream buf;
  76. for (unsigned int i : digest) {
  77. buf << std::hex << std::setfill('0') << std::setw(8) << i;
  78. }
  79. return buf.str();
  80. }
  81. private:
  82. unsigned int left_rotate(unsigned int x, std::size_t n) {
  83. return (x << n) ^ (x >> (32 - n));
  84. }
  85. void process_block_impl() {
  86. unsigned int w[80];
  87. for (std::size_t i = 0; i < 16; ++i) {
  88. w[i] = (block_[i * 4 + 0] << 24);
  89. w[i] |= (block_[i * 4 + 1] << 16);
  90. w[i] |= (block_[i * 4 + 2] << 8);
  91. w[i] |= (block_[i * 4 + 3]);
  92. }
  93. for (std::size_t i = 16; i < 80; ++i) {
  94. w[i] = left_rotate((w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]), 1);
  95. }
  96. unsigned int a = h_[0];
  97. unsigned int b = h_[1];
  98. unsigned int c = h_[2];
  99. unsigned int d = h_[3];
  100. unsigned int e = h_[4];
  101. for (std::size_t i = 0; i < 80; ++i) {
  102. unsigned int f = 0;
  103. unsigned int k = 0;
  104. if (i < 20) {
  105. f = (b & c) | (~b & d);
  106. k = 0x5A827999;
  107. } else if (i < 40) {
  108. f = b ^ c ^ d;
  109. k = 0x6ED9EBA1;
  110. } else if (i < 60) {
  111. f = (b & c) | (b & d) | (c & d);
  112. k = 0x8F1BBCDC;
  113. } else {
  114. f = b ^ c ^ d;
  115. k = 0xCA62C1D6;
  116. }
  117. unsigned temp = left_rotate(a, 5) + f + e + k + w[i];
  118. e = d;
  119. d = c;
  120. c = left_rotate(b, 30);
  121. b = a;
  122. a = temp;
  123. }
  124. h_[0] += a;
  125. h_[1] += b;
  126. h_[2] += c;
  127. h_[3] += d;
  128. h_[4] += e;
  129. }
  130. void process_byte_impl(unsigned char byte) {
  131. block_[block_byte_index_++] = byte;
  132. if (block_byte_index_ == 64) {
  133. block_byte_index_ = 0;
  134. process_block_impl();
  135. }
  136. }
  137. void process_byte(unsigned char byte) {
  138. process_byte_impl(byte);
  139. // size_t max value = 0xFFFFFFFF
  140. // if (bit_count_low + 8 >= 0x100000000) { // would overflow
  141. // if (bit_count_low >= 0x100000000-8) {
  142. if (bit_count_low < 0xFFFFFFF8) {
  143. bit_count_low += 8;
  144. } else {
  145. bit_count_low = 0;
  146. if (bit_count_high <= 0xFFFFFFFE) {
  147. ++bit_count_high;
  148. } else {
  149. TORCH_CHECK(false, "sha1 too many bytes");
  150. }
  151. }
  152. }
  153. void process_block(void const* bytes_begin, void const* bytes_end) {
  154. unsigned char const* begin = static_cast<unsigned char const*>(bytes_begin);
  155. unsigned char const* end = static_cast<unsigned char const*>(bytes_end);
  156. for (; begin != end; ++begin) {
  157. process_byte(*begin);
  158. }
  159. }
  160. void process_bytes(void const* buffer, std::size_t byte_count) {
  161. unsigned char const* b = static_cast<unsigned char const*>(buffer);
  162. process_block(b, b + byte_count);
  163. }
  164. void get_digest(digest_type& digest) {
  165. // append the bit '1' to the message
  166. process_byte_impl(0x80);
  167. // append k bits '0', where k is the minimum number >= 0
  168. // such that the resulting message length is congruent to 56 (mod 64)
  169. // check if there is enough space for padding and bit_count
  170. if (block_byte_index_ > 56) {
  171. // finish this block
  172. while (block_byte_index_ != 0) {
  173. process_byte_impl(0);
  174. }
  175. // one more block
  176. while (block_byte_index_ < 56) {
  177. process_byte_impl(0);
  178. }
  179. } else {
  180. while (block_byte_index_ < 56) {
  181. process_byte_impl(0);
  182. }
  183. }
  184. // append length of message (before pre-processing)
  185. // as a 64-bit big-endian integer
  186. process_byte_impl(
  187. static_cast<unsigned char>((bit_count_high >> 24) & 0xFF));
  188. process_byte_impl(
  189. static_cast<unsigned char>((bit_count_high >> 16) & 0xFF));
  190. process_byte_impl(static_cast<unsigned char>((bit_count_high >> 8) & 0xFF));
  191. process_byte_impl(static_cast<unsigned char>((bit_count_high) & 0xFF));
  192. process_byte_impl(static_cast<unsigned char>((bit_count_low >> 24) & 0xFF));
  193. process_byte_impl(static_cast<unsigned char>((bit_count_low >> 16) & 0xFF));
  194. process_byte_impl(static_cast<unsigned char>((bit_count_low >> 8) & 0xFF));
  195. process_byte_impl(static_cast<unsigned char>((bit_count_low) & 0xFF));
  196. // get final digest
  197. digest[0] = h_[0];
  198. digest[1] = h_[1];
  199. digest[2] = h_[2];
  200. digest[3] = h_[3];
  201. digest[4] = h_[4];
  202. }
  203. unsigned int h_[5]{};
  204. unsigned char block_[64]{};
  205. std::size_t block_byte_index_{};
  206. std::size_t bit_count_low{};
  207. std::size_t bit_count_high{};
  208. };
  // Scrambles a 64-bit key through a fixed sequence of shift/add and
  // shift/xor steps and returns the mixed value. All arithmetic wraps
  // modulo 2^64.
  constexpr uint64_t twang_mix64(uint64_t key) noexcept {
    key = (key << 21) + ~key; // same as key * ((1 << 21) - 1) - 1
    key ^= key >> 24;
    key += (key << 3) + (key << 8); // key *= 1 + (1 << 3) + (1 << 8)
    key ^= key >> 14;
    key += (key << 2) + (key << 4); // key *= 1 + (1 << 2) + (1 << 4)
    key ^= key >> 28;
    key += key << 31; // key *= 1 + (1 << 31)
    return key;
  }
  219. ////////////////////////////////////////////////////////////////////////////////
  220. // c10::hash implementation
  221. ////////////////////////////////////////////////////////////////////////////////
  namespace _hash_detail {

  // Hashes `o` with c10::hash<T>, letting template argument deduction supply
  // T. Only declared here; the definition follows the c10::hash
  // specializations further down in this file.
  template <typename T>
  size_t simple_get_hash(const T& o);

  // Names V when T is not an enum; otherwise substitution fails.
  template <typename T, typename V>
  using type_if_not_enum = std::enable_if_t<!std::is_enum_v<T>, V>;

  // dispatch_hash selects a hashing strategy through SFINAE:
  //   1. non-enum T for which std::hash<T> is usable -> std::hash<T>
  //   2. enum T -> std::hash of the enum's underlying type
  //   3. T providing a static T::hash(const T&)      -> T::hash
  // NOTE: C++14 added support for hashing enum types to the standard, and some
  // compilers implement it even when C++14 flags aren't specified. Overload 1
  // is therefore disabled for enums so that overload 2 is the only candidate.

  // Overload 1: std::hash for non-enum types.
  template <typename T>
  auto dispatch_hash(const T& o)
      -> decltype(std::hash<T>()(o), type_if_not_enum<T, size_t>()) {
    return std::hash<T>{}(o);
  }

  // Overload 2: enums are hashed through their underlying integer type.
  template <typename T>
  std::enable_if_t<std::is_enum_v<T>, size_t> dispatch_hash(const T& o) {
    using Underlying = std::underlying_type_t<T>;
    const auto raw = static_cast<Underlying>(o);
    return std::hash<Underlying>()(raw);
  }

  // Overload 3: fall back to a static member function T::hash.
  template <typename T>
  auto dispatch_hash(const T& o) -> decltype(T::hash(o), size_t()) {
    return T::hash(o);
  }

  } // namespace _hash_detail
  248. // Hasher struct
  249. template <typename T>
  250. struct hash {
  251. size_t operator()(const T& o) const {
  252. return _hash_detail::dispatch_hash(o);
  253. }
  254. };
  255. // Specialization for std::tuple
  256. template <typename... Types>
  257. struct hash<std::tuple<Types...>> {
  258. template <size_t idx, typename... Ts>
  259. struct tuple_hash {
  260. size_t operator()(const std::tuple<Ts...>& t) const {
  261. return hash_combine(
  262. _hash_detail::simple_get_hash(std::get<idx>(t)),
  263. tuple_hash<idx - 1, Ts...>()(t));
  264. }
  265. };
  266. template <typename... Ts>
  267. struct tuple_hash<0, Ts...> {
  268. size_t operator()(const std::tuple<Ts...>& t) const {
  269. return _hash_detail::simple_get_hash(std::get<0>(t));
  270. }
  271. };
  272. size_t operator()(const std::tuple<Types...>& t) const {
  273. return tuple_hash<sizeof...(Types) - 1, Types...>()(t);
  274. }
  275. };
  276. template <typename T1, typename T2>
  277. struct hash<std::pair<T1, T2>> {
  278. size_t operator()(const std::pair<T1, T2>& pair) const {
  279. std::tuple<T1, T2> tuple = std::make_tuple(pair.first, pair.second);
  280. return _hash_detail::simple_get_hash(tuple);
  281. }
  282. };
  283. template <typename T>
  284. struct hash<c10::ArrayRef<T>> {
  285. size_t operator()(c10::ArrayRef<T> v) const {
  286. size_t seed = 0;
  287. for (const auto& elem : v) {
  288. seed = hash_combine(seed, _hash_detail::simple_get_hash(elem));
  289. }
  290. return seed;
  291. }
  292. };
  293. // Specialization for std::vector
  294. template <typename T>
  295. struct hash<std::vector<T>> {
  296. size_t operator()(const std::vector<T>& v) const {
  297. return hash<c10::ArrayRef<T>>()(v);
  298. }
  299. };
  300. namespace _hash_detail {
  301. template <typename T>
  302. size_t simple_get_hash(const T& o) {
  303. return c10::hash<T>()(o);
  304. }
  305. } // namespace _hash_detail
  306. // Use this function to actually hash multiple things in one line.
  307. // Dispatches to c10::hash, so it can hash containers.
  308. // Example:
  309. //
  310. // static size_t hash(const MyStruct& s) {
  311. // return get_hash(s.member1, s.member2, s.member3);
  312. // }
  313. template <typename... Types>
  314. size_t get_hash(const Types&... args) {
  315. return c10::hash<decltype(std::tie(args...))>()(std::tie(args...));
  316. }
  317. // Specialization for c10::complex
  318. template <typename T>
  319. struct hash<c10::complex<T>> {
  320. size_t operator()(const c10::complex<T>& c) const {
  321. return get_hash(c.real(), c.imag());
  322. }
  323. };
  324. } // namespace c10
  325. #else
  326. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  327. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)