#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

# ---------------------------------------------------------------------------
# Imports (standard library, third-party, project-local)
# ---------------------------------------------------------------------------
import argparse
import os
import copy
import time
import pickle

import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details

# ---------------------------------------------------------------------------
# Experiment options
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# parser.add_argument('--name', '-n', default='foo', help='foo')
parser.add_argument('--model', type=str, default='mlp', help='model name(mlp or cnn)')
parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset(mnist or cifar)")
parser.add_argument('--epochs', type=int, default=5, help="number of rounds of training(10)")
parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
# No `type=` on purpose: the value stays whatever the caller passed (a string
# from the command line) and is normalised in the device-selection step below.
parser.add_argument('--gpu', default=None,
                    help="To use cuda, set to a specific GPU ID. Default set to use CPU.")
parser.add_argument('--frac', type=float, default=0.1, help='the fraction of clients: C')
parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--optimizer', type=str, default='sgd', help="type of optimizer")
parser.add_argument('--verbose', type=int, default=1, help='verbose')
parser.add_argument('--iid', type=int, default=1, help='Default set to IID. Set to 0 for non-IID.')
parser.add_argument('--unequal', type=int, default=0,
                    help='whether to use unequal data splits for '
                         'non-i.i.d setting (use 0 for equal splits)')
# parser.add_argument('--seed', type=int, default=1, help='random seed')

# The help text used to be a copy-paste of the --verbose one.
parser.add_argument('--num_clusters', type=int, default=2, help='number of clusters')

# parse_known_args() ignores argv entries this parser does not declare instead
# of aborting on them.
# args = parser.parse_args([])
args, _ = parser.parse_known_args()

args
# --dataset=mnist --gpu=0 --iid=0 --epochs=10

# ---------------------------------------------------------------------------
# Run bookkeeping: wall-clock start, project path, tensorboard writer
# ---------------------------------------------------------------------------
# if __name__ == '__main__':
start_time = time.time()
print("start time: ", start_time)

# define paths
path_project = os.path.abspath('..')
logger = SummaryWriter('../logs')

# args = args_parser()
# exp_details(args)

# ---------------------------------------------------------------------------
# Device selection
# ---------------------------------------------------------------------------
# `--gpu=0` arrives as the string '0'. The previous `if args.gpu:` test passed
# that string straight to torch.cuda.set_device, and would have treated an
# integer 0 as "no GPU". Compare against "not provided" explicitly and turn
# purely numeric ids into ints; any other spec (e.g. 'cuda:1') is forwarded
# unchanged.
gpu_requested = args.gpu not in (None, '')
if gpu_requested:
    gpu_id = int(args.gpu) if str(args.gpu).isdigit() else args.gpu
    torch.cuda.set_device(gpu_id)
device = 'cuda' if gpu_requested else 'cpu'

# ---------------------------------------------------------------------------
# load dataset and user groups
# ---------------------------------------------------------------------------
train_dataset, test_dataset, user_groups = get_dataset(args)
# ---------------------------------------------------------------------------
# Raw (non-executed) scratch cells from the notebook, kept for reference only.
# ---------------------------------------------------------------------------
# keys = user_groups.keys()
# keys
#
# l = np.arange(5,6)
# l
# {k:user_groups[k] for k in l if k in user_groups}
#
# np.arange(cluster_size, dtype=int)


def split_into_clusters(user_groups, num_users, num_clusters):
    """Partition user ids ``0 .. num_users-1`` into contiguous clusters.

    The ids are split with ``np.array_split``, so the clusters never overlap,
    never drop an id, and differ in size by at most one when ``num_users`` is
    not a multiple of ``num_clusters``.

    Args:
        user_groups: mapping from user id to that user's data; ids that are
            not present in the mapping are skipped in the per-cluster dicts.
        num_users: total number of user ids to distribute.
        num_clusters: number of clusters to create (must be >= 1).

    Returns:
        ``(cluster_ids, cluster_groups)`` where ``cluster_ids`` is a list of
        integer numpy arrays (one per cluster) and ``cluster_groups`` is a
        list of dicts holding the ``user_groups`` entries of each cluster.

    Raises:
        ValueError: if ``num_clusters`` is smaller than 1.
    """
    if num_clusters < 1:
        raise ValueError("num_clusters must be >= 1, got {}".format(num_clusters))

    cluster_ids = np.array_split(np.arange(num_users, dtype=int), num_clusters)
    cluster_groups = [
        {k: user_groups[k] for k in ids if k in user_groups}
        for ids in cluster_ids
    ]
    return cluster_ids, cluster_groups


# Guarded so importing this code only defines the helper; when executed as a
# notebook/script __name__ is '__main__' and the statements run as before.
if __name__ == '__main__':
    # Splitting into clusters. FL groups
    # Integer division: the true division used before printed "50.0" and
    # produced overlapping index ranges for non-divisible user counts.
    cluster_size = args.num_users // args.num_clusters
    print("Each cluster size: ", cluster_size)

    cluster_ids, cluster_groups = split_into_clusters(
        user_groups, args.num_users, args.num_clusters)
    for cluster_no, group in enumerate(cluster_groups, start=1):
        print(f"Size of cluster {cluster_no}: ", len(group))

    # Names the following cells rely on.
    # Cluster 1
    A1, user_groupsA = cluster_ids[0], cluster_groups[0]
    # Cluster 2 (empty when only one cluster was requested)
    if len(cluster_ids) > 1:
        B1, user_groupsB = cluster_ids[1], cluster_groups[1]
    else:
        B1, user_groupsB = np.array([], dtype=int), {}

    # Check that clusters are all different
    print(user_groupsA.keys())
    print(user_groupsB.keys())

# BUILD MODEL
def build_model(args, train_dataset):
    """Instantiate the network selected by ``args.model`` / ``args.dataset``.

    Args:
        args: options object; reads ``model``, ``dataset`` and, for the MLP,
            ``num_classes``. The CNN constructors receive ``args`` itself.
        train_dataset: only used for ``model == 'mlp'``; the first sample's
            input (``train_dataset[0][0]``) must expose ``.shape``, whose
            product becomes the MLP input width.

    Returns:
        A freshly constructed model (``CNNMnist``, ``CNNFashion_Mnist``,
        ``CNNCifar`` or ``MLP``).

    Raises:
        SystemExit: for an unrecognized model, or a CNN requested for an
            unrecognized dataset. The latter used to surface as an
            ``UnboundLocalError`` at the ``return`` statement.
    """
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            return CNNMnist(args=args)
        if args.dataset == 'fmnist':
            return CNNFashion_Mnist(args=args)
        if args.dataset == 'cifar':
            return CNNCifar(args=args)
        raise SystemExit('Error: unrecognized dataset')

    if args.model == 'mlp':
        # Multi-layer perceptron: flatten every input dimension into one width.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        return MLP(dim_in=len_in, dim_hidden=200,
                   dim_out=args.num_classes)

    # Raised directly instead of calling the interactive-shell helper `exit`;
    # the exception type and message are the same.
    raise SystemExit('Error: unrecognized model')


# Guarded so importing this code only defines build_model; when executed as a
# notebook/script __name__ is '__main__' and the statements run as before.
if __name__ == '__main__':
    # MODEL PARAM SUMMARY
    global_model = build_model(args, train_dataset)
    pytorch_total_params = sum(p.numel() for p in global_model.parameters())
    print(pytorch_total_params)
# Guarded so importing this code only defines fl_train; when executed as a
# notebook/script __name__ is '__main__' and the statements run as before.
if __name__ == '__main__':
    from torchsummary import summary

    # Use the dataset's own sample shape instead of a hard-coded (1, 28, 28)
    # so the summary also works for datasets with other input sizes.
    summary(global_model, tuple(train_dataset[0][0].shape))

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    # Set the cluster models to train and send it to device.
    cluster_modelA = build_model(args, train_dataset)
    cluster_modelA.to(device)
    cluster_modelA.train()
    # copy weights
    cluster_modelA_weights = cluster_modelA.state_dict()

    # Set the cluster models to train and send it to device.
    cluster_modelB = build_model(args, train_dataset)
    cluster_modelB.to(device)
    cluster_modelB.train()
    # copy weights (previously read from cluster_modelA by mistake)
    cluster_modelB_weights = cluster_modelB.state_dict()

# ---------------------------------------------------------------------------
# Raw (non-executed) scratch cells from the notebook, kept for reference only.
# ---------------------------------------------------------------------------
# print(np.random.choice(range(args.num_users), m, replace=False))
# np.random.choice(A1, m, replace=False)
#
# len(A1)


# Defining the training function
def fl_train(args, train_dataset, cluster_global_model, cluster, usergrp, epochs):
    """Run ``epochs`` rounds of federated averaging inside one cluster.

    Every round samples ``max(int(args.frac * len(cluster)), 1)`` user ids
    from ``cluster`` without replacement, trains a deep copy of
    ``cluster_global_model`` on each sampled user's data through
    ``LocalUpdate.update_weights``, averages the returned weights with
    ``average_weights`` and loads the average back into
    ``cluster_global_model``. The passed-in model is therefore updated in
    place. The module-level ``logger`` is handed to every ``LocalUpdate``.

    Args:
        args: options object; ``frac`` is read here and the whole object is
            forwarded to ``LocalUpdate``.
        train_dataset: dataset forwarded to ``LocalUpdate``.
        cluster_global_model: the cluster's shared model.
        cluster: sequence of user ids belonging to the cluster.
        usergrp: mapping from user id to the ``idxs`` value given to
            ``LocalUpdate`` for that user.
        epochs: number of communication rounds (must be >= 1).

    Returns:
        ``(cluster_global_weights, cluster_loss_avg, idxs_users)`` of the
        final round: the averaged weights, the mean of the local losses and
        the user ids sampled in that round.

    Raises:
        ValueError: if ``epochs`` is smaller than 1 or ``cluster`` is empty.
            Without this check zero rounds ended in an ``UnboundLocalError``
            at the ``return`` statement.
    """
    if epochs < 1:
        raise ValueError("epochs must be >= 1, got {}".format(epochs))
    if len(cluster) == 0:
        raise ValueError("cluster must contain at least one user")

    for epoch in tqdm(range(epochs)):
        cluster_local_weights, cluster_local_losses = [], []
        print(f'\n | Cluster Training Round : {epoch+1} |\n')

        cluster_global_model.train()
        # At least one user takes part in every round.
        m = max(int(args.frac * len(cluster)), 1)
        idxs_users = np.random.choice(cluster, m, replace=False)

        for idx in idxs_users:
            cluster_local_model = LocalUpdate(args=args, dataset=train_dataset,
                                              idxs=usergrp[idx], logger=logger)
            # Each user starts from its own copy of the current cluster model.
            cluster_w, cluster_loss = cluster_local_model.update_weights(
                model=copy.deepcopy(cluster_global_model), global_round=epoch)
            cluster_local_weights.append(copy.deepcopy(cluster_w))
            cluster_local_losses.append(copy.deepcopy(cluster_loss))

        # averaging global weights
        cluster_global_weights = average_weights(cluster_local_weights)

        # update global weights
        cluster_global_model.load_state_dict(cluster_global_weights)

        cluster_loss_avg = sum(cluster_local_losses) / len(cluster_local_losses)

    return cluster_global_weights, cluster_loss_avg, idxs_users


# ---------------------------------------------------------------------------
# Raw (non-executed) cell from the notebook, kept for reference only.
# NOTE(review): `A_idxs_users` is not defined anywhere in the visible source;
# presumably it comes from the main-training cell - confirm before enabling.
# ---------------------------------------------------------------------------
# # Defining the evaluation function
# def fl_eval(train_dataset, global_model, cluster, usergrp, idxs_users):
#     cluster_list_acc, cluster_list_loss = [], []
# #     for c in range(len(cluster)):
#     for c in idxs_users:
#         cluster_local_model = LocalUpdate(args=args, dataset=train_dataset,
#                                           idxs=usergrp[c], logger=logger)
#         acc, loss = cluster_local_model.inference(model=global_model)
#         cluster_list_acc.append(acc)
#         cluster_list_loss.append(loss)
# #     train_accuracy.append(sum(list_acc)/len(list_acc))
#
#     return cluster_list_acc, cluster_list_loss
#
#
# clusterA_list_acc, clusterA_list_loss = fl_eval(train_dataset, global_model, A1, user_groupsA, A_idxs_users)

# ### Main training
# NOTE(review): the source of the main-training cell is not present in the
# reviewed text, so it is not reproduced here. Later cells read `train_loss`
# and `train_accuracy`, which that cell presumably defines - TODO restore it.
# PLOTTING (optional)
import matplotlib
import matplotlib.pyplot as plt

# The non-interactive 'Agg' backend used to be forced here, after pyplot had
# already been imported. That only suits the commented-out savefig path below
# and conflicts with showing the figure, so the backend is left to the
# environment's default.

# Plot Loss curve
# NOTE(review): `train_loss` is not defined in the visible source; presumably
# one loss value per communication round from the main-training cell - confirm.
plt.figure()
plt.title('Training Loss vs Communication rounds')
plt.plot(range(len(train_loss)), train_loss, color='r')
plt.ylabel('Training loss')
plt.xlabel('Communication Rounds')
# plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.
#             format(args.dataset, args.model, args.epochs, args.frac,
#                    args.iid, args.local_ep, args.local_bs))
# `plt.show` without parentheses only echoed the function object.
plt.show()
# Plot Average Accuracy vs Communication rounds
# Relies on `plt` (matplotlib.pyplot) imported by the preceding plotting cell.
# NOTE(review): `train_accuracy` is not defined in the visible source;
# presumably one accuracy value per communication round - confirm.
plt.figure()
plt.title('Average Accuracy vs Communication rounds')
plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
plt.ylabel('Average Accuracy')
plt.xlabel('Communication Rounds')
# plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.
#             format(args.dataset, args.model, args.epochs, args.frac,
#                    args.iid, args.local_ep, args.local_bs))
# `plt.show` without parentheses only echoed the function object.
plt.show()