Registry.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #ifndef C10_UTIL_REGISTRY_H_
  3. #define C10_UTIL_REGISTRY_H_
  4. /**
  5. * Simple registry implementation that uses static variables to
  6. * register object creators during program initialization time.
  7. */
  8. // NB: This Registry works poorly when you have other namespaces.
  9. // Make all macro invocations from inside the at namespace.
  10. #include <cstdio>
  11. #include <cstdlib>
  12. #include <functional>
  13. #include <memory>
  14. #include <mutex>
  15. #include <stdexcept>
  16. #include <string>
  17. #include <unordered_map>
  18. #include <vector>
  19. #include <c10/macros/Export.h>
  20. #include <c10/macros/Macros.h>
  21. #include <c10/util/Type.h>
  22. namespace c10 {
  /// Produces a printable representation of a registry key for diagnostics.
  ///
  /// Generic fallback: the key value is ignored and a fixed placeholder string
  /// is returned, because an arbitrary key type has no known string form.
  template <typename KeyType>
  inline std::string KeyStrRepr(const KeyType&) {
    return std::string("[key type printing not supported]");
  }

  /// Specialization for std::string keys: the key is already printable, so a
  /// copy of it is returned unchanged.
  template <>
  inline std::string KeyStrRepr<std::string>(const std::string& key) {
    return key;
  }
  /// Precedence of a registration. When the same key is registered more than
  /// once, Registry::Register keeps the creator with the numerically larger
  /// priority.
  enum RegistryPriority {
    REGISTRY_FALLBACK = 1, // lowest precedence
    REGISTRY_DEFAULT, // == 2; used when no priority is given
    REGISTRY_PREFERRED, // == 3; highest precedence
  };
  36. /**
  37. * @brief A template class that allows one to register classes by keys.
  38. *
  39. * The keys are usually a std::string specifying the name, but can be anything
  40. * that can be used in a std::map.
  41. *
  42. * You should most likely not use the Registry class explicitly, but use the
  43. * helper macros below to declare specific registries as well as registering
  44. * objects.
  45. */
  46. template <class SrcType, class ObjectPtrType, class... Args>
  47. class Registry {
  48. public:
  49. typedef std::function<ObjectPtrType(Args...)> Creator;
  50. Registry(bool warning = true) : registry_(), priority_(), warning_(warning) {}
  51. ~Registry() = default;
  52. void Register(
  53. const SrcType& key,
  54. Creator creator,
  55. const RegistryPriority priority = REGISTRY_DEFAULT) {
  56. std::lock_guard<std::mutex> lock(register_mutex_);
  57. // The if statement below is essentially the same as the following line:
  58. // TORCH_CHECK_EQ(registry_.count(key), 0) << "Key " << key
  59. // << " registered twice.";
  60. // However, TORCH_CHECK_EQ depends on google logging, and since registration
  61. // is carried out at static initialization time, we do not want to have an
  62. // explicit dependency on glog's initialization function.
  63. if (registry_.count(key) != 0) {
  64. auto cur_priority = priority_[key];
  65. if (priority > cur_priority) {
  66. #ifdef DEBUG
  67. std::string warn_msg =
  68. "Overwriting already registered item for key " + KeyStrRepr(key);
  69. fprintf(stderr, "%s\n", warn_msg.c_str());
  70. #endif
  71. registry_[key] = creator;
  72. priority_[key] = priority;
  73. } else if (priority == cur_priority) {
  74. std::string err_msg =
  75. "Key already registered with the same priority: " + KeyStrRepr(key);
  76. fprintf(stderr, "%s\n", err_msg.c_str());
  77. if (terminate_) {
  78. std::exit(1);
  79. } else {
  80. throw std::runtime_error(err_msg);
  81. }
  82. } else if (warning_) {
  83. std::string warn_msg =
  84. "Higher priority item already registered, skipping registration of " +
  85. KeyStrRepr(key);
  86. fprintf(stderr, "%s\n", warn_msg.c_str());
  87. }
  88. } else {
  89. registry_[key] = creator;
  90. priority_[key] = priority;
  91. }
  92. }
  93. void Register(
  94. const SrcType& key,
  95. Creator creator,
  96. const std::string& help_msg,
  97. const RegistryPriority priority = REGISTRY_DEFAULT) {
  98. Register(key, creator, priority);
  99. help_message_[key] = help_msg;
  100. }
  101. inline bool Has(const SrcType& key) {
  102. return (registry_.count(key) != 0);
  103. }
  104. ObjectPtrType Create(const SrcType& key, Args... args) {
  105. auto it = registry_.find(key);
  106. if (it == registry_.end()) {
  107. // Returns nullptr if the key is not registered.
  108. return nullptr;
  109. }
  110. return it->second(args...);
  111. }
  112. /**
  113. * Returns the keys currently registered as a std::vector.
  114. */
  115. std::vector<SrcType> Keys() const {
  116. std::vector<SrcType> keys;
  117. keys.reserve(registry_.size());
  118. for (const auto& it : registry_) {
  119. keys.push_back(it.first);
  120. }
  121. return keys;
  122. }
  123. inline const std::unordered_map<SrcType, std::string>& HelpMessage() const {
  124. return help_message_;
  125. }
  126. const char* HelpMessage(const SrcType& key) const {
  127. auto it = help_message_.find(key);
  128. if (it == help_message_.end()) {
  129. return nullptr;
  130. }
  131. return it->second.c_str();
  132. }
  133. // Used for testing, if terminate is unset, Registry throws instead of
  134. // calling std::exit
  135. void SetTerminate(bool terminate) {
  136. terminate_ = terminate;
  137. }
  138. C10_DISABLE_COPY_AND_ASSIGN(Registry);
  139. Registry(Registry&&) = delete;
  140. Registry& operator=(Registry&&) = delete;
  141. private:
  142. std::unordered_map<SrcType, Creator> registry_;
  143. std::unordered_map<SrcType, RegistryPriority> priority_;
  144. bool terminate_{true};
  145. const bool warning_;
  146. std::unordered_map<SrcType, std::string> help_message_;
  147. std::mutex register_mutex_;
  148. };
  149. template <class SrcType, class ObjectPtrType, class... Args>
  150. class Registerer {
  151. public:
  152. explicit Registerer(
  153. const SrcType& key,
  154. Registry<SrcType, ObjectPtrType, Args...>* registry,
  155. typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
  156. const std::string& help_msg = "") {
  157. registry->Register(key, creator, help_msg);
  158. }
  159. explicit Registerer(
  160. const SrcType& key,
  161. const RegistryPriority priority,
  162. Registry<SrcType, ObjectPtrType, Args...>* registry,
  163. typename Registry<SrcType, ObjectPtrType, Args...>::Creator creator,
  164. const std::string& help_msg = "") {
  165. registry->Register(key, creator, help_msg, priority);
  166. }
  167. template <class DerivedType>
  168. static ObjectPtrType DefaultCreator(Args... args) {
  169. return ObjectPtrType(new DerivedType(args...));
  170. }
  171. };
  172. /**
  173. * C10_DECLARE_TYPED_REGISTRY is a macro that expands to a function
  174. * declaration, as well as creating a convenient typename for its corresponding
  175. * registerer.
  176. */
  177. // Note on C10_IMPORT and C10_EXPORT below: we need to explicitly mark DECLARE
  178. // as import and DEFINE as export, because these registry macros will be used
  179. // in downstream shared libraries as well, and one cannot use *_API - the API
  180. // macro will be defined on a per-shared-library basis. Semantically, when one
  181. // declares a typed registry it is always going to be IMPORT, and when one
  182. // defines a registry (which should happen ONLY ONCE and ONLY IN SOURCE FILE),
  183. // the instantiation unit is always going to be exported.
  184. //
  185. // The only unique condition is when in the same file one does DECLARE and
  186. // DEFINE - in Windows compilers, this generates a warning that dllimport and
  187. // dllexport are mixed, but the warning is fine and linker will be properly
  188. // exporting the symbol. Same thing happens in the gflags flag declaration and
  189. // definition cases.
  // Declares the accessor function `RegistryName()`, which returns a pointer to
  // a ::c10::Registry keyed by SrcType that produces PtrType<ObjectType>, and
  // introduces the type name `Registerer<RegistryName>` for the matching
  // ::c10::Registerer. Extra macro arguments become the creator argument types.
  // This variant marks the accessor with C10_API.
  #define C10_DECLARE_TYPED_REGISTRY(                                        \
      RegistryName, SrcType, ObjectType, PtrType, ...)                       \
    C10_API ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*    \
    RegistryName();                                                          \
    using Registerer##RegistryName =                                         \
        ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>

  // Same as C10_DECLARE_TYPED_REGISTRY, but the accessor is marked TORCH_API.
  #define TORCH_DECLARE_TYPED_REGISTRY(                                      \
      RegistryName, SrcType, ObjectType, PtrType, ...)                       \
    TORCH_API ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*  \
    RegistryName();                                                          \
    using Registerer##RegistryName =                                         \
        ::c10::Registerer<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>
  // Defines the accessor function `RegistryName()`. The registry is allocated
  // with `new` on the first call, stored in a function-local static pointer and
  // never deleted; every call returns that same pointer.
  #define C10_DEFINE_TYPED_REGISTRY(                                          \
      RegistryName, SrcType, ObjectType, PtrType, ...)                        \
    C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*  \
    RegistryName() {                                                          \
      static auto* registry =                                                 \
          new ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>(); \
      return registry;                                                        \
    }

  // Same as C10_DEFINE_TYPED_REGISTRY, but constructs the registry with
  // `warning = false`, which silences the "higher priority item already
  // registered" message.
  #define C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                          \
      RegistryName, SrcType, ObjectType, PtrType, ...)                        \
    C10_EXPORT ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>*  \
    RegistryName() {                                                          \
      static auto* registry =                                                 \
          new ::c10::Registry<SrcType, PtrType<ObjectType>, ##__VA_ARGS__>(   \
              false);                                                         \
      return registry;                                                        \
    }
  221. // Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated
  222. // creator with comma in its templated arguments.
  // Registers a creator under `key` by defining a static
  // Registerer<RegistryName> object; the variadic arguments are passed to the
  // Registerer constructor after the registry pointer (creator, then an
  // optional help message). The variable name comes from
  // C10_ANONYMOUS_VARIABLE, which presumably yields a unique identifier
  // (defined outside this file; confirm in c10/macros/Macros.h).
  #define C10_REGISTER_TYPED_CREATOR(RegistryName, key, ...) \
    static Registerer##RegistryName                          \
    C10_ANONYMOUS_VARIABLE(g_##RegistryName)(                \
        key, RegistryName(), ##__VA_ARGS__);

  // As C10_REGISTER_TYPED_CREATOR, with an explicit RegistryPriority.
  #define C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY( \
      RegistryName, key, priority, ...)             \
    static Registerer##RegistryName                 \
    C10_ANONYMOUS_VARIABLE(g_##RegistryName)(       \
        key, priority, RegistryName(), ##__VA_ARGS__);

  // Registers the class named by the variadic arguments under `key`. The
  // creator is Registerer<RegistryName>::DefaultCreator<Class>, and the help
  // message is the result of ::c10::demangle_type<Class>().
  #define C10_REGISTER_TYPED_CLASS(RegistryName, key, ...)       \
    static Registerer##RegistryName                              \
    C10_ANONYMOUS_VARIABLE(g_##RegistryName)(                    \
        key,                                                     \
        RegistryName(),                                          \
        Registerer##RegistryName::DefaultCreator<__VA_ARGS__>,   \
        ::c10::demangle_type<__VA_ARGS__>());

  // As C10_REGISTER_TYPED_CLASS, with an explicit RegistryPriority.
  #define C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(                \
      RegistryName, key, priority, ...)                          \
    static Registerer##RegistryName                              \
    C10_ANONYMOUS_VARIABLE(g_##RegistryName)(                    \
        key,                                                     \
        priority,                                                \
        RegistryName(),                                          \
        Registerer##RegistryName::DefaultCreator<__VA_ARGS__>,   \
        ::c10::demangle_type<__VA_ARGS__>());
  244. // C10_DECLARE_REGISTRY and C10_DEFINE_REGISTRY are hard-wired to use
  245. // std::string as the key type, because that is the most commonly used case.
  // String-keyed registries whose creators return std::unique_ptr<ObjectType>.
  // Each macro forwards to the corresponding *_TYPED_REGISTRY macro with
  // SrcType = std::string and PtrType = std::unique_ptr.
  #define C10_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
    C10_DECLARE_TYPED_REGISTRY(                               \
        RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

  #define TORCH_DECLARE_REGISTRY(RegistryName, ObjectType, ...) \
    TORCH_DECLARE_TYPED_REGISTRY(                               \
        RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

  #define C10_DEFINE_REGISTRY(RegistryName, ObjectType, ...) \
    C10_DEFINE_TYPED_REGISTRY(                               \
        RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

  #define C10_DEFINE_REGISTRY_WITHOUT_WARNING(RegistryName, ObjectType, ...) \
    C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(                               \
        RegistryName, std::string, ObjectType, std::unique_ptr, ##__VA_ARGS__)

  // String-keyed registries whose creators return std::shared_ptr<ObjectType>.
  // Same forwarding as above, with PtrType = std::shared_ptr.
  #define C10_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
    C10_DECLARE_TYPED_REGISTRY(                                      \
        RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

  #define TORCH_DECLARE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
    TORCH_DECLARE_TYPED_REGISTRY(                                      \
        RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

  #define C10_DEFINE_SHARED_REGISTRY(RegistryName, ObjectType, ...) \
    C10_DEFINE_TYPED_REGISTRY(                                      \
        RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)

  #define C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( \
      RegistryName, ObjectType, ...)                  \
    C10_DEFINE_TYPED_REGISTRY_WITHOUT_WARNING(        \
        RegistryName, std::string, ObjectType, std::shared_ptr, ##__VA_ARGS__)
  271. // C10_REGISTER_CREATOR and C10_REGISTER_CLASS are hard-wired to use std::string
  272. // as the key
  273. // type, because that is the most commonly used case.
  // String-keyed registration helpers. `key` is written as a bare token and is
  // stringified with `#key` before being forwarded to the corresponding
  // C10_REGISTER_TYPED_* macro.
  #define C10_REGISTER_CREATOR(RegistryName, key, ...) \
    C10_REGISTER_TYPED_CREATOR(RegistryName, #key, __VA_ARGS__)

  #define C10_REGISTER_CREATOR_WITH_PRIORITY(RegistryName, key, priority, ...) \
    C10_REGISTER_TYPED_CREATOR_WITH_PRIORITY(                                  \
        RegistryName, #key, priority, __VA_ARGS__)

  #define C10_REGISTER_CLASS(RegistryName, key, ...) \
    C10_REGISTER_TYPED_CLASS(RegistryName, #key, __VA_ARGS__)

  #define C10_REGISTER_CLASS_WITH_PRIORITY(RegistryName, key, priority, ...) \
    C10_REGISTER_TYPED_CLASS_WITH_PRIORITY(                                  \
        RegistryName, #key, priority, __VA_ARGS__)
  284. } // namespace c10
  285. #endif // C10_UTIL_REGISTRY_H_
  286. #else
  287. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  288. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)