import json
import os
import datetime
import hashlib
from abc import ABC, abstractmethod, abstractproperty
from copy import deepcopy
from enum import Enum
from threading import RLock
from typing import (Any, Dict, List, Mapping, Optional, Sequence, TextIO, Tuple,
Union)
import yaml
from .filename_checker import check_filename
from .role import Role
from .singleton import Singleton
# Alias for a raw (de-serialized) configuration mapping: string keys, arbitrary values.
RawConfigurationDict = Dict[str, Any]
# Default file names of the data / model / runtime configuration files (YAML).
DEFAULT_D_CFG_FILENAME_YAML = '1_data_config.yml'
DEFAULT_MDL_CFG_FILENAME_YAML = '2_model_config.yml'
DEFAULT_RT_CFG_FILENAME_YAML = '3_runtime_config.yml'
# Default file names used by the JSON loader.
# NOTE(review): these values end in `.yml` and are identical to the YAML
# defaults above; presumably `.json` was intended. Confirm before changing,
# since existing experiment directories may rely on the current names.
DEFAULT_D_CFG_FILENAME_JSON = '1_data_config.yml'
DEFAULT_MDL_CFG_FILENAME_JSON = '2_model_config.yml'
DEFAULT_RT_CFG_FILENAME_JSON = '3_runtime_config.yml'
# default data configurations
# Keys of the data configuration dict.
_D_DIR_KEY = 'data_dir'
_D_NAME_KEY = 'dataset'
# Non-i.i.d. settings: enable flag, class number and partition strategy name.
_D_NI_ENABLE_KEY = 'non-iid'
_D_NI_CLASS_KEY = 'non-iid-class'
_D_NI_STRATEGY_KEY = 'non-iid-strategy'
_D_NORMALIZE_KEY = 'normalize'
_D_SAMPLE_SIZE_KEY = 'sample_size'
# Three ratios; the key name suggests the order train / validation / test.
_D_PARTITION_KEY = 'train_val_test'
_D_FEATURE_SIZE = 'feature_size'
_D_RANDOM_SEED = 'random_seed'
# Built-in data configuration; used as the default argument of `_DataConfig`.
_DEFAULT_D_CFG: RawConfigurationDict = {
    _D_DIR_KEY: 'data',
    _D_NAME_KEY: 'mnist',
    _D_NI_ENABLE_KEY: False,
    _D_NI_CLASS_KEY: 1,
    _D_NI_STRATEGY_KEY: 'average',
    _D_NORMALIZE_KEY: True,
    _D_SAMPLE_SIZE_KEY: 300,
    # Raw ratios; `_DataConfig` re-normalizes them so that they sum to 1.
    _D_PARTITION_KEY: [0.8, 0.1, 0.1],
    _D_FEATURE_SIZE: 1000,
    _D_RANDOM_SEED: 100
}
# default model configurations
# Top-level key of the federated-strategy section and the keys inside it.
_STRATEGY_KEY = 'FedModel'
_STRATEGY_NAME_KEY = 'name'
_STRATEGY_ETA_KEY = 'eta'
_STRATEGY_B_KEY = 'B'
_STRATEGY_C_KEY = 'C'
_STRATEGY_E_KEY = 'E'
_STRATEGY_E_RATIO = 'evaluate_ratio'
_STRATEGY_E_DISTRIBUTE = 'distributed_evaluate'
_STRATEGY_MAX_ROUND_NUM_KEY = 'max_rounds'
_STRATEGY_TOLERANCE_NUM_KEY = 'num_tolerance'
_STRATEGY_NUM_ROUNDS_BETWEEN_VAL_KEY = 'rounds_between_val'
# Strategy-specific keys (FedSTC / FedProx / FedOpt / FetchSGD / FedSVD).
_STRATEGY_FEDSTC_SPARSITY_KEY = 'sparsity'
_STRATEGY_FEDPROX_MU_KEY = 'mu'
_STRATEGY_FEDOPT_TAU_KEY = 'tau'
_STRATEGY_FEDOPT_BETA1_KEY = 'beta1'
_STRATEGY_FEDOPT_BETA2_KEY = 'beta2'
_STRATEGY_FEDOPT_NAME_KEY = 'opt_name'
_STRATEGY_FETCHSGD_COL_NUM_KEY = 'num_col'
_STRATEGY_FETCHSGD_ROW_NUM_KEY = 'num_row'
_STRATEGY_FETCHSGD_BLOCK_NUM_KEY = 'num_block'
_STRATEGY_FETCHSGD_TOP_K_KEY = 'top_k'
# NOTE: the FedSVD keys below have no entry in `_DEFAULT_MDL_CFG`.
_STRATEGY_FEDSVD_BLOCK = 'block_size'
_STRATEGY_FEDSVD_MODE = 'fedsvd_mode'
_STRATEGY_FEDSVD_TOPK = 'fedsvd_top_k'
_STRATEGY_FEDSVD_L2 = 'fedsvd_lr_l2'
_STRATEGY_FEDSVD_OPT_1 = 'fedsvd_opt_1'
_STRATEGY_FEDSVD_OPT_2 = 'fedsvd_opt_2'
_STRATEGY_FEDSVD_EVALUATE = 'fedsvd_debug_evaluate'
# Top-level key of the machine-learning-model section and the keys inside it.
_ML_KEY = 'MLModel'
_ML_NAME_KEY = 'name'
_ML_ACTIVATION_KEY = 'activation'
_ML_DROPOUT_RATIO_KEY = 'dropout'
_ML_UNITS_SIZE_KEY = 'units'
_ML_OPTIMIZER_KEY = 'optimizer'
_ML_OPTIMIZER_NAME_KEY = 'name'
_ML_OPTIMIZER_LEARNING_RATE_KEY = 'lr'
_ML_OPTIMIZER_MOMENTUM_KEY = 'momentum'
_ML_LOSS_CALC_METHODS_KEY = 'loss'
_ML_METRICS_KEY = 'metrics'
_ML_DEFAULT_METRICS = ['accuracy']
# Built-in model configuration; used as the default argument of `_ModelConfig`.
_DEFAULT_MDL_CFG: RawConfigurationDict = {
    _STRATEGY_KEY: {
        _STRATEGY_NAME_KEY: 'FedAvg',
        # shared params
        _STRATEGY_B_KEY: 32,
        _STRATEGY_C_KEY: 0.1,
        _STRATEGY_E_KEY: 1,
        _STRATEGY_E_RATIO: 1.0,
        _STRATEGY_E_DISTRIBUTE: True,
        _STRATEGY_MAX_ROUND_NUM_KEY: 3000,
        _STRATEGY_TOLERANCE_NUM_KEY: 100,
        _STRATEGY_NUM_ROUNDS_BETWEEN_VAL_KEY: 1,
        # FedSTC
        _STRATEGY_FEDSTC_SPARSITY_KEY: 0.01,
        # FedProx
        _STRATEGY_FEDPROX_MU_KEY: 0.01,
        # FedOpt
        _STRATEGY_FEDOPT_TAU_KEY: 1e-4,
        _STRATEGY_FEDOPT_BETA1_KEY: 0.9,
        _STRATEGY_FEDOPT_BETA2_KEY: 0.99,
        _STRATEGY_FEDOPT_NAME_KEY: 'fedyogi',
        # server-side learning rate, used by FedSCA and FedOpt
        _STRATEGY_ETA_KEY: 1.0,
        # FetchSGD
        _STRATEGY_FETCHSGD_COL_NUM_KEY: 5,
        # A float literal; `_ModelConfig.row_num` converts it with int().
        _STRATEGY_FETCHSGD_ROW_NUM_KEY: 1e4,
        _STRATEGY_FETCHSGD_BLOCK_NUM_KEY: 10,
        # NOTE(review): fractional default, yet `_ModelConfig.top_k` applies
        # int() to it -- confirm whether a ratio or a count is intended.
        _STRATEGY_FETCHSGD_TOP_K_KEY: 0.1,
    },
    _ML_KEY: {
        _ML_NAME_KEY: 'MLP',
        _ML_ACTIVATION_KEY: 'relu',
        _ML_DROPOUT_RATIO_KEY: 0.2,
        _ML_UNITS_SIZE_KEY: [512, 512],
        _ML_OPTIMIZER_KEY: {
            _ML_OPTIMIZER_NAME_KEY: 'sgd',
            _ML_OPTIMIZER_LEARNING_RATE_KEY: 0.1,
            _ML_OPTIMIZER_MOMENTUM_KEY: 0,
            # _ML_OPTIMIZER_MOMENTUM_KEY: 0.9, # FetchSGD
        },
        _ML_LOSS_CALC_METHODS_KEY: 'categorical_crossentropy',
        # The same list object as `_ML_DEFAULT_METRICS` (not a copy).
        _ML_METRICS_KEY: _ML_DEFAULT_METRICS,
    },
}
# default runtime configurations
# _RT_CLIENTS_KEY = 'clients'
# _RT_C_BANDWIDTH_KEY = 'bandwidth'
# Central-server section and its keys.
_RT_SERVER_KEY = 'server'
_RT_S_HOST_KEY = 'host'
_RT_S_LISTEN_KEY = 'listen'
_RT_S_PORT_KEY = 'port'
_RT_S_CLIENTS_NUM_KEY = 'num_clients'
_RT_S_SECRET_KEY = 'secret_key'
# Docker section and its keys.
_RT_DOCKER_KEY = 'docker'
_RT_D_IMAGE_LABEL_KEY = 'image'
_RT_D_CONTAINER_NUM_KEY = 'num_containers'
_RT_D_GPU_ENABLE_KEY = 'enable_gpu'
_RT_D_GPU_NUM_KEY = 'num_gpu'
# Machines section (optional) and the keys of each machine entry.
_RT_MACHINES_KEY = 'machines'
_RT_M_ADDRESS_KEY = 'host'
_RT_M_PORT_KEY = 'port'
_RT_M_USERNAME_KEY = 'username'
_RT_M_WORK_DIR_KEY = 'dir'
_RT_M_SK_FILENAME_KEY = 'key'
_RT_M_CAPACITY_KEY = 'capacity'
# The machine entry with this name is treated as the central server.
_RT_M_SERVER_NAME = 'server'
# Logging section and its keys.
_RT_LOG_KEY = 'log'
_RT_L_BASE_LEVEL_KEY = 'base_level'
_RT_L_FILE_LEVEL_KEY = 'file_log_level'
_RT_L_CONSOLE_LEVEL_KEY = 'console_log_level'
_RT_L_DIR_PATH_KEY = 'log_dir'
# Communication section and its keys.
_RT_COMMUNICATION_KEY = 'communication'
_RT_COMM_METHOD_KEY = 'method'
_RT_COMM_PORT_KEY = 'port'
_RT_COMM_LIMIT_FLAG_KEY = 'limit_network_resource'
_RT_COMM_BANDWIDTH_UP_KEY = 'bandwidth_upload'
_RT_COMM_BANDWIDTH_DOWN_KEY = 'bandwidth_download'
_RT_COMM_LATENCY_KEY = 'latency'
_RT_COMM_FAST_MODE = 'fast_mode'
# Built-in runtime configuration; used as the default argument of `_RuntimeConfig`.
# NOTE(review): host addresses, work directories, image label and secret key
# below look like developer-specific placeholders -- confirm before relying on them.
_DEFAULT_RT_CFG: RawConfigurationDict = {
    _RT_COMMUNICATION_KEY: {
        _RT_COMM_METHOD_KEY: 'SocketIO',
        _RT_COMM_PORT_KEY: 8000,
        _RT_COMM_LIMIT_FLAG_KEY: True,
        _RT_COMM_BANDWIDTH_UP_KEY: '30Mbit',
        _RT_COMM_BANDWIDTH_DOWN_KEY: '10Mbit',
        _RT_COMM_LATENCY_KEY: '50ms',
        _RT_COMM_FAST_MODE: False
    },
    _RT_LOG_KEY: {
        _RT_L_DIR_PATH_KEY: 'log/quickstart',
        _RT_L_BASE_LEVEL_KEY: 'INFO',
        _RT_L_FILE_LEVEL_KEY: 'INFO',
        _RT_L_CONSOLE_LEVEL_KEY: 'ERROR',
    },
    _RT_DOCKER_KEY: {
        _RT_D_IMAGE_LABEL_KEY: 'fedeval:sdfsdf',
        _RT_D_CONTAINER_NUM_KEY: 10,
        _RT_D_GPU_ENABLE_KEY: False,
        _RT_D_GPU_NUM_KEY: 0,
    },
    # _RT_CLIENTS_KEY: {
    # _RT_C_BANDWIDTH_KEY: '100Mbit',
    # },
    _RT_SERVER_KEY: {
        _RT_S_HOST_KEY: 'server',
        _RT_S_LISTEN_KEY: 'server',
        _RT_S_CLIENTS_NUM_KEY: 10,
        _RT_S_PORT_KEY: 8000,
        _RT_S_SECRET_KEY: 'secret!',
    },
    _RT_MACHINES_KEY: {
        # The server entry carries no capacity; `_RT_Machine` only requires
        # a capacity for non-server machines.
        _RT_M_SERVER_NAME: {
            _RT_M_ADDRESS_KEY: '10.173.1.22',
            _RT_M_PORT_KEY: 22,
            _RT_M_USERNAME_KEY: 'ubuntu',
            _RT_M_WORK_DIR_KEY: '/ldisk/chaidi/FedEval',
            _RT_M_SK_FILENAME_KEY: 'id_rsa',
        },
        'm1': {
            _RT_M_ADDRESS_KEY: '10.173.1.22',
            _RT_M_PORT_KEY: 22,
            _RT_M_USERNAME_KEY: 'ubuntu',
            _RT_M_WORK_DIR_KEY: '/ldisk/chaidi/FedEval',
            _RT_M_SK_FILENAME_KEY: 'id_rsa',
            _RT_M_CAPACITY_KEY: 100,
        },
    },
}
# --- Configuration Entities ---
class _Configuraiton(object):
    """Base class of the configuration entities.

    Wraps a raw configuration dict. Subclasses may override
    ``_config_filter`` to adjust the dict before it is stored.
    """

    def __init__(self, config: RawConfigurationDict) -> None:
        """Store the (filtered) raw configuration.

        Args:
            config (RawConfigurationDict): the raw configuration. The object
                returned by ``_config_filter`` is stored as-is; no copy is
                taken here.
        """
        self._inner: RawConfigurationDict = self._config_filter(config)

    @property
    def inner(self) -> RawConfigurationDict:
        """Return a deep copy of the inner configuration data as a dict.

        Modifications on the returned object do NOT affect the original
        configuration.

        Returns:
            RawConfigurationDict: a deep copy of the inner data
                representation of this config object.
        """
        return deepcopy(self._inner)

    @staticmethod
    def _config_filter(config: RawConfigurationDict) -> RawConfigurationDict:
        """Hook applied to the raw configuration before it is stored.

        The base implementation applies no filter and returns ``config``
        itself.
        """
        # No filter by default
        return config
class _DataConfig(_Configuraiton):
    """Read-only, typed view of the data section of the configuration.

    The raw dict is deep-copied on construction, so neither the caller's
    dict nor the module-level default is modified by the filtering step.
    """

    _IID_EXCEPTiON_CONTENT = 'The dataset is configured as iid.'

    def __init__(self, data_config: RawConfigurationDict = _DEFAULT_D_CFG) -> None:
        """Validate and store the data configuration.

        Args:
            data_config (RawConfigurationDict): the raw data configuration;
                defaults to the built-in one.

        Raises:
            ValueError: the partition entry does not hold exactly three
                values, holds a negative value, or sums up to (almost) zero.
        """
        # Work on a private copy: `_config_filter` rewrites keys in place and
        # would otherwise corrupt the shared default dict (and the caller's).
        super().__init__(deepcopy(data_config))
        # non-iid
        self._non_iid: bool = self._inner.get(_D_NI_ENABLE_KEY, False)
        if self._non_iid:
            self._non_iid_strategy_name: str = self._inner.get(
                _D_NI_STRATEGY_KEY, 'average')
            # The class number is only read for strategies other than 'natural'.
            if self._non_iid_strategy_name != 'natural':
                self._non_iid_class_num: int = int(
                    self._inner.get(_D_NI_CLASS_KEY, 1))
        # partition
        partition = list(self._inner[_D_PARTITION_KEY])
        if len(partition) != 3:
            raise ValueError(
                f'there should be 3 values in {_D_PARTITION_KEY}.')
        if any(i < 0 for i in partition):
            raise ValueError(
                f'values in {_D_PARTITION_KEY} should not be negetive.')
        summation = sum(partition)
        if summation <= 1e-6:
            raise ValueError(f'values in {_D_PARTITION_KEY} are too small.')
        # Normalize so that the stored ratios sum up to 1.
        self._partition = [i / summation for i in partition]

    @staticmethod
    def _config_filter(config: RawConfigurationDict) -> RawConfigurationDict:
        """Blank the non-i.i.d. settings when non-i.i.d. sampling is disabled.

        A missing enable flag is treated as disabled, matching ``__init__``.
        The dict is modified in place and returned.
        """
        if not config.get(_D_NI_ENABLE_KEY, False):
            config[_D_NI_CLASS_KEY] = None
            config[_D_NI_STRATEGY_KEY] = None
        return config

    @property
    def dataset_name(self) -> str:
        """the name of the dataset (e.g. mnist, cifar10, cifar100, femnist).

        Returns:
            str: the name of chosen dataset.
        """
        return self._inner[_D_NAME_KEY]

    @property
    def iid(self) -> bool:
        """if the dataset would be used in an i.i.d. manner.

        Returns:
            bool: True if the dataset is sampled in an i.i.d. manner;
                otherwise, False.
        """
        return not self._non_iid

    @property
    def non_iid_class_num(self) -> int:
        """return the number of classes held by each client.

        Only available when the dataset is sampled in a non-i.i.d. form with
        a strategy other than "natural".

        Raises:
            AttributeError: raised when called without non-i.i.d. setting,
                or with the "natural" strategy (no class number is stored
                in that case).

        Returns:
            int: the number of classes held by each client.
        """
        if not self._non_iid:
            raise AttributeError(_DataConfig._IID_EXCEPTiON_CONTENT)
        if self._non_iid_strategy_name == 'natural':
            # Previously this surfaced as an accidental missing-attribute
            # error; the exception type is unchanged, the message is clearer.
            raise AttributeError(
                'non_iid_class_num is unavailable with the "natural" strategy.')
        return self._non_iid_class_num

    @property
    def non_iid_strategy_name(self) -> str:
        """return the name of non-i.i.d. data partition strategy.

        Two choices are given:
            1. "natural" strategy for femnist and celebA dataset
            2. "average" for mnist, cifar10 and cifar100

        Raises:
            AttributeError: raised when called without non-i.i.d. setting,
                or when the configured name is not a registered strategy.

        Returns:
            str: the name of non-i.i.d. data partition strategy.
        """
        if not self._non_iid:
            raise AttributeError(_DataConfig._IID_EXCEPTiON_CONTENT)
        if not self._non_iid_strategy_name_check():
            raise AttributeError(
                f'unregistered non-iid data partition strategy name: {self._non_iid_strategy_name}')
        return self._non_iid_strategy_name

    def _non_iid_strategy_name_check(self) -> bool:
        """check if the non-i.i.d. data partition strategy is known.

        Returns:
            bool: True if the strategy name is 'natural' or 'average';
                otherwise, False.
        """
        return self._non_iid_strategy_name in ['natural', 'average']

    @property
    def normalized(self) -> bool:
        """whether the image pixel data point will be normalized to [0, 1].

        Returns:
            bool: True if data points would be normalized; otherwise, False.
        """
        return self._inner[_D_NORMALIZE_KEY]

    @property
    def sample_size(self) -> Optional[int]:
        """return the number of samples owned by each client.

        Returns:
            Optional[int]: None when the configured value is None;
                otherwise the value converted to int.
        """
        raw_size = self._inner[_D_SAMPLE_SIZE_KEY]
        if raw_size is None:
            return None
        return int(raw_size)

    @property
    def data_partition(self) -> Sequence[float]:
        """get the normalized data partition proportions.

        The three ratios keep the order of the configuration entry.
        NOTE(review): the key name ('train_val_test') suggests
        train / validation / test, while an earlier docstring listed
        train / test / validation -- confirm against the consumers.

        Constraints met by the return value:
            1. all the ratios in the returned list sum up to 1.
            2. all the ratios in the returned list are non-negative.

        Returns:
            Sequence[float]: the three normalized ratios.
        """
        return self._partition

    @property
    def feature_size(self):
        """the raw feature-size entry of the data configuration."""
        # TODO(Di): Add constraints in the future
        # if self.dataset_name != 'synthetic_matrix_horizontal' and \
        # self.dataset_name != 'synthetic_matrix_vertical':
        # raise AttributeError
        return self._inner[_D_FEATURE_SIZE]

    @property
    def random_seed(self):
        """the configured random seed, converted to int."""
        return int(self._inner[_D_RANDOM_SEED])
