"""Unit tests for FedEval.config (module ``FedEval.config.test_config``)."""

import copy
from unittest import TestCase

from .configuration import (_D_PARTITION_KEY, _DEFAULT_D_CFG, _DEFAULT_MDL_CFG,
                            _DEFAULT_RT_CFG, ConfigurationManager, _DataConfig)
from .role import Role


class ConfigurationManagerTestCase(TestCase):
    """Tests for ConfigurationManager's sub-config accessors, filename setter and role."""

    def setUp(self):
        # A new handle per test. test_role_setting relies on separate
        # ConfigurationManager() handles observing the same role value.
        self.cfg_mgr = ConfigurationManager()

    def test_default_cfgs(self):
        """Each sub-config's ``inner`` dict equals its module-level default."""
        expected_by_attr = (
            ('data_config', _DEFAULT_D_CFG),
            ('model_config', _DEFAULT_MDL_CFG),
            ('runtime_config', _DEFAULT_RT_CFG),
        )
        for attr_name, expected in expected_by_attr:
            self.assertDictEqual(getattr(self.cfg_mgr, attr_name).inner, expected)

    def test_cfg_write_availability(self):
        """The three sub-config attributes reject assignment with AttributeError."""
        for attr_name in ('data_config', 'model_config', 'runtime_config'):
            with self.assertRaises(AttributeError):
                setattr(self.cfg_mgr, attr_name, {})

    def test_rt_cfg_accessability(self):
        """Reading the model config's ml/strategy sub-sections does not raise.

        NOTE(review): despite the ``rt_cfg`` in its name, this test exercises
        ``model_config``, not ``runtime_config``.
        """
        for section in ('ml_config', 'strategy_config'):
            getattr(self.cfg_mgr.model_config, section)

    def test_filename_setters(self):
        """Filenames containing a path separator are rejected; plain names are accepted."""
        bad_names = ('/data_config.yml', 'data\\_config.yml')
        for bad_name in bad_names:
            with self.assertRaisesRegex(ValueError, 'sep'):
                self.cfg_mgr.data_config_filename = bad_name
        # A bare file name must be accepted without raising.
        self.cfg_mgr.data_config_filename = 'data_config.yml'

    def test_role_setting(self):
        """A role set through one ConfigurationManager handle is visible through another."""
        other_handle = ConfigurationManager()
        other_handle.role = Role.Server
        # NOTE(review): the role is not reset afterwards; since the state appears
        # shared, it presumably persists into later tests.
        self.assertEqual(self.cfg_mgr.role, Role.Server)
class DataConfigPartitionTestCase(TestCase):
    """Validation and normalisation of the data-partition entry of _DataConfig."""

    def setUp(self) -> None:
        # Deep copy (not a shallow ``.copy()``) so that nothing a test or
        # _DataConfig does to nested values can leak back into the
        # module-level default that other test cases compare against.
        self.rcfg = copy.deepcopy(_DEFAULT_D_CFG)

    def test_sum_limit(self):
        """An all-zero partition is rejected (error message mentions 'small')."""
        self.rcfg[_D_PARTITION_KEY] = [0, 0, 0]
        with self.assertRaisesRegex(ValueError, "small"):
            _DataConfig(self.rcfg)

    def test_no_neg(self):
        """A partition containing a negative weight is rejected."""
        self.rcfg[_D_PARTITION_KEY] = [-1, 2, 0]
        # NOTE(review): 'negetive' presumably mirrors the exact (misspelled)
        # wording of _DataConfig's error message; change both together.
        with self.assertRaisesRegex(ValueError, 'negetive'):
            _DataConfig(self.rcfg)

    def test_not_enough(self):
        """A partition with fewer than three entries is rejected (message mentions 3)."""
        self.rcfg[_D_PARTITION_KEY] = [1, 2]
        with self.assertRaisesRegex(ValueError, '3'):
            _DataConfig(self.rcfg)

    def test_too_many(self):
        """A partition with more than three entries is rejected (message mentions 3)."""
        self.rcfg[_D_PARTITION_KEY] = [1, 2, 3, 4]
        with self.assertRaisesRegex(ValueError, '3'):
            _DataConfig(self.rcfg)

    def test_copy_attribute(self):
        """``data_partition`` is not the same object as the list passed in."""
        self.rcfg[_D_PARTITION_KEY] = [0.1, 0.1, 0.8]
        d_cfg = _DataConfig(self.rcfg)
        self.assertIsNot(d_cfg.data_partition, self.rcfg[_D_PARTITION_KEY])

    def test_ok(self):
        """A valid partition is normalised so its weights sum to 1."""
        partition = [0.1, 2, 3]
        total = sum(partition)
        expected = [weight / total for weight in partition]
        self.rcfg[_D_PARTITION_KEY] = partition

        data_partition = _DataConfig(self.rcfg).data_partition

        self.assertIsNot(data_partition, partition)
        # Check the length explicitly so a wrong-sized result is a clear test
        # failure rather than an IndexError or silently unchecked entries.
        self.assertEqual(len(data_partition), len(expected))
        self.assertAlmostEqual(sum(data_partition), 1.0)
        for want, got in zip(expected, data_partition):
            self.assertAlmostEqual(want, got)
class DataConfigTestCase(TestCase):
    """Basic behaviour of _DataConfig built from the default data config."""

    def setUp(self):
        self.cfg = _DataConfig(_DEFAULT_D_CFG)

    def test_inner_copy(self):
        """``inner`` equals the source dict but is a distinct object."""
        inner = self.cfg.inner
        self.assertDictEqual(inner, _DEFAULT_D_CFG)
        self.assertIsNot(inner, _DEFAULT_D_CFG)

    def test_sample_size(self):
        """The default sample size is a positive ``int``."""
        sample_size = self.cfg.sample_size
        # Type check first: a non-numeric value then fails with a clear
        # assertion message instead of a TypeError from the comparison.
        self.assertIsInstance(sample_size, int)
        self.assertGreater(sample_size, 0)

    def test_path_sep_in_default_dir_name(self):
        """The default data directory name contains no '/', '\\' or '.'."""
        # Build the manager once instead of once per separator.
        dir_name = ConfigurationManager().data_dir_name
        for sep in ('/', '\\', '.'):
            self.assertNotIn(sep, dir_name)
# TODO(fgh) add tests for config conversions and filters