{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Understanding the **FL mode** in ``armlet``\n", "\n", "This tutorial allows users to understand how the **FL mode** of **ARMLET** works (i.e., the default mode that serves to run federated learning experiments).\n", "It corresponds to the function that is called when running the command `armlet exp.mode=federation` (or just `armlet` as it is the default mode).\n", "\n", "## Prerequisites (environment configuration and installation)\n", "\n", "If you have not configured the environment and installed ``armlet`` yet, you have to:\n", "\n", "1. Install [conda](https://www.anaconda.com/docs/getting-started/main) for managing the environments and run the following commands:\n", "\n", "```bash\n", "conda create -n armlet python=3.13.5\n", "conda activate armlet\n", "```\n", "\n", "2. Install **``armlet``** using `pip`:\n", "\n", "```bash\n", "cd ARMLET_DIR\n", "pip install .\n", "```\n", "\n", "Then, fill the two following variables to detail the main paths and run the command." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "ARMLET_DIR = \"../../../../\"\n", "OUTPUT_DIR = os.path.join(ARMLET_DIR, \"outputs\", \"tutorial\", \"FL_mode\")\n", "\n", "if not os.path.exists(OUTPUT_DIR):\n", " os.makedirs(OUTPUT_DIR)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prepare the configuration\n", "\n", "Before calling the [run_federation()](https://sara-bouchenak.github.io/ARMLET/api/armlet.run.html#armlet.run.run_federation) main function (in `ARMLET_DIR/armlet/run.py`), **ARMLET** uses the [Hydra](https://hydra.cc/) framework to dynamically load the configurations. 
For further details about the **ARMLET** configuration (such as the configuration groups and values), see the [Configuration](https://sara-bouchenak.github.io/ARMLET/user_guide/config/index.html) documentation page.\n", "\n", "In this tutorial, we directly prepare the configuration object in Python by creating dictionaries and using [OmegaConf](https://omegaconf.readthedocs.io/en/2.3_branch/), the YAML-based hierarchical configuration system used in Hydra.\n", "\n", "The first config dictionary to create is `cfg_paths`, which contains the main paths." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from omegaconf import OmegaConf\n", "\n", "cfg_paths = OmegaConf.create({\n", " \"data_dir\": os.path.join(ARMLET_DIR, \"datasets\"),\n", " \"log_dir\": os.path.join(ARMLET_DIR, \"logs\"),\n", " \"output_dir\": OUTPUT_DIR,\n", " \"root_dir\": ARMLET_DIR,\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, `cfg_data` provides the information related to the data pipeline (dataset, data cleaning methods, train/validation/test set splitting, distribution for the data partitioning, other data processing steps, and the data seed).\n", "With the following configuration, we will:\n", "(a) load the Adult dataset;\n", "(b) perform an IID data partitioning across X clients (the number of clients is specified later, in the protocol configuration);\n", "(c) split each client dataset into a training set (80% of the client data) and a test set (20%);\n", "(d) concatenate the union of the clients' test sets to form the server test set;\n", "(e) clean all data by removing missing values;\n", "and (f) perform standard preprocessing for tabular data." 
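, "\n", "\n", "The per-client 80/20 split in step (c) can be sketched as follows (a minimal, self-contained illustration with a seeded shuffle; the function `split_client_data` is hypothetical and is not ARMLET's actual implementation):\n", "\n", "```python\n", "import random\n", "\n", "def split_client_data(samples, test_frac=0.2, seed=42):\n", "    # Shuffle a copy of the client's samples with a fixed seed for reproducibility.\n", "    shuffled = list(samples)\n", "    random.Random(seed).shuffle(shuffled)\n", "    # The last `test_frac` portion becomes the test set, the rest the training set.\n", "    n_test = int(len(shuffled) * test_frac)\n", "    return shuffled[n_test:], shuffled[:n_test]\n", "\n", "train, test = split_client_data(range(100))\n", "```"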
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "cfg_data = OmegaConf.create({\n", " \"cleaning\": {\n", " \"name\": \"default\",\n", " \"missing_values\": {\"_target_\": \"armlet.data.cleaning.missing_values.MissingValuesDataCleaningMethod\"},\n", " },\n", " \"dataset\": {\n", " \"_target_\": \"armlet.data.datasets.load_Adult_dataset\",\n", " \"dataset_name\": \"Adult\",\n", " \"path\": os.path.join(cfg_paths[\"data_dir\"], \"Adult\", \"raw_data\"),\n", " \"sensitive_attributes\": ['age', 'gender', 'race'],\n", " },\n", " \"distribution\": {\"_target_\": \"armlet.data.splitter.ArmletDataSplitter.iid\"},\n", " \"others\": {\n", " \"client_split\": 0.2,\n", " \"client_val_split\": 0.0,\n", " \"keep_test\": False,\n", " \"sampling_perc\": 1.0,\n", " \"server_split\": 0.0,\n", " \"server_test\": False,\n", " \"server_test_union\": True,\n", " \"server_val_split\": 0.0,\n", " \"uniform_test\": False,\n", " },\n", " \"processing\": {\n", " \"one_hot_encoding\": {\n", " \"_apply_directly_to_subdata_\": False,\n", " \"_target_\": \"armlet.data.processing.feature_encoding.one_hot_encoding_pipeline\",\n", " },\n", " \"conversion_to_num\": {\n", " \"_apply_directly_to_subdata_\": True,\n", " \"_target_\": \"armlet.data.processing.format_conversion.convert_bool_and_cat_to_num\",\n", " },\n", " \"normalization\": {\n", " \"_apply_directly_to_subdata_\": False,\n", " \"_target_\": \"armlet.data.processing.normalization.normalization_pipeline\",\n", " \"cols_to_exclude\": ['age', 'gender', 'race'],\n", " },\n", " \"conversion_to_tensors\": {\n", " \"_apply_directly_to_subdata_\": True,\n", " \"_target_\": \"armlet.data.processing.format_conversion.convert_dataframes_to_tensors\",\n", " \"sensitive_attributes\": ['age', 'gender', 'race'],\n", " },\n", " },\n", " \"seed\": 42,\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cfg_eval` contains the configuration values related to the evaluation.\n", "In this 
tutorial, we use the [armlet.eval.evaluators.MultiCriteriaBinaryClassEval](https://sara-bouchenak.github.io/ARMLET/api/armlet.eval.evaluators.html#armlet.eval.evaluators.MultiCriteriaBinaryClassEval) evaluation class and evaluate the global model at each round (on the server side with the server test set)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "cfg_eval = OmegaConf.create({\n", " \"_target_\": \"armlet.eval.evaluators.MultiCriteriaBinaryClassEval\",\n", " \"eval_every\": 1,\n", " \"locals\": False,\n", " \"post_fit\": False,\n", " \"pre_fit\": False,\n", " \"server\": True,\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cfg_exp` details information about how to run the experiment.\n", "Here, we choose the federation mode, enable training, specify cpu as the device, and fix an experiment seed." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "cfg_exp = OmegaConf.create({\n", " \"device\": \"cpu\",\n", " \"inmemory\": True,\n", " \"mode\": \"federation\",\n", " \"seed\": 42,\n", " \"train\": True,\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cfg_logger` is required to specify the logging class to be used during the experiment.\n", "In this tutorial, we use [armlet.utils.log.ArmletLog](https://sara-bouchenak.github.io/ARMLET/api/armlet.utils.log.html#armlet.utils.log.ArmletLog), the minimal logger provided by **ARMLET** (it only saves the results in a JSON file)." 
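, "\n", "\n", "Conceptually, such a minimal logger only needs to collect the per-round metrics and dump them to a JSON file when it is closed. A hypothetical sketch of this behavior (the class `MinimalJsonLog` is illustrative, not the actual `ArmletLog` code):\n", "\n", "```python\n", "import json\n", "\n", "class MinimalJsonLog:\n", "    def __init__(self, json_log_path):\n", "        self.json_log_path = json_log_path\n", "        self.rounds = []\n", "\n", "    def log_round(self, metrics):\n", "        # Accumulate the metrics computed at the end of each round.\n", "        self.rounds.append(metrics)\n", "\n", "    def close(self):\n", "        # Persist all collected metrics in a single JSON file.\n", "        with open(self.json_log_path, \"w\") as f:\n", "            json.dump(self.rounds, f, indent=2)\n", "```"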
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "cfg_logger = OmegaConf.create({\n", " \"_target_\": \"armlet.utils.log.ArmletLog\",\n", " \"json_log_dir\": cfg_paths[\"output_dir\"],\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cfg_method` is used to specify the FL algorithm, the ML model, the client hyperparameters, and the server behavior.\n", "Here, we use [armlet.FL_pipeline.FL_algorithms.ArmletCentralized](https://sara-bouchenak.github.io/ARMLET/api/armlet.FL_pipeline.FL_algorithms.html#armlet.FL_pipeline.FL_algorithms.ArmletCentralizedFL), i.e., the standard FedAvg algorithm adapted to **ARMLET** pipeline, and consider a logistic regression model.\n", "Moreover, at each round, each client performs 10 local epochs by using a SGD optimizer with a learning rate of 0.001 and a weight decay of 0.01." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "cfg_method = OmegaConf.create({\n", " \"_target_\": \"armlet.FL_pipeline.FL_algorithms.ArmletCentralizedFL\",\n", " \"hyperparameters\": {\n", " \"client\": {\n", " \"batch_size\": 128,\n", " \"local_epochs\": 10,\n", " \"loss\": {\"_target_\": \"torch.nn.BCELoss\"},\n", " \"optimizer\": {\"lr\": 0.001, \"name\": \"SGD\", \"weight_decay\": 0.01},\n", " \"scheduler\": {\"gamma\": 1, \"name\": \"StepLR\", \"step_size\": 1},\n", " },\n", " \"model\": {\n", " \"_target_\": \"armlet.utils.net.LogRegression\",\n", " \"input_size\": None, # Automatically adjusted after data loading\n", " \"num_classes\": None, # Automatically adjusted after data loading\n", " },\n", " \"server\": {\n", " \"loss\": {\"_target_\": \"torch.nn.BCELoss\"},\n", " \"time_to_accuracy_target\": None,\n", " \"weighted\": True,\n", " },\n", " },\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cfg_protocol` details the protocol of the FL process.\n", "With these configuration values, we run a FL process with 4 
clients, for a total of 2 rounds (as a toy example).\n", "During each round, all clients participate in training." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "cfg_protocol = OmegaConf.create({\n", " \"eligible_perc\": 1.0,\n", " \"n_clients\": 4,\n", " \"n_rounds\": 2,\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we compose the previous configuration dictionaries into a single configuration object." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from armlet.utils.configs import ArmletConfiguration\n", "\n", "cfg = OmegaConf.create({\n", " \"data\": cfg_data,\n", " \"eval\": cfg_eval,\n", " \"exp\": cfg_exp,\n", " \"logger\": cfg_logger,\n", " \"method\": cfg_method,\n", " \"paths\": cfg_paths,\n", " \"protocol\": cfg_protocol,\n", " \"save\": {},\n", "})\n", "\n", "cfg = ArmletConfiguration(cfg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we are ready to investigate the `run_federation()` function, which is called when `exp.mode=federation`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Run the data pipeline\n", "\n", "First, **ARMLET** runs the data pipeline to compute the [DataSplitter](https://makgyver.github.io/fluke/fluke.data.html#fluke.data.DataSplitter), i.e., the Fluke Python object that contains all the data after processing (loading, data cleaning, train/validation/test set splitting, data partitioning, normalization, etc.).\n", "As the data pipeline is quite complex, we do not detail it in this tutorial; for more details, see the dedicated tutorial: [Understanding the **data pipeline** in `armlet`](https://sara-bouchenak.github.io/ARMLET/getting_started/tutorials/data_pipeline.html)." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from armlet.data import data_pipeline\n", "\n", "data_splitter, val_data = data_pipeline(cfg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Adjust the configurations and save them\n", "\n", "Once we have prepared data, we setup the Fluke environment and adjust some configuration values (which are related to data, such as the number of classes and the input size of the ML model).\n", "We can now save the configuration in a YAML file for allowing reproductibility." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from fluke import FlukeENV\n", "\n", "FlukeENV().configure(cfg)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "input_size = data_splitter.data_container.clients_tr[0].tensors[0].shape[-1]\n", "cfg.method.hyperparameters.model.input_size = input_size\n", "if data_splitter.data_container.num_classes <= 2:\n", " cfg.method.hyperparameters.model.num_classes = 1 \n", "else:\n", " cfg.method.hyperparameters.model.num_classes = data_splitter.data_container.num_classes" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "import yaml\n", "\n", "config_path = os.path.join(cfg.paths.output_dir, \"config.yaml\")\n", "cfg_to_save = cfg.to_dict()\n", "cfg_to_save[\"paths\"][\"output_dir\"] = \"${hydra:runtime.output_dir}\"\n", "if \"json_log_dir\" in cfg_to_save[\"logger\"].keys():\n", " cfg_to_save[\"logger\"][\"json_log_dir\"] = \"${paths.output_dir}\"\n", "yaml.dump(cfg_to_save, open(config_path, \"w\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Instantiate the FL algorithm, evaluator, and logger\n", "\n", "Then, we instantiate the FL algorithm we specified in the `cfg_method` dictionary by using Hydra.\n", "We do the same for the evaluator and the logger with the `cfg_eval` and `cfg_logger` 
dictionaries, and pass them to the Fluke environment." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "import hydra\n", "\n", "fl_algo = hydra.utils.instantiate(\n", " cfg.method,\n", " n_clients=cfg.protocol.n_clients,\n", " data_splitter=data_splitter,\n", " val_data=val_data,\n", " _convert_=\"all\",\n", " _recursive_=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "evaluator = hydra.utils.instantiate(\n", " cfg.eval.exclude(\"locals\", \"post_fit\", \"pre_fit\", \"server\"), \n", " n_classes=data_splitter.data_container.num_classes,\n", " sensitive_attributes=cfg.data.dataset.sensitive_attributes,\n", ")\n", "FlukeENV().set_evaluator(evaluator)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
╭───────────────────────────────────────── Configuration ──────────────────────────────────────────╮\n",
"│ { │\n",
"│ 'save': {}, │\n",
"│ 'data': { │\n",
"│ 'cleaning': { │\n",
"│ 'name': 'default', │\n",
"│ 'missing_values': { │\n",
"│ '_target_': │\n",
"│ 'armlet.data.cleaning.missing_values.MissingValuesDataCleaningMethod' │\n",
"│ } │\n",
"│ }, │\n",
"│ 'dataset': { │\n",
"│ '_target_': 'armlet.data.datasets.load_Adult_dataset', │\n",
"│ 'dataset_name': 'Adult', │\n",
"│ 'path': '../../../../datasets/Adult/raw_data', │\n",
"│ 'sensitive_attributes': [ │\n",
"│ 'age', │\n",
"│ 'gender', │\n",
"│ 'race' │\n",
"│ ] │\n",
"│ }, │\n",
"│ 'distribution': { │\n",
"│ '_target_': 'armlet.data.splitter.ArmletDataSplitter.iid' │\n",
"│ }, │\n",
"│ 'others': { │\n",
"│ 'client_split': 0.2, │\n",
"│ 'client_val_split': 0.0, │\n",
"│ 'keep_test': False, │\n",
"│ 'sampling_perc': 1.0, │\n",
"│ 'server_split': 0.0, │\n",
"│ 'server_test': False, │\n",
"│ 'server_test_union': True, │\n",
"│ 'server_val_split': 0.0, │\n",
"│ 'uniform_test': False │\n",
"│ }, │\n",
"│ 'processing': { │\n",
"│ 'one_hot_encoding': { │\n",
"│ '_apply_directly_to_subdata_': False, │\n",
"│ '_target_': 'armlet.data.processing.feature_encoding.one_hot_encoding_pipeline' │\n",
"│ }, │\n",
"│ 'conversion_to_num': { │\n",
"│ '_apply_directly_to_subdata_': True, │\n",
"│ '_target_': │\n",
"│ 'armlet.data.processing.format_conversion.convert_bool_and_cat_to_num' │\n",
"│ }, │\n",
"│ 'normalization': { │\n",
"│ '_apply_directly_to_subdata_': False, │\n",
"│ '_target_': 'armlet.data.processing.normalization.normalization_pipeline', │\n",
"│ 'cols_to_exclude': [ │\n",
"│ 'age', │\n",
"│ 'gender', │\n",
"│ 'race' │\n",
"│ ] │\n",
"│ }, │\n",
"│ 'conversion_to_tensors': { │\n",
"│ '_apply_directly_to_subdata_': True, │\n",
"│ '_target_': │\n",
"│ 'armlet.data.processing.format_conversion.convert_dataframes_to_tensors', │\n",
"│ 'sensitive_attributes': [ │\n",
"│ 'age', │\n",
"│ 'gender', │\n",
"│ 'race' │\n",
"│ ] │\n",
"│ } │\n",
"│ }, │\n",
"│ 'seed': 42 │\n",
"│ }, │\n",
"│ 'eval': { │\n",
"│ '_target_': 'armlet.eval.evaluators.MultiCriteriaBinaryClassEval', │\n",
"│ 'eval_every': 1, │\n",
"│ 'locals': False, │\n",
"│ 'post_fit': False, │\n",
"│ 'pre_fit': False, │\n",
"│ 'server': True │\n",
"│ }, │\n",
"│ 'exp': { │\n",
"│ 'device': 'cpu', │\n",
"│ 'inmemory': True, │\n",
"│ 'mode': 'federation', │\n",
"│ 'seed': 42, │\n",
"│ 'train': True │\n",
"│ }, │\n",
"│ 'logger': { │\n",
"│ '_target_': 'armlet.utils.log.ArmletLog', │\n",
"│ 'json_log_dir': '../../../../outputs/tutorial/FL_mode' │\n",
"│ }, │\n",
"│ 'method': { │\n",
"│ '_target_': 'armlet.FL_pipeline.FL_algorithms.ArmletCentralizedFL', │\n",
"│ 'hyperparameters': { │\n",
"│ 'client': { │\n",
"│ 'batch_size': 128, │\n",
"│ 'local_epochs': 10, │\n",
"│ 'loss': { │\n",
"│ '_target_': 'torch.nn.BCELoss' │\n",
"│ }, │\n",
"│ 'optimizer': { │\n",
"│ 'lr': 0.001, │\n",
"│ 'name': 'SGD', │\n",
"│ 'weight_decay': 0.01 │\n",
"│ }, │\n",
"│ 'scheduler': { │\n",
"│ 'gamma': 1, │\n",
"│ 'name': 'StepLR', │\n",
"│ 'step_size': 1 │\n",
"│ } │\n",
"│ }, │\n",
"│ 'model': { │\n",
"│ '_target_': 'armlet.utils.net.LogRegression', │\n",
"│ 'input_size': 99, │\n",
"│ 'num_classes': 1 │\n",
"│ }, │\n",
"│ 'server': { │\n",
"│ 'loss': { │\n",
"│ '_target_': 'torch.nn.BCELoss' │\n",
"│ }, │\n",
"│ 'time_to_accuracy_target': None, │\n",
"│ 'weighted': True │\n",
"│ } │\n",
"│ } │\n",
"│ }, │\n",
"│ 'paths': { │\n",
"│ 'data_dir': '../../../../datasets', │\n",
"│ 'log_dir': '../../../../logs', │\n",
"│ 'output_dir': '../../../../outputs/tutorial/FL_mode', │\n",
"│ 'root_dir': '../../../../' │\n",
"│ }, │\n",
"│ 'protocol': { │\n",
"│ 'eligible_perc': 1.0, │\n",
"│ 'n_clients': 4, │\n",
"│ 'n_rounds': 2 │\n",
"│ }, │\n",
"│ 'exp_id': '16016ed727174cd19719c05654849004' │\n",
"│ } │\n",
"╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
"\n"
],
"text/plain": [
"╭───────────────────────────────────────── Configuration ──────────────────────────────────────────╮\n",
"│ \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'save'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'data'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'cleaning'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'name'\u001b[0m: \u001b[32m'default'\u001b[0m, │\n",
"│ \u001b[32m'missing_values'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_target_'\u001b[0m: │\n",
"│ \u001b[32m'armlet.data.cleaning.missing_values.MissingValuesDataCleaningMethod'\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'dataset'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.data.datasets.load_Adult_dataset'\u001b[0m, │\n",
"│ \u001b[32m'dataset_name'\u001b[0m: \u001b[32m'Adult'\u001b[0m, │\n",
"│ \u001b[32m'path'\u001b[0m: \u001b[32m'../../../../datasets/Adult/raw_data'\u001b[0m, │\n",
"│ \u001b[32m'sensitive_attributes'\u001b[0m: \u001b[1m[\u001b[0m │\n",
"│ \u001b[32m'age'\u001b[0m, │\n",
"│ \u001b[32m'gender'\u001b[0m, │\n",
"│ \u001b[32m'race'\u001b[0m │\n",
"│ \u001b[1m]\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'distribution'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.data.splitter.ArmletDataSplitter.iid'\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'others'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'client_split'\u001b[0m: \u001b[1;36m0.2\u001b[0m, │\n",
"│ \u001b[32m'client_val_split'\u001b[0m: \u001b[1;36m0.0\u001b[0m, │\n",
"│ \u001b[32m'keep_test'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n",
"│ \u001b[32m'sampling_perc'\u001b[0m: \u001b[1;36m1.0\u001b[0m, │\n",
"│ \u001b[32m'server_split'\u001b[0m: \u001b[1;36m0.0\u001b[0m, │\n",
"│ \u001b[32m'server_test'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n",
"│ \u001b[32m'server_test_union'\u001b[0m: \u001b[3;92mTrue\u001b[0m, │\n",
"│ \u001b[32m'server_val_split'\u001b[0m: \u001b[1;36m0.0\u001b[0m, │\n",
"│ \u001b[32m'uniform_test'\u001b[0m: \u001b[3;91mFalse\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'processing'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'one_hot_encoding'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_apply_directly_to_subdata_'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n",
"│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.data.processing.feature_encoding.one_hot_encoding_pipeline'\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'conversion_to_num'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_apply_directly_to_subdata_'\u001b[0m: \u001b[3;92mTrue\u001b[0m, │\n",
"│ \u001b[32m'_target_'\u001b[0m: │\n",
"│ \u001b[32m'armlet.data.processing.format_conversion.convert_bool_and_cat_to_num'\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'normalization'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_apply_directly_to_subdata_'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n",
"│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.data.processing.normalization.normalization_pipeline'\u001b[0m, │\n",
"│ \u001b[32m'cols_to_exclude'\u001b[0m: \u001b[1m[\u001b[0m │\n",
"│ \u001b[32m'age'\u001b[0m, │\n",
"│ \u001b[32m'gender'\u001b[0m, │\n",
"│ \u001b[32m'race'\u001b[0m │\n",
"│ \u001b[1m]\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'conversion_to_tensors'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_apply_directly_to_subdata_'\u001b[0m: \u001b[3;92mTrue\u001b[0m, │\n",
"│ \u001b[32m'_target_'\u001b[0m: │\n",
"│ \u001b[32m'armlet.data.processing.format_conversion.convert_dataframes_to_tensors'\u001b[0m, │\n",
"│ \u001b[32m'sensitive_attributes'\u001b[0m: \u001b[1m[\u001b[0m │\n",
"│ \u001b[32m'age'\u001b[0m, │\n",
"│ \u001b[32m'gender'\u001b[0m, │\n",
"│ \u001b[32m'race'\u001b[0m │\n",
"│ \u001b[1m]\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'seed'\u001b[0m: \u001b[1;36m42\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'eval'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.eval.evaluators.MultiCriteriaBinaryClassEval'\u001b[0m, │\n",
"│ \u001b[32m'eval_every'\u001b[0m: \u001b[1;36m1\u001b[0m, │\n",
"│ \u001b[32m'locals'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n",
"│ \u001b[32m'post_fit'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n",
"│ \u001b[32m'pre_fit'\u001b[0m: \u001b[3;91mFalse\u001b[0m, │\n",
"│ \u001b[32m'server'\u001b[0m: \u001b[3;92mTrue\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'exp'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'device'\u001b[0m: \u001b[32m'cpu'\u001b[0m, │\n",
"│ \u001b[32m'inmemory'\u001b[0m: \u001b[3;92mTrue\u001b[0m, │\n",
"│ \u001b[32m'mode'\u001b[0m: \u001b[32m'federation'\u001b[0m, │\n",
"│ \u001b[32m'seed'\u001b[0m: \u001b[1;36m42\u001b[0m, │\n",
"│ \u001b[32m'train'\u001b[0m: \u001b[3;92mTrue\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'logger'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.utils.log.ArmletLog'\u001b[0m, │\n",
"│ \u001b[32m'json_log_dir'\u001b[0m: \u001b[32m'../../../../outputs/tutorial/FL_mode'\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'method'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.FL_pipeline.FL_algorithms.ArmletCentralizedFL'\u001b[0m, │\n",
"│ \u001b[32m'hyperparameters'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'client'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'batch_size'\u001b[0m: \u001b[1;36m128\u001b[0m, │\n",
"│ \u001b[32m'local_epochs'\u001b[0m: \u001b[1;36m10\u001b[0m, │\n",
"│ \u001b[32m'loss'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'torch.nn.BCELoss'\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'optimizer'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'lr'\u001b[0m: \u001b[1;36m0.001\u001b[0m, │\n",
"│ \u001b[32m'name'\u001b[0m: \u001b[32m'SGD'\u001b[0m, │\n",
"│ \u001b[32m'weight_decay'\u001b[0m: \u001b[1;36m0.01\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'scheduler'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'gamma'\u001b[0m: \u001b[1;36m1\u001b[0m, │\n",
"│ \u001b[32m'name'\u001b[0m: \u001b[32m'StepLR'\u001b[0m, │\n",
"│ \u001b[32m'step_size'\u001b[0m: \u001b[1;36m1\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'model'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'armlet.utils.net.LogRegression'\u001b[0m, │\n",
"│ \u001b[32m'input_size'\u001b[0m: \u001b[1;36m99\u001b[0m, │\n",
"│ \u001b[32m'num_classes'\u001b[0m: \u001b[1;36m1\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'server'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'loss'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'_target_'\u001b[0m: \u001b[32m'torch.nn.BCELoss'\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'time_to_accuracy_target'\u001b[0m: \u001b[3;35mNone\u001b[0m, │\n",
"│ \u001b[32m'weighted'\u001b[0m: \u001b[3;92mTrue\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'paths'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'data_dir'\u001b[0m: \u001b[32m'../../../../datasets'\u001b[0m, │\n",
"│ \u001b[32m'log_dir'\u001b[0m: \u001b[32m'../../../../logs'\u001b[0m, │\n",
"│ \u001b[32m'output_dir'\u001b[0m: \u001b[32m'../../../../outputs/tutorial/FL_mode'\u001b[0m, │\n",
"│ \u001b[32m'root_dir'\u001b[0m: \u001b[32m'../../../../'\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'protocol'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'eligible_perc'\u001b[0m: \u001b[1;36m1.0\u001b[0m, │\n",
"│ \u001b[32m'n_clients'\u001b[0m: \u001b[1;36m4\u001b[0m, │\n",
"│ \u001b[32m'n_rounds'\u001b[0m: \u001b[1;36m2\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'exp_id'\u001b[0m: \u001b[32m'16016ed727174cd19719c05654849004'\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m │\n",
"╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"log_name = f\"{fl_algo.__class__.__name__} [{fl_algo.id}]\"\n",
"log = hydra.utils.instantiate(cfg.logger, name=log_name)\n",
"log.init(**cfg, exp_id=fl_algo.id)\n",
"fl_algo.set_callbacks([log])\n",
"FlukeENV().set_logger(log)"
]
},
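{ "cell_type": "markdown", "metadata": {}, "source": [ "Under the hood, `hydra.utils.instantiate` resolves the dotted `_target_` string of a config to a Python class (or callable) and calls it with the remaining config entries as keyword arguments; extra keyword arguments (such as `n_clients` above) are merged in. The following minimal, hypothetical sketch illustrates this idea with standard-library targets only (it is not Hydra's actual implementation, which also handles recursive instantiation and type conversion):\n", "\n", "```python\n", "from importlib import import_module\n", "\n", "def instantiate_sketch(config, **overrides):\n", "    # Resolve the dotted \"_target_\" path to a class/callable.\n", "    module_path, _, attr = config[\"_target_\"].rpartition(\".\")\n", "    target = getattr(import_module(module_path), attr)\n", "    # Call it with the remaining config entries, merged with any overrides.\n", "    kwargs = {k: v for k, v in config.items() if k != \"_target_\"}\n", "    kwargs.update(overrides)\n", "    return target(**kwargs)\n", "\n", "td = instantiate_sketch({\"_target_\": \"datetime.timedelta\", \"seconds\": 90})\n", "print(td.total_seconds())  # 90.0\n", "```" ] },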
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run the FL experiment\n",
"\n",
"Finally, we call the `run()` function of the FL algorithm object and close the logger once the FL process is completed.\n",
"As training progresses, you can follow the evaluation metrics that are computed at the end of each round.\n",
"Moreover, by closing `ArmletLog`, all metrics are saved in a JSON file (`results.json`) that can be found in the `OUTPUT_DIR` folder (i.e., `ARMLET_DIR/outputs/tutorial/FL_mode` as defined in the beginning of the tutorial)."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"[UserWarning] /home/bnaline/anaconda3/envs/armlet/lib/python3.13/site-packages/rich/live.py:260\n", "install \"ipywidgets\" for Jupyter support\n", "\n" ], "text/plain": [ "\u001b[93m[UserWarning]\u001b[0m \u001b[94m/home/bnaline/anaconda3/envs/armlet/lib/python3.13/site-packages/rich/live.py:260\u001b[0m\n", "\u001b[93minstall \"ipywidgets\" for Jupyter support\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
╭──────────────────────────────────────────── Round: 1 ────────────────────────────────────────────╮\n",
"│ { │\n",
"│ 'post-fit': { │\n",
"│ 'training_loss': 0.60185, │\n",
"│ 'support': 4, │\n",
"│ 'round': 1 │\n",
"│ }, │\n",
"│ 'global': { │\n",
"│ 'accuracy': 0.78646, │\n",
"│ 'precision': 0.56932, │\n",
"│ 'recall': 0.56754, │\n",
"│ 'f1': 0.56843, │\n",
"│ 'loss': 0.5378, │\n",
"│ 'round': 1 │\n",
"│ }, │\n",
"│ 'comm_cost': 800 │\n",
"│ } │\n",
"╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
"\n"
],
"text/plain": [
"╭──────────────────────────────────────────── Round: 1 ────────────────────────────────────────────╮\n",
"│ \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'post-fit'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'training_loss'\u001b[0m: \u001b[1;36m0.60185\u001b[0m, │\n",
"│ \u001b[32m'support'\u001b[0m: \u001b[1;36m4\u001b[0m, │\n",
"│ \u001b[32m'round'\u001b[0m: \u001b[1;36m1\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'global'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.78646\u001b[0m, │\n",
"│ \u001b[32m'precision'\u001b[0m: \u001b[1;36m0.56932\u001b[0m, │\n",
"│ \u001b[32m'recall'\u001b[0m: \u001b[1;36m0.56754\u001b[0m, │\n",
"│ \u001b[32m'f1'\u001b[0m: \u001b[1;36m0.56843\u001b[0m, │\n",
"│ \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.5378\u001b[0m, │\n",
"│ \u001b[32m'round'\u001b[0m: \u001b[1;36m1\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'comm_cost'\u001b[0m: \u001b[1;36m800\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m │\n",
"╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Memory usage: 588.6 MB [6.04 %]\n", "\n" ], "text/plain": [ " Memory usage: \u001b[1;36m588.6\u001b[0m MB \u001b[1m[\u001b[0m\u001b[1;36m6.04\u001b[0m %\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
╭──────────────────────────────────────────── Round: 2 ────────────────────────────────────────────╮\n",
"│ { │\n",
"│ 'post-fit': { │\n",
"│ 'training_loss': 0.49904, │\n",
"│ 'support': 4, │\n",
"│ 'round': 2 │\n",
"│ }, │\n",
"│ 'global': { │\n",
"│ 'accuracy': 0.81407, │\n",
"│ 'precision': 0.65054, │\n",
"│ 'recall': 0.53946, │\n",
"│ 'f1': 0.58981, │\n",
"│ 'loss': 0.4717, │\n",
"│ 'round': 2 │\n",
"│ }, │\n",
"│ 'comm_cost': 800 │\n",
"│ } │\n",
"╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
"\n"
],
"text/plain": [
"╭──────────────────────────────────────────── Round: 2 ────────────────────────────────────────────╮\n",
"│ \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'post-fit'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'training_loss'\u001b[0m: \u001b[1;36m0.49904\u001b[0m, │\n",
"│ \u001b[32m'support'\u001b[0m: \u001b[1;36m4\u001b[0m, │\n",
"│ \u001b[32m'round'\u001b[0m: \u001b[1;36m2\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'global'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.81407\u001b[0m, │\n",
"│ \u001b[32m'precision'\u001b[0m: \u001b[1;36m0.65054\u001b[0m, │\n",
"│ \u001b[32m'recall'\u001b[0m: \u001b[1;36m0.53946\u001b[0m, │\n",
"│ \u001b[32m'f1'\u001b[0m: \u001b[1;36m0.58981\u001b[0m, │\n",
"│ \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.4717\u001b[0m, │\n",
"│ \u001b[32m'round'\u001b[0m: \u001b[1;36m2\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'comm_cost'\u001b[0m: \u001b[1;36m800\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m │\n",
"╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"Memory usage: 592.1 MB [6.06 %]\n", "\n" ], "text/plain": [ " Memory usage: \u001b[1;36m592.1\u001b[0m MB \u001b[1m[\u001b[0m\u001b[1;36m6.06\u001b[0m %\u001b[1m]\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n" ], "text/plain": [] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
╭────────────────────────────────────── Overall Performance ───────────────────────────────────────╮\n",
"│ { │\n",
"│ 'post-fit': { │\n",
"│ 'training_loss': 0.49904, │\n",
"│ 'support': 4, │\n",
"│ 'round': 2 │\n",
"│ }, │\n",
"│ 'global': { │\n",
"│ 'accuracy': 0.81407, │\n",
"│ 'precision': 0.65054, │\n",
"│ 'recall': 0.53946, │\n",
"│ 'f1': 0.58981, │\n",
"│ 'loss': 0.4717, │\n",
"│ 'round': 2 │\n",
"│ }, │\n",
"│ 'comm_costs': 2000 │\n",
"│ } │\n",
"╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
"\n"
],
"text/plain": [
"╭────────────────────────────────────── Overall Performance ───────────────────────────────────────╮\n",
"│ \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'post-fit'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'training_loss'\u001b[0m: \u001b[1;36m0.49904\u001b[0m, │\n",
"│ \u001b[32m'support'\u001b[0m: \u001b[1;36m4\u001b[0m, │\n",
"│ \u001b[32m'round'\u001b[0m: \u001b[1;36m2\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'global'\u001b[0m: \u001b[1m{\u001b[0m │\n",
"│ \u001b[32m'accuracy'\u001b[0m: \u001b[1;36m0.81407\u001b[0m, │\n",
"│ \u001b[32m'precision'\u001b[0m: \u001b[1;36m0.65054\u001b[0m, │\n",
"│ \u001b[32m'recall'\u001b[0m: \u001b[1;36m0.53946\u001b[0m, │\n",
"│ \u001b[32m'f1'\u001b[0m: \u001b[1;36m0.58981\u001b[0m, │\n",
"│ \u001b[32m'loss'\u001b[0m: \u001b[1;36m0.4717\u001b[0m, │\n",
"│ \u001b[32m'round'\u001b[0m: \u001b[1;36m2\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m, │\n",
"│ \u001b[32m'comm_costs'\u001b[0m: \u001b[1;36m2000\u001b[0m │\n",
"│ \u001b[1m}\u001b[0m │\n",
"╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"try:\n",
" fl_algo.run(cfg.protocol.n_rounds, cfg.protocol.eligible_perc)\n",
"except Exception as e:\n",
" log.log(f\"Error: {e}\")\n",
" FlukeENV().force_close()\n",
" FlukeENV.clear()\n",
" log.close()\n",
" FlukeENV().close_cache()\n",
" raise e\n",
"log.close()"
]
}
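, { "cell_type": "markdown", "metadata": {}, "source": [ "As a final note, the `weighted: True` server option used above means that, at each round, the global model is the sample-size-weighted average of the client models (standard FedAvg aggregation). The following minimal sketch illustrates this aggregation on plain Python lists (the function `weighted_fedavg` is an illustration only, not ARMLET's actual implementation):\n", "\n", "```python\n", "def weighted_fedavg(client_params, client_sizes):\n", "    # Average flat parameter vectors, weighting each client by its\n", "    # number of training samples (the `weighted: True` behavior).\n", "    total = sum(client_sizes)\n", "    dim = len(client_params[0])\n", "    return [\n", "        sum(params[i] * n for params, n in zip(client_params, client_sizes)) / total\n", "        for i in range(dim)\n", "    ]\n", "\n", "print(weighted_fedavg([[1.0, 0.0], [3.0, 2.0]], [1, 3]))  # [2.5, 1.5]\n", "```" ] }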
],
"metadata": {
"kernelspec": {
"display_name": "armlet",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}