class _ModelConfig(_Configuraiton):
    """Read-only, typed view of the model section of the configuration.

    The section is split into a federated-strategy part and a machine
    learning model part. The raw dict is deep-copied on construction, so
    the filtering step never modifies the caller's dict or the shared
    module-level default.
    """

    def __init__(self, model_config: RawConfigurationDict = _DEFAULT_MDL_CFG) -> None:
        """Validate and store the model configuration.

        Args:
            model_config (RawConfigurationDict): the raw model configuration;
                defaults to the built-in one.

        Raises:
            AssertionError: the strategy section is missing, or the dropout
                ratio is out of [0, 1].
        """
        _ModelConfig.__check_raw_config(model_config)
        # Private copy: `_config_filter` sets unused strategy params to None
        # in place, which must not leak into the shared default dict.
        super().__init__(deepcopy(model_config))
        self._strategy_cfg = self._inner[_STRATEGY_KEY]
        # The model config could be empty or absent, e.g., in FedSVD
        self._ml_cfg = self._inner.get(_ML_KEY) or {}

    @staticmethod
    def _config_filter(config: RawConfigurationDict) -> RawConfigurationDict:
        """Set the parameters of strategies other than the chosen one to None.

        The dict is modified in place and returned.
        """
        # Fed Model filters
        strategy_cfg = config[_STRATEGY_KEY]
        strategy_name = strategy_cfg[_STRATEGY_NAME_KEY]
        if strategy_name != 'FedSTC':
            strategy_cfg[_STRATEGY_FEDSTC_SPARSITY_KEY] = None
        if strategy_name != 'FedProx':
            strategy_cfg[_STRATEGY_FEDPROX_MU_KEY] = None
        if strategy_name != 'FedOpt':
            strategy_cfg[_STRATEGY_FEDOPT_TAU_KEY] = None
            strategy_cfg[_STRATEGY_FEDOPT_NAME_KEY] = None
            strategy_cfg[_STRATEGY_FEDOPT_BETA1_KEY] = None
            strategy_cfg[_STRATEGY_FEDOPT_BETA2_KEY] = None
        # The server-side learning rate is shared by FedOpt and FedSCA.
        if strategy_name not in ('FedOpt', 'FedSCA'):
            strategy_cfg[_STRATEGY_ETA_KEY] = None
        return config

    @staticmethod
    def __check_raw_config(config: RawConfigurationDict) -> None:
        """Run all the sanity checks on the raw model configuration."""
        _ModelConfig.__check_runtime_config_shallow_structure(config)
        _ModelConfig.__check_ML_model_params(config.get(_ML_KEY))

    @staticmethod
    def __check_runtime_config_shallow_structure(config: RawConfigurationDict) -> None:
        """Make sure the strategy section exists (the ML section is optional)."""
        # assert config.get(
        # _ML_KEY) != None, f'model_config should have `{_ML_KEY}`'
        assert config.get(
            _STRATEGY_KEY) is not None, f'model_config should have `{_STRATEGY_KEY}`'

    @staticmethod
    def __check_ML_model_params(ml_config: RawConfigurationDict) -> None:
        """Check the dropout ratio of a non-empty ML section."""
        if ml_config:
            dropout_ratio = ml_config.get(_ML_DROPOUT_RATIO_KEY)
            # A missing or zero ratio needs no range check.
            if dropout_ratio:
                assert 0 <= dropout_ratio <= 1, 'dropout ration out of range.'

    @property
    def strategy_config(self) -> RawConfigurationDict:
        """a variant of inner method: return a copy of inner strategy raw dict.

        Returns:
            RawConfigurationDict: a deep copy of the strategy-related
                configuration dict.
        """
        return deepcopy(self._strategy_cfg)

    @property
    def ml_config(self) -> RawConfigurationDict:
        """a variant of inner method: return a copy of inner machine learning raw dict.

        Returns:
            RawConfigurationDict: a deep copy of the ML model-related
                configuration dict (empty if there is no ML section).
        """
        return deepcopy(self._ml_cfg)

    @property
    def strategy_name(self) -> str:
        """get the class name of the federated strategy (i.e., the main
        controller of the federated process).

        Returns:
            str: the classname/typename of the federated strategy.
        """
        return self._strategy_cfg[_STRATEGY_NAME_KEY]

    @property
    def ml_method_name(self) -> Optional[str]:
        """get the class name of the machine learning model.

        Returns:
            Optional[str]: the classname/typename of the inner machine
                learning model, or None if the ML section has no name.
        """
        return self._ml_cfg.get(_ML_NAME_KEY)

    @property
    def server_learning_rate(self) -> float:
        """get the learning rate on the server side.

        Only available in FedOpt and FedSCA.

        Raises:
            AttributeError: called with an improper federated strategy.

        Returns:
            float: the learning rate on the server side.
        """
        if self.strategy_name != 'FedOpt' and self.strategy_name != 'FedSCA':
            raise AttributeError
        return float(self._strategy_cfg[_STRATEGY_ETA_KEY])

    @property
    def B(self) -> int:
        """the local minibatch size used for the updates on the client side."""
        return int(self._strategy_cfg[_STRATEGY_B_KEY])

    @property
    def C(self) -> float:
        """the fraction of clients that perform computation in each round.

        Examples:
            if there are 100 available clients in a test network with a C
            of 0.2, then there should be (100*0.2=)20 clients in each round
            of iterations.
        """
        return float(self._strategy_cfg[_STRATEGY_C_KEY])

    @property
    def E(self) -> int:
        """the number of training passes that each client makes over its
        local dataset in each round.
        """
        return int(self._strategy_cfg[_STRATEGY_E_KEY])

    @property
    def evaluate_ratio(self):
        """the 'evaluate_ratio' strategy entry, converted to float."""
        return float(self._strategy_cfg[_STRATEGY_E_RATIO])

    @property
    def distributed_evaluate(self):
        """the 'distributed_evaluate' strategy entry, converted to bool."""
        return bool(self._strategy_cfg[_STRATEGY_E_DISTRIBUTE])

    @property
    def max_round_num(self) -> int:
        """the total/maximum number of the iteration rounds."""
        return int(self._strategy_cfg[_STRATEGY_MAX_ROUND_NUM_KEY])

    @property
    def tolerance_num(self) -> int:
        """the patience for early stopping"""
        return int(self._strategy_cfg[_STRATEGY_TOLERANCE_NUM_KEY])

    @property
    def num_of_rounds_between_val(self) -> int:
        """the number of rounds between test or validation"""
        return int(self._strategy_cfg[_STRATEGY_NUM_ROUNDS_BETWEEN_VAL_KEY])

    @property
    def stc_sparsity(self) -> float:
        """TODO(fgh): the origin of FedSTC

        The entry is set to None for strategies other than FedSTC, in which
        case the float conversion fails with TypeError.
        """
        return float(self._strategy_cfg[_STRATEGY_FEDSTC_SPARSITY_KEY])

    @property
    def prox_mu(self) -> float:
        """the /mu parameter in FedProx, a scaler that measures the
        approximation between the local model and the global model.

        More info available in Federated Optimization in Heterogeneous
        Networks (arXiv:1812.06127). The entry is set to None for strategies
        other than FedProx, in which case the conversion fails with TypeError.
        """
        return float(self._strategy_cfg[_STRATEGY_FEDPROX_MU_KEY])

    @property
    def opt_tau(self) -> float:
        """the 'tau' entry of FedOpt (None, hence TypeError, for other strategies)."""
        # TODO(fgh) can not find a corresponding variable in FedOpt.
        return float(self._strategy_cfg[_STRATEGY_FEDOPT_TAU_KEY])

    @property
    def opt_beta_1(self) -> float:
        """the 'beta1' entry of FedOpt (None, hence TypeError, for other strategies)."""
        # TODO(fgh) can not find a corresponding variable in FedOpt.
        return float(self._strategy_cfg[_STRATEGY_FEDOPT_BETA1_KEY])

    @property
    def opt_beta_2(self) -> float:
        """the 'beta2' entry of FedOpt (None, hence TypeError, for other strategies)."""
        # TODO(fgh) can not find a corresponding variable in FedOpt.
        return float(self._strategy_cfg[_STRATEGY_FEDOPT_BETA2_KEY])

    @property
    def activation(self) -> str:
        """the name of activation mechanism in tensorflow layers.

        More info available in
        https://tensorflow.google.cn/api_docs/python/tf/keras/activations.
        """
        return self._ml_cfg[_ML_ACTIVATION_KEY]

    @property
    def dropout(self) -> float:
        """the dropout fraction of Dropout layer in the DL model."""
        return float(self._ml_cfg[_ML_DROPOUT_RATIO_KEY])

    @property
    def unit_size(self) -> Sequence[int]:
        """the size of sequential neural network components.

        Returns:
            Sequence[int]: a new list with the size of network components
                (ordered the same with data flow direction)
        """
        return [int(i) for i in self._ml_cfg[_ML_UNITS_SIZE_KEY]]

    @property
    def optimizer_name(self) -> str:
        """the name of the optimizer in tensorflow network.

        More info available in
        https://tensorflow.google.cn/api_docs/python/tf/keras/optimizers.
        """
        return self._ml_cfg[_ML_OPTIMIZER_KEY][_ML_OPTIMIZER_NAME_KEY]

    @property
    def learning_rate(self) -> float:
        """the learning rate of model training in tensorflow."""
        return float(self._ml_cfg[_ML_OPTIMIZER_KEY][_ML_OPTIMIZER_LEARNING_RATE_KEY])

    @property
    def momentum(self) -> float:
        """the momentum of the optimizer."""
        return float(self._ml_cfg[_ML_OPTIMIZER_KEY][_ML_OPTIMIZER_MOMENTUM_KEY])

    @property
    def loss_calc_method(self) -> str:
        """the identifier of a loss function in tensorflow.

        More info available in
        https://tensorflow.google.cn/api_docs/python/tf/keras/losses.

        Returns:
            str: the string name of the loss function during model training.
        """
        return self._ml_cfg[_ML_LOSS_CALC_METHODS_KEY]

    @property
    def metrics(self) -> Sequence[str]:
        """names of the metrics used in model training and validation in tensorflow.

        More info in https://tensorflow.google.cn/api_docs/python/tf/keras/metrics.

        Returns:
            Sequence[str]: a copy of metric names.
        """
        return self._ml_cfg[_ML_METRICS_KEY].copy()

    # The FetchSGD parameters are stored in the strategy section; they are
    # therefore read from `_strategy_cfg` (the ML section has no such keys).
    @property
    def col_num(self) -> int:
        """the number of columns in FetchSGD.

        More info available at https://export.arxiv.org/abs/2007.07682.
        """
        return int(self._strategy_cfg[_STRATEGY_FETCHSGD_COL_NUM_KEY])

    @property
    def row_num(self) -> int:
        """the number of rows in FetchSGD.

        More info available at https://export.arxiv.org/abs/2007.07682.
        """
        return int(self._strategy_cfg[_STRATEGY_FETCHSGD_ROW_NUM_KEY])

    @property
    def block_num(self) -> int:
        """the number of blocks in FetchSGD.

        More info available at https://export.arxiv.org/abs/2007.07682.
        """
        return int(self._strategy_cfg[_STRATEGY_FETCHSGD_BLOCK_NUM_KEY])

    @property
    def top_k(self) -> int:
        """the number of top items in FetchSGD.

        More info available at https://export.arxiv.org/abs/2007.07682.

        NOTE(review): the value is truncated by int(); a fractional entry
        such as 0.1 yields 0 -- confirm whether a ratio was intended.
        """
        return int(self._strategy_cfg[_STRATEGY_FETCHSGD_TOP_K_KEY])

    @property
    def block_size(self) -> int:
        """block size of FedSVD.

        Raises:
            AttributeError: the strategy is not FedSVD.
        """
        if self.strategy_name != 'FedSVD':
            raise AttributeError
        return int(self._strategy_cfg[_STRATEGY_FEDSVD_BLOCK])

    @property
    def svd_mode(self) -> str:
        """working mode of FedSVD: one of 'svd', 'pca' or 'lr'.

        Raises:
            AttributeError: the strategy is not FedSVD.
            AssertionError: the configured mode is not a known one.
        """
        if self.strategy_name != 'FedSVD':
            raise AttributeError
        mode = self._strategy_cfg[_STRATEGY_FEDSVD_MODE]
        assert mode in ['svd', 'pca', 'lr'], \
            f'Unknown FedSVD Mode: {mode}, ' \
            f'should be one of svd, pca or lr'
        return str(mode)

    @property
    def svd_top_k(self) -> int:
        """the 'fedsvd_top_k' entry of FedSVD, converted to int.

        Raises:
            AttributeError: the strategy is not FedSVD.
        """
        if self.strategy_name != 'FedSVD':
            raise AttributeError
        return int(self._strategy_cfg[_STRATEGY_FEDSVD_TOPK])

    @property
    def svd_lr_l2(self):
        """L2 penalize of FedSVD.

        Raises:
            AttributeError: the strategy is not FedSVD.
        """
        if self.strategy_name != 'FedSVD':
            raise AttributeError
        return float(self._strategy_cfg[_STRATEGY_FEDSVD_L2])

    @property
    def svd_opt_1(self):
        """True iff the 'fedsvd_opt_1' entry reads 'true' (case-insensitive).

        Raises:
            AttributeError: the strategy is not FedSVD.
        """
        if self.strategy_name != 'FedSVD':
            raise AttributeError
        return str(self._strategy_cfg[_STRATEGY_FEDSVD_OPT_1]).lower() == 'true'

    @property
    def svd_opt_2(self):
        """True iff the 'fedsvd_opt_2' entry reads 'true' (case-insensitive).

        Raises:
            AttributeError: the strategy is not FedSVD.
        """
        if self.strategy_name != 'FedSVD':
            raise AttributeError
        return str(self._strategy_cfg[_STRATEGY_FEDSVD_OPT_2]).lower() == 'true'

    @property
    def svd_evaluate(self):
        """True iff the 'fedsvd_debug_evaluate' entry reads 'true' (case-insensitive).

        Raises:
            AttributeError: the strategy is not FedSVD.
        """
        if self.strategy_name != 'FedSVD':
            raise AttributeError
        return str(self._strategy_cfg[_STRATEGY_FEDSVD_EVALUATE]).lower() == 'true'
