Communication-Efficient Learning of Deep Networks from Decentralized Data [paper]

Paper Analysis

The seminal paper of federated learning.

Notes:

https://zhuanlan.zhihu.com/p/445458807

https://zhuanlan.zhihu.com/p/370198194

https://www.zhihu.com/question/315844487/answer/1100845525

The training data stays distributed across the devices, and a shared model is learned by aggregating locally computed updates; only model updates are ever transmitted. The paper introduces the Federated Averaging (FedAvg) algorithm, which combines local stochastic gradient descent (SGD) on each client with a server that performs model averaging. It demonstrates that the method is robust to unbalanced and non-IID data distributions, and that it can reduce the number of communication rounds needed to train deep networks on decentralized data by orders of magnitude. Labels on such data can often be inferred naturally from user interaction.
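
A minimal sketch of the server-side averaging step that FedAvg performs each communication round, weighted by each client's number of samples; the function and variable names here are illustrative, not the repository's exact API:

```python
import copy

def fed_avg(local_weights, local_sizes):
    """Weighted average of the clients' state_dicts; weights proportional to local sample counts."""
    total = float(sum(local_sizes))
    avg = copy.deepcopy(local_weights[0])
    for key in avg.keys():
        avg[key] = avg[key] * (local_sizes[0] / total)
        for w, n in zip(local_weights[1:], local_sizes[1:]):
            avg[key] = avg[key] + w[key] * (n / total)
    return avg

# One round (sketch): the server sends the global weights to a random fraction of clients,
# each selected client runs a few local epochs of SGD on its own data, and the server then
# replaces the global model with the weighted average of the returned weights:
# global_model.load_state_dict(fed_avg(collected_weights, collected_sizes))
```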

Data Partitioning

In the Federated Learning setting (say there are exactly 10 local users), one user may hold only images of the digit 0, another only images of the digit 8, and so on. The total amount of data is unchanged, but the data on each user's device is biased, i.e. Non-IID (not independent and identically distributed), which has a large impact on optimization.

MNIST is partitioned in two ways:

IID partition

MNIST has 60,000 training samples in total. With 100 local devices, each device receives 600 samples, drawn uniformly at random.
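
A minimal sketch of such an IID split (illustrative; the repository provides an equivalent sampling helper whose exact signature may differ):

```python
import numpy as np

def mnist_iid(dataset, num_users=100):
    """Assign len(dataset)//num_users samples to each user, drawn uniformly at random without replacement."""
    num_items = len(dataset) // num_users              # 60,000 / 100 = 600 samples per user
    all_idxs = np.random.permutation(len(dataset))     # shuffle all sample indices
    return {i: all_idxs[i * num_items:(i + 1) * num_items] for i in range(num_users)}
```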

Non-IID partition

We first sort the data by label and then split the 60,000 samples into 200 shards of 300 samples each. The point of sorting is to artificially make each shard contain only a single class. Each local device then receives two shards at random, so after the partition each device holds samples of at most two digits.
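
A minimal sketch of this shard-based Non-IID split (illustrative; how the labels are accessed depends on the torchvision version, e.g. dataset.targets vs. dataset.train_labels):

```python
import numpy as np

def mnist_noniid(dataset, num_users=100, num_shards=200, shard_size=300):
    """Sort by label, cut into 200 shards of 300 samples, give each user 2 random shards."""
    labels = np.array(dataset.targets)                 # digit label of every training sample
    idxs = np.argsort(labels)                          # sample indices sorted by label
    shard_ids = np.random.permutation(num_shards)      # shuffle the shard order before assignment
    shards_per_user = num_shards // num_users          # 200 / 100 = 2 shards per user
    dict_users = {}
    for i in range(num_users):
        own = shard_ids[i * shards_per_user:(i + 1) * shards_per_user]
        dict_users[i] = np.concatenate(
            [idxs[s * shard_size:(s + 1) * shard_size] for s in own])
    return dict_users
```

Because each shard contains only one digit after sorting, every user ends up with samples from at most two classes.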

Code:

https://github.com/shaoxiongji/federated-learning

https://zhuanlan.zhihu.com/p/359060612

https://zhuanlan.zhihu.com/p/438065296

  • python main_fed.py

    > d:\python\pycharmprojects\fedavg\federated-learning\main_fed.py(65)<module>()
    -> img_size = dataset_train[0][0].shape
    (Pdb) c
    MLP(
      (layer_input): Linear(in_features=784, out_features=200, bias=True)
      (relu): ReLU()
      (dropout): Dropout(p=0.5, inplace=False)
      (layer_hidden): Linear(in_features=200, out_features=10, bias=True)
    )
    Round   0, Average loss 0.081
    Round   1, Average loss 0.093
    Round   2, Average loss 0.075
    Round   3, Average loss 0.072
    Round   4, Average loss 0.053
    Round   5, Average loss 0.065
    Round   6, Average loss 0.045
    Round   7, Average loss 0.053
    Round   8, Average loss 0.063
    Round   9, Average loss 0.075
    Training accuracy: 69.21
    Testing accuracy: 69.20
    
  • python main_fed.py --dataset mnist --iid --num_channels 1 --model cnn --epochs 50 --gpu 0

    D:\Python\PycharmProjects\FedAvg\federated-learning>python main_fed.py --dataset mnist --iid --num_channels 1 --model cnn --epochs 50 --gpu 0
    CNNMnist(
      (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
      (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
      (conv2_drop): Dropout2d(p=0.5, inplace=False)
      (fc1): Linear(in_features=320, out_features=50, bias=True)
      (fc2): Linear(in_features=50, out_features=10, bias=True)
    )
    Round   0, Average loss 1.630
    Round   1, Average loss 0.841
    Round   2, Average loss 0.600
    Round   3, Average loss 0.497
    Round   4, Average loss 0.428
    Round   5, Average loss 0.377
    Round   6, Average loss 0.337
    Round   7, Average loss 0.324
    Round   8, Average loss 0.307
    Round   9, Average loss 0.289
    Round  10, Average loss 0.273
    Round  11, Average loss 0.273
    Round  12, Average loss 0.268
    Round  13, Average loss 0.253
    Round  14, Average loss 0.248
    Round  15, Average loss 0.235
    Round  16, Average loss 0.239
    Round  17, Average loss 0.215
    Round  18, Average loss 0.209
    Round  19, Average loss 0.204
    Round  20, Average loss 0.206
    Round  21, Average loss 0.213
    Round  22, Average loss 0.193
    Round  23, Average loss 0.203
    Round  24, Average loss 0.204
    Round  25, Average loss 0.187
    Round  26, Average loss 0.198
    Round  27, Average loss 0.182
    Round  28, Average loss 0.180
    Round  29, Average loss 0.175
    Round  30, Average loss 0.173
    Round  31, Average loss 0.192
    Round  32, Average loss 0.174
    Round  33, Average loss 0.167
    Round  34, Average loss 0.178
    Round  35, Average loss 0.163
    Round  36, Average loss 0.150
    Round  37, Average loss 0.177
    Round  38, Average loss 0.172
    Round  39, Average loss 0.157
    Round  40, Average loss 0.154
    Round  41, Average loss 0.146
    Round  42, Average loss 0.144
    Round  43, Average loss 0.141
    Round  44, Average loss 0.145
    Round  45, Average loss 0.151
    Round  46, Average loss 0.136
    Round  47, Average loss 0.155
    Round  48, Average loss 0.150
    Round  49, Average loss 0.142
    Training accuracy: 98.42
    Testing accuracy: 98.48
    
  • python main_nn.py

    D:\Python\PycharmProjects\FedAvg\federated-learning>python main_nn.py
    MLP(
      (layer_input): Linear(in_features=784, out_features=64, bias=True)
      (relu): ReLU()
      (dropout): Dropout(p=0.5, inplace=False)
      (layer_hidden): Linear(in_features=64, out_features=10, bias=True)
    )
    Train Epoch: 0 [0/60000 (0%)]   Loss: 2.382788
    Train Epoch: 0 [3200/60000 (5%)]        Loss: 1.377547
    Train Epoch: 0 [6400/60000 (11%)]       Loss: 1.040499
    Train Epoch: 0 [9600/60000 (16%)]       Loss: 0.801196
    Train Epoch: 0 [12800/60000 (21%)]      Loss: 0.649839
    Train Epoch: 0 [16000/60000 (27%)]      Loss: 0.579395
    Train Epoch: 0 [19200/60000 (32%)]      Loss: 0.774160
    Train Epoch: 0 [22400/60000 (37%)]      Loss: 0.453726
    Train Epoch: 0 [25600/60000 (43%)]      Loss: 0.668951
    Train Epoch: 0 [28800/60000 (48%)]      Loss: 0.530153
    Train Epoch: 0 [32000/60000 (53%)]      Loss: 0.632806
    Train Epoch: 0 [35200/60000 (59%)]      Loss: 0.585639
    Train Epoch: 0 [38400/60000 (64%)]      Loss: 0.397279
    Train Epoch: 0 [41600/60000 (69%)]      Loss: 0.631374
    Train Epoch: 0 [44800/60000 (75%)]      Loss: 0.422581
    Train Epoch: 0 [48000/60000 (80%)]      Loss: 0.423022
    Train Epoch: 0 [51200/60000 (85%)]      Loss: 0.381796
    Train Epoch: 0 [54400/60000 (91%)]      Loss: 0.760502
    Train Epoch: 0 [57600/60000 (96%)]      Loss: 0.506489
    
    Train loss: 0.6550482136608441
    Train Epoch: 1 [0/60000 (0%)]   Loss: 0.318453
    Train Epoch: 1 [3200/60000 (5%)]        Loss: 0.364851
    Train Epoch: 1 [6400/60000 (11%)]       Loss: 0.367190
    Train Epoch: 1 [9600/60000 (16%)]       Loss: 0.386182
    Train Epoch: 1 [12800/60000 (21%)]      Loss: 0.467612
    Train Epoch: 1 [16000/60000 (27%)]      Loss: 0.438048
    Train Epoch: 1 [19200/60000 (32%)]      Loss: 0.446184
    Train Epoch: 1 [22400/60000 (37%)]      Loss: 0.440449
    Train Epoch: 1 [25600/60000 (43%)]      Loss: 0.380581
    Train Epoch: 1 [28800/60000 (48%)]      Loss: 0.673389
    Train Epoch: 1 [32000/60000 (53%)]      Loss: 0.562654
    Train Epoch: 1 [35200/60000 (59%)]      Loss: 0.361667
    Train Epoch: 1 [38400/60000 (64%)]      Loss: 0.357094
    Train Epoch: 1 [41600/60000 (69%)]      Loss: 0.370722
    Train Epoch: 1 [44800/60000 (75%)]      Loss: 0.482470
    Train Epoch: 1 [48000/60000 (80%)]      Loss: 0.365681
    Train Epoch: 1 [51200/60000 (85%)]      Loss: 0.346680
    Train Epoch: 1 [54400/60000 (91%)]      Loss: 0.554997
    Train Epoch: 1 [57600/60000 (96%)]      Loss: 0.373696
    
    Train loss: 0.39213755361433983
    Train Epoch: 2 [0/60000 (0%)]   Loss: 0.269020
    Train Epoch: 2 [3200/60000 (5%)]        Loss: 0.424938
    Train Epoch: 2 [6400/60000 (11%)]       Loss: 0.501854
    Train Epoch: 2 [9600/60000 (16%)]       Loss: 0.444985
    Train Epoch: 2 [12800/60000 (21%)]      Loss: 0.395568
    Train Epoch: 2 [16000/60000 (27%)]      Loss: 0.240778
    Train Epoch: 2 [19200/60000 (32%)]      Loss: 0.279499
    Train Epoch: 2 [22400/60000 (37%)]      Loss: 0.126053
    Train Epoch: 2 [25600/60000 (43%)]      Loss: 0.307477
    Train Epoch: 2 [28800/60000 (48%)]      Loss: 0.387789
    Train Epoch: 2 [32000/60000 (53%)]      Loss: 0.261300
    Train Epoch: 2 [35200/60000 (59%)]      Loss: 0.183429
    Train Epoch: 2 [38400/60000 (64%)]      Loss: 0.237269
    Train Epoch: 2 [41600/60000 (69%)]      Loss: 0.151103
    Train Epoch: 2 [44800/60000 (75%)]      Loss: 0.258133
    Train Epoch: 2 [48000/60000 (80%)]      Loss: 0.192494
    Train Epoch: 2 [51200/60000 (85%)]      Loss: 0.302599
    Train Epoch: 2 [54400/60000 (91%)]      Loss: 0.331346
    Train Epoch: 2 [57600/60000 (96%)]      Loss: 0.295655
    
    Train loss: 0.3434041797526991
    Train Epoch: 3 [0/60000 (0%)]   Loss: 0.249048
    Train Epoch: 3 [3200/60000 (5%)]        Loss: 0.352690
    Train Epoch: 3 [6400/60000 (11%)]       Loss: 0.207267
    Train Epoch: 3 [9600/60000 (16%)]       Loss: 0.284693
    Train Epoch: 3 [12800/60000 (21%)]      Loss: 0.431070
    Train Epoch: 3 [16000/60000 (27%)]      Loss: 0.441519
    Train Epoch: 3 [19200/60000 (32%)]      Loss: 0.423299
    Train Epoch: 3 [22400/60000 (37%)]      Loss: 0.311703
    Train Epoch: 3 [25600/60000 (43%)]      Loss: 0.207815
    Train Epoch: 3 [28800/60000 (48%)]      Loss: 0.316640
    Train Epoch: 3 [32000/60000 (53%)]      Loss: 0.371618
    Train Epoch: 3 [35200/60000 (59%)]      Loss: 0.245004
    Train Epoch: 3 [38400/60000 (64%)]      Loss: 0.292778
    Train Epoch: 3 [41600/60000 (69%)]      Loss: 0.345176
    Train Epoch: 3 [44800/60000 (75%)]      Loss: 0.246085
    Train Epoch: 3 [48000/60000 (80%)]      Loss: 0.280428
    Train Epoch: 3 [51200/60000 (85%)]      Loss: 0.479609
    Train Epoch: 3 [54400/60000 (91%)]      Loss: 0.264962
    Train Epoch: 3 [57600/60000 (96%)]      Loss: 0.250311
    
    Train loss: 0.312608087558482
    Train Epoch: 4 [0/60000 (0%)]   Loss: 0.310402
    Train Epoch: 4 [3200/60000 (5%)]        Loss: 0.336302
    Train Epoch: 4 [6400/60000 (11%)]       Loss: 0.153214
    Train Epoch: 4 [9600/60000 (16%)]       Loss: 0.277832
    Train Epoch: 4 [12800/60000 (21%)]      Loss: 0.130915
    Train Epoch: 4 [16000/60000 (27%)]      Loss: 0.318101
    Train Epoch: 4 [19200/60000 (32%)]      Loss: 0.425775
    Train Epoch: 4 [22400/60000 (37%)]      Loss: 0.365087
    Train Epoch: 4 [25600/60000 (43%)]      Loss: 0.243406
    Train Epoch: 4 [28800/60000 (48%)]      Loss: 0.159906
    Train Epoch: 4 [32000/60000 (53%)]      Loss: 0.313192
    Train Epoch: 4 [35200/60000 (59%)]      Loss: 0.270967
    Train Epoch: 4 [38400/60000 (64%)]      Loss: 0.548697
    Train Epoch: 4 [41600/60000 (69%)]      Loss: 0.366319
    Train Epoch: 4 [44800/60000 (75%)]      Loss: 0.265159
    Train Epoch: 4 [48000/60000 (80%)]      Loss: 0.322257
    Train Epoch: 4 [51200/60000 (85%)]      Loss: 0.225261
    Train Epoch: 4 [54400/60000 (91%)]      Loss: 0.214001
    Train Epoch: 4 [57600/60000 (96%)]      Loss: 0.318979
    
    Train loss: 0.29650619769814424
    Train Epoch: 5 [0/60000 (0%)]   Loss: 0.310624
    Train Epoch: 5 [3200/60000 (5%)]        Loss: 0.341441
    Train Epoch: 5 [6400/60000 (11%)]       Loss: 0.239370
    Train Epoch: 5 [9600/60000 (16%)]       Loss: 0.278191
    Train Epoch: 5 [12800/60000 (21%)]      Loss: 0.287788
    Train Epoch: 5 [16000/60000 (27%)]      Loss: 0.173155
    Train Epoch: 5 [19200/60000 (32%)]      Loss: 0.240729
    Train Epoch: 5 [22400/60000 (37%)]      Loss: 0.161884
    Train Epoch: 5 [25600/60000 (43%)]      Loss: 0.232861
    Train Epoch: 5 [28800/60000 (48%)]      Loss: 0.295267
    Train Epoch: 5 [32000/60000 (53%)]      Loss: 0.227393
    Train Epoch: 5 [35200/60000 (59%)]      Loss: 0.294346
    Train Epoch: 5 [38400/60000 (64%)]      Loss: 0.202981
    Train Epoch: 5 [41600/60000 (69%)]      Loss: 0.251558
    Train Epoch: 5 [44800/60000 (75%)]      Loss: 0.244267
    Train Epoch: 5 [48000/60000 (80%)]      Loss: 0.383483
    Train Epoch: 5 [51200/60000 (85%)]      Loss: 0.298123
    Train Epoch: 5 [54400/60000 (91%)]      Loss: 0.308985
    Train Epoch: 5 [57600/60000 (96%)]      Loss: 0.473712
    
    Train loss: 0.2828775059058468
    Train Epoch: 6 [0/60000 (0%)]   Loss: 0.243838
    Train Epoch: 6 [3200/60000 (5%)]        Loss: 0.220742
    Train Epoch: 6 [6400/60000 (11%)]       Loss: 0.089961
    Train Epoch: 6 [9600/60000 (16%)]       Loss: 0.205233
    Train Epoch: 6 [12800/60000 (21%)]      Loss: 0.389221
    Train Epoch: 6 [16000/60000 (27%)]      Loss: 0.318107
    Train Epoch: 6 [19200/60000 (32%)]      Loss: 0.378846
    Train Epoch: 6 [22400/60000 (37%)]      Loss: 0.307503
    Train Epoch: 6 [25600/60000 (43%)]      Loss: 0.119020
    Train Epoch: 6 [28800/60000 (48%)]      Loss: 0.360318
    Train Epoch: 6 [32000/60000 (53%)]      Loss: 0.213947
    Train Epoch: 6 [35200/60000 (59%)]      Loss: 0.343688
    Train Epoch: 6 [38400/60000 (64%)]      Loss: 0.199512
    Train Epoch: 6 [41600/60000 (69%)]      Loss: 0.279146
    Train Epoch: 6 [44800/60000 (75%)]      Loss: 0.331717
    Train Epoch: 6 [48000/60000 (80%)]      Loss: 0.281463
    Train Epoch: 6 [51200/60000 (85%)]      Loss: 0.216211
    Train Epoch: 6 [54400/60000 (91%)]      Loss: 0.350641
    Train Epoch: 6 [57600/60000 (96%)]      Loss: 0.333605
    
    Train loss: 0.26997971273005517
    Train Epoch: 7 [0/60000 (0%)]   Loss: 0.385539
    Train Epoch: 7 [3200/60000 (5%)]        Loss: 0.350820
    Train Epoch: 7 [6400/60000 (11%)]       Loss: 0.166582
    Train Epoch: 7 [9600/60000 (16%)]       Loss: 0.358578
    Train Epoch: 7 [12800/60000 (21%)]      Loss: 0.277666
    Train Epoch: 7 [16000/60000 (27%)]      Loss: 0.141554
    Train Epoch: 7 [19200/60000 (32%)]      Loss: 0.280924
    Train Epoch: 7 [22400/60000 (37%)]      Loss: 0.248621
    Train Epoch: 7 [25600/60000 (43%)]      Loss: 0.340090
    Train Epoch: 7 [28800/60000 (48%)]      Loss: 0.310846
    Train Epoch: 7 [32000/60000 (53%)]      Loss: 0.326671
    Train Epoch: 7 [35200/60000 (59%)]      Loss: 0.232063
    Train Epoch: 7 [38400/60000 (64%)]      Loss: 0.254690
    Train Epoch: 7 [41600/60000 (69%)]      Loss: 0.566192
    Train Epoch: 7 [44800/60000 (75%)]      Loss: 0.225977
    Train Epoch: 7 [48000/60000 (80%)]      Loss: 0.504349
    Train Epoch: 7 [51200/60000 (85%)]      Loss: 0.269759
    Train Epoch: 7 [54400/60000 (91%)]      Loss: 0.156826
    Train Epoch: 7 [57600/60000 (96%)]      Loss: 0.184405
    
    Train loss: 0.2583980952451097
    Train Epoch: 8 [0/60000 (0%)]   Loss: 0.393707
    Train Epoch: 8 [3200/60000 (5%)]        Loss: 0.296386
    Train Epoch: 8 [6400/60000 (11%)]       Loss: 0.220201
    Train Epoch: 8 [9600/60000 (16%)]       Loss: 0.275108
    Train Epoch: 8 [12800/60000 (21%)]      Loss: 0.303884
    Train Epoch: 8 [16000/60000 (27%)]      Loss: 0.264850
    Train Epoch: 8 [19200/60000 (32%)]      Loss: 0.172773
    Train Epoch: 8 [22400/60000 (37%)]      Loss: 0.258944
    Train Epoch: 8 [25600/60000 (43%)]      Loss: 0.175459
    Train Epoch: 8 [28800/60000 (48%)]      Loss: 0.353995
    Train Epoch: 8 [32000/60000 (53%)]      Loss: 0.238126
    Train Epoch: 8 [35200/60000 (59%)]      Loss: 0.429915
    Train Epoch: 8 [38400/60000 (64%)]      Loss: 0.123635
    Train Epoch: 8 [41600/60000 (69%)]      Loss: 0.271555
    Train Epoch: 8 [44800/60000 (75%)]      Loss: 0.190366
    Train Epoch: 8 [48000/60000 (80%)]      Loss: 0.188578
    Train Epoch: 8 [51200/60000 (85%)]      Loss: 0.317133
    Train Epoch: 8 [54400/60000 (91%)]      Loss: 0.251918
    Train Epoch: 8 [57600/60000 (96%)]      Loss: 0.260449
    
    Train loss: 0.2549204707685818
    Train Epoch: 9 [0/60000 (0%)]   Loss: 0.259996
    Train Epoch: 9 [3200/60000 (5%)]        Loss: 0.113057
    Train Epoch: 9 [6400/60000 (11%)]       Loss: 0.137881
    Train Epoch: 9 [9600/60000 (16%)]       Loss: 0.180240
    Train Epoch: 9 [12800/60000 (21%)]      Loss: 0.278744
    Train Epoch: 9 [16000/60000 (27%)]      Loss: 0.293762
    Train Epoch: 9 [19200/60000 (32%)]      Loss: 0.385209
    Train Epoch: 9 [22400/60000 (37%)]      Loss: 0.225448
    Train Epoch: 9 [25600/60000 (43%)]      Loss: 0.207232
    Train Epoch: 9 [28800/60000 (48%)]      Loss: 0.151411
    Train Epoch: 9 [32000/60000 (53%)]      Loss: 0.144787
    Train Epoch: 9 [35200/60000 (59%)]      Loss: 0.428082
    Train Epoch: 9 [38400/60000 (64%)]      Loss: 0.221369
    Train Epoch: 9 [41600/60000 (69%)]      Loss: 0.220435
    Train Epoch: 9 [44800/60000 (75%)]      Loss: 0.206271
    Train Epoch: 9 [48000/60000 (80%)]      Loss: 0.275636
    Train Epoch: 9 [51200/60000 (85%)]      Loss: 0.222783
    Train Epoch: 9 [54400/60000 (91%)]      Loss: 0.409510
    Train Epoch: 9 [57600/60000 (96%)]      Loss: 0.325869
    
    Train loss: 0.24870677012751605
    test on 10000 samples
    
    Test set: Average loss: 0.0001
    Accuracy: 9568/10000 (95.68%)
    

pdb

pdb is a breakpoint tool: a package shipped with Python that provides interactive source-code debugging.

Its main features include setting breakpoints, single-stepping, stepping into functions, viewing the current source, inspecting stack frames, and changing variable values on the fly.
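
A minimal sketch of how the (Pdb) sessions below can be started: insert a breakpoint into main_fed.py (the placement shown is an assumption; the traces below suggest it sits shortly before the model is built):

```python
# inside main_fed.py, after the dataset has been loaded and partitioned (illustrative placement)
import pdb; pdb.set_trace()    # execution pauses here and drops into the interactive (Pdb) prompt

# Frequently used commands at the prompt:
#   p <expr>   print an expression, e.g.  p dataset_train[0][1]
#   n          execute the next line (step over)
#   s          step into a function call
#   l          list the source around the current line
#   c          continue until the next breakpoint
#   q          quit the debugger
```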

Basic commands

  • Inspect the training set: number of datapoints, split, root location, and normalization transforms
(Pdb) dataset_train
Dataset MNIST
    Number of datapoints: 60000
    Split: train
    Root Location: ../data/mnist/
    Transforms (if any): Compose(
                             ToTensor()
                             Normalize(mean=(0.1307,), std=(0.3081,))
                         )
    Target Transforms (if any): None
  • Print one image with p (print). The image is 28×28 pixels; the numbers are the normalized pixel values, so the first rows (the uniform background) all share one value while other values correspond to different grey levels. tensor(5): the label of this image is 5.
(Pdb) p dataset_train[0]
(tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.3860, -0.1951,
          -0.1951, -0.1951,  1.1795,  1.3068,  1.8032, -0.0933,  1.6887,
           2.8215,  2.7197,  1.1923, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.0424,  0.0340,  0.7722,  1.5359,  1.7396,  2.7960,
           2.7960,  2.7960,  2.7960,  2.7960,  2.4396,  1.7650,  2.7960,
           2.6560,  2.0578,  0.3904, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           0.1995,  2.6051,  2.7960,  2.7960,  2.7960,  2.7960,  2.7960,
           2.7960,  2.7960,  2.7960,  2.7706,  0.7595,  0.6195,  0.6195,
           0.2886,  0.0722, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.1951,  2.3633,  2.7960,  2.7960,  2.7960,  2.7960,  2.7960,
           2.0960,  1.8923,  2.7197,  2.6433, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242,  0.5940,  1.5614,  0.9377,  2.7960,  2.7960,  2.1851,
          -0.2842, -0.4242,  0.1231,  1.5359, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.2460, -0.4115,  1.5359,  2.7960,  0.7213,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242,  1.3450,  2.7960,  1.9942,
          -0.3988, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.2842,  1.9942,  2.7960,
           0.4668, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  0.0213,  2.6433,
           2.4396,  1.6123,  0.9504, -0.4115, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  0.6068,
           2.6306,  2.7960,  2.7960,  1.0904, -0.1060, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           0.1486,  1.9432,  2.7960,  2.7960,  1.4850, -0.0806, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.2206,  0.7595,  2.7833,  2.7960,  1.9560, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242,  2.7451,  2.7960,  2.7451,  0.3904,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
           0.1613,  1.2305,  1.9051,  2.7960,  2.7960,  2.2105, -0.3988,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,  0.0722,  1.4596,
           2.4906,  2.7960,  2.7960,  2.7960,  2.7578,  1.8923, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.1187,  1.0268,  2.3887,  2.7960,
           2.7960,  2.7960,  2.7960,  2.1342,  0.5686, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.1315,  0.4159,  2.2869,  2.7960,  2.7960,  2.7960,
           2.7960,  2.0960,  0.6068, -0.3988, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.1951,
           1.7523,  2.3633,  2.7960,  2.7960,  2.7960,  2.7960,  2.0578,
           0.5940, -0.3097, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242,  0.2758,  1.7650,  2.4524,
           2.7960,  2.7960,  2.7960,  2.7960,  2.6815,  1.2686, -0.2842,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242,  1.3068,  2.7960,  2.7960,
           2.7960,  2.2742,  1.2941,  1.2559, -0.2206, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242]]]), tensor(5))
(Pdb) p dataset_train[0][1]
tensor(5)
  • After the training set is distributed to the users, inspect the second user's data. It contains 600 indices, each giving the position of an image within the full dataset. shape describes the array: a one-dimensional array with 600 entries.
(Pdb) p dict_users[1]
array([ 8097, 28228,  8102, 11346, 50756, 11344, 38547, 28218, 37053,
       28214, 28212, 50802, 28210, 38542,  8133, 11319, 53391, 53390,
       11339,  8078, 50753, 53434, 53473, 28268, 50695, 28266, 26340,
       50700, 26341, 11420, 50705,  8043, 37015, 50716, 37020, 53455,
       26348, 50720, 50722,  8058, 11394, 26355, 50733, 26357, 50736,
       11384, 11380, 50744, 50748, 53388, 50812, 50814, 37063, 50875,
       53343, 38509, 53340, 11230,  8226, 26431, 26434,  8229, 37103,
       28156, 28155, 28154, 50897, 50898, 50899, 50901, 26447, 11203,
       53323, 53322, 50913,  8251, 26454, 11193,  8257, 37120,  8212,
       28140,  8211, 11243, 11306, 53377,  8152, 11299, 11298,  8154,
       38530, 50826, 28200, 53373, 11286, 53367, 53366, 38527, 28189,
       11271, 28185, 50840, 37083, 37084, 11257,  8201,  8202, 11251,
       11250,  8204, 28169, 11242, 10472, 38359, 28021, 26817, 27880,
        8707, 26821, 27878, 37360, 10629, 27874, 51379, 10620, 10616,
       38171, 52938, 38170,  8727, 27864, 10603, 37367, 27860, 51398,
       51399,  8740,  8742, 26850,  8743, 52898, 27869, 38195, 27883,
       38197, 51340, 52971, 10688, 10687,  8670, 38216, 26779, 10680,
       26782, 27898, 27897, 10673, 26788, 38208, 52959, 52956, 52953,
       26794,  8691, 51356, 51357, 10657, 51359, 37347, 26800, 27888,
       52942, 52894, 51418, 10583, 10582, 38127, 10514, 52851, 38124,
       52841, 52840,  8810, 38120, 10502, 52839, 26899, 27810, 10497,
       52835, 52834, 26907, 52829, 37412, 51506,  8831, 38108, 38107,
       27797, 51517,  8833,  8834, 10477, 26888, 38221, 37403, 10526,
        8749, 27852, 51419,  8752, 52889, 10574, 38148, 10570, 10567,
       26863, 27844, 38143, 52874, 10554, 51435, 51443, 38140, 27834,
       51448, 52869, 26875, 51458, 10535,  8784, 37386, 10528, 26883,
        8793,  8474, 26770, 52975, 37241, 37243, 26646, 53072, 27989,
       38326, 26649, 51171, 38319, 26658,  8552, 37258, 38330, 53056,
        8564,  8565, 26669, 38307, 27957, 37274,  8573, 10825, 10824,
       51212, 51214, 51216, 10846, 53076, 51151, 37234, 38356, 51117,
       53121, 51118, 10931, 51119,  8480, 26613,  8490, 38354, 10920,
        8493, 28013, 37210, 53105, 37214, 26624, 51131, 26626, 26630,
       10898, 53089, 26632, 38339, 26635, 53088, 26638, 38288,  8582,
       26688, 26689, 53019,  5168,  3900, 33498, 58991, 32465, 35587,
       28153,  8529, 29974,   993, 37110, 32462, 30667,  1230,  8530,
       33600,  1197, 55846, 55832,  5123,  1123, 35552, 53233, 30007,
       35538, 59074, 59071, 56772, 28056, 32392, 33580, 53176, 33577,
       59044, 56779, 55819,  8351,  3963, 53162,  3962, 35554,  8348,
       56757, 32358, 32365,  1115,  8422, 59085, 28090,  1118, 53168,
        8418, 33586,  8423, 53177, 32367,  1131,  8394,  1148, 53206,
       59060, 32378,  3955, 53194, 33567,  5107,  8379, 56764,  8388,
        3954, 35000, 28069,  8382, 53199, 59041, 55825, 35001, 53178,
        1132, 53180, 30695, 37182, 33556, 37179,  3956,  5114, 59052,
        8402, 28063, 33558,  8373, 53187,  1155, 56768, 27961, 37155,
        3966, 34990, 59114, 59033, 34992, 35527,  8452, 35558,  8451,
        8319, 28035,  1080,  1081,  5090, 59112, 56744, 37200, 59028,
        1192,  8456,  5073, 33516,  1060, 56734,  5075, 56737, 30684,
        1084, 37201, 33596, 53135, 30031, 53273,  1069, 28110, 59116,
        1065, 56783,  8330,  8331, 37191,  1181, 37190,  4089,  5132,
       30711,  3971, 59102, 56750, 35008,  5128,  1174, 35006,  8430,
       53160, 32406, 28094, 37154, 32411, 37196,  8332,  8445, 28100,
        3926, 53140,  1085, 28038, 32413, 53143, 28099, 33535,  1095,
       35009, 28098, 32353, 53147, 59036,  8218, 37232, 32535,  4972,
       33646,   871, 55768, 59247, 33440, 59248,  8595, 58935,  8615,
        4971, 29968,   920,  8177,   870,  8652,   919,  3889,   891,
       52984,  8160,  1293,  8161,  4049,  4970,  5012,  8203,  8588,
        4064, 32496,  8589, 32530,  1259, 37330, 33456, 59233, 30737,
       33480, 56660, 30089,  8176, 37290, 55784, 35052,  8162,  8660,
       33658,  8633, 55931, 59289,  1287, 33654, 52990, 29957, 33462,
        8189, 53361, 58963,  8610, 33458, 53375, 35600, 56656, 27933,
        4976,  3885, 32540, 29961, 30639, 33457, 27920,  8185, 28184,
        4998,  4978, 27913, 37081,  4979, 58944, 56668,  5006,  5004,
       33467, 35596,  8648, 32543, 30649, 28187,   905, 37305,  8180,
       33441, 59266,  8602,  5189,  8156, 30742, 27924, 35048, 58965,
        8194, 35496, 32528, 52967,  5204,  3878, 53040, 32534, 55901,
        8578, 53394, 55786, 56834,   951, 59297, 53016, 53395, 33433,
        4070, 37337, 58953,   961, 28167,  8664], dtype=int64)
(Pdb) p dict_users[1].shape
(600,)
  • By default the data is not independent and identically distributed (args.iid is False)
(Pdb) p args.iid
False
  • n: step to the next line in the debugger; after stepping past the model-selection branches, print img_size, which is torch.Size([1, 28, 28]) (one channel, 28×28 pixels)
(Pdb) n
> d:\python\pycharmprojects\fedavg\federated-learning\main_fed.py(68)<module>()
-> if args.model == 'cnn' and args.dataset == 'cifar':
(Pdb) n
> d:\python\pycharmprojects\fedavg\federated-learning\main_fed.py(70)<module>()
-> elif args.model == 'cnn' and args.dataset == 'mnist':
(Pdb) n
> d:\python\pycharmprojects\fedavg\federated-learning\main_fed.py(72)<module>()
-> elif args.model == 'mlp':
(Pdb) p img_size
torch.Size([1, 28, 28])

Code Walkthrough

Video

The MLP network structure in nets:

The input layer is one-dimensional: the 28×28 image is flattened into a 784-dimensional vector.

The hidden layer has 200 units.

The output layer has 10 units, one for each of the digits 0-9.

Activation function:

softmax: a normalization that converts the 10 output scores into probabilities summing to 1; the class with the highest probability is the model's prediction for the image.
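
A minimal sketch of an MLP matching the structure printed earlier (784 → 200 → 10 with ReLU, dropout, and a softmax over the ten outputs); details may differ slightly from the repository's network definitions:

```python
import torch
from torch import nn

class MLP(nn.Module):
    def __init__(self, dim_in=784, dim_hidden=200, dim_out=10):
        super().__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        self.softmax = nn.Softmax(dim=1)       # turn the 10 scores into probabilities

    def forward(self, x):
        x = x.view(x.size(0), -1)              # flatten the 1x28x28 image into a 784-vector
        x = self.dropout(self.relu(self.layer_input(x)))
        x = self.layer_hidden(x)
        return self.softmax(x)                 # the predicted digit is the index of the largest probability
```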

Training process: