{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Understanding the **FL mode** in ``armlet``\n", "\n", "This tutorial explains how the **FL mode** of **ARMLET** works (i.e., the default mode, used to run federated learning experiments).\n", "It corresponds to the function that is called when running the command `armlet exp.mode=federation` (or simply `armlet`, since this is the default mode).\n", "\n", "## Prerequisites (environment configuration and installation)\n", "\n", "If you have not yet configured the environment and installed ``armlet``, you need to:\n", "\n", "1. Install [conda](https://www.anaconda.com/docs/getting-started/main) to manage environments, then run the following commands:\n", "\n", "```bash\n", "conda create -n armlet python=3.13.5\n", "conda activate armlet\n", "```\n", "\n", "2. Install **``armlet``** using `pip`:\n", "\n", "```bash\n", "cd ARMLET_DIR\n", "pip install .\n", "```\n", "\n", "Then, fill in the following two variables with the main paths and run the cell." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "ARMLET_DIR = \"../../../../\"\n", "OUTPUT_DIR = os.path.join(ARMLET_DIR, \"outputs\", \"tutorial\", \"FL_mode\")\n", "\n", "# Create the output directory if it does not exist yet\n", "os.makedirs(OUTPUT_DIR, exist_ok=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare the configuration\n", "\n", "Before calling the [run_federation()](https://sara-bouchenak.github.io/ARMLET/api/armlet.run.html#armlet.run.run_federation) main function (in `ARMLET_DIR/armlet/run.py`), **ARMLET** uses the [Hydra](https://hydra.cc/) framework to dynamically load the configurations. 
For further details about **ARMLET** configuration (such as the configuration groups and values), see the [Configuration](https://sara-bouchenak.github.io/ARMLET/user_guide/config/index.html) documentation page.\n", "\n", "In this tutorial, we directly prepare the configuration object in Python by creating dictionaries with [OmegaConf](https://omegaconf.readthedocs.io/en/2.3_branch/), the YAML-based hierarchical configuration system used by Hydra.\n", "\n", "The first configuration dictionary to create is `cfg_paths`, which contains the main paths." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from omegaconf import OmegaConf\n", "\n", "cfg_paths = OmegaConf.create({\n", " \"data_dir\": os.path.join(ARMLET_DIR, \"datasets\"),\n", " \"log_dir\": os.path.join(ARMLET_DIR, \"logs\"),\n", " \"output_dir\": OUTPUT_DIR,\n", " \"root_dir\": ARMLET_DIR,\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, `cfg_data` provides the information related to the data pipeline (dataset, data cleaning methods, train/validation/test splitting, data partitioning distribution, other data processing steps, and data seed).\n", "With the following configuration, we:\n", "(a) load the Adult dataset;\n", "(b) partition the data among the clients in an IID fashion (the number of clients is set later, in `cfg_protocol`);\n", "(c) split each client dataset into a training set (80% of the client data) and a test set (20%);\n", "(d) form the server test set as the union of the clients' test sets;\n", "(e) clean all data by removing missing values;\n", "and (f) apply standard preprocessing for tabular data (one-hot encoding, conversion to numerical values, normalization, and conversion to tensors)." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "cfg_data = OmegaConf.create({\n", " \"cleaning\": {\n", " \"name\": \"default\",\n", " \"missing_values\": {\"_target_\": \"armlet.data.cleaning.missing_values.MissingValuesDataCleaningMethod\"},\n", " },\n", " \"dataset\": {\n", " \"_target_\": \"armlet.data.datasets.load_Adult_dataset\",\n", " \"dataset_name\": \"Adult\",\n", " \"path\": os.path.join(cfg_paths[\"data_dir\"], \"Adult\", \"raw_data\"),\n", " \"sensitive_attributes\": ['age', 'gender', 'race'],\n", " },\n", " \"distribution\": {\"_target_\": \"armlet.data.splitter.ArmletDataSplitter.iid\"},\n", " \"others\": {\n", " \"client_split\": 0.2,\n", " \"client_val_split\": 0.0,\n", " \"keep_test\": False,\n", " \"sampling_perc\": 1.0,\n", " \"server_split\": 0.0,\n", " \"server_test\": False,\n", " \"server_test_union\": True,\n", " \"server_val_split\": 0.0,\n", " \"uniform_test\": False,\n", " },\n", " \"processing\": {\n", " \"one_hot_encoding\": {\n", " \"_apply_directly_to_subdata_\": False,\n", " \"_target_\": \"armlet.data.processing.feature_encoding.one_hot_encoding_pipeline\",\n", " },\n", " \"conversion_to_num\": {\n", " \"_apply_directly_to_subdata_\": True,\n", " \"_target_\": \"armlet.data.processing.format_conversion.convert_bool_and_cat_to_num\",\n", " },\n", " \"normalization\": {\n", " \"_apply_directly_to_subdata_\": False,\n", " \"_target_\": \"armlet.data.processing.normalization.normalization_pipeline\",\n", " \"cols_to_exclude\": ['age', 'gender', 'race'],\n", " },\n", " \"conversion_to_tensors\": {\n", " \"_apply_directly_to_subdata_\": True,\n", " \"_target_\": \"armlet.data.processing.format_conversion.convert_dataframes_to_tensors\",\n", " \"sensitive_attributes\": ['age', 'gender', 'race'],\n", " },\n", " },\n", " \"seed\": 42,\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cfg_eval` contains the configuration values related to the evaluation.\n", "In this 
 tutorial, we use the [armlet.eval.evaluators.MultiCriteriaBinaryClassEval](https://sara-bouchenak.github.io/ARMLET/api/armlet.eval.evaluators.html#armlet.eval.evaluators.MultiCriteriaBinaryClassEval) evaluation class and evaluate the global model at every round (`eval_every` is 1) on the server side, using the server test set." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "cfg_eval = OmegaConf.create({\n", " \"_target_\": \"armlet.eval.evaluators.MultiCriteriaBinaryClassEval\",\n", " \"eval_every\": 1,\n", " \"locals\": False,\n", " \"post_fit\": False,\n", " \"pre_fit\": False,\n", " \"server\": True,\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cfg_exp` specifies how to run the experiment.\n", "Here, we select the federation mode, enable training, run on the CPU, and fix the experiment seed." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "cfg_exp = OmegaConf.create({\n", " \"device\": \"cpu\",\n", " \"inmemory\": True,\n", " \"mode\": \"federation\",\n", " \"seed\": 42,\n", " \"train\": True,\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cfg_logger` specifies the logging class used during the experiment.\n", "In this tutorial, we use [armlet.utils.log.ArmletLog](https://sara-bouchenak.github.io/ARMLET/api/armlet.utils.log.html#armlet.utils.log.ArmletLog), the minimal logger provided by **ARMLET** (it only saves the results in a JSON file)." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "cfg_logger = OmegaConf.create({\n", " \"_target_\": \"armlet.utils.log.ArmletLog\",\n", " \"json_log_dir\": cfg_paths[\"output_dir\"],\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cfg_method` specifies the FL algorithm, the ML model, the client hyperparameters, and the server behavior.\n", "Here, we use [armlet.FL_pipeline.FL_algorithms.ArmletCentralizedFL](https://sara-bouchenak.github.io/ARMLET/api/armlet.FL_pipeline.FL_algorithms.html#armlet.FL_pipeline.FL_algorithms.ArmletCentralizedFL), i.e., the standard FedAvg algorithm adapted to the **ARMLET** pipeline, with a logistic regression model.\n", "Moreover, at each round, each client performs 10 local epochs with a batch size of 128, using an SGD optimizer with a learning rate of 0.001 and a weight decay of 0.01." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "cfg_method = OmegaConf.create({\n", " \"_target_\": \"armlet.FL_pipeline.FL_algorithms.ArmletCentralizedFL\",\n", " \"hyperparameters\": {\n", " \"client\": {\n", " \"batch_size\": 128,\n", " \"local_epochs\": 10,\n", " \"loss\": {\"_target_\": \"torch.nn.BCELoss\"},\n", " \"optimizer\": {\"lr\": 0.001, \"name\": \"SGD\", \"weight_decay\": 0.01},\n", " \"scheduler\": {\"gamma\": 1, \"name\": \"StepLR\", \"step_size\": 1},\n", " },\n", " \"model\": {\n", " \"_target_\": \"armlet.utils.net.LogRegression\",\n", " \"input_size\": None, # Automatically adjusted after data loading\n", " \"num_classes\": None, # Automatically adjusted after data loading\n", " },\n", " \"server\": {\n", " \"loss\": {\"_target_\": \"torch.nn.BCELoss\"},\n", " \"time_to_accuracy_target\": None,\n", " \"weighted\": True,\n", " },\n", " },\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cfg_protocol` details the protocol of the FL process.\n", "With these configuration values, we run an FL process with 4
clients, for a total of 2 rounds (as a toy example).\n", "During each round, all clients participate in training (`eligible_perc` is 1.0)." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "cfg_protocol = OmegaConf.create({\n", " \"eligible_perc\": 1.0,\n", " \"n_clients\": 4,\n", " \"n_rounds\": 2,\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we combine all the previous configuration dictionaries into a single configuration, wrapped in an `ArmletConfiguration` object." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from armlet.utils.configs import ArmletConfiguration\n", "\n", "cfg = OmegaConf.create({\n", " \"data\": cfg_data,\n", " \"eval\": cfg_eval,\n", " \"exp\": cfg_exp,\n", " \"logger\": cfg_logger,\n", " \"method\": cfg_method,\n", " \"paths\": cfg_paths,\n", " \"protocol\": cfg_protocol,\n", " \"save\": {},\n", "})\n", "\n", "cfg = ArmletConfiguration(cfg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we are ready to investigate the `run_federation()` function, which is called when `exp.mode=federation`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the data pipeline\n", "\n", "First, **ARMLET** runs the data pipeline to compute the [DataSplitter](https://makgyver.github.io/fluke/fluke.data.html#fluke.data.DataSplitter), i.e., the Fluke Python object that contains all the data after processing (loading, data cleaning, train/validation/test splitting, data partitioning, normalization, etc.).\n", "As the data pipeline is quite complex, we do not detail it in this tutorial. More information is available in the following tutorial: [Understanding the **data pipeline** in `armlet`](https://sara-bouchenak.github.io/ARMLET/getting_started/tutorials/data_pipeline.html)."
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from armlet.data import data_pipeline\n", "\n", "data_splitter, val_data = data_pipeline(cfg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Adjust the configurations and save them\n", "\n", "Once the data are prepared, we set up the Fluke environment and adjust the configuration values that depend on the data, such as the number of classes and the input size of the ML model.\n", "We then save the configuration in a YAML file for reproducibility." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from fluke import FlukeENV\n", "\n", "FlukeENV().configure(cfg)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "input_size = data_splitter.data_container.clients_tr[0].tensors[0].shape[-1]\n", "cfg.method.hyperparameters.model.input_size = input_size\n", "# Binary tasks use a single output unit, consistent with the BCELoss used above\n", "if data_splitter.data_container.num_classes <= 2:\n", " cfg.method.hyperparameters.model.num_classes = 1\n", "else:\n", " cfg.method.hyperparameters.model.num_classes = data_splitter.data_container.num_classes" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import yaml\n", "\n", "config_path = os.path.join(cfg.paths.output_dir, \"config.yaml\")\n", "cfg_to_save = cfg.to_dict()\n", "# Use Hydra interpolations instead of the local paths used in this notebook\n", "cfg_to_save[\"paths\"][\"output_dir\"] = \"${hydra:runtime.output_dir}\"\n", "if \"json_log_dir\" in cfg_to_save[\"logger\"]:\n", " cfg_to_save[\"logger\"][\"json_log_dir\"] = \"${paths.output_dir}\"\n", "# Use a context manager so the file is closed after writing\n", "with open(config_path, \"w\") as f:\n", " yaml.dump(cfg_to_save, f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Instantiate the FL algorithm, evaluator, and logger\n", "\n", "Then, we use Hydra to instantiate the FL algorithm specified in the `cfg_method` dictionary.\n", "We do the same for the evaluator and the logger with the `cfg_eval` and `cfg_logger`
dictionaries, and pass them to the Fluke environment." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import hydra\n", "\n", "fl_algo = hydra.utils.instantiate(\n", " cfg.method,\n", " n_clients=cfg.protocol.n_clients,\n", " data_splitter=data_splitter,\n", " val_data=val_data,\n", " _convert_=\"all\",\n", " _recursive_=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "evaluator = hydra.utils.instantiate(\n", " cfg.eval.exclude(\"locals\", \"post_fit\", \"pre_fit\", \"server\"), \n", " n_classes=data_splitter.data_container.num_classes,\n", " sensitive_attributes=cfg.data.dataset.sensitive_attributes,\n", ")\n", "FlukeENV().set_evaluator(evaluator)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
╭───────────────────────────────────────── Configuration ──────────────────────────────────────────╮\n",
       "│ {                                                                                                │\n",
       "│     'save': {},                                                                                  │\n",
       "│     'data': {                                                                                    │\n",
       "│         'cleaning': {                                                                            │\n",
       "│             'name': 'default',                                                                   │\n",
       "│             'missing_values': {                                                                  │\n",
       "│                 '_target_':                                                                      │\n",
       "│ 'armlet.data.cleaning.missing_values.MissingValuesDataCleaningMethod'                            │\n",
       "│             }                                                                                    │\n",
       "│         },                                                                                       │\n",
       "│         'dataset': {                                                                             │\n",
       "│             '_target_': 'armlet.data.datasets.load_Adult_dataset',                               │\n",
       "│             'dataset_name': 'Adult',                                                             │\n",
       "│             'path': '../../../../datasets/Adult/raw_data',                                       │\n",
       "│             'sensitive_attributes': [                                                            │\n",
       "│                 'age',                                                                           │\n",
       "│                 'gender',                                                                        │\n",
       "│                 'race'                                                                           │\n",
       "│             ]                                                                                    │\n",
       "│         },                                                                                       │\n",
       "│         'distribution': {                                                                        │\n",
       "│             '_target_': 'armlet.data.splitter.ArmletDataSplitter.iid'                            │\n",
       "│         },                                                                                       │\n",
       "│         'others': {                                                                              │\n",
       "│             'client_split': 0.2,                                                                 │\n",
       "│             'client_val_split': 0.0,                                                             │\n",
       "│             'keep_test': False,                                                                  │\n",
       "│             'sampling_perc': 1.0,                                                                │\n",
       "│             'server_split': 0.0,                                                                 │\n",
       "│             'server_test': False,                                                                │\n",
       "│             'server_test_union': True,                                                           │\n",
       "│             'server_val_split': 0.0,                                                             │\n",
       "│             'uniform_test': False                                                                │\n",
       "│         },                                                                                       │\n",
       "│         'processing': {                                                                          │\n",
       "│             'one_hot_encoding': {                                                                │\n",
       "│                 '_apply_directly_to_subdata_': False,                                            │\n",
       "│                 '_target_': 'armlet.data.processing.feature_encoding.one_hot_encoding_pipeline'  │\n",
       "│             },                                                                                   │\n",
       "│             'conversion_to_num': {                                                               │\n",
       "│                 '_apply_directly_to_subdata_': True,                                             │\n",
       "│                 '_target_':                                                                      │\n",
       "│ 'armlet.data.processing.format_conversion.convert_bool_and_cat_to_num'                           │\n",
       "│             },                                                                                   │\n",
       "│             'normalization': {                                                                   │\n",
       "│                 '_apply_directly_to_subdata_': False,                                            │\n",
       "│                 '_target_': 'armlet.data.processing.normalization.normalization_pipeline',       │\n",
       "│                 'cols_to_exclude': [                                                             │\n",
       "│                     'age',                                                                       │\n",
       "│                     'gender',                                                                    │\n",
       "│                     'race'                                                                       │\n",
       "│                 ]                                                                                │\n",
       "│             },                                                                                   │\n",
       "│             'conversion_to_tensors': {                                                           │\n",
       "│                 '_apply_directly_to_subdata_': True,                                             │\n",
       "│                 '_target_':                                                                      │\n",
       "│ 'armlet.data.processing.format_conversion.convert_dataframes_to_tensors',                        │\n",
       "│                 'sensitive_attributes': [                                                        │\n",
       "│                     'age',                                                                       │\n",
       "│                     'gender',                                                                    │\n",
       "│                     'race'                                                                       │\n",
       "│                 ]                                                                                │\n",
       "│             }                                                                                    │\n",
       "│         },                                                                                       │\n",
       "│         'seed': 42                                                                               │\n",
       "│     },                                                                                           │\n",
       "│     'eval': {                                                                                    │\n",
       "│         '_target_': 'armlet.eval.evaluators.MultiCriteriaBinaryClassEval',                       │\n",
       "│         'eval_every': 1,                                                                         │\n",
       "│         'locals': False,                                                                         │\n",
       "│         'post_fit': False,                                                                       │\n",
       "│         'pre_fit': False,                                                                        │\n",
       "│         'server': True                                                                           │\n",
       "│     },                                                                                           │\n",
       "│     'exp': {                                                                                     │\n",
       "│         'device': 'cpu',                                                                         │\n",
       "│         'inmemory': True,                                                                        │\n",
       "│         'mode': 'federation',                                                                    │\n",
       "│         'seed': 42,                                                                              │\n",
       "│         'train': True                                                                            │\n",
       "│     },                                                                                           │\n",
       "│     'logger': {                                                                                  │\n",
       "│         '_target_': 'armlet.utils.log.ArmletLog',                                                │\n",
       "│         'json_log_dir': '../../../../outputs/tutorial/FL_mode'                                   │\n",
       "│     },                                                                                           │\n",
       "│     'method': {                                                                                  │\n",
       "│         '_target_': 'armlet.FL_pipeline.FL_algorithms.ArmletCentralizedFL',                      │\n",
       "│         'hyperparameters': {                                                                     │\n",
       "│             'client': {                                                                          │\n",
       "│                 'batch_size': 128,                                                               │\n",
       "│                 'local_epochs': 10,                                                              │\n",
       "│                 'loss': {                                                                        │\n",
       "│                     '_target_': 'torch.nn.BCELoss'                                               │\n",
       "│                 },                                                                               │\n",
       "│                 'optimizer': {                                                                   │\n",
       "│                     'lr': 0.001,                                                                 │\n",
       "│                     'name': 'SGD',                                                               │\n",
       "│                     'weight_decay': 0.01                                                         │\n",
       "│                 },                                                                               │\n",
       "│                 'scheduler': {                                                                   │\n",
       "│                     'gamma': 1,                                                                  │\n",
       "│                     'name': 'StepLR',                                                            │\n",
       "│                     'step_size': 1                                                               │\n",
       "│                 }                                                                                │\n",
       "│             },                                                                                   │\n",
       "│             'model': {                                                                           │\n",
       "│                 '_target_': 'armlet.utils.net.LogRegression',                                    │\n",
       "│                 'input_size': 99,                                                                │\n",
       "│                 'num_classes': 1                                                                 │\n",
       "│             },                                                                                   │\n",
       "│             'server': {                                                                          │\n",
       "│                 'loss': {                                                                        │\n",
       "│                     '_target_': 'torch.nn.BCELoss'                                               │\n",
       "│                 },                                                                               │\n",
       "│                 'time_to_accuracy_target': None,                                                 │\n",
       "│                 'weighted': True                                                                 │\n",
       "│             }                                                                                    │\n",
       "│         }                                                                                        │\n",
       "│     },                                                                                           │\n",
       "│     'paths': {                                                                                   │\n",
       "│         'data_dir': '../../../../datasets',                                                      │\n",
       "│         'log_dir': '../../../../logs',                                                           │\n",
       "│         'output_dir': '../../../../outputs/tutorial/FL_mode',                                    │\n",
       "│         'root_dir': '../../../../'                                                               │\n",
       "│     },                                                                                           │\n",
       "│     'protocol': {                                                                                │\n",
       "│         'eligible_perc': 1.0,                                                                    │\n",
       "│         'n_clients': 4,                                                                          │\n",
       "│         'n_rounds': 2                                                                            │\n",
       "│     },                                                                                           │\n",
       "│     'exp_id': '16016ed727174cd19719c05654849004'                                                 │\n",
       "│ }                                                                                                │\n",
       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
       "
\n" ], "text/plain": [ "╭───────────────────────────────────────── Configuration ──────────────────────────────────────────╮\n", "│ \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'save'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'data'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'cleaning'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'name'\u001b[0m: \u001b[32m'default'\u001b[0m, │\n", "│ \u001b[32m'missing_values'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_target_'\u001b[0m: │\n", "│ \u001b[32m'armlet.data.cleaning.missing_values.MissingValuesDataCleaningMethod'\u001b[0m │\n", "│ \u001b[1m}\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'dataset'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.data.datasets.load_Adult_dataset'\u001b[0m, │\n", "│ \u001b[32m'dataset_name'\u001b[0m: \u001b[32m'Adult'\u001b[0m, │\n", "│ \u001b[32m'path'\u001b[0m: \u001b[32m'../../../../datasets/Adult/raw_data'\u001b[0m, │\n", "│ \u001b[32m'sensitive_attributes'\u001b[0m: \u001b[1m[\u001b[0m │\n", "│ \u001b[32m'age'\u001b[0m, │\n", "│ \u001b[32m'gender'\u001b[0m, │\n", "│ \u001b[32m'race'\u001b[0m │\n", "│ \u001b[1m]\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'distribution'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.data.splitter.ArmletDataSplitter.iid'\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'others'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'client_split'\u001b[0m: \u001b[1;36m0.2\u001b[0m, │\n", "│ \u001b[32m'client_val_split'\u001b[0m: \u001b[1;36m0.0\u001b[0m, │\n", "│ \u001b[32m'keep_test'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n", "│ \u001b[32m'sampling_perc'\u001b[0m: \u001b[1;36m1.0\u001b[0m, │\n", "│ \u001b[32m'server_split'\u001b[0m: \u001b[1;36m0.0\u001b[0m, │\n", "│ \u001b[32m'server_test'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n", "│ \u001b[32m'server_test_union'\u001b[0m: \u001b[3;92mTrue\u001b[0m, 
│\n", "│ \u001b[32m'server_val_split'\u001b[0m: \u001b[1;36m0.0\u001b[0m, │\n", "│ \u001b[32m'uniform_test'\u001b[0m: \u001b[3;91mFalse\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'processing'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'one_hot_encoding'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_apply_directly_to_subdata_'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n", "│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.data.processing.feature_encoding.one_hot_encoding_pipeline'\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'conversion_to_num'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_apply_directly_to_subdata_'\u001b[0m: \u001b[3;92mTrue\u001b[0m, │\n", "│ \u001b[32m'_target_'\u001b[0m: │\n", "│ \u001b[32m'armlet.data.processing.format_conversion.convert_bool_and_cat_to_num'\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'normalization'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_apply_directly_to_subdata_'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n", "│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.data.processing.normalization.normalization_pipeline'\u001b[0m, │\n", "│ \u001b[32m'cols_to_exclude'\u001b[0m: \u001b[1m[\u001b[0m │\n", "│ \u001b[32m'age'\u001b[0m, │\n", "│ \u001b[32m'gender'\u001b[0m, │\n", "│ \u001b[32m'race'\u001b[0m │\n", "│ \u001b[1m]\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'conversion_to_tensors'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_apply_directly_to_subdata_'\u001b[0m: \u001b[3;92mTrue\u001b[0m, │\n", "│ \u001b[32m'_target_'\u001b[0m: │\n", "│ \u001b[32m'armlet.data.processing.format_conversion.convert_dataframes_to_tensors'\u001b[0m, │\n", "│ \u001b[32m'sensitive_attributes'\u001b[0m: \u001b[1m[\u001b[0m │\n", "│ \u001b[32m'age'\u001b[0m, │\n", "│ \u001b[32m'gender'\u001b[0m, │\n", "│ \u001b[32m'race'\u001b[0m │\n", "│ \u001b[1m]\u001b[0m │\n", "│ \u001b[1m}\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'seed'\u001b[0m: 
\u001b[1;36m42\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'eval'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.eval.evaluators.MultiCriteriaBinaryClassEval'\u001b[0m, │\n", "│ \u001b[32m'eval_every'\u001b[0m: \u001b[1;36m1\u001b[0m, │\n", "│ \u001b[32m'locals'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n", "│ \u001b[32m'post_fit'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n", "│ \u001b[32m'pre_fit'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n", "│ \u001b[32m'server'\u001b[0m: \u001b[3;92mTrue\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'exp'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'device'\u001b[0m: \u001b[32m'cpu'\u001b[0m, │\n", "│ \u001b[32m'inmemory'\u001b[0m: \u001b[3;92mTrue\u001b[0m, │\n", "│ \u001b[32m'mode'\u001b[0m: \u001b[32m'federation'\u001b[0m, │\n", "│ \u001b[32m'seed'\u001b[0m: \u001b[1;36m42\u001b[0m, │\n", "│ \u001b[32m'train'\u001b[0m: \u001b[3;92mTrue\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'logger'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.utils.log.ArmletLog'\u001b[0m, │\n", "│ \u001b[32m'json_log_dir'\u001b[0m: \u001b[32m'../../../../outputs/tutorial/FL_mode'\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'method'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.FL_pipeline.FL_algorithms.ArmletCentralizedFL'\u001b[0m, │\n", "│ \u001b[32m'hyperparameters'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'client'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'batch_size'\u001b[0m: \u001b[1;36m128\u001b[0m, │\n", "│ \u001b[32m'local_epochs'\u001b[0m: \u001b[1;36m10\u001b[0m, │\n", "│ \u001b[32m'loss'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'torch.nn.BCELoss'\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'optimizer'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'lr'\u001b[0m: \u001b[1;36m0.001\u001b[0m, 
│\n", "│ \u001b[32m'name'\u001b[0m: \u001b[32m'SGD'\u001b[0m, │\n", "│ \u001b[32m'weight_decay'\u001b[0m: \u001b[1;36m0.01\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'scheduler'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'gamma'\u001b[0m: \u001b[1;36m1\u001b[0m, │\n", "│ \u001b[32m'name'\u001b[0m: \u001b[32m'StepLR'\u001b[0m, │\n", "│ \u001b[32m'step_size'\u001b[0m: \u001b[1;36m1\u001b[0m │\n", "│ \u001b[1m}\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'model'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.utils.net.LogRegression'\u001b[0m, │\n", "│ \u001b[32m'input_size'\u001b[0m: \u001b[1;36m99\u001b[0m, │\n", "│ \u001b[32m'num_classes'\u001b[0m: \u001b[1;36m1\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'server'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'loss'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'torch.nn.BCELoss'\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'time_to_accuracy_target'\u001b[0m: \u001b[3;35mNone\u001b[0m, │\n", "│ \u001b[32m'weighted'\u001b[0m: \u001b[3;92mTrue\u001b[0m │\n", "│ \u001b[1m}\u001b[0m │\n", "│ \u001b[1m}\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'paths'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'data_dir'\u001b[0m: \u001b[32m'../../../../datasets'\u001b[0m, │\n", "│ \u001b[32m'log_dir'\u001b[0m: \u001b[32m'../../../../logs'\u001b[0m, │\n", "│ \u001b[32m'output_dir'\u001b[0m: \u001b[32m'../../../../outputs/tutorial/FL_mode'\u001b[0m, │\n", "│ \u001b[32m'root_dir'\u001b[0m: \u001b[32m'../../../../'\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'protocol'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'eligible_perc'\u001b[0m: \u001b[1;36m1.0\u001b[0m, │\n", "│ \u001b[32m'n_clients'\u001b[0m: \u001b[1;36m4\u001b[0m, │\n", "│ \u001b[32m'n_rounds'\u001b[0m: \u001b[1;36m2\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'exp_id'\u001b[0m: 
\u001b[32m'16016ed727174cd19719c05654849004'\u001b[0m │\n", "│ \u001b[1m}\u001b[0m │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "log_name = f\"{fl_algo.__class__.__name__} [{fl_algo.id}]\"\n", "log = hydra.utils.instantiate(cfg.logger, name=log_name)\n", "log.init(**cfg, exp_id=fl_algo.id)\n", "fl_algo.set_callbacks([log])\n", "FlukeENV().set_logger(log)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the FL experiment\n", "\n", "Finally, we call the `run()` function of the FL algorithm object and close the logger once the FL process is completed.\n", "As training progresses, you can follow the evaluation metrics that are computed at the end of each round.\n", "Moreover, by closing `ArmletLog`, all metrics are saved in a JSON file (`results.json`) that can be found in the `OUTPUT_DIR` folder (i.e., `ARMLET_DIR/outputs/tutorial/FL_mode` as defined in the beginning of the tutorial)." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[UserWarning] /home/bnaline/anaconda3/envs/armlet/lib/python3.13/site-packages/rich/live.py:260\n",
       "install \"ipywidgets\" for Jupyter support\n",
       "
\n" ], "text/plain": [ "\u001b[93m[UserWarning]\u001b[0m \u001b[94m/home/bnaline/anaconda3/envs/armlet/lib/python3.13/site-packages/rich/live.py:260\u001b[0m\n", "\u001b[93minstall \"ipywidgets\" for Jupyter support\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
╭──────────────────────────────────────────── Round: 1 ────────────────────────────────────────────╮\n",
       "│ {                                                                                                │\n",
       "│     'post-fit': {                                                                                │\n",
       "│         'training_loss': 0.60185,                                                                │\n",
       "│         'support': 4,                                                                            │\n",
       "│         'round': 1                                                                               │\n",
       "│     },                                                                                           │\n",
       "│     'global': {                                                                                  │\n",
       "│         'accuracy': 0.78646,                                                                     │\n",
       "│         'precision': 0.56932,                                                                    │\n",
       "│         'recall': 0.56754,                                                                       │\n",
       "│         'f1': 0.56843,                                                                           │\n",
       "│         'loss': 0.5378,                                                                          │\n",
       "│         'round': 1                                                                               │\n",
       "│     },                                                                                           │\n",
       "│     'comm_cost': 800                                                                             │\n",
       "│ }                                                                                                │\n",
       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
       "
\n" ], "text/plain": [ "╭──────────────────────────────────────────── Round: 1 ────────────────────────────────────────────╮\n", "│ \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'post-fit'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'training_loss'\u001b[0m: \u001b[1;36m0.60185\u001b[0m, │\n", "│ \u001b[32m'support'\u001b[0m: \u001b[1;36m4\u001b[0m, │\n", "│ \u001b[32m'round'\u001b[0m: \u001b[1;36m1\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'global'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.78646\u001b[0m, │\n", "│ \u001b[32m'precision'\u001b[0m: \u001b[1;36m0.56932\u001b[0m, │\n", "│ \u001b[32m'recall'\u001b[0m: \u001b[1;36m0.56754\u001b[0m, │\n", "│ \u001b[32m'f1'\u001b[0m: \u001b[1;36m0.56843\u001b[0m, │\n", "│ \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.5378\u001b[0m, │\n", "│ \u001b[32m'round'\u001b[0m: \u001b[1;36m1\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'comm_cost'\u001b[0m: \u001b[1;36m800\u001b[0m │\n", "│ \u001b[1m}\u001b[0m │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
  Memory usage: 588.6 MB [6.04 %]\n",
       "
\n" ], "text/plain": [ " Memory usage: \u001b[1;36m588.6\u001b[0m MB \u001b[1m[\u001b[0m\u001b[1;36m6.04\u001b[0m %\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
╭──────────────────────────────────────────── Round: 2 ────────────────────────────────────────────╮\n",
       "│ {                                                                                                │\n",
       "│     'post-fit': {                                                                                │\n",
       "│         'training_loss': 0.49904,                                                                │\n",
       "│         'support': 4,                                                                            │\n",
       "│         'round': 2                                                                               │\n",
       "│     },                                                                                           │\n",
       "│     'global': {                                                                                  │\n",
       "│         'accuracy': 0.81407,                                                                     │\n",
       "│         'precision': 0.65054,                                                                    │\n",
       "│         'recall': 0.53946,                                                                       │\n",
       "│         'f1': 0.58981,                                                                           │\n",
       "│         'loss': 0.4717,                                                                          │\n",
       "│         'round': 2                                                                               │\n",
       "│     },                                                                                           │\n",
       "│     'comm_cost': 800                                                                             │\n",
       "│ }                                                                                                │\n",
       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
       "
\n" ], "text/plain": [ "╭──────────────────────────────────────────── Round: 2 ────────────────────────────────────────────╮\n", "│ \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'post-fit'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'training_loss'\u001b[0m: \u001b[1;36m0.49904\u001b[0m, │\n", "│ \u001b[32m'support'\u001b[0m: \u001b[1;36m4\u001b[0m, │\n", "│ \u001b[32m'round'\u001b[0m: \u001b[1;36m2\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'global'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.81407\u001b[0m, │\n", "│ \u001b[32m'precision'\u001b[0m: \u001b[1;36m0.65054\u001b[0m, │\n", "│ \u001b[32m'recall'\u001b[0m: \u001b[1;36m0.53946\u001b[0m, │\n", "│ \u001b[32m'f1'\u001b[0m: \u001b[1;36m0.58981\u001b[0m, │\n", "│ \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.4717\u001b[0m, │\n", "│ \u001b[32m'round'\u001b[0m: \u001b[1;36m2\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'comm_cost'\u001b[0m: \u001b[1;36m800\u001b[0m │\n", "│ \u001b[1m}\u001b[0m │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
  Memory usage: 592.1 MB [6.06 %]\n",
       "
\n" ], "text/plain": [ " Memory usage: \u001b[1;36m592.1\u001b[0m MB \u001b[1m[\u001b[0m\u001b[1;36m6.06\u001b[0m %\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
╭────────────────────────────────────── Overall Performance ───────────────────────────────────────╮\n",
       "│ {                                                                                                │\n",
       "│     'post-fit': {                                                                                │\n",
       "│         'training_loss': 0.49904,                                                                │\n",
       "│         'support': 4,                                                                            │\n",
       "│         'round': 2                                                                               │\n",
       "│     },                                                                                           │\n",
       "│     'global': {                                                                                  │\n",
       "│         'accuracy': 0.81407,                                                                     │\n",
       "│         'precision': 0.65054,                                                                    │\n",
       "│         'recall': 0.53946,                                                                       │\n",
       "│         'f1': 0.58981,                                                                           │\n",
       "│         'loss': 0.4717,                                                                          │\n",
       "│         'round': 2                                                                               │\n",
       "│     },                                                                                           │\n",
       "│     'comm_costs': 2000                                                                           │\n",
       "│ }                                                                                                │\n",
       "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
       "
\n" ], "text/plain": [ "╭────────────────────────────────────── Overall Performance ───────────────────────────────────────╮\n", "│ \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'post-fit'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'training_loss'\u001b[0m: \u001b[1;36m0.49904\u001b[0m, │\n", "│ \u001b[32m'support'\u001b[0m: \u001b[1;36m4\u001b[0m, │\n", "│ \u001b[32m'round'\u001b[0m: \u001b[1;36m2\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'global'\u001b[0m: \u001b[1m{\u001b[0m │\n", "│ \u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.81407\u001b[0m, │\n", "│ \u001b[32m'precision'\u001b[0m: \u001b[1;36m0.65054\u001b[0m, │\n", "│ \u001b[32m'recall'\u001b[0m: \u001b[1;36m0.53946\u001b[0m, │\n", "│ \u001b[32m'f1'\u001b[0m: \u001b[1;36m0.58981\u001b[0m, │\n", "│ \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.4717\u001b[0m, │\n", "│ \u001b[32m'round'\u001b[0m: \u001b[1;36m2\u001b[0m │\n", "│ \u001b[1m}\u001b[0m, │\n", "│ \u001b[32m'comm_costs'\u001b[0m: \u001b[1;36m2000\u001b[0m │\n", "│ \u001b[1m}\u001b[0m │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "try:\n", " fl_algo.run(cfg.protocol.n_rounds, cfg.protocol.eligible_perc)\n", "except Exception as e:\n", " log.log(f\"Error: {e}\")\n", " FlukeENV().force_close()\n", " FlukeENV.clear()\n", " log.close()\n", " FlukeENV().close_cache()\n", " raise e\n", "log.close()" ] } ], "metadata": { "kernelspec": { "display_name": "armlet", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 2 }