class _RT_Machine(_Configuraiton):
    """Read-only view of one machine entry of the runtime configuration."""

    __ITEM_CHECK_VALUE_ERROR_PATTERN = 'machine configuraitons should have {}.'

    def __init__(self, machine_config: RawConfigurationDict, is_server: bool = False) -> None:
        """Validate and store one machine entry.

        Args:
            machine_config (RawConfigurationDict): the raw machine entry.
            is_server (bool): whether the entry describes the central server.

        Raises:
            ValueError: a required key is missing.
        """
        _RT_Machine.__check_items(machine_config, is_server)
        super().__init__(machine_config)
        self._is_server = is_server

    @staticmethod
    def __check_items(config: RawConfigurationDict, is_server: bool = False) -> None:
        """Make sure every required key is present in the machine entry.

        Raises:
            ValueError: a required key is missing. (An explicit raise is
                used so that the check also runs under ``python -O``; the
                former ``assert cond, ValueError(...)`` form raised
                AssertionError instead.)
        """
        required_keys = [_RT_M_ADDRESS_KEY, _RT_M_WORK_DIR_KEY,
                         _RT_M_PORT_KEY, _RT_M_USERNAME_KEY, _RT_M_SK_FILENAME_KEY]
        # Only non-server machines need a capacity.
        if not is_server:
            required_keys.append(_RT_M_CAPACITY_KEY)
        for k in required_keys:
            if k not in config:
                raise ValueError(
                    _RT_Machine.__ITEM_CHECK_VALUE_ERROR_PATTERN.format(k))

    @property
    def is_server(self) -> bool:
        """if the machine is a central server."""
        return self._is_server

    @property
    def addr(self) -> str:
        """the IP address of this machine or the name of this container in docker."""
        return self._inner[_RT_M_ADDRESS_KEY]

    @property
    def port(self) -> int:
        """the port of this virtual machine on the physical machine."""
        return int(self._inner[_RT_M_PORT_KEY])

    @property
    def username(self) -> str:
        """the username of this machine."""
        return self._inner[_RT_M_USERNAME_KEY]

    @property
    def work_dir_path(self) -> str:
        """the path of this machine's working directory."""
        return self._inner[_RT_M_WORK_DIR_KEY]

    @property
    def key_filename(self) -> str:
        """the name of ssh connection secret key file."""
        return self._inner[_RT_M_SK_FILENAME_KEY]

    @property
    def capacity(self) -> int:
        """the number of container that this machine can handle.

        Only available on the client side.

        Raises:
            AttributeError: called from the server side.
        """
        if self._is_server:
            raise AttributeError(
                'capacity is inaccessible for the server side.')
        return int(self._inner[_RT_M_CAPACITY_KEY])
