import logging import tempfile import unittest import yacs.config import yaml from yacs.config import CfgNode as CN try: _ignore = unicode # noqa: F821 PY2 = True except Exception as _ignore: PY2 = False class SubCN(CN): pass def get_cfg(cls=CN): cfg = cls() cfg.NUM_GPUS = 8 cfg.TRAIN = cls() cfg.TRAIN.HYPERPARAMETER_1 = 0.1 cfg.TRAIN.SCALES = (2, 4, 8, 16) cfg.MODEL = cls() cfg.MODEL.TYPE = "a_foo_model" # Some extra stuff to test CfgNode.__str__ cfg.STR = cls() cfg.STR.KEY1 = 1 cfg.STR.KEY2 = 2 cfg.STR.FOO = cls() cfg.STR.FOO.KEY1 = 1 cfg.STR.FOO.KEY2 = 2 cfg.STR.FOO.BAR = cls() cfg.STR.FOO.BAR.KEY1 = 1 cfg.STR.FOO.BAR.KEY2 = 2 cfg.register_deprecated_key("FINAL_MSG") cfg.register_deprecated_key("MODEL.DILATION") cfg.register_renamed_key( "EXAMPLE.OLD.KEY", "EXAMPLE.NEW.KEY", message="Please update your config fil config file.", ) cfg.KWARGS = cls(new_allowed=True) cfg.KWARGS.z = 0 cfg.KWARGS.Y = cls() cfg.KWARGS.Y.X = 1 return cfg class TestCfgNode(unittest.TestCase): def test_immutability(self): # Top level immutable a = CN() a.foo = 0 a.freeze() with self.assertRaises(AttributeError): a.foo = 1 a.bar = 1 assert a.is_frozen() assert a.foo == 0 a.defrost() assert not a.is_frozen() a.foo = 1 assert a.foo == 1 # Recursively immutable a.level1 = CN() a.level1.foo = 0 a.level1.level2 = CN() a.level1.level2.foo = 0 a.freeze() assert a.is_frozen() with self.assertRaises(AttributeError): a.level1.level2.foo = 1 a.level1.bar = 1 assert a.level1.level2.foo == 0 class TestCfg(unittest.TestCase): def test_copy_cfg(self): cfg = get_cfg() cfg2 = cfg.clone() s = cfg.MODEL.TYPE cfg2.MODEL.TYPE = "dummy" assert cfg.MODEL.TYPE == s def test_merge_cfg_from_cfg(self): # Test: merge from clone cfg = get_cfg() s = "dummy0" cfg2 = cfg.clone() cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s # Test: merge from yaml s = "dummy1" cfg2 = CN.load_cfg(cfg.dump()) cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s # Test: merge with a valid key s = "dummy2" cfg2 = CN() cfg2.MODEL = CN() cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s # Test: merge with an invalid key s = "dummy3" cfg2 = CN() cfg2.FOO = CN() cfg2.FOO.BAR = s with self.assertRaises(KeyError): cfg.merge_from_other_cfg(cfg2) # Test: merge with converted type cfg2 = CN() cfg2.TRAIN = CN() cfg2.TRAIN.SCALES = [1] cfg.merge_from_other_cfg(cfg2) assert type(cfg.TRAIN.SCALES) is tuple assert cfg.TRAIN.SCALES[0] == 1 # Test str (bytes) <-> unicode conversion for py2 if PY2: cfg.A_UNICODE_KEY = u"foo" cfg2 = CN() cfg2.A_UNICODE_KEY = b"bar" cfg.merge_from_other_cfg(cfg2) assert type(cfg.A_UNICODE_KEY) == unicode # noqa: F821 assert cfg.A_UNICODE_KEY == u"bar" # Test: merge with invalid type cfg2 = CN() cfg2.TRAIN = CN() cfg2.TRAIN.SCALES = 1 with self.assertRaises(ValueError): cfg.merge_from_other_cfg(cfg2) def test_merge_cfg_from_file(self): with tempfile.NamedTemporaryFile(mode="wt") as f: cfg = get_cfg() f.write(cfg.dump()) f.flush() s = cfg.MODEL.TYPE cfg.MODEL.TYPE = "dummy" assert cfg.MODEL.TYPE != s cfg.merge_from_file(f.name) assert cfg.MODEL.TYPE == s def test_merge_cfg_from_list(self): cfg = get_cfg() opts = ["TRAIN.SCALES", "(100, )", "MODEL.TYPE", "foobar", "NUM_GPUS", 2] assert len(cfg.TRAIN.SCALES) > 0 assert cfg.TRAIN.SCALES[0] != 100 assert cfg.MODEL.TYPE != "foobar" assert cfg.NUM_GPUS != 2 cfg.merge_from_list(opts) assert type(cfg.TRAIN.SCALES) is tuple assert len(cfg.TRAIN.SCALES) == 1 assert cfg.TRAIN.SCALES[0] == 100 assert cfg.MODEL.TYPE == "foobar" assert cfg.NUM_GPUS == 2 def test_deprecated_key_from_list(self): # You should see logger messages like: # "Deprecated config key (ignoring): MODEL.DILATION" cfg = get_cfg() opts = ["FINAL_MSG", "foobar", "MODEL.DILATION", 2] with self.assertRaises(AttributeError): _ = cfg.FINAL_MSG # noqa with self.assertRaises(AttributeError): _ = cfg.MODEL.DILATION # noqa cfg.merge_from_list(opts) with self.assertRaises(AttributeError): _ = cfg.FINAL_MSG # noqa with self.assertRaises(AttributeError): _ = cfg.MODEL.DILATION # noqa def test_nonexistant_key_from_list(self): cfg = get_cfg() opts = ["MODEL.DOES_NOT_EXIST", "IGNORE"] with self.assertRaises(AssertionError): cfg.merge_from_list(opts) def test_load_cfg_invalid_type(self): class CustomClass(yaml.YAMLObject): """A custom class that yaml.safe_load can load.""" yaml_loader = yaml.SafeLoader yaml_tag = u"!CustomClass" # FOO.BAR.QUUX will have type CustomClass, which is not allowed cfg_string = "FOO:\n BAR:\n QUUX: !CustomClass {}" with self.assertRaises(AssertionError): yacs.config.load_cfg(cfg_string) def test_deprecated_key_from_file(self): # You should see logger messages like: # "Deprecated config key (ignoring): MODEL.DILATION" cfg = get_cfg() with tempfile.NamedTemporaryFile("wt") as f: cfg2 = cfg.clone() cfg2.MODEL.DILATION = 2 f.write(cfg2.dump()) f.flush() with self.assertRaises(AttributeError): _ = cfg.MODEL.DILATION # noqa cfg.merge_from_file(f.name) with self.assertRaises(AttributeError): _ = cfg.MODEL.DILATION # noqa def test_renamed_key_from_list(self): cfg = get_cfg() opts = ["EXAMPLE.OLD.KEY", "foobar"] with self.assertRaises(AttributeError): _ = cfg.EXAMPLE.OLD.KEY # noqa with self.assertRaises(KeyError): cfg.merge_from_list(opts) def test_renamed_key_from_file(self): cfg = get_cfg() with tempfile.NamedTemporaryFile("wt") as f: cfg2 = cfg.clone() cfg2.EXAMPLE = CN() cfg2.EXAMPLE.RENAMED = CN() cfg2.EXAMPLE.RENAMED.KEY = "foobar" f.write(cfg2.dump()) f.flush() with self.assertRaises(AttributeError): _ = cfg.EXAMPLE.RENAMED.KEY # noqa with self.assertRaises(KeyError): cfg.merge_from_file(f.name) def test_load_cfg_from_file(self): cfg = get_cfg() with tempfile.NamedTemporaryFile("wt") as f: f.write(cfg.dump()) f.flush() with open(f.name, "rt") as f_read: yacs.config.load_cfg(f_read) def test_load_from_python_file(self): # Case 1: exports CfgNode cfg = get_cfg() cfg.merge_from_file("example/config_override.py") assert cfg.TRAIN.HYPERPARAMETER_1 == 0.9 # Case 2: exports dict cfg = get_cfg() cfg.merge_from_file("example/config_override_from_dict.py") assert cfg.TRAIN.HYPERPARAMETER_1 == 0.9 def test_invalid_type(self): cfg = get_cfg() with self.assertRaises(AssertionError): cfg.INVALID_KEY_TYPE = object() def test__str__(self): expected_str = """ KWARGS: Y: X: 1 z: 0 MODEL: TYPE: a_foo_model NUM_GPUS: 8 STR: FOO: BAR: KEY1: 1 KEY2: 2 KEY1: 1 KEY2: 2 KEY1: 1 KEY2: 2 TRAIN: HYPERPARAMETER_1: 0.1 SCALES: (2, 4, 8, 16) """.strip() cfg = get_cfg() assert str(cfg) == expected_str def test_new_allowed(self): cfg = get_cfg() cfg.merge_from_file("example/config_new_allowed.yaml") assert cfg.KWARGS.a == 1 assert cfg.KWARGS.B.c == 2 assert cfg.KWARGS.B.D.e == "3" def test_new_allowed_bad(self): cfg = get_cfg() with self.assertRaises(KeyError): cfg.merge_from_file("example/config_new_allowed_bad.yaml") cfg.set_new_allowed(True) cfg.merge_from_file("example/config_new_allowed_bad.yaml") assert cfg.KWARGS.Y.f == 4 class TestCfgNodeSubclass(unittest.TestCase): def test_merge_cfg_from_file(self): with tempfile.NamedTemporaryFile(mode="wt") as f: cfg = get_cfg(SubCN) f.write(cfg.dump()) f.flush() s = cfg.MODEL.TYPE cfg.MODEL.TYPE = "dummy" assert cfg.MODEL.TYPE != s cfg.merge_from_file(f.name) assert cfg.MODEL.TYPE == s def test_merge_cfg_from_list(self): cfg = get_cfg(SubCN) opts = ["TRAIN.SCALES", "(100, )", "MODEL.TYPE", "foobar", "NUM_GPUS", 2] assert len(cfg.TRAIN.SCALES) > 0 assert cfg.TRAIN.SCALES[0] != 100 assert cfg.MODEL.TYPE != "foobar" assert cfg.NUM_GPUS != 2 cfg.merge_from_list(opts) assert type(cfg.TRAIN.SCALES) is tuple assert len(cfg.TRAIN.SCALES) == 1 assert cfg.TRAIN.SCALES[0] == 100 assert cfg.MODEL.TYPE == "foobar" assert cfg.NUM_GPUS == 2 def test_merge_cfg_from_cfg(self): cfg = get_cfg(SubCN) cfg2 = get_cfg(SubCN) s = "dummy0" cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s # Test: merge from yaml s = "dummy1" cfg2 = SubCN.load_cfg(cfg.dump()) cfg2.MODEL.TYPE = s cfg.merge_from_other_cfg(cfg2) assert cfg.MODEL.TYPE == s if __name__ == "__main__": logging.basicConfig() yacs_logger = logging.getLogger("yacs.config") yacs_logger.setLevel(logging.DEBUG) unittest.main()