tests.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. import logging
  2. import tempfile
  3. import unittest
  4. import yacs.config
  5. import yaml
  6. from yacs.config import CfgNode as CN
  7. try:
  8. _ignore = unicode # noqa: F821
  9. PY2 = True
  10. except Exception as _ignore:
  11. PY2 = False
  12. class SubCN(CN):
  13. pass
  14. def get_cfg(cls=CN):
  15. cfg = cls()
  16. cfg.NUM_GPUS = 8
  17. cfg.TRAIN = cls()
  18. cfg.TRAIN.HYPERPARAMETER_1 = 0.1
  19. cfg.TRAIN.SCALES = (2, 4, 8, 16)
  20. cfg.MODEL = cls()
  21. cfg.MODEL.TYPE = "a_foo_model"
  22. # Some extra stuff to test CfgNode.__str__
  23. cfg.STR = cls()
  24. cfg.STR.KEY1 = 1
  25. cfg.STR.KEY2 = 2
  26. cfg.STR.FOO = cls()
  27. cfg.STR.FOO.KEY1 = 1
  28. cfg.STR.FOO.KEY2 = 2
  29. cfg.STR.FOO.BAR = cls()
  30. cfg.STR.FOO.BAR.KEY1 = 1
  31. cfg.STR.FOO.BAR.KEY2 = 2
  32. cfg.register_deprecated_key("FINAL_MSG")
  33. cfg.register_deprecated_key("MODEL.DILATION")
  34. cfg.register_renamed_key(
  35. "EXAMPLE.OLD.KEY",
  36. "EXAMPLE.NEW.KEY",
  37. message="Please update your config fil config file.",
  38. )
  39. cfg.KWARGS = cls(new_allowed=True)
  40. cfg.KWARGS.z = 0
  41. cfg.KWARGS.Y = cls()
  42. cfg.KWARGS.Y.X = 1
  43. return cfg
  44. class TestCfgNode(unittest.TestCase):
  45. def test_immutability(self):
  46. # Top level immutable
  47. a = CN()
  48. a.foo = 0
  49. a.freeze()
  50. with self.assertRaises(AttributeError):
  51. a.foo = 1
  52. a.bar = 1
  53. assert a.is_frozen()
  54. assert a.foo == 0
  55. a.defrost()
  56. assert not a.is_frozen()
  57. a.foo = 1
  58. assert a.foo == 1
  59. # Recursively immutable
  60. a.level1 = CN()
  61. a.level1.foo = 0
  62. a.level1.level2 = CN()
  63. a.level1.level2.foo = 0
  64. a.freeze()
  65. assert a.is_frozen()
  66. with self.assertRaises(AttributeError):
  67. a.level1.level2.foo = 1
  68. a.level1.bar = 1
  69. assert a.level1.level2.foo == 0
  70. class TestCfg(unittest.TestCase):
  71. def test_copy_cfg(self):
  72. cfg = get_cfg()
  73. cfg2 = cfg.clone()
  74. s = cfg.MODEL.TYPE
  75. cfg2.MODEL.TYPE = "dummy"
  76. assert cfg.MODEL.TYPE == s
  77. def test_merge_cfg_from_cfg(self):
  78. # Test: merge from clone
  79. cfg = get_cfg()
  80. s = "dummy0"
  81. cfg2 = cfg.clone()
  82. cfg2.MODEL.TYPE = s
  83. cfg.merge_from_other_cfg(cfg2)
  84. assert cfg.MODEL.TYPE == s
  85. # Test: merge from yaml
  86. s = "dummy1"
  87. cfg2 = CN.load_cfg(cfg.dump())
  88. cfg2.MODEL.TYPE = s
  89. cfg.merge_from_other_cfg(cfg2)
  90. assert cfg.MODEL.TYPE == s
  91. # Test: merge with a valid key
  92. s = "dummy2"
  93. cfg2 = CN()
  94. cfg2.MODEL = CN()
  95. cfg2.MODEL.TYPE = s
  96. cfg.merge_from_other_cfg(cfg2)
  97. assert cfg.MODEL.TYPE == s
  98. # Test: merge with an invalid key
  99. s = "dummy3"
  100. cfg2 = CN()
  101. cfg2.FOO = CN()
  102. cfg2.FOO.BAR = s
  103. with self.assertRaises(KeyError):
  104. cfg.merge_from_other_cfg(cfg2)
  105. # Test: merge with converted type
  106. cfg2 = CN()
  107. cfg2.TRAIN = CN()
  108. cfg2.TRAIN.SCALES = [1]
  109. cfg.merge_from_other_cfg(cfg2)
  110. assert type(cfg.TRAIN.SCALES) is tuple
  111. assert cfg.TRAIN.SCALES[0] == 1
  112. # Test str (bytes) <-> unicode conversion for py2
  113. if PY2:
  114. cfg.A_UNICODE_KEY = u"foo"
  115. cfg2 = CN()
  116. cfg2.A_UNICODE_KEY = b"bar"
  117. cfg.merge_from_other_cfg(cfg2)
  118. assert type(cfg.A_UNICODE_KEY) == unicode # noqa: F821
  119. assert cfg.A_UNICODE_KEY == u"bar"
  120. # Test: merge with invalid type
  121. cfg2 = CN()
  122. cfg2.TRAIN = CN()
  123. cfg2.TRAIN.SCALES = 1
  124. with self.assertRaises(ValueError):
  125. cfg.merge_from_other_cfg(cfg2)
  126. def test_merge_cfg_from_file(self):
  127. with tempfile.NamedTemporaryFile(mode="wt") as f:
  128. cfg = get_cfg()
  129. f.write(cfg.dump())
  130. f.flush()
  131. s = cfg.MODEL.TYPE
  132. cfg.MODEL.TYPE = "dummy"
  133. assert cfg.MODEL.TYPE != s
  134. cfg.merge_from_file(f.name)
  135. assert cfg.MODEL.TYPE == s
  136. def test_merge_cfg_from_list(self):
  137. cfg = get_cfg()
  138. opts = ["TRAIN.SCALES", "(100, )", "MODEL.TYPE", "foobar", "NUM_GPUS", 2]
  139. assert len(cfg.TRAIN.SCALES) > 0
  140. assert cfg.TRAIN.SCALES[0] != 100
  141. assert cfg.MODEL.TYPE != "foobar"
  142. assert cfg.NUM_GPUS != 2
  143. cfg.merge_from_list(opts)
  144. assert type(cfg.TRAIN.SCALES) is tuple
  145. assert len(cfg.TRAIN.SCALES) == 1
  146. assert cfg.TRAIN.SCALES[0] == 100
  147. assert cfg.MODEL.TYPE == "foobar"
  148. assert cfg.NUM_GPUS == 2
  149. def test_deprecated_key_from_list(self):
  150. # You should see logger messages like:
  151. # "Deprecated config key (ignoring): MODEL.DILATION"
  152. cfg = get_cfg()
  153. opts = ["FINAL_MSG", "foobar", "MODEL.DILATION", 2]
  154. with self.assertRaises(AttributeError):
  155. _ = cfg.FINAL_MSG # noqa
  156. with self.assertRaises(AttributeError):
  157. _ = cfg.MODEL.DILATION # noqa
  158. cfg.merge_from_list(opts)
  159. with self.assertRaises(AttributeError):
  160. _ = cfg.FINAL_MSG # noqa
  161. with self.assertRaises(AttributeError):
  162. _ = cfg.MODEL.DILATION # noqa
  163. def test_nonexistant_key_from_list(self):
  164. cfg = get_cfg()
  165. opts = ["MODEL.DOES_NOT_EXIST", "IGNORE"]
  166. with self.assertRaises(AssertionError):
  167. cfg.merge_from_list(opts)
  168. def test_load_cfg_invalid_type(self):
  169. class CustomClass(yaml.YAMLObject):
  170. """A custom class that yaml.safe_load can load."""
  171. yaml_loader = yaml.SafeLoader
  172. yaml_tag = u"!CustomClass"
  173. # FOO.BAR.QUUX will have type CustomClass, which is not allowed
  174. cfg_string = "FOO:\n BAR:\n QUUX: !CustomClass {}"
  175. with self.assertRaises(AssertionError):
  176. yacs.config.load_cfg(cfg_string)
  177. def test_deprecated_key_from_file(self):
  178. # You should see logger messages like:
  179. # "Deprecated config key (ignoring): MODEL.DILATION"
  180. cfg = get_cfg()
  181. with tempfile.NamedTemporaryFile("wt") as f:
  182. cfg2 = cfg.clone()
  183. cfg2.MODEL.DILATION = 2
  184. f.write(cfg2.dump())
  185. f.flush()
  186. with self.assertRaises(AttributeError):
  187. _ = cfg.MODEL.DILATION # noqa
  188. cfg.merge_from_file(f.name)
  189. with self.assertRaises(AttributeError):
  190. _ = cfg.MODEL.DILATION # noqa
  191. def test_renamed_key_from_list(self):
  192. cfg = get_cfg()
  193. opts = ["EXAMPLE.OLD.KEY", "foobar"]
  194. with self.assertRaises(AttributeError):
  195. _ = cfg.EXAMPLE.OLD.KEY # noqa
  196. with self.assertRaises(KeyError):
  197. cfg.merge_from_list(opts)
  198. def test_renamed_key_from_file(self):
  199. cfg = get_cfg()
  200. with tempfile.NamedTemporaryFile("wt") as f:
  201. cfg2 = cfg.clone()
  202. cfg2.EXAMPLE = CN()
  203. cfg2.EXAMPLE.RENAMED = CN()
  204. cfg2.EXAMPLE.RENAMED.KEY = "foobar"
  205. f.write(cfg2.dump())
  206. f.flush()
  207. with self.assertRaises(AttributeError):
  208. _ = cfg.EXAMPLE.RENAMED.KEY # noqa
  209. with self.assertRaises(KeyError):
  210. cfg.merge_from_file(f.name)
  211. def test_load_cfg_from_file(self):
  212. cfg = get_cfg()
  213. with tempfile.NamedTemporaryFile("wt") as f:
  214. f.write(cfg.dump())
  215. f.flush()
  216. with open(f.name, "rt") as f_read:
  217. yacs.config.load_cfg(f_read)
  218. def test_load_from_python_file(self):
  219. # Case 1: exports CfgNode
  220. cfg = get_cfg()
  221. cfg.merge_from_file("example/config_override.py")
  222. assert cfg.TRAIN.HYPERPARAMETER_1 == 0.9
  223. # Case 2: exports dict
  224. cfg = get_cfg()
  225. cfg.merge_from_file("example/config_override_from_dict.py")
  226. assert cfg.TRAIN.HYPERPARAMETER_1 == 0.9
  227. def test_invalid_type(self):
  228. cfg = get_cfg()
  229. with self.assertRaises(AssertionError):
  230. cfg.INVALID_KEY_TYPE = object()
  231. def test__str__(self):
  232. expected_str = """
  233. KWARGS:
  234. Y:
  235. X: 1
  236. z: 0
  237. MODEL:
  238. TYPE: a_foo_model
  239. NUM_GPUS: 8
  240. STR:
  241. FOO:
  242. BAR:
  243. KEY1: 1
  244. KEY2: 2
  245. KEY1: 1
  246. KEY2: 2
  247. KEY1: 1
  248. KEY2: 2
  249. TRAIN:
  250. HYPERPARAMETER_1: 0.1
  251. SCALES: (2, 4, 8, 16)
  252. """.strip()
  253. cfg = get_cfg()
  254. assert str(cfg) == expected_str
  255. def test_new_allowed(self):
  256. cfg = get_cfg()
  257. cfg.merge_from_file("example/config_new_allowed.yaml")
  258. assert cfg.KWARGS.a == 1
  259. assert cfg.KWARGS.B.c == 2
  260. assert cfg.KWARGS.B.D.e == "3"
  261. def test_new_allowed_bad(self):
  262. cfg = get_cfg()
  263. with self.assertRaises(KeyError):
  264. cfg.merge_from_file("example/config_new_allowed_bad.yaml")
  265. cfg.set_new_allowed(True)
  266. cfg.merge_from_file("example/config_new_allowed_bad.yaml")
  267. assert cfg.KWARGS.Y.f == 4
  268. class TestCfgNodeSubclass(unittest.TestCase):
  269. def test_merge_cfg_from_file(self):
  270. with tempfile.NamedTemporaryFile(mode="wt") as f:
  271. cfg = get_cfg(SubCN)
  272. f.write(cfg.dump())
  273. f.flush()
  274. s = cfg.MODEL.TYPE
  275. cfg.MODEL.TYPE = "dummy"
  276. assert cfg.MODEL.TYPE != s
  277. cfg.merge_from_file(f.name)
  278. assert cfg.MODEL.TYPE == s
  279. def test_merge_cfg_from_list(self):
  280. cfg = get_cfg(SubCN)
  281. opts = ["TRAIN.SCALES", "(100, )", "MODEL.TYPE", "foobar", "NUM_GPUS", 2]
  282. assert len(cfg.TRAIN.SCALES) > 0
  283. assert cfg.TRAIN.SCALES[0] != 100
  284. assert cfg.MODEL.TYPE != "foobar"
  285. assert cfg.NUM_GPUS != 2
  286. cfg.merge_from_list(opts)
  287. assert type(cfg.TRAIN.SCALES) is tuple
  288. assert len(cfg.TRAIN.SCALES) == 1
  289. assert cfg.TRAIN.SCALES[0] == 100
  290. assert cfg.MODEL.TYPE == "foobar"
  291. assert cfg.NUM_GPUS == 2
  292. def test_merge_cfg_from_cfg(self):
  293. cfg = get_cfg(SubCN)
  294. cfg2 = get_cfg(SubCN)
  295. s = "dummy0"
  296. cfg2.MODEL.TYPE = s
  297. cfg.merge_from_other_cfg(cfg2)
  298. assert cfg.MODEL.TYPE == s
  299. # Test: merge from yaml
  300. s = "dummy1"
  301. cfg2 = SubCN.load_cfg(cfg.dump())
  302. cfg2.MODEL.TYPE = s
  303. cfg.merge_from_other_cfg(cfg2)
  304. assert cfg.MODEL.TYPE == s
  305. if __name__ == "__main__":
  306. logging.basicConfig()
  307. yacs_logger = logging.getLogger("yacs.config")
  308. yacs_logger.setLevel(logging.DEBUG)
  309. unittest.main()