class _RuntimeConfig(_Configuraiton):
    """Read-only, typed view of the runtime section of the configuration.

    The raw dict is deep-copied on construction, so the filtering step
    never modifies the caller's dict or the shared module-level default.
    """

    __ITEM_CHECK_VALUE_ERROR_PATTERN = 'runtime configurations should have {}.'
    __AVAILABLE_LOGGING_LEVELS = {
        'NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'}

    def __init__(self, runtime_config: RawConfigurationDict = _DEFAULT_RT_CFG) -> None:
        """Validate and store the runtime configuration.

        Args:
            runtime_config (RawConfigurationDict): the raw runtime
                configuration; defaults to the built-in one.

        Raises:
            ValueError: a required section (or a required key of a machine
                entry) is missing.
        """
        _RuntimeConfig.__check_items(runtime_config)
        # Private copy: `_config_filter` may blank entries in place.
        super().__init__(deepcopy(runtime_config))
        self.__init_machines()

    @staticmethod
    def _config_filter(config: RawConfigurationDict) -> RawConfigurationDict:
        """Blank bandwidth/latency settings when the network is not limited.

        The dict is modified in place and returned.
        """
        comm_cfg = config[_RT_COMMUNICATION_KEY]
        if not comm_cfg[_RT_COMM_LIMIT_FLAG_KEY]:
            comm_cfg[_RT_COMM_BANDWIDTH_UP_KEY] = None
            comm_cfg[_RT_COMM_BANDWIDTH_DOWN_KEY] = None
            comm_cfg[_RT_COMM_LATENCY_KEY] = None
        return config

    @staticmethod
    def __check_items(config: RawConfigurationDict) -> None:
        """Make sure the required top-level sections are present.

        Raises:
            ValueError: a required section is missing. (An explicit raise
                is used so that the check also runs under ``python -O``;
                the former ``assert cond, ValueError(...)`` form raised
                AssertionError instead.)
        """
        required_keys = [_RT_DOCKER_KEY, _RT_SERVER_KEY,
                         _RT_COMMUNICATION_KEY, _RT_LOG_KEY]
        for k in required_keys:
            if k not in config:
                raise ValueError(
                    _RuntimeConfig.__ITEM_CHECK_VALUE_ERROR_PATTERN.format(k))

    def _has_machines(self) -> bool:
        """whether the configuration contains a machines section."""
        return _RT_MACHINES_KEY in self._inner

    def __init_machines(self) -> bool:
        """Build one `_RT_Machine` per machine entry.

        Returns:
            bool: False if there is no machines section (nothing is built);
                otherwise True.
        """
        if not self._has_machines():
            return False
        self._machines: Dict[str, _RT_Machine] = dict()
        for name, machine_cfg in self._inner[_RT_MACHINES_KEY].items():
            # The entry named like the server is the central server.
            self._machines[name] = _RT_Machine(
                machine_cfg, name == _RT_M_SERVER_NAME)
        return True

    @property
    def machines(self) -> Optional[Mapping[str, _RT_Machine]]:
        """return a deep copy of all the machines in the configuration.

        Returns:
            Optional[Mapping[str, _RT_Machine]]: None if there is no
                machine setting.
        """
        if not self._has_machines():
            return None
        return deepcopy(self._machines)

    @property
    def client_machines(self) -> Optional[Mapping[str, _RT_Machine]]:
        """return a deep copy of all the client machines in the configuration.

        Returns:
            Optional[Mapping[str, _RT_Machine]]: None if there is no
                machine setting.
        """
        if not self._has_machines():
            return None
        return deepcopy({name: v for name, v in self._machines.items() if not v.is_server})

    @property
    def server_machine(self):
        """return a deep copy of the server machine.

        Returns:
            None if there is no machine setting; otherwise the single
            machine flagged as server.

        Raises:
            AssertionError: the number of server machines is not exactly one.
        """
        if not self._has_machines():
            return None
        server = [v for v in self._machines.values() if v.is_server]
        assert len(server) == 1, 'The system requires one server'
        return deepcopy(server[0])

    @property
    def limit_network_resource(self) -> bool:
        """whether limit the network resource"""
        return bool(self._inner[_RT_COMMUNICATION_KEY][_RT_COMM_LIMIT_FLAG_KEY])

    @property
    def bandwidth_upload(self) -> str:
        """the upload bandwidth of each container.

        Raises:
            AttributeError: the network resource is not limited.
        """
        if not self.limit_network_resource:
            raise AttributeError
        return self._inner[_RT_COMMUNICATION_KEY][_RT_COMM_BANDWIDTH_UP_KEY]

    @property
    def bandwidth_download(self) -> str:
        """the download bandwidth of each container.

        Raises:
            AttributeError: the network resource is not limited.
        """
        if not self.limit_network_resource:
            raise AttributeError
        return self._inner[_RT_COMMUNICATION_KEY][_RT_COMM_BANDWIDTH_DOWN_KEY]

    @property
    def latency(self) -> str:
        """the latency of each container.

        Raises:
            AttributeError: the network resource is not limited.
        """
        if not self.limit_network_resource:
            raise AttributeError
        return self._inner[_RT_COMMUNICATION_KEY][_RT_COMM_LATENCY_KEY]

    @property
    def image_label(self) -> str:
        """the label of the docker image used in this experiment."""
        return self._inner[_RT_DOCKER_KEY][_RT_D_IMAGE_LABEL_KEY]

    @property
    def container_num(self) -> int:
        """the number of total docker containers in this experiment."""
        return int(self._inner[_RT_DOCKER_KEY][_RT_D_CONTAINER_NUM_KEY])

    @property
    def central_server_addr(self) -> str:
        """the IP address of the central server."""
        return self._inner[_RT_SERVER_KEY][_RT_S_HOST_KEY]

    @property
    def central_server_listen_at(self) -> str:
        """the listening IP address of the flask services on the central server side."""
        return self._inner[_RT_SERVER_KEY][_RT_S_LISTEN_KEY]

    @property
    def central_server_port(self) -> int:
        """the port that the central server occupies."""
        return int(self._inner[_RT_SERVER_KEY][_RT_S_PORT_KEY])

    @property
    def client_num(self) -> int:
        """the total number of the clients."""
        return int(self._inner[_RT_SERVER_KEY][_RT_S_CLIENTS_NUM_KEY])

    @staticmethod
    def _check_log_level_validity(level: str) -> None:
        """make sure the given string is one of the logging levels.

        Args:
            level (str): a string representation of a logging level.

        Raises:
            ValueError: the given string is not a valid logging level.
        """
        if level not in _RuntimeConfig.__AVAILABLE_LOGGING_LEVELS:
            raise ValueError(
                f'invalid logging level, available choices: {_RuntimeConfig.__AVAILABLE_LOGGING_LEVELS}')

    @property
    def base_log_level(self) -> str:
        """the base logging level of all the loggers.

        Raises:
            ValueError: the configured level is not a valid logging level.
        """
        lvl = self._inner[_RT_LOG_KEY][_RT_L_BASE_LEVEL_KEY]
        _RuntimeConfig._check_log_level_validity(lvl)
        return lvl

    @property
    def file_log_level(self) -> str:
        """the logging level in the log files.

        Raises:
            ValueError: the configured level is not a valid logging level.
        """
        lvl = self._inner[_RT_LOG_KEY][_RT_L_FILE_LEVEL_KEY]
        _RuntimeConfig._check_log_level_validity(lvl)
        return lvl

    @property
    def console_log_level(self) -> str:
        """the logging level in consoles.

        Raises:
            ValueError: the configured level is not a valid logging level.
        """
        lvl = self._inner[_RT_LOG_KEY][_RT_L_CONSOLE_LEVEL_KEY]
        _RuntimeConfig._check_log_level_validity(lvl)
        return lvl

    @property
    def secret_key(self) -> str:
        """the secret key of the flask service on the central server side.

        Returns:
            str: the secret key as a string.
        """
        return self._inner[_RT_SERVER_KEY][_RT_S_SECRET_KEY]

    @property
    def gpu_enabled(self) -> bool:
        """whether the GPU is enabled in this experiment."""
        return bool(self._inner[_RT_DOCKER_KEY][_RT_D_GPU_ENABLE_KEY])

    @property
    def gpu_num(self) -> int:
        """the number of GPUs.

        Raises:
            AttributeError: called without GPUs enabled.
        """
        if not self.gpu_enabled:
            raise AttributeError('GPU is not enabled.')
        return int(self._inner[_RT_DOCKER_KEY][_RT_D_GPU_NUM_KEY])

    @property
    def comm_method(self) -> str:
        """the method/technique used for machine-wise communication in the experiment."""
        return self._inner[_RT_COMMUNICATION_KEY][_RT_COMM_METHOD_KEY]

    @property
    def comm_port(self) -> int:
        """the port for communication on the server side."""
        return int(self._inner[_RT_COMMUNICATION_KEY][_RT_COMM_PORT_KEY])

    @property
    def comm_fast_mode(self) -> bool:
        """
        In fast mode, all the clients in one container will only download
        the parameters once to improve the efficiency, e.g., when tuning
        the parameters.
        Turn off the fast_mode if you are benchmarking the communication
        and time.
        """
        return bool(self._inner[_RT_COMMUNICATION_KEY][_RT_COMM_FAST_MODE])
# --- Configuration Manager Interfaces ---
class ConfigurationManagerInterface(ABC):
    """an interface of ConfigurationManager shared by all the roles.

    Every member is an abstract read-only property; a concrete manager has
    to override all of them before it can be instantiated.

    NOTE(review): `data_config` and `runtime_config` are annotated as raw
    dicts while `model_config` is annotated as `_ModelConfig` -- confirm
    the intended return types against the concrete implementation.
    """

    # `@property` stacked on `@abstractmethod` replaces the deprecated
    # `abc.abstractproperty` with identical semantics.
    @property
    @abstractmethod
    def data_config_filename(self) -> str:
        raise NotImplementedError

    @property
    @abstractmethod
    def model_config_filename(self) -> str:
        raise NotImplementedError

    @property
    @abstractmethod
    def runtime_config_filename(self) -> str:
        raise NotImplementedError

    @property
    @abstractmethod
    def data_config(self) -> RawConfigurationDict:
        raise NotImplementedError

    @property
    @abstractmethod
    def model_config(self) -> _ModelConfig:
        raise NotImplementedError

    @property
    @abstractmethod
    def runtime_config(self) -> RawConfigurationDict:
        raise NotImplementedError

    @property
    @abstractmethod
    def job_id(self):
        raise NotImplementedError
class ClientConfigurationManagerInterface(ABC):
    """an interface of ConfigurationManager from the client side,
    regulating the essential functions as clients.

    It currently declares no abstract members, so it only serves as a
    marker base class.
    """
class ServerConfigurationManagerInterface(ABC):
    """an interface of ConfigurationManager from the central server side,
    regulating the essential functions as the server.

    Raises:
        NotImplementedError: an abstract property is reached without being
            overridden (e.g. through an explicit super() call).
    """

    # `@property` stacked on `@abstractmethod` replaces the deprecated
    # `abc.abstractproperty` with identical semantics.
    @property
    @abstractmethod
    def num_of_train_clients_contacted_per_round(self) -> int:
        raise NotImplementedError
# --- Configuration Serializer Interfaces ---
# Default text encoding of the configuration files read and written below.
_DEFAULT_ENCODING = 'utf-8'
# Type alias for a stream-like argument: text, bytes, or an open text file.
_Stream = Union[str, bytes, TextIO]
class _CfgYamlInterface(ABC):
    """an interface that regulates the methods used to serialize
    and deserialize configurations in YAML.
    """

    @staticmethod
    def load_configs(
        src_path,
        data_config_filename: str = DEFAULT_D_CFG_FILENAME_YAML,
        model_config_filename: str = DEFAULT_MDL_CFG_FILENAME_YAML,
        runtime_config_filename: str = DEFAULT_RT_CFG_FILENAME_YAML,
        encoding=_DEFAULT_ENCODING
    ) -> Tuple[RawConfigurationDict, RawConfigurationDict, RawConfigurationDict]:
        """Load the data, model and runtime configurations from YAML files.

        Args:
            src_path: the directory that holds the three files.
            data_config_filename (str): file name of the data configuration.
            model_config_filename (str): file name of the model configuration.
            runtime_config_filename (str): file name of the runtime configuration.
            encoding: the text encoding used to open the files.

        Returns:
            Tuple[RawConfigurationDict, RawConfigurationDict, RawConfigurationDict]:
                the parsed (data, model, runtime) configurations.
        """
        # All paths are built first, then the files are read in the order
        # data -> model -> runtime.
        cfg_paths = [
            os.path.join(src_path, filename)
            for filename in (data_config_filename,
                             model_config_filename,
                             runtime_config_filename)]
        loaded = []
        for cfg_path in cfg_paths:
            with open(cfg_path, 'r', encoding=encoding) as f:
                loaded.append(yaml.safe_load(f))
        d_cfg, mdl_cfg, rt_cfg = loaded
        return d_cfg, mdl_cfg, rt_cfg

    @staticmethod
    def save_configs(
        data_cfg: RawConfigurationDict,
        model_cfg: RawConfigurationDict,
        runtime_cfg: RawConfigurationDict,
        dst_path,
        data_config_filename: str = DEFAULT_D_CFG_FILENAME_YAML,
        model_config_filename: str = DEFAULT_MDL_CFG_FILENAME_YAML,
        runtime_config_filename: str = DEFAULT_RT_CFG_FILENAME_YAML,
        encoding=_DEFAULT_ENCODING
    ) -> None:
        """Dump the data, model and runtime configurations as YAML files.

        The destination directory is created if it does not exist; existing
        files with the same names are overwritten.

        Args:
            data_cfg (RawConfigurationDict): the data configuration.
            model_cfg (RawConfigurationDict): the model configuration.
            runtime_cfg (RawConfigurationDict): the runtime configuration.
            dst_path: the directory to write the three files into.
            data_config_filename (str): file name of the data configuration.
            model_config_filename (str): file name of the model configuration.
            runtime_config_filename (str): file name of the runtime configuration.
            encoding: the text encoding used to open the files.
        """
        os.makedirs(dst_path, exist_ok=True)
        cfg_paths = [
            os.path.join(dst_path, filename)
            for filename in (data_config_filename,
                             model_config_filename,
                             runtime_config_filename)]
        # Written in the order data -> model -> runtime.
        for cfg, cfg_path in zip((data_cfg, model_cfg, runtime_cfg), cfg_paths):
            with open(cfg_path, 'w', encoding=encoding) as f:
                yaml.dump(cfg, f)
[docs]
class _CfgJsonInterface(ABC):
    """an interface that regulates the methods used to serialize
    and deserialize configurations in JSON.
    """

    @staticmethod
    def load_configs(
        src_path,
        data_config_filename: str = DEFAULT_D_CFG_FILENAME_JSON,
        model_config_filename: str = DEFAULT_MDL_CFG_FILENAME_JSON,
        runtime_config_filename: str = DEFAULT_RT_CFG_FILENAME_JSON,
        encoding=_DEFAULT_ENCODING
    ) -> Tuple[RawConfigurationDict, RawConfigurationDict, RawConfigurationDict]:
        """load the three configurations from JSON files under ``src_path``.

        Args:
            src_path: directory that contains the configuration files.
            data_config_filename (str): file name of the data configuration.
            model_config_filename (str): file name of the model configuration.
            runtime_config_filename (str): file name of the runtime configuration.
            encoding: text encoding handed to ``open``.

        Returns:
            the parsed (data, model, runtime) configurations, in that order.
        """
        _d_cfg_path = os.path.join(src_path, data_config_filename)
        _mdl_cfg_path = os.path.join(src_path, model_config_filename)
        _rt_cfg_path = os.path.join(src_path, runtime_config_filename)

        with open(_d_cfg_path, 'r', encoding=encoding) as f:
            d_cfg = json.load(f)
        with open(_mdl_cfg_path, 'r', encoding=encoding) as f:
            mdl_cfg = json.load(f)
        with open(_rt_cfg_path, 'r', encoding=encoding) as f:
            rt_cfg = json.load(f)
        return d_cfg, mdl_cfg, rt_cfg

    @staticmethod
    def save_configs(
        data_cfg: RawConfigurationDict,
        model_cfg: RawConfigurationDict,
        runtime_cfg: RawConfigurationDict,
        dst_path,
        # The defaults use the JSON constants, like load_configs, so that a
        # default save followed by a default load addresses the same files.
        data_config_filename: str = DEFAULT_D_CFG_FILENAME_JSON,
        model_config_filename: str = DEFAULT_MDL_CFG_FILENAME_JSON,
        runtime_config_filename: str = DEFAULT_RT_CFG_FILENAME_JSON,
        encoding=_DEFAULT_ENCODING
    ) -> None:
        """dump the three configurations as JSON files under ``dst_path``.

        The destination directory is created when it does not exist yet.

        Args:
            data_cfg: the data configuration to dump.
            model_cfg: the model configuration to dump.
            runtime_cfg: the runtime configuration to dump.
            dst_path: directory that receives the files.
            data_config_filename (str): file name of the data configuration.
            model_config_filename (str): file name of the model configuration.
            runtime_config_filename (str): file name of the runtime configuration.
            encoding: text encoding handed to ``open``.
        """
        os.makedirs(dst_path, exist_ok=True)
        _d_cfg_path = os.path.join(dst_path, data_config_filename)
        _mdl_cfg_path = os.path.join(dst_path, model_config_filename)
        _rt_cfg_path = os.path.join(dst_path, runtime_config_filename)

        with open(_d_cfg_path, 'w', encoding=encoding) as f:
            json.dump(data_cfg, f)
        with open(_mdl_cfg_path, 'w', encoding=encoding) as f:
            json.dump(model_cfg, f)
        with open(_rt_cfg_path, 'w', encoding=encoding) as f:
            json.dump(runtime_cfg, f)
[docs]
class _CfgSerializer(Enum):
    """Serialization formats available for the configuration files.

    The member values are the lower-case names accepted as plain strings.
    """

    # configurations stored as YAML documents
    YAML = 'yaml'
    # configurations stored as JSON documents
    JSON = 'json'
[docs]
class _CfgFileInterface(ABC):
    """File-system side of configuration (de)serialization.

    Implementers provide ``from_files`` and ``to_files``;
    ``serializer2enum`` is a ready-made helper that normalises the
    ``serializer`` argument both of them accept.
    """

    @staticmethod
    @abstractmethod
    def from_files(from_config_path: str,
                   serializer: Union[str, _CfgSerializer] = _CfgSerializer.YAML,
                   encoding=_DEFAULT_ENCODING) -> ConfigurationManagerInterface:
        """Create a configuration manager from configuration files.

        Args:
            from_config_path (str): where the configuration files are read from.
            serializer: the format of the files, as enum member or name.
            encoding: text encoding of the files.

        Raises:
            NotImplementedError: called without implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def to_files(self,
                 dst_dir_path: str,
                 serializer: Union[str, _CfgSerializer] = _CfgSerializer.YAML,
                 encoding: Optional[str] = None) -> None:
        """Write the configurations to files.

        Args:
            dst_dir_path (str): the directory the files are written to.
            serializer: the format of the files, as enum member or name.
            encoding (Optional[str]): text encoding of the files.

        Raises:
            NotImplementedError: called without implementation.
        """
        raise NotImplementedError

    @staticmethod
    def serializer2enum(serializer: Union[str, _CfgSerializer]) -> _CfgSerializer:
        """Normalise a serializer given by name or by enum member.

        Args:
            serializer: a ``_CfgSerializer`` member or the value of one.

        Raises:
            ValueError: the name matches no member, or the argument is
                neither a string nor a ``_CfgSerializer``.

        Returns:
            _CfgSerializer: the matching enum member.
        """
        # Enum members pass through untouched.
        if isinstance(serializer, _CfgSerializer):
            return serializer
        if not isinstance(serializer, str):
            raise ValueError(f'invalid serializer type: {serializer.__class__.__name__}.')
        try:
            return _CfgSerializer(serializer)
        except ValueError:
            raise ValueError(f'{serializer} is not supported currently.')
# --- Role-related Configuration Interface ---
[docs]
class _RoledConfigurationInterface(ABC):
    """an interface for configurations that carry the role of the
    runtime entity.
    """

    # ``@property`` + ``@abstractmethod`` replaces the deprecated
    # ``abc.abstractproperty`` while keeping the member abstract.
    @property
    @abstractmethod
    def role(self) -> Role:
        """the role of this runtime entity.

        Raises:
            NotImplementedError: called without implementation.
        """
        raise NotImplementedError
# --- Configuration Manager ---
[docs]
class ConfigurationManager(Singleton,
ConfigurationManagerInterface,
ClientConfigurationManagerInterface,
ServerConfigurationManagerInterface,
_CfgYamlInterface,
_CfgJsonInterface,
_CfgFileInterface,
_RoledConfigurationInterface):
# Class-level re-entrant lock held by __init__ while it checks and updates
# __initiated.
__init_once_lock = RLock()
# Set to True at the end of the first __init__ run; later runs then skip
# the whole set-up.
__initiated = False
def __init__(self,
             data_config: RawConfigurationDict = _DEFAULT_D_CFG,
             model_config: RawConfigurationDict = _DEFAULT_MDL_CFG,
             runtime_config: RawConfigurationDict = _DEFAULT_RT_CFG,
             thread_safe: bool = False) -> None:
    """Set up the manager; only the first call does any work.

    Later calls return without touching the existing state, because
    the class-level ``__initiated`` flag is already set.

    Args:
        data_config (RawConfigurationDict): raw data configuration.
        model_config (RawConfigurationDict): raw model configuration.
        runtime_config (RawConfigurationDict): raw runtime configuration.
        thread_safe (bool): forwarded to the ``Singleton`` initialiser.
    """
    # NOTE(review): the defaults are shared module-level dicts; whether the
    # wrapper classes copy them is not visible here - confirm.
    with ConfigurationManager.__init_once_lock:
        if ConfigurationManager.__initiated:
            return

        super().__init__(thread_safe)
        self._d_cfg: _DataConfig = _DataConfig(data_config)
        self._mdl_cfg: _ModelConfig = _ModelConfig(model_config)
        self._rt_cfg: _RuntimeConfig = _RuntimeConfig(runtime_config)

        # UNIFIED_JOB_TIME from the environment wins over the current time.
        fallback_job_time = datetime.datetime.now().strftime('%Y_%m%d_%H%M%S')
        self._job_time = os.environ.get('UNIFIED_JOB_TIME', fallback_job_time)

        self._init_file_names()
        self._encoding = _DEFAULT_ENCODING
        self.__init_role()

        # Seed TensorFlow and NumPy from the data configuration.
        import tensorflow as tf
        import numpy as np
        tf.random.set_seed(self._d_cfg.random_seed)
        np.random.seed(self._d_cfg.random_seed)

        ConfigurationManager.__initiated = True
[docs]
def _init_file_names(self,
                     data_config_filename: str = DEFAULT_D_CFG_FILENAME_YAML,
                     model_config_filename: str = DEFAULT_MDL_CFG_FILENAME_YAML,
                     runtime_config_filename: str = DEFAULT_RT_CFG_FILENAME_YAML) -> None:
    """Store the file names of the three configuration files.

    Args:
        data_config_filename (str): file name of the data configuration.
        model_config_filename (str): file name of the model configuration.
        runtime_config_filename (str): file name of the runtime configuration.
    """
    (self._d_cfg_filename,
     self._mdl_cfg_filename,
     self._rt_cfg_filename) = (data_config_filename,
                               model_config_filename,
                               runtime_config_filename)
# TODO(fgh) add unit tests for this method in test_config.py
@property
def data_unique_id(self):
    """MD5-based identifier of the settings that determine the datasets.

    Every entry of the data configuration plus the ML method name and
    the number of clients is rendered as ``key=value``. The entries are
    sorted, joined with commas and hashed.
    """
    entries = [f'{key}={value}' for key, value in self._d_cfg.inner.items()]
    entries.append(f'ml_model={self._mdl_cfg.ml_method_name}')
    entries.append(f'n_clients={self._rt_cfg.client_num}')
    # Sorting makes the result independent of the dict's insertion order.
    entries.sort()
    return self._get_md5(','.join(entries))
@property
def config_unique_id(self):
    """Identifier derived from the data, model and runtime configurations.

    Delegates to ``generate_unique_id`` with the three inner dicts.
    """
    data_cfg = self._d_cfg.inner
    model_cfg = self._mdl_cfg.inner
    runtime_cfg = self._rt_cfg.inner
    return self.generate_unique_id(data_cfg, model_cfg, runtime_cfg)
[docs]
@classmethod
def generate_unique_id(cls, data_config: dict, model_config: dict, runtime_config: dict):
    """Hash three configuration dicts into one identifier.

    Each dict is dumped as JSON with sorted keys; the three dumps are
    joined with commas and passed to ``_get_md5``.

    Args:
        data_config (dict): the data configuration.
        model_config (dict): the model configuration.
        runtime_config (dict): the runtime configuration.

    Returns:
        whatever ``_get_md5`` returns for the joined text.
    """
    dumps = [
        json.dumps(config, sort_keys=True)
        for config in (data_config, model_config, runtime_config)
    ]
    return cls._get_md5(','.join(dumps))
[docs]
@staticmethod
def _get_md5(config_string):
    """Return the hexadecimal MD5 digest of ``config_string``.

    The string is encoded as UTF-8 before it is hashed.
    """
    return hashlib.md5(config_string.encode(encoding='utf-8')).hexdigest()
@property
def data_dir_name(self) -> str:
    """The output directory of the clients' data.

    The path is the configured data directory joined with
    ``<dataset name>_<first 10 characters of data_unique_id>``.

    Returns:
        str: the name of the data directory.
    """
    base_dir = self._d_cfg.inner[_D_DIR_KEY]
    sub_dir = f'{self._d_cfg.dataset_name}_{self.data_unique_id[:10]}'
    return os.path.join(base_dir, sub_dir)
@property
def log_dir_path(self) -> str:
    """The log directory of this job.

    The path is the log directory from the runtime configuration joined
    with ``<job time>_<config_unique_id>``.
    """
    base_dir = self._rt_cfg.inner[_RT_LOG_KEY][_RT_L_DIR_PATH_KEY]
    job_dir = '_'.join((self._job_time, self.config_unique_id))
    return os.path.join(base_dir, job_dir)
@property
def history_record_path(self) -> str:
    """The path of the history record.

    This is the log directory path taken from the runtime configuration,
    returned unchanged.
    """
    log_section = self._rt_cfg.inner[_RT_LOG_KEY]
    return log_section[_RT_L_DIR_PATH_KEY]
@property
@Singleton.thread_safe_ensurance
def job_id(self) -> str:
    """The job identifier, i.e. the stored job time converted to ``str``."""
    job_time = self._job_time
    return str(job_time)
@property
@Singleton.thread_safe_ensurance
def encoding(self) -> str:
    """The text encoding used when the configurations are (de)serialized."""
    return self._encoding

@encoding.setter
@Singleton.thread_safe_ensurance
def encoding(self, encoding):
    """Replace the stored encoding; the value is stored without validation."""
    self._encoding = encoding
@property
@Singleton.thread_safe_ensurance
def data_config_filename(self) -> str:
    """The file name used for the data configuration."""
    return self._d_cfg_filename

# NOTE(review): check_filename(1) presumably validates the argument at
# position 1 (the file name) - confirm in filename_checker.
@data_config_filename.setter
@Singleton.thread_safe_ensurance
@check_filename(1)
def data_config_filename(self, filename: str):
    """Store a new file name for the data configuration."""
    self._d_cfg_filename = filename
@property
@Singleton.thread_safe_ensurance
def model_config_filename(self) -> str:
    """The file name used for the model configuration."""
    return self._mdl_cfg_filename

# NOTE(review): check_filename(1) presumably validates the argument at
# position 1 (the file name) - confirm in filename_checker.
@model_config_filename.setter
@Singleton.thread_safe_ensurance
@check_filename(1)
def model_config_filename(self, filename: str) -> None:
    """Store a new file name for the model configuration."""
    self._mdl_cfg_filename = filename
@property
@Singleton.thread_safe_ensurance
def runtime_config_filename(self) -> str:
    """The file name used for the runtime configuration."""
    return self._rt_cfg_filename

# NOTE(review): check_filename(1) presumably validates the argument at
# position 1 (the file name) - confirm in filename_checker.
@runtime_config_filename.setter
@Singleton.thread_safe_ensurance
@check_filename(1)
def runtime_config_filename(self, filename: str) -> None:
    """Store a new file name for the runtime configuration."""
    self._rt_cfg_filename = filename
@property
def data_config(self) -> _DataConfig:
    """The data configuration object held by this manager."""
    return self._d_cfg
@property
def model_config(self) -> _ModelConfig:
    """The model configuration object held by this manager."""
    return self._mdl_cfg
@property
def runtime_config(self) -> _RuntimeConfig:
    """The runtime configuration object held by this manager."""
    return self._rt_cfg
@property
def num_of_train_clients_contacted_per_round(self) -> int:
    """How many clients are contacted for training in each round.

    Computed as ``client_num * C`` truncated to an integer, and never
    less than 1.
    """
    contacted = int(self._rt_cfg.client_num * self._mdl_cfg.C)
    return contacted if contacted > 1 else 1
@property
def num_of_eval_clients_contacted_per_round(self) -> int:
    """How many clients are contacted for evaluation in each round.

    Computed as ``client_num * evaluate_ratio`` truncated to an integer,
    and never less than 1.
    """
    contacted = int(self._rt_cfg.client_num * self._mdl_cfg.evaluate_ratio)
    return contacted if contacted > 1 else 1
[docs]
@staticmethod
def load_configs(
    src_path,
    serializer: Union[str, _CfgSerializer] = _CfgSerializer.YAML,
    data_config_filename: str = DEFAULT_D_CFG_FILENAME_YAML,
    model_config_filename: str = DEFAULT_MDL_CFG_FILENAME_YAML,
    runtime_config_filename: str = DEFAULT_RT_CFG_FILENAME_YAML,
    encoding=_DEFAULT_ENCODING
) -> Tuple[RawConfigurationDict, RawConfigurationDict, RawConfigurationDict]:
    """load the three raw configurations with the chosen serializer.

    Args:
        src_path: directory that contains the configuration files.
        serializer: the file format, as enum member or name.
        data_config_filename (str): file name of the data configuration.
        model_config_filename (str): file name of the model configuration.
        runtime_config_filename (str): file name of the runtime configuration.
        encoding: text encoding of the files.

    Raises:
        ValueError: the serializer is not recognised.
        NotImplementedError: no loader exists for the serializer.

    Returns:
        the parsed (data, model, runtime) configurations, in that order.
    """
    serializer = _CfgFileInterface.serializer2enum(serializer)
    if serializer == _CfgSerializer.YAML:
        loader = _CfgYamlInterface.load_configs
    elif serializer == _CfgSerializer.JSON:
        # JSON files go through the JSON loader; this branch used to call
        # the YAML loader by mistake.
        loader = _CfgJsonInterface.load_configs
    else:
        raise NotImplementedError(f'Invalid serializer {serializer}')
    return loader(
        src_path, encoding=encoding,
        data_config_filename=data_config_filename,
        model_config_filename=model_config_filename,
        runtime_config_filename=runtime_config_filename
    )
[docs]
@staticmethod
def save_configs(
    data_cfg: RawConfigurationDict,
    model_cfg: RawConfigurationDict,
    runtime_cfg: RawConfigurationDict,
    dst_path,
    data_config_filename: str = DEFAULT_D_CFG_FILENAME_YAML,
    model_config_filename: str = DEFAULT_MDL_CFG_FILENAME_YAML,
    runtime_config_filename: str = DEFAULT_RT_CFG_FILENAME_YAML,
    encoding=_DEFAULT_ENCODING,
    serializer: Union[str, _CfgSerializer] = _CfgSerializer.YAML,
):
    """Write the three raw configurations with the chosen serializer.

    Args:
        data_cfg: the data configuration to dump.
        model_cfg: the model configuration to dump.
        runtime_cfg: the runtime configuration to dump.
        dst_path: directory that receives the files.
        data_config_filename (str): file name of the data configuration.
        model_config_filename (str): file name of the model configuration.
        runtime_config_filename (str): file name of the runtime configuration.
        encoding: text encoding of the files.
        serializer: the file format, as enum member or name.

    Raises:
        ValueError: the serializer is not recognised.
        NotImplementedError: no writer exists for the serializer.
    """
    serializer = _CfgFileInterface.serializer2enum(serializer)
    if serializer == _CfgSerializer.YAML:
        backend = _CfgYamlInterface
    elif serializer == _CfgSerializer.JSON:
        backend = _CfgJsonInterface
    else:
        raise NotImplementedError(f'Invalid serializer {serializer}')
    return backend.save_configs(
        data_cfg=data_cfg, model_cfg=model_cfg, runtime_cfg=runtime_cfg,
        dst_path=dst_path, encoding=encoding,
        data_config_filename=data_config_filename,
        model_config_filename=model_config_filename,
        runtime_config_filename=runtime_config_filename
    )
[docs]
@staticmethod
def from_files(
    src_path: str,
    data_config_filename: str = DEFAULT_D_CFG_FILENAME_YAML,
    model_config_filename: str = DEFAULT_MDL_CFG_FILENAME_YAML,
    runtime_config_filename: str = DEFAULT_RT_CFG_FILENAME_YAML,
    serializer: Union[str, _CfgSerializer] = _CfgSerializer.YAML,
    encoding=_DEFAULT_ENCODING
):
    """Load the configuration files and build a ConfigurationManager.

    Args:
        src_path (str): directory that contains the configuration files.
        data_config_filename (str): file name of the data configuration.
        model_config_filename (str): file name of the model configuration.
        runtime_config_filename (str): file name of the runtime configuration.
        serializer: the file format, as enum member or name.
        encoding: text encoding of the files.

    Returns:
        the ConfigurationManager created from the three loaded configurations.
    """
    raw_configs = ConfigurationManager.load_configs(
        src_path=src_path, encoding=encoding, serializer=serializer,
        data_config_filename=data_config_filename,
        model_config_filename=model_config_filename,
        runtime_config_filename=runtime_config_filename
    )
    data_cfg, model_cfg, runtime_cfg = raw_configs
    return ConfigurationManager(data_cfg, model_cfg, runtime_cfg)
[docs]
def to_files(
    self,
    dst_dir_path: str,
    serializer: Union[str, _CfgSerializer] = _CfgSerializer.YAML,
    encoding: Optional[str] = None
) -> None:
    """Write the three configurations held by this manager to files.

    The file names come from the ``*_config_filename`` properties.

    Args:
        dst_dir_path (str): directory that receives the files.
        serializer: the file format, as enum member or name.
        encoding (Optional[str]): text encoding; a falsy value falls
            back to ``self.encoding``.

    Raises:
        ValueError: the serializer is not recognised.
        NotImplementedError: no writer exists for the serializer.
    """
    serializer = _CfgFileInterface.serializer2enum(serializer)
    configs = (self.data_config.inner,
               self.model_config.inner,
               self.runtime_config.inner)
    filenames = (self.data_config_filename,
                 self.model_config_filename,
                 self.runtime_config_filename)
    encoding = encoding or self.encoding

    if serializer == _CfgSerializer.YAML:
        writer = _CfgYamlInterface.save_configs
    elif serializer == _CfgSerializer.JSON:
        writer = _CfgJsonInterface.save_configs
    else:
        raise NotImplementedError(f'Invalid serializer {serializer}')
    # Positional order: the three configs, the directory, the three names.
    return writer(*configs, dst_dir_path, *filenames, encoding=encoding)
def __init_role(self) -> None:
    """Start with no role assigned; ``None`` marks the role as unset."""
    self._role: Optional[Role] = None
@property
@Singleton.thread_safe_ensurance
def role(self) -> Role:
    """The role of this runtime entity.

    Raises:
        AttributeError: no role has been set so far.

    Returns:
        Role: the role of this runtime entity.
    """
    current_role = self._role
    if current_role is None:
        raise AttributeError('the role of this node has not been set yet.')
    return current_role

@role.setter
@Singleton.thread_safe_ensurance
def role(self, role: Role):
    """Assign the role of this runtime entity.

    The role can be assigned only once. It is recommended to set it
    as soon as the role of this runtime is known.

    Args:
        role (Role): the role which this entity should be.

    Raises:
        AttributeError: a role has already been set.
    """
    if self._role is None:
        self._role = role
        return
    raise AttributeError('the role of a node can only be set once.')