视频

PFLlib_code

fedbn

ditto

fedala

fedb

PFLLIB

数据集处理

dataset等三个文件夹mark directory as sources root–import下列的文件时才可以识别到

system

​ flscore:

​ client

​ server

​ optimizers

​ until:

generate_mnist.py:生成数据集mnist在dirpath下

rawdata:原始数据集60000trian10000test

test、train:20个npz文件

jason串格式化ctrl+alt+l(所有代码格式化都可用,和网易云冲突但是)


For the label skew scenario :标签偏移

In non-IID scenario, 2 situations exist. The first one is the pathological non-IID scenario, the second one is practical non-IID scenario.

第一种是病理性非 IID 场景,第二种是实际的非 IID 场景。例如,在病理性非 IID 场景中,每个客户机上的数据只包含特定数量的标签(可能只有2个标签) ,尽管所有客户机上的数据包含10个标签,比如 MNIST 数据集。在实际的非内部 ID 场景中,使用了狄利克雷分布。

划分数据的不同场景

1、python generate_mnist.py iid balance - # for iid and balanced scenario

IID,平衡–每个client包含所有类,且每个client分到的当前类的样本数相同

2、python generate_mnist.py iid - - # for iid and unbalanced scenario

IID,不平衡场景–每个client包含所有类,但数据量当前类对应的

3、python generate_mnist.py noniid - pat # for pathological noniid and unbalanced scenario

4、python generate_mnist.py noniid - dir # for practical noniid and unbalanced scenario

“balance” 通常表示在训练过程中平衡各个参与方(如客户端)的贡献或资源分配。这可以包括确保每个参与方有足够的训练样本、计算资源或者对模型更新的贡献,以实现公平性和效率性。

联邦学习的不平衡场景可能包括以下情况: 1. 数据分布不均:参与方(如客户端)拥有的数据量差异很大,有些参与方拥有的样本数量非常少,而有些参与方拥有的样本数量非常多。 2. 计算资源不均:参与方的计算能力不同,有些参与方的设备性能较低,而有些参与方的设备性能较高。 3. 数据标签不均:在分类任务中,不同类别的样本数量差异很大,导致一些类别的数据在训练过程中被较少地考虑。 这些不平衡的场景可能会对联邦学习的模型训练和整体性能产生影响,需要采取相应的策略来解决。

1、iid and balanced scenario:将每个类(10)划分到所有客户端,将当前类对应的数据平均分给每个client(20)

2、iid and unbalanced scenario:


下载数据,将数据的训练集测试集揉在一起,然后划分到每个client,然后分为训练集和测试集。然后将划分的参数和每个client的数据情况保存到json串中。

pfl每个client都有测试集,揉和后划分之后再区分–保证客户端的训练集和测试集的分布一致,有利于pfl的准确率

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# idxs:7000个数据对应的下标数组
(Pdb) p idxs
array([    0,     1,     2, ..., 69997, 69998, 69999])
(Pdb) p len(idxs)
70000

# idx_for_each_class 对应了数据的每个类别的下标
(Pdb) p idx_for_each_class
[array([    1,    21,    34, ..., 69964, 69983, 69993]), array([    3,     6,     8, ..., 69978, 69984, 69994]), array([    5,    16,    25, ..., 69980, 69985, 69995]), array([    7,
  10,    12, ..., 69975, 69986, 69996]), array([    2,     9,    20, ..., 69977, 69987, 69997]), array([    0,    11,    35, ..., 69982, 69988, 69998]), array([   13,    18,    32, ...
, 69981, 69989, 69999]), array([   15,    29,    38, ..., 69968, 69979, 69990]), array([   17,    31,    41, ..., 69959, 69967, 69991]), array([    4,    19,    22, ..., 69945, 69973,
69992])]
(Pdb) p len(idx_for_each_class)
10
# 每个client的种类数量--平衡都是10,不平衡默认2
(Pdb) p class_num_per_client
[10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
# 对于每一个class类别i,选择应该划分该类别的client
# 选取client的比例=客户端种类的采样比例,不平衡len=2,选两个客户端
(Pdb) p selected_clients
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
# 类别i的图片数量
(Pdb) p num_all_samples
6903
# 类别i被划分的客户端数目
(Pdb) p num_selected_clients
20
# 每个类别被分到的类别i对应的数量
(Pdb) p num_per
345.15
# 每个客户端对应类别i的图片数量,前n-1个取整num_per,剩余的图片全放在第n个client
# 不平衡的话,以num_per为中间值高斯分布
(Pdb) p num_samples
[345, 345, 345, 345, 345, 345, 345, 345, 345, 345, 345, 345, 345, 345, 345, 345, 345, 345, 345, 348]

# 对于当前class,每个client分到的数据下标
(Pdb) p dataidx_map
{0: array([   1,   21,   34,   37,   51,   56,   63,   68,   69,   75,   81,
         88,   95,  108,  114,  118,  119,  121,  156,  169,  192,  206,
        209,  210,  216,  229,  232,  234,  246,  249,  260,  283,  293,
        296,  303,  320,  326,  359,  399,  427,  429,  435,  440,  451,
        453,  458,  462,  464,  473,  489,  519,  524,  526,  527,  542,
        577,  582,  596,  603,  612,  633,  639,  656,  662,  666,  667,
        668,  669,  689,  702,  709,  712,  733,  743,  745,  776,  781,
        787,  790,  818,  825,  849,  859,  860,  869,  872,  889,  903,
        927,  943,  949,  952,  957,  965,  979,  984,  997, 1000, 1015,
       1018, 1028, 1029, 1040, 1046, 1049, 1076, 1078, 1090, 1093, 1102,
       1107, 1128, 1137, 1152, 1168, 1179, 1195, 1209, 1268, 1304, 1310,
       1346, 1349, 1359, 1363, 1367, 1368, 1371, 1372, 1377, 1386, 1387,
       1403, 1423, 1443, 1454, 1471, 1479, 1489, 1495, 1501, 1502, 1512,
       1517, 1530, 1532, 1571, 1578, 1590, 1596, 1600, 1605, 1606, 1625,
       1626, 1645, 1664, 1678, 1682, 1701, 1709, 1712, 1723, 1725, 1729,
       1742, 1769, 1771, 1775, 1796, 1797, 1798, 1819, 1828, 1837, 1857,
       1868, 1876, 1877, 1897, 1904, 1907, 1916, 1926, 1927, 1930, 1956,
       1963, 1969, 1995, 1999, 2009, 2051, 2058, 2066, 2079, 2081, 2082,
       2084, 2100, 2101, 2112, 2121, 2144, 2147, 2160, 2191, 2192, 2195,
       2218, 2245, 2253, 2257, 2269, 2278, 2298, 2310, 2325, 2327, 2333,
       2340, 2352, 2353, 2373, 2396, 2403, 2404, 2411, 2435, 2436, 2440,
       2473, 2483, 2493, 2500, 2525, 2526, 2528, 2532, 2538, 2539, 2557,
       2561, 2581, 2582, 2584, 2586, 2597, 2606, 2615, 2617, 2621, 2624,
       2629, 2642, 2709, 2718, 2729, 2736, 2745, 2746, 2765, 2770, 2782,
       2806, 2817, 2826, 2839, 2854, 2864, 2873, 2882, 2890, 2897, 2899,
       2904, 2914, 2919, 2935, 2944, 2952, 2955, 2974, 2975, 2996, 3001,
       3012, 3015, 3016, 3021, 3024, 3035, 3049, 3067, 3106, 3107, 3128,
       3131, 3135, 3143, 3160, 3175, 3195, 3198, 3213, 3231, 3241, 3247,
       3248, 3259, 3262, 3269, 3286, 3309, 3328, 3337, 3367, 3369, 3376,
       3377, 3391, 3396, 3409, 3410, 3429, 3434, 3441, 3443, 3461, 3479,
       3490, 3492, 3514, 3516, 3529, 3534, 3541, 3562, 3565, 3568, 3585,
       3590, 3603, 3610, 3612])}
{0: array([   1,   21,   34,   37,   51,   56,   63,   68,   69,   75,   81,
         88,   95,  108,  114,  118,  119,  121,  156,  169,  192,  206,
        209,  210,  216,  229,  232,  234,  246,  249,  260,  283,  293,
        296,  303,  320,  326,  359,  399,  427,  429,  435,  440,  451,
        453,  458,  462,  464,  473,  489,  519,  524,  526,  527,  542,
        577,  582,  596,  603,  612,  633,  639,  656,  662,  666,  667,
        668,  669,  689,  702,  709,  712,  733,  743,  745,  776,  781,
        787,  790,  818,  825,  849,  859,  860,  869,  872,  889,  903,
        927,  943,  949,  952,  957,  965,  979,  984,  997, 1000, 1015,
       1018, 1028, 1029, 1040, 1046, 1049, 1076, 1078, 1090, 1093, 1102,
       1107, 1128, 1137, 1152, 1168, 1179, 1195, 1209, 1268, 1304, 1310,
       1346, 1349, 1359, 1363, 1367, 1368, 1371, 1372, 1377, 1386, 1387,
       1403, 1423, 1443, 1454, 1471, 1479, 1489, 1495, 1501, 1502, 1512,
       1517, 1530, 1532, 1571, 1578, 1590, 1596, 1600, 1605, 1606, 1625,
       1626, 1645, 1664, 1678, 1682, 1701, 1709, 1712, 1723, 1725, 1729,
       1742, 1769, 1771, 1775, 1796, 1797, 1798, 1819, 1828, 1837, 1857,
       1868, 1876, 1877, 1897, 1904, 1907, 1916, 1926, 1927, 1930, 1956,
       1963, 1969, 1995, 1999, 2009, 2051, 2058, 2066, 2079, 2081, 2082,
       2084, 2100, 2101, 2112, 2121, 2144, 2147, 2160, 2191, 2192, 2195,
       2218, 2245, 2253, 2257, 2269, 2278, 2298, 2310, 2325, 2327, 2333,
       2340, 2352, 2353, 2373, 2396, 2403, 2404, 2411, 2435, 2436, 2440,
       2473, 2483, 2493, 2500, 2525, 2526, 2528, 2532, 2538, 2539, 2557,
       2561, 2581, 2582, 2584, 2586, 2597, 2606, 2615, 2617, 2621, 2624,
       2629, 2642, 2709, 2718, 2729, 2736, 2745, 2746, 2765, 2770, 2782,
       2806, 2817, 2826, 2839, 2854, 2864, 2873, 2882, 2890, 2897, 2899,
       2904, 2914, 2919, 2935, 2944, 2952, 2955, 2974, 2975, 2996, 3001,
       3012, 3015, 3016, 3021, 3024, 3035, 3049, 3067, 3106, 3107, 3128,
       3131, 3135, 3143, 3160, 3175, 3195, 3198, 3213, 3231, 3241, 3247,
       3248, 3259, 3262, 3269, 3286, 3309, 3328, 3337, 3367, 3369, 3376,
       3377, 3391, 3396, 3409, 3410, 3429, 3434, 3441, 3443, 3461, 3479,
       3490, 3492, 3514, 3516, 3529, 3534, 3541, 3562, 3565, 3568, 3585,
       3590, 3603, 3610, 3612]), 1: array([3622, 3632, 3661, 3667, 3673, 3677, 3691, 3693, 3698, 3702, 3709,
       3715, 3720, 3734, 3743, 3776, 3777, 3778, 3806, 3809, 3829, 3849,
       3876, 3881, 3882, 3884, 3888, 3896, 3906, 3929, 3937, 3943, 3953,
       3964, 4002, 4008, 4009, 4023, 4027, 4044, 4047, 4067, 4082, 4106,
       4108, 4115, 4123, 4142, 4145, 4159, 4181, 4201, 4203, 4216, 4218,
       4220, 4239, 4245, 4265, 4270, 4279, 4284, 4288, 4289, 4310, 4316,
       4321, 4325, 4340, 4343, 4344, 4347, 4356, 4374, 4389, 4415, 4444,
       4453, 4455, 4465, 4485, 4488, 4496, 4500, 4505, 4527, 4532, 4535,
       4539, 4556, 4563, 4565, 4588, 4597, 4607, 4608, 4624, 4629, 4642,
       4653, 4655, 4656, 4682, 4686, 4688, 4713, 4718, 4729, 4738, 4744,
       4748, 4749, 4756, 4773, 4776, 4793, 4804, 4849, 4852, 4854, 4855,
       4870, 4889, 4892, 4906, 4911, 4918, 4926, 4931, 4951, 4962, 4981,
       4985, 4994, 5007, 5010, 5015, 5019, 5048, 5052, 5053, 5072, 5074,
       5082, 5083, 5096, 5108, 5115, 5120, 5131, 5133, 5144, 5147, 5154,
       5167, 5187, 5192, 5194, 5196, 5202, 5203, 5228, 5244, 5246, 5249,
       5266, 5268, 5272, 5275, 5285, 5286, 5288, 5298, 5317, 5319, 5321,
       5327, 5330, 5340, 5342, 5344, 5353, 5368, 5373, 5393, 5398, 5461,
       5462, 5464, 5468, 5469, 5470, 5475, 5487, 5488, 5498, 5502, 5505,
       5512, 5514, 5525, 5541, 5560, 5563, 5583, 5585, 5604, 5608, 5612,
       5615, 5622, 5627, 5641, 5646, 5662, 5665, 5674, 5690, 5692, 5694,
       5697, 5698, 5711, 5729, 5730, 5747, 5767, 5773, 5801, 5808, 5809,
       5830, 5844, 5881, 5883, 5884, 5886, 5888, 5905, 5921, 5927, 5929,
       5932, 5938, 5961, 5978, 5981, 6001, 6008, 6023, 6034, 6045, 6065,
       6076, 6094, 6097, 6103, 6113, 6117, 6133, 6137, 6146, 6179, 6207,
       6227, 6228, 6245, 6279, 6292, 6296, 6302, 6316, 6320, 6325, 6326,
       6330, 6331, 6332, 6337, 6341, 6357, 6367, 6394, 6400, 6406, 6409,
       6421, 6422, 6441, 6444, 6446, 6461, 6468, 6496, 6502, 6538, 6591,
       6593, 6603, 6607, 6611, 6614, 6615, 6619, 6629, 6630, 6645, 6648,
       6663, 6676, 6681, 6684, 6700, 6704, 6708, 6723, 6730, 6737, 6738,
       6757, 6758, 6763, 6768, 6791, 6802, 6819, 6823, 6832, 6847, 6857,
       6875, 6883, 6917, 6923, 6928, 6935, 6958, 6960, 6967, 6968, 6970,
       6972, 6973, 6979, 6987])}
(Pdb) p dataidx_map
{0: array([   1,   21,   34, ..., 3510, 3511, 3519]), 1: array([3622, 3632, 3661, ..., 6964, 6969, 6993]), 2: array([ 6991,  7003,  7004, ..., 10576, 10577, 10588]), 3: array([10276, 1
0283, 10323, ..., 13897, 13898, 13904]), 4: array([13799, 13807, 13808, ..., 17317, 17319, 17323]), 5: array([17360, 17365, 17376, ..., 20554, 20562, 20567]), 6: array([20816, 20827, 2
0840, ..., 24185, 24193, 24194]), 7: array([24419, 24429, 24440, ..., 27815, 27817, 27839]), 8: array([28102, 28105, 28108, ..., 31115, 31120, 31128]), 9: array([31425, 31430, 31445, .
.., 34489, 34491, 34496]), 10: array([35192, 35209, 35217, ..., 37867, 37873, 37877]), 11: array([38729, 38744, 38746, ..., 41502, 41503, 41510]), 12: array([41987, 41989, 41999, ...,
44882, 44884, 44890]), 13: array([45544, 45557, 45577, ..., 48616, 48618, 48633]), 14: array([49065, 49075, 49085, ..., 52232, 52236, 52239]), 15: array([52393, 52396, 52397, ..., 5585
9, 55884, 55889]), 16: array([55890, 55892, 55907, ..., 59458, 59464, 59486]), 17: array([59418, 59420, 59421, ..., 62983, 63001, 63005]), 18: array([63210, 63215, 63219, ..., 66214, 6
6221, 66228]), 19: array([66633, 66643, 66651, ..., 69945, 69973, 69992])}
# 划分好之后
(Pdb) p dataidx_map[0]
array([   1,   21,   34, ..., 3510, 3511, 3519])
(Pdb) p len(dataidx_map[client])
3495
(Pdb) p y[client]
array([0, 0, 0, ..., 9, 9, 9], dtype=int64)
(Pdb) p y[client].shape
(3495,)

# 从给定的数组中提取出唯一元素,对提取的唯一元素进行排序
(Pdb) p np.unique(y[client])
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64)
(Pdb) p statistic[client]
[(0, 345)]
(Pdb) p statistic
[[(0, 345)], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []]
(Pdb) p statistic[client]
[(0, 348), (1, 410), (2, 359), (3, 358), (4, 345), (5, 328), (6, 359), (7, 377), (8, 346), (9, 365)]
(Pdb) p len(statistic[client])
10
(Pdb) p len(statistic)
20

# 划分训练集和测试集
(Pdb) p len(X)
20
(Pdb) p len(X[i])
3495
(Pdb) p len(X_train)
2621
(Pdb) p len(X_test)
874
# 划分之后 train_data[0]是一个字典,包含两个键值对
# 'x'为client0对应的数据,'y'为标签
(Pdb) p len(train_data)
20
(Pdb) p len(train_data[0])
2
(Pdb) p type(train_data[0])
<class 'dict'>
(Pdb) p train_data[0]['x'].shape
(2621, 1, 28, 28)
(Pdb) p train_data[0]['y']
array([4, 7, 5, ..., 3, 0, 2], dtype=int64)
(Pdb) p train_data[0]['y'].shape
(2621,)

# 划分后每个客户端训练集和测试集的数量
(Pdb) p num_samples
{'train': [2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2696], 'test': [874, 874, 874, 874, 874, 874, 874, 874, 874
, 874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 899]}

运行

1
2
3
rm ../dataset/mnist/config.json
cd ../dataset/
nohup python -u generate_mnist.py noniid - dir > mnist_dataset.out 2>&1

这是一个后台运行Python脚本的命令,使用了**nohup**命令来让命令在后台持续运行,同时将输出重定向到mnist_dataset.out文件中。**-u**选项在Python命令行中表示使用无缓冲的输出。它的作用是强制Python在标准输出和标准错误流中不进行缓冲,而是立即输出到终端。这对于实时查看脚本的输出或日志很有用,尤其是在脚本需要长时间运行或需要实时监控输出时。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
python generate_mnist.py iid balance - # for iid and balanced scenario

Client 0	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 1	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 2	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 3	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 4	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 5	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 6	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 7	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 8	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 9	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 10	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 11	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 12	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 13	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 14	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 15	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 16	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 17	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 18	 Size of data: 3495	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 345), (1, 393), (2, 349), (3, 357), (4, 341), (5, 315), (6, 343), (7, 364), (8, 341), (9, 347)]
--------------------------------------------------
Client 19	 Size of data: 3595	 Labels:  [0 1 2 3 4 5 6 7 8 9]
		 Samples of labels:  [(0, 348), (1, 410), (2, 359), (3, 358), (4, 345), (5, 328), (6, 359), (7, 377), (8, 346), (9, 365)]
--------------------------------------------------

Total number of samples: 70000
The number of train samples: [2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2621, 2696]
The number of test samples: [874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 874, 899]

Saving to disk.

Finish generating dataset.

# IID 不平衡python generate_mnist.py iid - -
# 对于每个i,每个client分到的样本数会重新划分
(Pdb) p num_samples
[71, 269, 106, 289, 237, 167, 178, 163, 105, 271, 315, 212, 310, 288, 286, 190, 84, 102, 249, 3011]
(Pdb) p num_samples
[280, 391, 125, 180, 46, 358, 356, 61, 352, 40, 355, 248, 303, 255, 180, 154, 160, 69, 110, 3854]

Client 0         Size of data: 1882      Labels:  [0 1 2 3 4 5 6 7 8 9]
                 Samples of labels:  [(0, 71), (1, 280), (2, 83), (3, 300), (4, 336), (5, 67), (6, 137), (7, 253), (8, 111), (9, 244)]
--------------------------------------------------
Client 1         Size of data: 2922      Labels:  [0 1 2 3 4 5 6 7 8 9]
                 Samples of labels:  [(0, 269), (1, 391), (2, 347), (3, 230), (4, 250), (5, 294), (6, 300), (7, 267), (8, 253), (9, 321)]
--------------------------------------------------
Client 2         Size of data: 1323      Labels:  [0 1 2 3 4 5 6 7 8 9]
                 Samples of labels:  [(0, 106), (1, 125), (2, 37), (3, 161), (4, 185), (5, 88), (6, 216), (7, 132), (8, 211), (9, 62)]

client处理

learning_rate_decay:收敛性分析的时候要求学习率必须衰减,不衰减的话就没办法收敛,理论上是这样,实际上不衰减也能收敛,能达到一个比较高的准确率,北大的 On the Convergence of FedAvg on Non-IID Data,blog

local epochs:每次聚合的时候本地训练次数

每个本地 epoch 内执行多通常情况下,在每个本地 epoch 中,模型会根据本地训练数据进行一次参数更新,然后进行一次模型评估。然而,当采用多个更新步骤时,模型将在同一本地 epoch 内执行多次参数更新,而不是仅执行一次。这种方法可能会提高模型的收敛速度,特别是在使用大型模型和数据集进行训练时。然而,需要注意的是,多次更新步骤可能会增加过拟合的风险,因此需要仔细权衡。个参数更新步骤。

join_ratio:客户端每轮参加的比例,用于衡量client dift的程度

eval_gap:几轮评估test一次模型

copy.deepcopy 是用于创建对象的深度拷贝,这意味着它会递归地复制对象及其包含的所有对象,而不仅仅是复制对象本身。这样做可以确保 self.modelargs.model 的完整和独立的副本,而不是对原始对象的引用。

nn.BatchNorm2d 批标准化被用于加速神经网络的训练,并且有助于处理梯度消失/爆炸问题。用于在卷积神经网络的卷积层后应用批标准化。它通过规范化每个输入通道的输出,然后应用缩放和偏移,以使模型更容易训练并提高泛化能力。在使用深度学习框架构建卷积神经网络时,批标准化通常被认为是一种标准的正则化技术,并且已被证明在许多情况下能够提高模型的性能和训练速度。


准确率和AUC是两个不同的指标,用于评估分类模型的性能,它们分别从整体准确性和类别排序能力的角度来衡量模型的表现。ROC 曲线下面积(ROC AUC)是一种用于衡量分类模型性能的指标,它表示分类模型在不同阈值下真正例率(True Positive Rate)与假正例率(False Positive Rate)之间的权衡。ROC 曲线是以真正例率为纵轴,假正例率为横轴所绘制的曲线,ROC AUC 则是 ROC 曲线下方的面积,取值范围在 0 到 1 之间。ROC AUC 值越接近 1,表示模型性能越好;越接近 0.5,表示模型性能越一般;小于 0.5,则表示模型性能不如随机猜测。ROC 曲线和 ROC AUC 可以帮助我们评估分类模型在不同阈值下的整体性能和稳定性。

AUC(Area Under the Curve)和准确率是两种不同的性能评估指标:

  1. AUC(Area Under the Curve):用于衡量分类模型在不同阈值下真正例率(True Positive Rate)与假正例率(False Positive Rate)之间的权衡。ROC 曲线下面积(ROC AUC)是对整个 ROC 曲线的一个总体性能指标,表示分类模型对正负样本的排序能力,范围在 0 到 1 之间。
  2. 准确率(Accuracy):表示模型在所有预测样本中正确分类的比例,是最常见的分类模型性能指标之一。但准确率不能很好地处理样本不均衡的情况,当正负样本比例严重失衡时,准确率并不是一个很好的评价指标。

总的来说,AUC 是评估分类模型排序能力的指标,而准确率是评估模型分类能力的指标。 AUC 更适用于样本不均衡的情况,而准确率则更适用于样本均衡的情况。

  • Epoch(周期) 是指整个训练数据集被送入神经网络进行了一次正向传播和反向传播的过程。一个 epoch 表示神经网络已经学习完整个训练数据集的过程。
  • Batch(批次) 是指将整个训练数据集分成若干批次,每个批次包含若干个样本。在每个批次中,模型根据批次内的样本进行一次正向传播和反向传播,然后更新参数。
  • 在使用时,可以将整个训练数据集划分为多个批次,每个批次包含固定数量的样本。在每个 epoch 中,神经网络会依次处理每个批次的数据,并进行参数更新。这样可以加速训练过程,减少内存占用,并且有助于模型的收敛。
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
# ClientAVG继承了client类
# 定义了clientAVG,它是 `Client` 类的子类。
# 该类具有一个构造函数 `__init__`,它接受 `args`、`id`、`train_samples`、`test_samples` 等参数
# 并通过调用 `super()` 函数来调用父类 `Client` 的构造函数,并将参数传递给父类构造函数。
from flcore.clients.clientbase import Client

class clientAVG(Client):
    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        # 调用父类的构造函数 
        super().__init__(args, id, train_samples, test_samples, **kwargs)
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
import argparse  
  
# 创建 ArgumentParser 对象  
parser = argparse.ArgumentParser(description='Process some integers.')  
  
# 添加命令行参数  
parser.add_argument('integers', metavar='N', type=int, nargs='+', help='an integer for the accumulator')  
  
# 解析命令行参数  
args = parser.parse_args()  
  
# 访问解析后的参数  
print(args.integers)  

# 添加参数直接add,允许用户在命令行中指定程序的行为,而不需要修改代码。
# 不用在函数中添加参数,再修改调用的代码
# 管理超参数的一种方式

client端先通过DataLoader分好batch,用batch的每个样本做梯度下降

clientavg.py

一个epoch把所有训练集跑一遍,一个batch是一个batchsize的数据

client.train()

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
# client0 训练
(Pdb) p len(trainloader)
141
# 本地训练的epoch
max_local_epochs = self.local_epochs=1
for i, (x, y) in enumerate(trainloader)#按batch进行
(Pdb) p x.shape
torch.Size([10, 1, 28, 28])
(Pdb) p y.shape
torch.Size([10])
(Pdb) p loss
tensor(2.3381, device='cuda:0', grad_fn=<NllLossBackward0>)
(Pdb) p self.train_time_cost# 每个client独立计算
{'num_rounds': 1, 'total_cost': 657.0169236660004}
# 每个client单独记录,不会被清空
(Pdb) p self.selected_clients[0].train_time_cost
{'num_rounds': 1, 'total_cost': 657.0169236660004}
(Pdb) p self.selected_clients[1].train_time_cost
{'num_rounds': 1, 'total_cost': 26.226897716522217}

server处理

main.py

初始化超参数,生成模型和选择算法,使用 FedAvg 类来执行 FedAvg 算法,设置client(数据集和实例化client),server.train()服务器端训练,计算测试结果

for i in range(args.prev, args.times):# (0,1)–实验运行的次数,打印信息用的,并不是模型聚合的次数

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
	if args.algorithm == "FedAvg":
        	# 备份模型的全连接层 (fc) 
            args.head = copy.deepcopy(args.model.fc)
            # 将模型的全连接层 (fc) 替换为 nn.Identity(),
            # 这样做的目的是为了在聚合时只考虑模型的特征提取部分,而不包括全连接层的参数
            
            # 用于创建一个恒等映射(identity mapping)的模块。
            #在神经网络中,它通常被用作一个占位符,不对输入进行任何变换,直接将输入作为输出返回。
            args.model.fc = nn.Identity() 
        
            # 将模型和备份的全连接层重新组合,这可能是为了在聚合后重新还原模型结构。
            args.model = BaseHeadSplit(args.model, args.head)
            
            # FedAvg 类来执行 FedAvg 算法
            # 传入参数 args 和 i,进行模型参数的聚合和更新。
            server = FedAvg(args, i)
            
# split an original model into a base and a head
# 可以灵活地将原始模型拆分为基础部分和头部部分,
# 以便在需要的时候进行个性化定制或模型结构的重新组合。
class BaseHeadSplit(nn.Module):
    def __init__(self, base, head):
        super(BaseHeadSplit, self).__init__()

        self.base = base
        self.head = head
        
    def forward(self, x):
        out = self.base(x)
        out = self.head(out)

        return out         
# 备份全连接层的作用是为了在聚合过程中保留原始模型的全连接层参数。通过备份全连接层,可以确保在模型参数聚合后,还能够恢复原始的全连接层结构和参数,以便保持模型的完整性和性能。

# 在代码中,备份全连接层的操作使得在模型参数聚合过程中,只聚合模型的特征提取部分,而不包括全连接层的参数。这有助于保持模型结构的一致性,并且在分布式学习中更容易实现模型参数的共享和聚合。

# 在联邦学习中,是否需要聚合全连接层取决于具体的联邦学习任务和模型架构。一般来说,如果在联邦学习中全连接层的参数在不同设备上是相同的,那么可能不需要聚合全连接层;如果全连接层的参数需要在不同设备上进行调整,那么可能需要聚合全连接层。

# 全连接层(Fully Connected Layer)通常位于神经网络的末尾,负责将前面的特征表示映射到最终的输出。在分类任务中,全连接层通常将特征映射到不同类别的得分或概率;在其他任务中,全连接层也扮演着类似的作用,将中间特征映射到最终输出。
(Pdb) p args.head
Linear(in_features=512, out_features=10, bias=True)

(Pdb) p args.model
FedAvgCNN(
  (conv1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU(inplace=True)
  )
  (fc): Identity()
)
           
(Pdb) p args.model
BaseHeadSplit(
  (base): FedAvgCNN(
    (conv1): Sequential(
      (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (conv2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (fc1): Sequential(
      (0): Linear(in_features=1024, out_features=512, bias=True)
      (1): ReLU(inplace=True)
    )
    (fc): Identity()
  )
  (head): Linear(in_features=512, out_features=10, bias=True)
)
(Pdb) 
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# 使用 FedAvg 类来执行 FedAvg 算法
server = FedAvg(args, i)
# 设置client
self.set_clients(clientAVG)
(Pdb) p self.num_clients
2

# 读取数据 mnist,0(client_id),train
def read_client_data(dataset, idx, is_train=True)
(Pdb) p train_file
'../dataset\\mnist\\train/0.npz'
# train_data读取数据
(Pdb) p train_data['x'].shape
(1411, 1, 28, 28)
(Pdb) p type(train_data['x'])
<class 'numpy.ndarray'>
# 转为tensor张量
(Pdb) p X_train.shape
torch.Size([1411, 1, 28, 28])
(Pdb) p type(X_train)
<class 'torch.Tensor'>
(Pdb) p y_train.shape
torch.Size([1411])

# 打包为元组的列表
# 1411个二元组组成的列表
(Pdb) p type(train_data)
<class 'list'>
(Pdb) p len(train_data)
1411
(Pdb) p type(train_data[0])
<class 'tuple'>
(Pdb) p train_data[0][0].shape
torch.Size([1, 28, 28])
(Pdb) p train_data[0][1].shape
torch.Size([])
(Pdb)  p train_data[0][1]
tensor(4)

# 训练集和测试集 client0 1882
(Pdb) p len(train_data)
1411
(Pdb) p len(test_data)
471
# client1 2922
(Pdb) p len(train_data)
2191
(Pdb) p len(test_data)
731

# 实例化client类
self.set_clients(clientAVG)
client = clientObj(self.args, 
                            id=i, 
                            train_samples=len(train_data), 
                            test_samples=len(test_data), 
                            train_slow=train_slow, 
                            send_slow=send_slow)
(Pdb) p client
<flcore.clients.clientavg.clientAVG object at 0x0000022DCD9EAC08>

serveravg.py

server.train():

for i in range(self.global_rounds+1)#2000+1–模型聚合的次数2000次

下发全局模型(server->client),评估全局模型(acc,auc,loss),每个client进行训练(client.train()),接收client模型,聚合模型,保存结果,保存全局模型

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
(Pdb) p self.selected_clients
[<flcore.clients.clientavg.clientAVG object at 0x0000022DC7683948>, <flcore.clients.clientavg.clientAVG object at 0x0000022DCD9EAC08>]

self.global_model = copy.deepcopy(args.model)
# server->client下发模型
self.send_models()
self.global_model = copy.deepcopy(args.model)
client.set_parameters(self.global_model)# 全局模型set到cliet模型上
# new_param 是 model 的参数。在 PyTorch 中,模型的参数通常是指模型中的权重和偏置等可学习的参数。这些参数以张量(tensor)的形式存在,用于表示神经网络中不同层之间的连接权重和偏置。new_param 是一个模型的参数张量,其格式可以是任意形状的张量,取决于具体模型的结构和层的数量。
(Pdb) p new_param
Parameter containing:
tensor([[[[-1.4974e-03,  1.0729e-01, -1.6461e-01, -1.4719e-01, -7.7031e-02],
          [ 5.3631e-02, -3.9626e-03,  1.5858e-01, -1.7749e-02,  5.2923e-02],
          [-6.0443e-02, -3.9313e-02, -1.9107e-01, -1.3246e-01, -8.2445e-02],
          [ 7.4087e-03,  7.9067e-02,  1.2000e-01, -1.3559e-01, -8.7093e-02],
          [ 7.2643e-02,  1.6608e-01, -4.1160e-02,  1.4966e-01, -3.2237e-02]]],


        [[[ 2.1163e-02,  1.8110e-01, -1.8553e-01, -1.2591e-01, -5.0633e-02],
          [-7.7960e-02,  1.7280e-01, -1.2964e-01, -9.2067e-02, -1.3973e-01],
          [-1.8731e-01, -1.1675e-01,  1.7192e-01,  8.9244e-02,  9.6935e-02],
          [ 1.0518e-02, -1.0254e-01,  3.3837e-02, -1.8674e-01, -1.4451e-01],
          [-1.0311e-01,  1.2619e-01,  1.1726e-01, -8.8699e-02, -7.2165e-03]]],


        [[[ 1.2791e-01,  1.9883e-01,  7.9376e-02,  2.7019e-02,  1.3410e-01],
          [-1.1776e-01,  3.7269e-02, -1.5506e-01, -1.3862e-01, -1.0332e-01],
          [ 9.0495e-02,  8.0432e-02, -1.1847e-01,  6.0421e-02,  1.0979e-01],
          [-2.5243e-02,  7.6363e-03,  4.6341e-02,  1.2408e-01,  1.9204e-01],
          [-1.5412e-01, -7.3294e-02,  7.8602e-02,  1.6571e-01,  1.7404e-01]]],


        [[[ 1.7647e-01,  3.9803e-02, -1.7392e-01,  1.8398e-02, -1.2512e-01],
          [-1.8639e-01,  1.7770e-01,  1.5207e-01, -1.9951e-01,  3.7434e-02],
          [-3.3692e-02, -3.2912e-02, -9.1551e-02,  7.6911e-02, -1.1846e-01],
          [ 7.3318e-02,  1.0114e-01,  1.4317e-01,  7.4782e-02, -1.9795e-01],
          [-1.2974e-01,  9.9863e-02,  4.1860e-02, -1.5602e-01, -1.1516e-01]]],


        [[[ 1.8815e-01,  1.3476e-01, -8.7205e-02, -5.0337e-02, -1.9052e-01],
          [-3.5948e-03, -1.5061e-01, -1.5427e-01, -1.1020e-02,  3.0029e-02],
          [-8.1906e-02,  1.1868e-01, -1.2171e-01,  1.8147e-01,  1.3706e-01],
          [-1.6866e-01, -4.9777e-02,  9.0245e-03,  2.9180e-02,  4.7435e-02],
          [ 7.8486e-02,  1.1980e-02, -9.7586e-02,  9.4638e-02, -1.9185e-01]]],


        [[[-1.1854e-01, -5.0066e-02, -9.7423e-02, -6.9967e-02, -1.6392e-01],
          [-4.2543e-02,  4.2751e-02, -1.3029e-01, -1.0264e-02,  1.4317e-01],
          [-2.0560e-02,  5.5584e-03, -1.7254e-02,  4.0476e-02,  1.2717e-01],
          [ 1.8945e-01,  1.2701e-01,  1.8988e-01, -1.4464e-02, -1.7966e-01],
          [-9.4815e-02,  1.3618e-01, -1.2965e-03, -9.9409e-02, -1.5326e-01]]],


        [[[-1.8717e-01, -1.6880e-01, -4.0567e-02,  1.0968e-01,  1.0813e-01],
          [-1.9289e-01,  1.2476e-01, -1.5650e-01, -4.2282e-02, -8.1095e-02],
          [-3.8523e-02, -3.9269e-02, -1.7947e-01, -1.7269e-01, -3.1296e-02],
          [ 2.5864e-03, -9.0855e-02,  7.5340e-02, -1.8001e-01, -1.3497e-02],
          [ 1.7588e-01, -8.1578e-02,  1.8060e-01,  7.2431e-02, -1.8049e-01]]],


        [[[ 1.2654e-01, -2.3079e-02, -8.9281e-02,  1.5993e-01, -1.6162e-01],
          [ 2.1461e-02, -4.1874e-02,  1.4282e-01,  5.5829e-02,  9.6101e-02],
          [ 7.0632e-02, -4.8095e-02, -4.2061e-02, -1.6482e-01,  1.0837e-01],
          [ 1.5880e-01,  1.3684e-01, -1.4108e-01,  8.9200e-03, -1.4099e-01],
          [-1.1010e-01, -1.1654e-01,  6.8349e-02, -1.1918e-01, -4.3634e-03]]],


        [[[ 8.4136e-03,  1.2892e-01, -1.5118e-01, -1.3730e-01, -1.1613e-01],
          [ 1.3999e-01, -7.1893e-02,  1.6870e-01,  7.2322e-02,  2.5325e-02],
          [-1.4888e-03, -3.9536e-02,  2.5093e-02, -4.5669e-02, -1.4053e-03],
          [ 2.5519e-02, -1.5644e-01, -1.0483e-01,  1.6150e-01, -1.6231e-01],
          [-1.4361e-02,  1.9785e-01,  7.2247e-02,  5.6626e-03, -1.7332e-01]]],


        [[[ 9.9076e-02, -1.4246e-01, -5.6773e-02, -6.7103e-02, -2.9617e-02],
          [ 2.1877e-03,  1.6496e-01,  2.4968e-02,  1.7914e-01,  1.2234e-01],
          [-1.2644e-01,  8.9701e-02, -1.4138e-01, -8.4765e-02,  5.8825e-02],
          [ 6.6038e-02,  1.5005e-01, -6.4383e-02,  3.2017e-04,  1.0296e-01],
          [-1.9342e-01,  1.4460e-01, -1.6538e-01,  2.7565e-03, -3.4003e-02]]],


        [[[-1.0533e-01,  2.6434e-02,  1.6538e-01, -5.8464e-02, -1.1874e-01],
          [-7.3968e-02, -1.9823e-01,  9.0279e-02, -9.6053e-02, -1.3347e-01],
          [-1.1522e-01,  1.1499e-01,  1.0591e-01,  1.5350e-01,  7.2545e-02],
          [-6.6792e-02, -5.5890e-02,  5.9086e-02,  1.6441e-01,  5.4378e-02],
          [-9.4629e-02, -9.4015e-02, -1.8909e-01,  4.3210e-02, -1.1224e-01]]],


        [[[-1.7832e-01,  1.7538e-01, -1.2988e-01, -2.2752e-02,  5.7299e-02],
          [ 6.3717e-03, -1.3458e-01, -1.6166e-01,  1.5942e-01,  3.2568e-02],
          [ 1.6592e-01, -6.7048e-02,  5.8911e-02, -4.5735e-02, -8.8942e-03],
          [-1.2181e-01,  6.7640e-02,  6.3233e-02, -4.1257e-03, -4.4980e-02],
          [-1.2329e-01,  1.3831e-01, -1.4888e-01,  8.1933e-02, -6.7251e-02]]],


        [[[-9.6494e-02,  3.5929e-02, -1.0389e-01,  4.6081e-02,  3.9276e-02],
          [-1.4850e-01,  3.3300e-02,  8.5186e-02,  7.9166e-02, -2.5176e-02],
          [-1.6396e-01, -3.0831e-02,  6.9462e-02, -7.2975e-02,  7.5916e-02],
          [ 1.3319e-01, -1.0443e-01,  1.9724e-03,  8.2700e-02,  1.5676e-02],
          [ 1.6706e-02,  2.4973e-02, -1.5723e-01,  1.5718e-02,  1.3849e-01]]],


        [[[ 1.8022e-01,  1.1755e-01,  2.6801e-02,  9.3403e-02, -9.7296e-02],
          [-1.6574e-01, -1.7198e-01,  1.9952e-01,  1.2696e-01, -1.3825e-01],
          [ 7.8249e-02,  1.5103e-01,  1.9992e-01,  1.7488e-01,  1.5495e-01],
          [-4.5852e-02, -7.0188e-02,  1.6421e-01,  1.1208e-01, -1.2036e-01],
          [ 1.7980e-01,  9.6632e-02,  1.0903e-01, -1.2536e-01,  5.7380e-02]]],


        [[[-7.0112e-02,  1.5627e-01, -3.5988e-02,  7.7863e-02,  3.5524e-02],
          [ 8.5094e-02, -6.7966e-02,  9.7514e-02, -1.3969e-01,  4.5171e-02],
          [-1.3532e-01, -1.9731e-01, -1.6061e-01,  1.5789e-01,  1.0821e-01],
          [ 1.8764e-01,  1.6022e-01, -1.7861e-01, -1.3649e-01, -3.2317e-02],
          [-1.2989e-01,  1.3888e-01, -1.5120e-01, -9.7580e-02, -1.9322e-01]]],


        [[[-1.1355e-01,  1.6450e-01,  1.6375e-01,  1.4317e-01,  1.5443e-01],
          [ 1.7784e-01, -5.1213e-02,  8.8000e-02,  1.7819e-01,  6.6164e-02],
          [ 1.9994e-01,  1.0373e-01,  1.2433e-01, -6.9998e-02,  9.5966e-02],
          [ 2.2981e-02, -4.7763e-02, -1.1274e-01, -1.1222e-01, -1.5390e-01],
          [ 1.3427e-01,  1.4219e-01, -2.2763e-02, -1.1574e-01,  1.5458e-01]]],


        [[[ 1.2790e-01,  1.4867e-02, -9.4427e-02,  1.8381e-01,  8.1788e-02],
          [-1.5183e-01,  1.9142e-01,  1.5187e-01, -7.2898e-02,  1.1243e-01],
          [-1.1364e-01, -3.1343e-02,  1.6982e-01,  8.2638e-03, -1.4144e-01],
          [-6.6847e-02, -5.4289e-02, -3.8586e-02,  1.9140e-02,  1.8497e-01],
          [ 1.0708e-02, -1.2349e-01,  1.0251e-02,  9.5897e-02,  9.9208e-02]]],


        [[[-1.8279e-01, -3.5788e-02, -1.4863e-01, -8.5337e-02,  7.2059e-02],
          [-1.4203e-01,  7.4345e-02,  1.6976e-01,  1.3118e-02, -1.3330e-01],
          [-7.1657e-02,  4.3673e-02, -1.5246e-01,  9.9362e-02, -1.8157e-01],
          [-1.9226e-01, -1.9433e-01, -4.0573e-02,  1.3449e-01, -1.8930e-01],
          [ 1.6624e-01, -8.0005e-02,  5.8577e-02,  9.1206e-03, -1.8034e-01]]],


        [[[ 1.6587e-01,  1.0769e-01,  1.9880e-01,  1.0104e-01, -1.3201e-01],
          [ 1.6692e-01,  1.0749e-02,  9.4843e-02, -1.6037e-01, -5.7525e-02],
          [-1.9638e-01, -7.7899e-02,  4.3146e-02, -1.5703e-01,  6.3753e-02],
          [ 1.0736e-01,  2.7862e-02, -1.3382e-01, -1.5506e-01, -6.1703e-02],
          [ 8.7792e-02,  1.9728e-01,  1.1501e-01, -2.2522e-02,  7.0123e-02]]],


        [[[-1.9621e-01, -1.7082e-01,  9.3322e-02, -1.1328e-01,  9.6219e-02],
          [-1.4119e-01, -9.9062e-02, -1.6474e-01,  1.0437e-01, -2.0379e-02],
          [ 1.5392e-01,  1.2377e-01,  1.1067e-01,  6.4312e-03, -6.1836e-02],
          [-4.3486e-02,  2.6580e-02,  9.9142e-02, -1.4012e-01,  1.6786e-01],
          [-2.1746e-02, -1.6759e-01, -1.0821e-01,  1.7696e-01,  1.8291e-01]]],


        [[[-1.8526e-01,  1.4106e-01,  1.0023e-01,  1.1838e-01,  1.6931e-01],
          [-1.0779e-01,  6.3155e-02,  8.1847e-02, -5.9099e-02,  6.6931e-02],
          [-5.7543e-02,  1.2365e-01, -5.5491e-02, -7.4559e-02,  5.0350e-02],
          [ 7.0939e-02, -9.7714e-02,  1.7680e-02,  1.1591e-01, -1.9899e-02],
          [ 6.0867e-02, -4.8237e-02,  7.0100e-02, -1.4488e-01, -1.1761e-01]]],


        [[[-1.0152e-01,  1.8380e-01, -5.3817e-02, -5.4607e-04, -9.6899e-02],
          [ 1.9966e-01,  1.9533e-01, -1.5084e-01, -1.6213e-01, -1.5160e-01],
          [-9.6455e-04, -5.0981e-02, -1.3091e-01, -7.1734e-02,  3.7786e-02],
          [-1.0450e-01,  4.4317e-02, -4.5863e-02, -9.6913e-02,  2.7477e-02],
          [ 1.6445e-01, -1.3522e-01,  9.2869e-03, -7.3754e-02,  1.9626e-01]]],


        [[[-1.8975e-01, -1.9173e-01,  1.9708e-01, -1.2654e-01,  3.8345e-02],
          [-1.7261e-02, -4.2131e-02, -4.4670e-02,  1.2709e-01,  9.5594e-03],
          [-1.9472e-01, -1.1807e-01, -6.8180e-02,  1.0064e-01, -1.2943e-01],
          [ 1.8858e-01, -4.4546e-02, -3.5912e-02,  1.5671e-01,  1.0052e-01],
          [ 1.6962e-01,  1.1569e-01, -6.0671e-02, -1.3269e-01, -1.4881e-02]]],


        [[[ 1.6553e-01, -6.7124e-02, -1.8547e-01,  8.1986e-02,  1.9469e-01],
          [-5.6937e-02, -1.6560e-01, -1.8141e-01,  5.0120e-02, -1.5144e-02],
          [-1.0100e-01,  4.0426e-02,  7.5952e-02,  1.5906e-01,  1.5528e-01],
          [-2.9937e-02, -1.7635e-01, -1.8072e-01,  1.8674e-01,  8.8411e-02],
          [ 8.7181e-02, -1.7304e-01,  1.8520e-01,  1.8947e-01,  1.8057e-01]]],


        [[[-1.6872e-01, -7.5465e-02, -1.3755e-01,  1.8939e-01, -8.5936e-02],
          [-9.1311e-02,  1.0478e-01, -9.2518e-02, -9.8504e-02, -1.7494e-02],
          [-1.9224e-02, -1.5580e-01,  1.6673e-01, -8.8224e-02,  7.0943e-02],
          [ 1.7396e-01,  1.0086e-01,  2.8316e-02,  1.7017e-01,  2.6885e-02],
          [-9.2531e-02,  1.8920e-01,  4.7336e-02, -1.9514e-01, -5.6938e-02]]],


        [[[-1.3623e-01,  1.7535e-01, -3.3029e-02, -1.8230e-01, -1.2573e-02],
          [ 1.2560e-01,  5.1960e-02,  6.3241e-02,  1.8575e-02,  7.4564e-02],
          [-4.8730e-02, -7.9560e-02, -1.8694e-01, -1.5067e-01,  8.6681e-02],
          [-1.1842e-01,  2.8693e-02,  6.3838e-02,  1.4161e-02, -1.2967e-01],
          [ 1.9127e-01, -1.1630e-01,  1.6450e-01, -1.5910e-01, -4.8110e-02]]],


        [[[ 1.0879e-01, -8.1716e-02,  1.6801e-01, -1.3763e-01, -1.6795e-01],
          [-9.0184e-02,  3.2330e-02,  1.8415e-01, -9.5480e-02,  7.1526e-02],
          [-5.0141e-02, -4.3372e-02,  1.4706e-01, -1.5500e-01,  2.1243e-02],
          [ 1.8808e-01, -2.7482e-02,  1.5529e-01, -6.1591e-02,  1.6099e-01],
          [-1.9347e-01, -2.8825e-02, -3.5129e-02,  6.4817e-02,  7.8465e-02]]],


        [[[ 1.5356e-01, -2.9797e-02, -7.9310e-03,  1.3696e-01, -5.4115e-02],
          [ 1.7533e-01, -1.3316e-01, -2.1642e-02, -1.0742e-02,  8.9242e-02],
          [ 1.3674e-01, -3.1701e-02, -1.6570e-01,  9.9098e-02,  5.9826e-02],
          [ 8.0340e-02, -1.2337e-01,  1.2872e-01,  1.8943e-01,  1.7354e-02],
          [-1.8681e-01,  1.4040e-01, -1.4829e-01,  4.5974e-02,  2.9064e-02]]],


        [[[-9.3607e-02,  6.9639e-02, -1.7889e-01,  4.5552e-02, -1.2679e-01],
          [-2.1627e-02,  2.5726e-02,  1.7039e-01, -9.5424e-02,  1.2813e-01],
          [-2.5407e-02, -9.4986e-02, -1.7416e-01, -1.8350e-01,  1.9532e-01],
          [-4.9880e-02,  9.9898e-03,  5.4222e-02,  1.3595e-01,  1.7069e-01],
          [ 1.6220e-01, -1.4818e-01, -3.2043e-02, -1.1835e-01, -1.1427e-01]]],


        [[[ 4.7441e-02,  1.8773e-01, -1.6022e-01,  1.2104e-01, -1.0369e-01],
          [-3.8955e-02,  1.5875e-01, -4.5234e-02,  1.8211e-02, -1.3981e-01],
          [ 1.7025e-01, -2.5854e-02, -1.4628e-01,  5.8562e-02, -1.4220e-01],
          [-1.5870e-01,  1.2184e-02,  1.5857e-01, -5.6597e-02,  9.4133e-02],
          [ 1.7186e-01,  1.3265e-01, -1.0491e-01, -2.1924e-02, -6.2937e-02]]],


        [[[-1.6082e-01,  8.2642e-05,  1.5049e-01,  1.6850e-01,  1.8660e-02],
          [ 4.5425e-02, -8.6584e-02,  1.5097e-01, -8.3222e-02, -1.3893e-01],
          [ 3.0799e-02,  1.1988e-01, -1.8032e-01,  1.8079e-01,  7.1946e-02],
          [-1.4012e-01, -4.3081e-02,  1.7352e-01, -1.5345e-01, -5.8455e-02],
          [ 6.5596e-02, -1.7522e-01,  1.0964e-01,  1.0411e-01,  1.2404e-01]]],


        [[[-1.2751e-01,  1.9920e-01, -1.1855e-01,  1.9967e-01, -1.9194e-01],
          [-1.7819e-01,  1.2284e-01,  2.0903e-02,  1.1538e-02, -1.1075e-01],
          [-8.3895e-02, -5.8477e-02, -1.9483e-01,  1.0394e-02,  3.5372e-02],
          [-1.6932e-04,  6.4586e-02,  1.8977e-01,  5.3177e-02, -7.3214e-02],
          [-8.2309e-02, -1.2796e-01, -1.3864e-01, -3.2210e-02, -3.5373e-02]]]],
       device='cuda:0', requires_grad=True)
# 模型发送时间 每个客户端单独计算发送时间
(Pdb) p client.send_time_cost
{'num_rounds': 1, 'total_cost': 1461.3775424957275}

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# 每一轮评估一次模型
(Pdb) p self.eval_gap
1
# rounds = 0: 初始状态评估,全局模型下发到client之后
-------------Round number: 0-------------
self.evaluate()
# 测试指标 中心方没有test数据集 中心方得到准确率的方式是从客户端这里拿,然后加权
self.test_metrics()

# client的测试方法client0,acc和auc
c.test_metrics()
# 一个batch一个batch的操作 DataLoader 第一个batchsize
for x, y in testloaderfull
(Pdb) p x.shape
torch.Size([10, 1, 28, 28])
(Pdb) p y.shape
torch.Size([10])
(Pdb) p output.shape
torch.Size([10, 10])
(Pdb) p test_acc
0
(Pdb) p test_num
10
# 真实值y对应的独热编码
(Pdb) p lb
array([[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]])

# 遍历完所有client0 测试集
# 48个batchsize后,重新按第一个维度进行拼接
(Pdb) p len(y_true)
48
(Pdb) p y_true[0].shape# 每个batchsize,(batchsize,numclasses)
(10, 10)
(Pdb) p y_true[0][0]
array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0])# 每个数据
# 按第一个维度拼接之后,所有数据拼接,去掉了batchsize
(Pdb) p len(y_true)
471
(Pdb) p y_true[0]
array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0])#同上
(Pdb) p auc
0.38275912327597994
(Pdb) p test_acc
16
(Pdb) p test_num
471

# 对于server而言
# client 0 
(Pdb) p tot_correct # 正确的样本数
[16.0]
(Pdb) p tot_auc # auc 乘 总样本数
[180.27954706298655]
(Pdb) p num_samples
[471]
(Pdb) p tot_correct
[16.0, 55.0]
(Pdb) p tot_auc
[180.27954706298655, 331.7627298981608]
(Pdb) p num_samples
[471, 731]
(Pdb) p ids
[0, 1]

# client测试方法,loss,client 0 
c.train_metrics()
(Pdb) p len(trainloader)# 141个batchsize
141
# 第一个batchsize
(Pdb) p x.shape
torch.Size([10, 1, 28, 28])
(Pdb) p y.shape
torch.Size([10])
(Pdb) p y
tensor([2, 1, 4, 4, 5, 3, 3, 1, 4, 9], device='cuda:0')
(Pdb) p output.shape
torch.Size([10, 10]
(Pdb) p loss
tensor(2.3223, device='cuda:0')
(Pdb) p train_num
10
(Pdb) p losses # loss*样本数
23.22331666946411
(Pdb) p losses
3269.2880749702454
(Pdb) p train_num # 少一个数据?
1410
           
# 对于server而言
(Pdb) p num_samples
[1410]
(Pdb) p losses
[3269.2880749702454]
(Pdb) p num_samples
[1410, 2190]
(Pdb) p losses
[3269.2880749702454, 5063.167586326599]
(Pdb) p ids
[0, 1]

# 计算acc,auc,loss
(Pdb) p stats
([0, 1], [471, 731], [16.0, 55.0], [180.27954706298655, 331.7627298981608])
(Pdb) p stats_train
([0, 1], [1410, 2190], [3269.2880749702454, 5063.167586326599])
# acc:每个client样本正确数分别除所有client样本数 想加
# 相当于,每个客户端的正确率*每个客户端数据的比重 相加          
test_acc = sum(stats[2])*1.0 / sum(stats[1])
# <=>stats[2][0]/stats[1][0])*(stats[1][0]/sum[stats[1]]+...
(Pdb) p test_acc
0.05906821963394343
(Pdb) p test_auc
0.4259919109493738
(Pdb) p train_loss
2.314571017026901
# 每个客户端的acc和auc
(Pdb) p accs
[0.03397027600849257, 0.07523939808481532] 
(Pdb) p aucs
[0.38275912327597994, 0.4538477837184142]
(Pdb) p np.std(accs)# 标准差
0.020634561038161376   
           
Averaged Train Loss: 2.3146
Averaged Test Accurancy: 0.0591
Averaged Test AUC: 0.4260
Std Test Accurancy: 0.0206
Std Test AUC: 0.0355
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# client ->server
self.receive_models()
# 选择client的时候是随机选择的,所有会导致选择的客户端的id的顺序会变化
# `random.choice` 函数在随机抽取多个元素时会打乱它们的顺序
(Pdb) p self.clients
[<flcore.clients.clientavg.clientAVG object at 0x000001BC6BCD4F88>, <flcore.clients.clientavg.clientAVG object at 0x000001BC3A056D88>]
(Pdb) p self.selected_clients
[<flcore.clients.clientavg.clientAVG object at 0x000001BC3A056D88>, <flcore.clients.clientavg.clientAVG object at 0x000001BC6BCD4F88>]
for client in active_clients:
# client 1
(Pdb) p client_time_cost# 训练时间+发送时间
1.2847285270690918
(Pdb) p self.time_threthold # 超时就不参与聚合(s)(2.78h)
10000
(Pdb) p tot_samples # 样本数
2191
(Pdb) p self.uploaded_ids # client id
[1]
(Pdb) p self.uploaded_weights # 权重,数据量
[2191]
self.uploaded_models.append(client.model)# client模型
# 所有client
(Pdb) p tot_samples
3602
(Pdb) p self.uploaded_ids
[1, 0]
(Pdb) p self.uploaded_weights
[2191, 1411]
(Pdb) p self.uploaded_weights# 权重归一化
[0.6082731815657968, 0.3917268184342032]
1
2
3
4
5
6
7
8
9
# 聚合模型
self.aggregate_parameters()
# 初始化全局模型,复制client 0的模型结构且参数置为0 
self.global_model = copy.deepcopy(self.uploaded_models[0])
(Pdb) p self.uploaded_weights
[0.6082731815657968, 0.3917268184342032]
# 加权累加
self.add_parameters(w, client_model)
server_param.data += client_param.data.clone() * w
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# python main.py -gr 20
-------------Round number: 20-------------

Evaluate global model
Averaged Train Loss: 0.0773
Averaged Test Accurancy: 0.9584
Averaged Test AUC: 0.9933
Std Test Accurancy: 0.0025
Std Test AUC: 0.0011

(Pdb) p self.rs_test_acc
[0.05906821963394343, 0.3011647254575707, 0.6755407653910149, 0.8236272878535774, 0.8569051580698835, 0.889351081530782, 0.9026622296173045, 0.913477537437604, 0.9234608985024958, 0.93
51081530782029, 0.9334442595673876, 0.937603993344426, 0.9484193011647255, 0.9450915141430949, 0.9492512479201332, 0.9559068219633944, 0.9534109816971714, 0.9492512479201332, 0.9550748
752079867, 0.9534109816971714, 0.9584026622296173]

Best accuracy.
0.9584026622296173
Average time cost per round.# 时间开销
1.283463180065155

# 保存结果 
self.save_results()
(Pdb) p algo
'mnist_FedAvg'
(Pdb) p result_path
'../results/'
(Pdb) p algo
'mnist_FedAvg_test_0'
(Pdb) p file_path
'../results/mnist_FedAvg_test_0.h5'
# 向文件中写入三个数据集,分别是 ‘rs_test_acc’、‘rs_test_auc’ 和 ‘rs_train_loss’
# evaluate中并没有保存rs_test_auc
(Pdb) p self.rs_test_acc
[0.05906821963394343, 0.3011647254575707, 0.6755407653910149, 0.8236272878535774, 0.8569051580698835, 0.889351081530782, 0.9026622296173045, 0.913477537437604, 0.9234608985024958, 0.93
51081530782029, 0.9334442595673876, 0.937603993344426, 0.9484193011647255, 0.9450915141430949, 0.9492512479201332, 0.9559068219633944, 0.9534109816971714, 0.9492512479201332, 0.9550748
752079867, 0.9534109816971714, 0.9584026622296173]
(Pdb) p self.rs_test_auc
[]
(Pdb) p self.rs_train_loss
[2.314571017026901, 2.086947536468506, 1.3195898769630325, 0.6253579562943843, 0.43360667809223136, 0.32859760250689257, 0.28546613085911504, 0.2471489678292225, 0.21599847520701587, 0
.1882220826137604, 0.1742167657499926, 0.1539727338169339, 0.14151522711229822, 0.13180814096962826, 0.126626226349294, 0.10841290605668392, 0.10612137640515963, 0.10253492878497733, 0
.08905219737142842, 0.0842796761914441, 0.07731760282153523]

# 重新修改 self.evaluate() 后
(Pdb) p self.rs_test_auc
[0.4259919109493738, 0.7447302648864543, 0.9229162438448569, 0.9672684372148672, 0.9744737834460371, 0.9766528826588904, 0.9814767508333857, 0.9850432755687425, 0.985603523913785, 0.98
7561845229809, 0.9881369731008514, 0.989969924724848, 0.989925098209637, 0.9911689493678782, 0.9914534505980559, 0.9919368626935388, 0.9926894980663109, 0.9917173501511063, 0.992863217
0223384, 0.9926117283057737, 0.9932817668849713]

# 读取h5文件并按原格式打印
import h5py
with h5py.File('D:/Python/PycharmProjects/PFL-Non-IID-master/PFL-Non-IID-master/results/mnist_FedAvg_test_0.h5', 'r') as file:
    # 遍历文件中的所有数据集并打印
    for key in file.keys():
        print(key + ":")
        print(file[key][:])
        
# 保存全局模型
self.save_global_model()
(Pdb) p model_path
'models\\mnist'
(Pdb) p model_path
'models\\mnist\\FedAvg_server.pt'
torch.save(self.global_model, model_path)

# 全局平均,多次实验中(args.times)
# 从文件中读取多次实验的acc,得到每次实验的最大acc,计算标准差和均值
average_data
(Pdb) p algorithm
'FedAvg'
(Pdb) p dataset
'mnist'
(Pdb) p goal
'test'
(Pdb) p times
1
(Pdb) p max_accurancy
[0.9584026622296173]
# 多次实验结果的标准差和均值
std for best accurancy: 0.0
mean for best accurancy: 0.9584026622296173

# 输出内存使用情况
reporter.report()

Storage on cuda:0
-------------------------------------------------------------------------------
Total Tensors: 4656208  Used Memory: 13.33M
The allocated memory on cuda:0: 13.33M
-------------------------------------------------------------------------------

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# 参考文件 example.sh
# `2>&1` 表示将标准错误(stderr)重定向到标准输出(stdout),这意味着错误信息将与标准输出一起处理。
# `&` 表示在后台运行命令,这样可以使命令在后台持续运行而不会阻塞当前终端。
# 保存输出到mnist_fedavg.out文件中
python -u main.py -gr 20 > mnist_fedavg.out 2>&1 &

# 读取mnist_fedavg.out文件
# 查找并提取出包含最佳准确率的行,然后计算这些最佳准确率的平均值和标准差
from statistics import mean
import numpy as np

file_name = input() + '.out'

acc = []

with open(file_name, 'r') as f:
    is_best = False
    for l in f.readlines():
        if is_best:
            acc.append(float(l))
            is_best = False
        elif 'Best accuracy' in l:
            is_best = True

print(acc)
print(mean(acc)*100, np.std(acc)*100)
"""
D:\Python\Python37\python.exe D:/Python/PycharmProjects/PFL-Non-IID-master/PFL-Non-IID-master/system/get_mean_std.py
D:/Python/PycharmProjects/PFL-Non-IID-master/PFL-Non-IID-master/system/mnist_fedavg
[0.9575707154742097]
95.75707154742096 0.0
"""
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
# python generate_mnist.py iid - -

D:\Python\Python37\python.exe D:/Python/PycharmProjects/PFL-Non-IID-master/PFL-Non-IID-master/system/main.py
==================================================
Algorithm: FedAvg
Local batch size: 10
Local steps: 1 #每次聚合的时候本地训练次数
Local learing rate: 0.005
Local learing rate decay: False
Total number of clients: 2
Clients join in each round: 1.0 # cliet全部参与
Clients randomly join: False
Client drop rate: 0.0
Client select regarding time: False
Running times: 1
Dataset: mnist
Number of classes: 10
Backbone: cnn
Using device: cuda
Using DP: False
Auto break: False
Global rounds: 2000
Cuda device id: 0
DLG attack: False
Total number of new clients: 0
Fine tuning epoches on new clients: 0
==================================================

============= Running time: 0th =============# 运行次数(整个实验的次数)
Creating server and clients ...
FedAvgCNN(
  (conv1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU(inplace=True)
  )
  (fc): Linear(in_features=512, out_features=10, bias=True)
)

Join ratio / total clients: 1.0 / 2
Finished creating server and clients.

-------------Round number: 0-------------#轮次

Evaluate global model
Averaged Train Loss: 2.3146
Averaged Test Accurancy: 0.0591
Averaged Test AUC: 0.4260
Std Test Accurancy: 0.0206
Std Test AUC: 0.0355
------------------------- time cost ------------------------- 5.320608139038086

-------------Round number: 1-------------

Evaluate global model
Averaged Train Loss: 2.0873
Averaged Test Accurancy: 0.2995
Averaged Test AUC: 0.7448
Std Test Accurancy: 0.0540
Std Test AUC: 0.0632
------------------------- time cost ------------------------- 1.006882905960083

-------------Round number: 2-------------

Evaluate global model
Averaged Train Loss: 1.3040
Averaged Test Accurancy: 0.7155
Averaged Test AUC: 0.9386
Std Test Accurancy: 0.0332
Std Test AUC: 0.0046
------------------------- time cost ------------------------- 0.971451997756958

......

-------------Round number: 20-------------

Evaluate global model
Averaged Train Loss: 0.0775
Averaged Test Accurancy: 0.9576
Averaged Test AUC: 0.9933
Std Test Accurancy: 0.0035
Std Test AUC: 0.0011
------------------------- time cost ------------------------- 0.9970800876617432

Best accuracy.
0.9575707154742097

Average time cost per round.
1.0050838351249696
File path: ../results/mnist_FedAvg_test_0.h5

Average time cost: 25.04s.
Length:  21
std for best accurancy: 0.0
mean for best accurancy: 0.9575707154742097
All done!

Storage on cuda:0
-------------------------------------------------------------------------------
Total Tensors: 4656208  Used Memory: 13.33M
The allocated memory on cuda:0: 13.33M
-------------------------------------------------------------------------------

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
-ls 1
mean for best accurancy: 0.9559068219633944
-ls 3 # 最优不是最后一轮
mean for best accurancy: 0.9700499168053245
-ls 5 # 最优不是最后一轮
mean for best accurancy: 0.9725457570715474

这种情况可能是由于神经网络出现了过拟合过拟合指模型在训练集上表现良好但在测试集上性能下降的现象解决方法可以包括

早停Early Stopping):在训练过程中监测验证集上的性能当性能开始下降时停止训练避免过拟合
使用正则化通过 L1  L2 正则化限制模型的复杂度防止过度拟合
数据增强Data Augmentation):增加训练集的样本多样性有助于提高模型的泛化能力
降低模型复杂度减少神经网络的层数或每层神经元数量以降低模型的复杂度
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
# 模型结构图
Input: (batch_size, 1, height, width)  
  
      -> Conv2d(1, 32)  
      -> ReLU  
      -> MaxPool2d  
      -> Conv2d(32, 64)  
      -> ReLU  
      -> MaxPool2d  
      -> Flatten  
      -> Linear(1024, 512)  
      -> ReLU  
      -> Linear(512, 10)  

Installing collected packages: torch, opacus Attempting uninstall: torch Found existing installation: torch 1.10.1+cu113 Uninstalling torch-1.10.1+cu113: Successfully uninstalled torch-1.10.1+cu113 ERROR: pip’s dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.

重新安装了torch 1.10.1+cu113

凸优化?

“Convex"是一个几何和数学术语,用来描述凸形或凸函数。在数学中,一个集合被称为凸集,如果集合中包含的任意两点之间的线段仍然在该集合内部。凸函数则是一种特殊的函数,其图像上的任意两点之间的线段位于函数图像的上方。

在优化和凸优化中,凸集和凸函数具有重要的性质,例如任意局部最小值都是全局最小值等。因此,凸集和凸函数在数学建模和优化问题中具有重要的作用。

“Non-convex"是指非凸的意思。在数学中,非凸通常用来描述不满足凸性质的集合或函数。对于一个集合来说,如果存在集合内的两点,连接这两点的线段不完全位于集合内部,则该集合是非凸的。对于一个函数来说,如果函数图像上的某一部分不满足凸性质,那么该函数就是非凸的。

在优化问题中,非凸函数通常具有多个局部最小值,使得寻找全局最小值变得更加困难。非凸问题的优化通常比凸问题更具挑战性,因为非凸函数的性质更为复杂。

MLR(多元线性回归)是凸函数,因为其损失函数是平方损失,平方损失函数是凸函数。

CNN(卷积神经网络)通常是非凸函数,因为它们的损失函数通常是非凸的,例如交叉熵损失函数。交叉熵损失函数在一般情况下是非凸的,因此导致了 CNN 通常是非凸函数。

需要注意的是,在特定情况下,CNN 的一些部分可能是凸函数,但整个网络通常是非凸的。

交叉熵损失:对于 Logit模型来说,交叉熵损失函数是凸的,因为这个函数是对数和 exp 函数的一个凸组合,这个函数是凸的。这种凸性允许使用基于梯度的方法进行有效的优化。然而,在神经网络环境中,由于交叉熵损失函数由于多层和激活函数引入的非线性而不能保证是凸的。神经网络中权值和偏差的复杂相互作用导致损失函数通常是非凸的,使得优化更具挑战性。因此,神经网络训练通常涉及使用基于梯度的优化算法,可以处理非凸函数,如随机梯度下降及其变体。

梯度下降的学习率是指在每次迭代中,沿着梯度方向更新参数时所乘以的步长大小。学习率的选择会影响算法的收敛速度和最终的收敛效果。过大的学习率可能导致震荡甚至发散,而过小的学习率则会导致收敛速度过慢。因此,选择合适的学习率对于梯度下降算法的有效性至关重要。

梯度下降可能会遇到以下问题:

  1. 局部最优解:可能陷入局部最优解而无法找到全局最优解。
  2. 学习率选择困难:选择不合适的学习率可能导致收敛缓慢或发散。
  3. 高维度问题:在高维空间中,梯度下降的收敛可能变得困难,需要更复杂的优化算法来解决。
  4. 鞍点问题:在鞍点处梯度接近零,使得优化过程变得困难。

针对这些问题,有许多改进的优化算法,如随机梯度下降、动量法、自适应学习率方法等。

在深度学习中,梯度下降不当可能导致以下问题:

  1. 收敛缓慢或不稳定:选择不合适的学习率或优化算法可能导致训练过程收敛缓慢或不稳定。
  2. 陷入局部最优解:梯度下降可能使模型陷入局部最优解,而无法达到全局最优解。
  3. 梯度消失或爆炸:在深层网络中,梯度下降不当可能导致梯度消失或梯度爆炸,使得网络难以训练。
  4. 过拟合:梯度下降过程中的参数更新可能导致模型过度拟合训练数据,导致泛化性能下降。

为了解决这些问题,深度学习中常用的优化算法包括随机梯度下降(SGD)、动量法、Adam优化器等,它们通常能够更稳定地训练深度神经网络

过拟合是指模型在训练数据上表现很好,但在测试数据上表现不佳的情况。过拟合通常是由以下原因导致的:

  1. 模型复杂度过高:模型过于复杂,能够很好地拟合训练数据的细节和噪声,但泛化能力较差。
  2. 训练数据不足:训练数据量太少,模型无法从有限的数据中学习到数据的真实分布,容易记住训练集的特定样本而无法泛化到新数据。
  3. 特征选择不当:选择的特征过多或过少,或者特征工程不合适,都可能导致模型过拟合。
  4. 噪声干扰:数据中的噪声干扰使得模型学习到了数据的随机变化而非真实的模式。

为了缓解过拟合问题,可以采取的方法包括增加训练数据、正则化、特征选择、交叉验证等。

收敛性分析是指对迭代算法在解决特定问题时的收敛性质进行分析和研究的过程。在数值优化、机器学习等领域中,收敛性分析通常用于评估迭代算法在何种条件下能够收敛到期望的解,或者收敛到问题的某种性质(如局部最优解)。

收敛性分析通常包括以下内容:

  1. 收敛准则:确定算法收敛的准则,例如迭代序列是否收敛到某个极限、误差是否趋于零等。
  2. 收敛速度:分析算法的收敛速度,即迭代序列收敛到目标解的速度有多快。
  3. 收敛性质:研究算法收敛到的解的性质,如局部最优解、全局最优解等。

通过收敛性分析,可以评估算法在实际问题中的表现,并为选择合适的算法、调整参数提供指导。

在CNN中,梯度下降是通过反向传播算法来计算的。

  1. 前向传播:首先进行前向传播,通过输入数据,计算损失函数,并沿着网络逐层计算每一层的输出。
  2. 反向传播:然后进行反向传播,利用链式法则计算损失函数对每一层权重参数的梯度。
  3. 梯度下降更新:最后使用计算得到的梯度,按照梯度下降的更新规则,对每一层的权重参数进行更新。w = w - learning_rate * gradient

这样,通过反向传播算法,CNN可以高效地计算梯度并更新网络参数,从而实现模型的训练和优化。

知乎答主LEON分享了他的并行式FL工作,可以大大增加FL的运算速度 知乎主页如下:https://www.zhihu.com/people/leon-6-17-37 代码如下:https://github.com/LEON-gittech/PFLlib.git

FedALA -AAAI2023

paper-张剑清 上交

code

blog2

公式

本文:添加了一个微调的ALA模块进行元素级别的聚合(ALA只修改局部模型初始化的方法)->模型能够适应客户端的目标

全局模型较低层包含了很多的信息,因此在对局部模型进行初始化的时候设置一个超参数p,在较高层使用上述的ALA方法进行初始化,而在下面的层直接将全局模型的参数复制过来。

W矩阵式每个客户端自己的一个需要学习的超参数,当在第二轮W收敛之后的学习过程中W几乎保持不变,因此FedALA在之后训练过程中”复用“它。p也是超参数,但是没有对p的学习方法进行说明。

超参数分析:

s:客户端用来ala初始化的数据占比。实验表明s更大能够取得更好的测试准确率,然而s过大会增加计算开销,设置s为80

p:实验表明减少p,ala中需要学习的参数减少了,并且ala的准确率下降并不明显,设为1

结果分析:效果好

计算开销:用时与fedavg类似,但fedala只用了额外的0.34min实现了极大的提升

通信开销:相似

ALA 模块通过以下方式自适应地聚合全局模型和本地模型,以应对不同的数据分布和模型结构:

  1. 逐元素聚合:ALA 模块逐元素地聚合全局模型和本地模型,以适应每个客户端的局部目标。这样可以更好地捕捉全局模型中有利于改进局部模型的信息。
  2. 局部初始化:在每次迭代中,ALA 模块在训练之前用自适应地聚合的全局模型和本地模型来初始化局部模型。这有助于提高局部模型的质量,从而提高全局模型的性能。
  3. 分层聚合:ALA 模块允许用户通过设置超参数 p 来控制聚合范围,将 ALA 应用于模型的高层,而在较低层次上保留全局模型的信息。这样可以在降低计算开销的同时,仍然捕捉到有用的一般信息。
  4. 适用性:由于 ALA 模块仅修改 FL 中的局部初始化过程,因此它可以应用于大多数现有的 FL 方法,以提高它们的性能,而无需修改其他学习过程。这使得 ALA 可以广泛应用于不同的数据分布和模型结构。

但由于客户机之间数据不可见,数据的统计异质性(数据非独立同分布(non-IID)和数据量不平衡现象)便成了FL 的巨大挑战之一。数据的统计异质性使得传统联邦学习方法(如FedAvg等)很难通过FL过程训练得到适用于每个客户机的单一全局模型。

与寻求高质量全局模型的传统FL不同,pFL方法的目标是借助联邦学习的协同计算能力为每个客户机训练适用于自身的个性化模型。现有的在服务器上聚合模型的pFL研究可以分为以下三类:

  • (1)学习单个全局模型并对其进行微调的方法,包括Per-FedAvg和FedRep;

    per-fedavg(联邦元学习):客户端根据全局模型初始化本身,然后在本地选择一小批数据计算损失的梯度,然后得到本地的元函数,再选取一小批的数据对元函数进行求导,最后本地模型更新以及上传模型

  • (2)学习额外个性化模型的方法,包括pFedMe和Ditto;

    pfedme:额外的个性化模型即 moreau envelopes

  • (3)通过个性化聚合(或本地聚合)学习本地模型的方法,包括FedAMP、FedPHP、FedFomo、APPLE和PartialFed。

    fedfomo:在上传本地的模型到服务器后,服务器会记录各个模型上传的参数,在本地客户端从服务器下载模型时,首先服务器会决定将哪些模型发送给哪些客户端(使用到关联矩阵),接着本地客户端根据下载的其他模型以及自己当前的模型在验证集上的损失来计算其他模型更新的权重,最后对本地模型聚合;使用差分隐私的方法对存在的风险归并

类别(1)和(2)中的pFL方法将全局模型中的所有信息用于本地初始化(指在每次迭代的局部训练之前初始化局部模型)。然而,在全局模型中,只有提高本地模型质量的信息(符合本地训练目标的客户机所需信息)才对客户机有益。全局模型的泛化能力较差是因为其中同时存在对于单一客户机来说需要和不需要的信息。

类别(3)中的pFL方法,通过个性化聚合捕获全局模型中每个客户机所需的信息。但是,类别(3)中的pFL方法依旧存在(a)没有考虑客户机本地训练目标(如FedAMP和FedPHP)、(b)计算代价和通讯代价较高(如FedFomo和APPLE)、(c)隐私泄露(如FedFomo和APPLE)和(d)个性化聚合与本地训练目标不匹配(如PartialFed)等问题。此外,由于这些方法对FL过程做了大量修改,它们使用的个性化聚合方法并不能被直接用于大多数现有FL方法。

为了从全局模型中精确地捕获客户机所需信息,且相比于FedAvg不增加每一轮迭代中的通讯代价,作者提出了一种用于联邦学习的自适应本地聚合方法(FedALA)。如图1所示,FedALA在每次本地训练之前,通过自适应本地聚合(ALA)模块将全局模型与本地模型进行聚合的方式,捕获全局模型中的所需信息。由于FedALA相比于FedAvg仅使用ALA修改了每一轮迭代中的本地模型初始化过程,而没有改动其他FL过程,因此ALA可被直接应用于大多数现有的其他FL方法,以提升它们的个性化表现。

图1:客户端i第t轮,从服务器下载全局模型,通过ALA模块将全局模型与旧的局部模型局部聚合,进行局部初始化,训练局部模型,最后将训练好的局部模型上传到服务器。

自适应本地聚合(ALA)过程

图2:ALA的学习过程。LA代表“本地聚合”。这里,我们考虑一个五层模型,setp = 3。颜色越浅,值越大。

相比于传统联邦学习中直接将下载的全局模型覆盖本地模型得到本地初始化模型的方式,FedALA通过为每个参数学习本地聚合权重,进行自适应本地聚合

作者通过逐元素权重剪枝方法实现正则化,并将W中的值限制在[0,1]中。

Hadamard乘积指的是对应位置上的元素相乘。在神经网络中,Hadamard乘积通常用于执行元素级别的乘法操作,例如两个具有相同维度的矩阵或向量进行元素级别的乘法运算。“element-wise aggregate” 是一个术语,指的是对每个元素进行独立的聚合操作。在神经网络中,这通常指的是对两个具有相同维度的张量或矩阵进行逐元素的聚合操作,例如逐元素相加、相乘等。

因为深度神经网络(DNN)的较低层网络相比于较高层倾向于学习相对更通用的信息,而通用信息是各个本地模型所需信息,所以全局模型中较低层网络中的大部分信息与本地模型中较低层网络所需信息一致。为了降低学习本地聚合权重所需的计算代价,作者引入一个超参数p来控制ALA的作用范围,使得全局模型中较低层网络参数直接覆盖本地模型中的较低层网络,而只在较高层启用ALA

其中,|Θi|表示Θi中的神经网络层数(或神经网络块数),[;]中的前者与Θi的低层网络形状一致,后者与Θi中剩下的p层高层网络形状一致。

Wip 中的值全部初始化为1,且在每一轮本地初始化过程中基于旧的Wip来更新Wip。为了进一步降低计算代价,采用随机采样s%,本地训练数据的方式,在数据集Dis,t上通过基于梯度的学习方法更新Wip。n是更新wip的学习率。在学习wip的过程中,将除wip之外的其他可训练参数冻结。

ˆΘt i:t轮clienti的模型参数

Θt−1:t-1轮的全局模型参数

分号通常表示条件,表示在给定条件下的损失函数–类似于条件分布函数

这种形式的损失函数通常用于需要考虑额外条件或先前参数状态的情况

在给定 Θt−1 的情况下,通过使用 ˆΘt i 和 Ds,t i 来计算损失函数 L 的值。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
# w 是待更新的权重,learning_rate 是学习率,
# gradient 是目标函数关于权重 w 的梯度。
w = w - learning_rate * gradient  
# ∇w L 表示损失函数 L 对模型参数 w 的梯度。
gradient = w L 

在卷积神经网络CNN梯度下降是针对网络中的权重参数进行的这些权重参数包括卷积层的卷积核参数池化层的池化核参数以及全连接层的权重参数和偏置项

# 前向传播
output = self.model(x) 
loss = self.loss(output, y)
# #梯度缓存清零,以确保每个训练批次的梯度都是从头开始计算的
self.optimizer.zero_grad()
# 对损失值1oss进行反向传播,计算模型参数的梯度
loss.backward()
# 梯度下降,更新模型的参数,以使损失函数达到最小值。
self.optimizer.step()

图3 在MNIST和Cifar10数据集上8号客户机的学习曲线,在每次迭代中训练至少六个 epoch 的权重。

一旦我们训练 W pi 在第二次迭代(初始阶段)中收敛,它在随后的迭代中几乎没有变化。换句话说,可以复用 W pi。我们只为 W pi 训练一个 epoch

通过选择较小的p值,在几乎不影响FedALA表现的情况下,大幅度地降低ALA中训练所需的参数。此外,如图3,一旦在第一次训练Wip 将其训练到收敛,即使在后续迭代中训练Wip,其对本地模型质量也没有很大影响。也就是说,每个客户机可以复用旧的Wip实现对其所需信息的捕获。作者采取在后续迭代中微调Wip的方式,降低计算代价。

ala分析???

把在ALA中更新W看成更新Θt i。将更新Wi视为更新Θt i意味着将权重矩阵Wi的更新视为对整体参数Θt中的子参数Θt i的更新。?

实验:

在实际(practical)数据异质环境下的Tiny-ImageNet数据集上用ResNet-18进行了对超参数s和p的对FedALA影响的研究。

对于s来说,采用越多的随机采样的本地训练数据用于ALA模块学习可以使个性化模型表现更好,但也同时增加了计算代价。在使用ALA的过程中,可以根据每个客户机的计算能力调整s的大小。从表中可以得知,即使使用极小的s(如s=5),FedALA依旧具有杰出的表现。

对于p来说,不同的p值对个性化模型的表现几乎没有影响,在计算代价方面却有着巨大的差别。这一现象也从一个侧面展示了FedRep等方法,将模型分割后保留靠近输出的神经网络层在客户机不上传的做法的有效性。使用ALA时,我们可以采用较小且合适的p值,在保证个性化模型表现能力的情况下,进一步降低计算代价。

作者在病态(pathological)数据异质环境和实际(practical)数据异质环境下,将FedALA与11个SOTA方法进行了对比和详细分析。如表2所示,数据显示FedALA在这些情况下的表现都超越了这11个SOTA方法,其中“TINY”表示在Tiny-ImageNet上使用4-layer CNN。例如,FedALA在TINY情况下比最优基线(baseline)(看表是FedRep)高了3.27%。

也在不同异质性环境和客户机总量情况下评估了FedALA的表现。如表3所示,FedALA在这些情况下依旧保持着优异的表现。

图4 病理异质环境下局部学习轨迹(从迭代140到200)和局部损失面的二维可视化

可视化在MNIST上可视化了ALA模块的加入对原本FL过程中模型训练的影响,如图4所示。不激活ALA时,模型训练轨迹与使用FedAvg一致。一旦ALA被激活,模型便可以通过全局模型中捕获的其训练所需信息径直朝着最优目标优化。

更新方向校正。局部学习轨迹(从迭代 140 到 200)和病理异质环境中 MNIST 局部损失面的 2D 可视化。绿色方圆点和红色圆圈分别表示每次迭代开始和结束时的局部模型。带有箭头的黑色和蓝色轨迹分别代表 FedAvg 和 FedALA。使用 PCA 将局部模型投影到 2D 平面。C1 和 C2 是 PCA 生成的两个主成分。

实验设置:

MNIST Cifar10/100 Tiny-ImageNet 4 层 CNN

Tiny-ImageNet 上使用 ResNet-18->tiny*

局部学习率设置为 0.005

批量大小设置为 10,将局部模型训练 epoch 数设置为 1。我们运行了 2000 次迭代的所有任务,以使所有方法都在经验上收敛。在 pFedMe 之后,我们有 20 个客户端,默认情况下设置 ρ = 1,ρ: client joining ratio。

病理异构设置:

每个客户的10/10/100个类别中抽取了MNIST/Cifar10/Cifar100的2/2/10个类别,数据样本不相交

实际的异质环境:

它由Dirichlet分布控制,表示为Dir(β)。β越小,设置就越异构。我们为默认的异构设置设置β=0.1

我们使用与pFedMe相同的评估指标,它报告了传统FL的最佳单个全局模型的测试精度和pFL的最佳局部模型的平均测试精度。为了模拟实际的pFL设置,我们在客户端评估所学习的模型。25%的局部数据形成测试数据集,其余75%的数据用于训练。我们运行所有任务五次,并报告平均值和标准偏差。

FedALA设置为s=80。

通过减小超参数p,我们可以缩小ALA的范围,而精度下降可以忽略不计,如表1所示。当p从6减少到1时,ALA中可训练参数的数量也会减少,特别是从p=2减少到p=1,因为ResNet-18中的最后一个块包含了大部分参数(He等人,2016)。尽管FedALA在这里p=2时表现最好,但我们为ResNet-18设置p=1以减少计算开销。这也表明,全局模型的较低层大多包含客户端所需的通用信息

类别(1)中的pFL方法。个性化方法表现得更好。Per-FedAvg的准确性在这些方法中是最低的,因为它只找到与所有客户的学习趋势相对应的初始共享模型,这可能无法满足单个客户的需求。在FedAvg-C/FedProx-C中微调全局模型会生成特定于客户端的本地模型,从而提高FedAvg/FedProx的准确性。然而,在像FedALA这样的本地训练中,微调只关注本地数据,而不能意识到通用信息。尽管FedRep在每次迭代时也会对头部进行微调,但它在微调时会冻结下载的表示部分,并将大部分通用信息保留在全局模型中,因此表现出色。然而,在客户端之间不共享头部的情况下,头部的通用信息丢失。

下载的表示部分通常指的是在联邦学习(或分布式学习)中,从全局模型中发送到客户端的模型参数或表示权重。这些表示部分可以包括卷积神经网络(CNN)的卷积层权重、循环神经网络(RNN)的循环权重等。即使在每次迭代时对头部进行微调,但在微调过程中冻结了下载的表示部分,这意味着在客户端进行微调时,下载的表示部分权重不会被修改。这种做法旨在保留大部分通用信息在全局模型中,从而提高模型的整体性能。然而,如果客户端之间不共享头部,这种方法可能会导致头部的通用信息丢失。具体来说,在神经网络中,头部通常指的是网络结构的顶部,包括用于特定任务的层或模块,例如分类器。这些层或模块负责将底层表示转换为最终的任务特定输出,比如对图像进行分类或对文本进行情感分析。头部的通用信息表示则指的是这些顶部层中学到的特征或模式,这些特征对于多个任务都是有用的。在联邦学习中,通过冻结表示部分并保留通用信息,可以确保在客户端的微调过程中,保留了全局模型中对多个任务都有用的通用特征表示。

尽管 pFedMe 和 Dititto 都使用近端项来学习它们额外的个性化模型,但 pFedMe 从局部模型中学习所需的信息,而 Dititto 从全局模型中学习它。因此,Ditto 在本地学习更通用的信息,它表现更好。然而,使用近端项学习个性化模型是提取所需信息的隐式方法。

使用基于规则的方法聚合模型是无目标的,无法捕捉全局模型中所需的信息,因此 FedPHP 的性能比 FedRep 和 Ditto 差。FedAMP、FedFomo 和 APPLE 中的模型级个性化聚合以及PartialFed 中的层级和二进制选择不精确,这可能会在全局模型中引入不希望的信息到局部模型。此外,在每次迭代中为每个客户端下载多个模型对于 FedFomo 和 APPLE 提供了额外的通信成本。

在实际异构设置中,由于每个客户端的数据分布复杂,很难衡量客户端之间的相似性。因此,FedAMP 不能通过注意力引导函数精确地为局部模型分配重要性,以生成具有个性化聚合的聚合模型。在下载全局模型/表示后,Ditto 和 FedRep 可以从中捕获通用信息,而不是测量局部模型之间的相似性。这样,在大多数任务中,它们都取得了优异的性能。可训练权重比近似权重信息量更大,因此 APPLE 的性能优于 FedFomo。

尽管 FedPHP 在 TINY 上表现良好,但标准偏差相对较高。由于 FedALA 可以通过 ALA 适应不断变化的环境,因此在实际设置中仍然优于所有基线。如果我们在不适应的情况下重用初始阶段学习到的聚合权重,则 TINY 的准确度下降到 33.81%。由于 ALA 的细粒度特征,它仍然优于 Per-FedAvg、pFedMe、Ditto、FedAMP 和 FedFomo。

code

以 Dir(0.1)为例,在默认的异构设置中上传 mnist 数据集–practical,论文中无对应

1
2
3
4
5
6
7
{"num_clients":20,"num_classes":10,"non_iid":true,"balance":false,"partition":"dir","Size of samples for labels in clients":[[[0,140],[1,890],[4,1],[5,319],[7,29],[8,1067],[9,184]],[[0,5],[2,27],[5,19],[6,335],[8,6],[9,107]],[[0,3],[3,143],[6,1461],[9,23]],[[0,155],[4,1],[7,2381],[8,4]],[[0,71],[1,13],[3,207],[5,1129],[6,6],[8,40],[9,451]],[[1,38],[3,1],[4,39],[8,25],[9,6086]],[[1,873],[2,176],[3,46],[6,42],[8,13],[9,106]],[[1,21],[2,5],[3,11],[5,787],[7,4],[8,441]],[[0,1],[1,3599]],[[0,633],[1,1997],[2,89],[4,519],[6,768]],[[0,920],[1,2],[2,1450],[3,513],[4,134],[5,97]],[[2,159],[3,3055],[5,558]],[[0,8],[1,180],[2,3277],[5,148]],[[1,237],[2,343],[4,6],[5,453],[7,1095]],[[5,2719],[7,3011]],[[0,31],[3,1785],[5,16],[6,4],[7,756],[8,2856]],[[0,3628]],[[1,26],[2,1463],[3,1379],[4,335],[5,60],[7,17],[8,2373]],[[0,998],[5,8],[6,4260]],[[0,310],[1,1],[2,1],[3,1],[4,5789],[9,1]]],"alpha":0.1,"batch_size":10}

client 0 -- 0 1 4 5 7 8 9

client 1 -- 0 2 5 6 8 9

client 2 -- 0 3 6 9

ALA的使用

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
# 将 ALA 模块导入为 Python 模块
import ALA

# 输入所需的参数以初始化模块。
class Client(object):
    def __init__(self, ...):
        # other code
        self.ALA = ALA(self.id, self.loss, self.train_data, self.batch_size, 
                    self.rand_percent, self.layer_idx, self.eta, self.device)
        # other code
        
# 将恢复的全局模型和旧的本地模型馈送到本地初始化。 
self.ALA.adaptive_local_aggregation()

# 在启动阶段,可能需要为ALA模块设置一个合适的threshold (我们在本文中默认设置为0.01)来控制其收敛水平
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
nohup python -u main.py -t 1 -jr 1 -nc 20 -nb 10 -data mnist-0.1-npz -m cnn -algo FedALA -et 1 -p 2 -s 80 -did 0 > result-mnist-0.1-npz.out 2>&1 

# -t:Running times 1
# -jr:Ratio of clients per round 1
# -nc:Total number of clients 20
# -nb:num_classes 10
# -data:dataset 数据集文件夹名 mnist->mnist-0.1-npz
# -m:model cnn
# -algo:algorithm FedALA
# -et: eta ALA weight学习率 1.0
# -p:layer_idx 2
# -s:rand_percent 80
# -did:device_id 0

# test
# -gr:global_rounds 1000->20
python main.py -data mnist-0.1-npz -gr 20
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
(Pdb) p model_str
'cnn'
(Pdb) p args.model
FedAvgCNN(
  (conv1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): ReLU(inplace=True)
  )
  (fc): Linear(in_features=512, out_features=10, bias=True)
)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 实例化clientALA
self.set_clients(args, clientALA)

# ALA初始化
self.ALA = ALA(self.id, self.loss, train_data, self.batch_size, 
          self.rand_percent, self.layer_idx, self.eta, self.device)
(Pdb) p self.cid
0
(Pdb) p self.loss
CrossEntropyLoss()
(Pdb) p len(self.train_data)
1972
(Pdb) p self.rand_percent# s
80
#  Control the weight range. By default, all the layers are selected. Default: 0
(Pdb) p self.layer_idx # 选中的层数 后两层
2
(Pdb) p self.eta# weight 学习率
1.0
# 训练权重直到记录的损失的标准差小于给定的 阈值
(Pdb) p self.threshold
0.1
# 在计算标准差时要考虑的记录损失的数量。默认值:10
(Pdb) p self.num_pre_loss
10
self.weights = None # 可学习的局部聚合权值
self.start_phase = True 
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 下发模型
self.send_models()
# 客户端初始化
client.local_initialization(self.global_model)

# ALA
# self.model = copy.deepcopy(args.model)
# self.global_model = copy.deepcopy(args.model)
self.ALA.adaptive_local_aggregation(received_global_model, self.model)
# 随机抽取局部训练数据 s=80
(Pdb) p rand_ratio
0.8
(Pdb) p rand_num# 0.8*1972
1577
# randint(0,395) 作为数据的数据切片的下标-(50,1627)
(Pdb) p rand_idx
50
(Pdb) p len(rand_loader)
157
Round number: 0
# 在第一次通信迭代时停用ALA 全局模型参数和客户端模型参数一致
# 直接return 
-------------------------------------
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
Round number: 1
-------------------------------------
self.send_models()
client.local_initialization(self.global_model)
self.ALA.adaptive_local_aggregation(received_global_model, self.model)

由于每层神经网络参数包括权重和偏置所以每一层实际上有两个参数列表因此即使有四层神经网络每层都包括权重和偏置所以总共会有 4  * 2 = 8 个参数列表

(Pdb) p len(params) # 网络参数
8
# 保留较低层的所有更新 global model的前3层 赋给client 0 
params[:-self.layer_idx]# [:-2]->[0,5]
param.data = param_g.data.clone()

# 临时副本,不影响原始本地模型的情况下进行一些权重学习的操作
model_t = copy.deepcopy(local_model)
params_t = list(model_t.parameters())

# model高层
params_p = params[-self.layer_idx:]
params_gp = params_g[-self.layer_idx:]
params_tp = params_t[-self.layer_idx:]

# 冻结model_t低层,减少在 PyTorch 中进行计算时的计算成本。
for param in params_t[:-self.layer_idx]:
     param.requires_grad = False# 参数设为不需要梯度计算
        
# SGD 优化器,用于获取高层参数的梯度,但不会执行参数更新步骤。
# 因此,学习率被设为 0,表示不会对参数进行更新。
# no need to use optimizer.step(),
optimizer = torch.optim.SGD(params_tp, lr=0)# params_tp要优化的参数列表

(Pdb) p self.weights
None
# 初始化权重,全为1 params_p的格式,仅高层参数
self.weights = [torch.ones_like(param.data).to(self.device) for param in params_p]
(Pdb) p self.weights
[tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0'), tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')]
(Pdb) p len(self.weights)# 高层参数
2
(Pdb) p self.weights[0].shape
torch.Size([10, 512])
(Pdb) p self.weights[1].shape
torch.Size([10])

# 在临时本地模型中初始化高层参数。计算临时本地模型的高层参数值。
# 按权重计算高层参数值 ALA的实现
param_t.data = param + (param_g - param) * weight

# weight学习
while True: # 在当前选取的数据集上一直训练,直至满足收敛条件跳出循环
    for x, y in rand_loader:

optimizer.zero_grad()# 梯度归零
output = model_t(x)
# 根据当地目标进行修改
loss_value = self.loss(output, y) 
loss_value.backward()# 反向传播,得到梯度param_t.grad
(Pdb) p loss_value
tensor(2.1265, device='cuda:0', grad_fn=<NllLossBackward0>)

# update weight in this batch 
# 参数weight更新,类似于self.optimizer.step()
# torch.clamp` 函数将更新后的值限制在 [0, 1] 的范围内
weight.data = torch.clamp(
         weight - self.eta * (param_t.grad * (param_g - param)), 0, 1)
(Pdb) p weight.data
tensor([[0.9998, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 0.9999,  ..., 1.0000, 0.9999, 0.9979],
        [0.9997, 1.0000, 1.0000,  ..., 0.9992, 0.9998, 0.9994],
        ...,
        [0.9998, 1.0000, 0.9999,  ..., 1.0000, 1.0000, 0.9999],
        [0.9959, 1.0000, 0.9997,  ..., 0.9929, 0.9985, 0.9982],
        [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],
       device='cuda:0')

# update temp local model in this batch 更新model_t的参数
param_t.data = param + (param_g - param) * weight

# 所有选中的数据训练完后,得到的loss
# 模型训练按batch进行,loss值会覆盖,表示最后一个batch的loss值
# 由于每次计算损失之后会更新模型参数,所以仅记录最后一个batch是有意义的
losses.append(loss_value.item())
(Pdb) p losses
[1.810765266418457]
(Pdb) p cnt # weight training iteration counter
1

# 在随后的迭代中只训练一个epoch
self.start_phase = True 
if not self.start_phase: #false跳出
     break
        
# train the weight until convergence
# 如果损失值列表的长度大于预先指定的损失值数量,
# 并且最近几个损失值的标准差小于预先设定的阈值,
# 则打印相关信息并跳出训练循环,以表示权重已经收敛。
if len(losses) > self.num_pre_loss and 
	np.std(losses[-self.num_pre_loss:]) < self.threshold:
    print('Client:', self.cid, '\tStd:', 
    	np.std(losses[-self.num_pre_loss:]),'\tALA epochs:', cnt)
 	break
    
# 跳出训练
self.start_phase = False

# obtain initialized local model 复制高层
param.data = param_t.data.clone()

Client: 0       Std: 0.04929309959898841        ALA epochs: 11
Client: 1       Std: 0.09025316159724787        ALA epochs: 12
Client: 2       Std: 0.062314922642027974       ALA epochs: 11

Round number: 2
-------------------------------------
(Pdb) p self.start_phase
False
(Pdb) p self.weights
[tensor([[0.8682, 1.0000, 0.9655,  ..., 1.0000, 0.9862, 1.0000],
        [1.0000, 1.0000, 0.8617,  ..., 0.9901, 0.7574, 0.0000],
        [0.7473, 0.9997, 0.9452,  ..., 0.3092, 0.7954, 0.4487],
        ...,
        [0.7925, 1.0000, 0.9288,  ..., 0.9782, 1.0000, 0.8773],
        [0.0000, 0.9990, 0.6782,  ..., 0.0000, 0.0000, 0.0000],
        [0.9319, 1.0000, 0.9962,  ..., 0.5651, 0.9631, 0.7355]],
       device='cuda:0'), tensor([0.8297, 0.0000, 0.0000, 0.0000, 0.0000, 0.7981, 0.0000, 0.0387, 0.0000,
        0.7162], device='cuda:0')]

继续学习权重

# 当前 权重训练只训练一个batch
if not self.start_phase:
	break# 跳出训练
    
从第一次聚合之后以后轮次只训练1个batch
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 评估 ---全局模型发送到client之后评估
# 当前轮次先评估再训练,训练后下一轮的评估是这一轮的结果
self.evaluate()
# client训练
client.train()

#梯度缓存清零,以确保每个训练批次的梯度都是从头开始计算的
self.optimizer.zero_grad() 
output = self.model(x) # 前向传播
loss = self.loss(output, y)
# 对损失值1oss进行反向传播,计算模型参数的梯度,loss函数对于参数的梯度
loss.backward()
# 梯度下降,更新模型的参数,以使损失函数达到最小值
# 根据梯度下降算法中的优化器规则,通过调整参数的数值来最小化损失函数,以达到优化模型的目的
self.optimizer.step()

# 接收模型
self.receive_models()
(Pdb) p active_train_samples
52492
(Pdb) p client.id
7
(Pdb) p client.train_samples
951
(Pdb) p self.uploaded_weights# train_samples/active_train_samples
[0.018117046407071555]
(Pdb) p self.uploaded_weights
[0.018117046407071555, 0.05722776804084432, 0.08186009296654728, 0.07784043282785948, 0.027375600091442506, 0.01794559170921283, 0.0884134725291473, 0.030480835174883793, 0.05160786405547512, 0.023279737864817497, 0.051836470319286745, 0.03629124438009602, 0.03756762935304427, 0.08719423912215195, 0.04452106987731464, 0.0514364093576164, 0.007124895222129086, 0.08075516269145774, 0.0538939266935914, 0.07523051131601007]
(Pdb) p self.uploaded_ids
[7, 9, 14, 15, 4, 6, 5, 13, 12, 2, 16, 3, 0, 19, 10, 8, 1, 17, 11, 18]
1
2
3
# 聚合模型
self.aggregate_parameters()
server_param.data += client_param.data.clone() * w
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
-------------Round number: 0-------------
Averaged Train Loss: 2.3127
Averaged Test Accurancy: 0.0587
Averaged Test AUC: 0.4315
Std Test Accurancy: 0.0710
Std Test AUC: 0.1843
-------------Round number: 1-------------
Averaged Train Loss: 0.9177
Averaged Test Accurancy: 0.8254
Averaged Test AUC: 0.9823
Std Test Accurancy: 0.1348
Std Test AUC: 0.0219

病态(pathological)数据异质环境pat

每个客户机上的数据只包含特定数量的标签

注意实验参数的设置

1
2
3
4
5
6
7
8
python -u generate_mnist.py noniid - pat > mnist_dataset.out 2>&1

{"num_clients":20,"num_classes":10,"non_iid":true,"balance":false,"partition":"pat","Size of samples for labels in clients":[[[0,1233],[1,1101]],[[0,407],[1,911]],[[0,1268],[1,1865]],[[0,3995],[1,4000]],[[2,1021],[3,307]],[[2,1134],[3,927]],[[2,318],[3,686]],[[2,4517],[3,5221]],[[4,1584],[5,1457]],[[4,1475],[5,1435]],[[4,1372],[5,514]],[[4,2393],[5,2907]],[[6,1085],[7,434]],[[6,639],[7,1696]],[[6,1078],[7,850]],[[6,4074],[7,4313]],[[8,568],[9,1412]],[[8,732],[9,1174]],[[8,750],[9,926]],[[8,4775],[9,3446]]],"alpha":0.1,"batch_size":10}

client 0 :0 1
client 1 :0 1 
...
clinet 4 :2 3
1
2
3
python -u main.py -t 1 -jr 1 -nc 20 -nb 10 -data mnist -m cnn -algo FedALA -et 1 -p 2 -s 80 -did 0 -gr 2000 > result-mnist-pat-npz.out 2>&1 

0.9989147198263552
1
2
3
python -u generate_tiny_imagenet.py noniid - dir > tiny_dataset.out 2>&1

python -u main.py -t 1 -jr 1 -nc 20 -nb 200 -data Tiny-imagenet -m cnn -algo FedALA -et 1 -p 2 -s 80 -did 3 -gr 2000 > result-tiny-dir-npz.out 2>&1

FedCP -KDD2023

通过条件策略分离个性化联邦学习的特征信息、

paper

code

blog

大多数现有的 pFL 方法将全局模型视为存储全局信息的容器,并使用全局模型中的参数丰富个性化模型。然而,它们只关注客户端级别的模型参数,即全局/个性化模型来利用全局/个性化信息。具体来说,基于元学习的方法(如Per-FedAvg[8])只微调全局模型参数以适应本地数据,而正则化方法(如pFedMe[42]、FedAMP[15]和Ditto[24])只在局部训练过程中正则化模型参数。尽管基于个性化头的方法(例如 FedPer [2]、FedRep [6] 和 FedRoD [4])明确地将主干拆分为全局部分(特征提取器)和个性化部分(头),但它们仍然专注于利用模型参数中的全局和局部信息,而不是信息来源:数据。

由于模型是在数据上训练的,因此模型参数中的全局/个性化信息是从客户端数据中导出的。换句话说,客户端的异构数据包含全局信息和个性化信息。如图1所示,广泛使用的颜色,如蓝色,以及很少使用的颜色(如紫色和粉色)分别包含图像中的全局信息和个性化信息。𝑾ℎ𝑑 : 冻结的全局头,𝑾ℎ𝑑 𝑖/𝑗 : 个性化的头部。

为了分别利用数据中的全局信息和个性化信息,我们提出了一种基于条件计算技术的联邦条件策略(FedCP)方法[11,35]。由于原始输入数据的维数远大于特征提取器提取的特征向量,因此为了提高效率,我们将重点放在特征向量上。由于全局信息和个性化信息在特征中的比例在样本和客户端之间不同,我们提出了一个辅助条件策略网络(CPN)来生成用于特征信息分离的样本特定策略。然后,我们分别通过全局头和个性化头在不同的路径上处理全局特征信息和个性化特征信息,如图1所示。我们将个性化信息存储在个性化头部中,并通过冻结全局头部来保留全局信息,而无需对其进行本地训练。通过端到端学习,CPN自动学习生成样本特定策略。

贡献:

  • 我们是第一个考虑FL中样本特定特征信息的个性化。它比大多数现有FL方法中使用客户端级模型参数更细粒度。
  • 我们提出了一种新的FedCP,它生成一个特定于样本的策略,以在每个客户端的特征中分离全局信息和个性化信息。它分别通过每个客户端上的冻结全局头和个性化头来处理这两种特征信息。
  • 此外,即使某些客户端意外退出,FedCP也能保持其卓越的性能

个性化联邦

  • 在基于元学习的方法中,Per-FedAvg[8]学习初始共享模型作为全局模型,满足每个客户端的学习趋势。
  • 在基于正则化的方法中,pFedMe [42] 使用 Moreau 包络为每个客户端在本地学习一个额外的个性化模型。除了只为所有客户端学习一个全局模型外,FedAMP[15]还通过注意诱导函数为一个客户端生成一个服务器模型,以找到相似的客户端。在Ditto[24]中,每个客户端使用近端项在本地学习其个性化模型,以从全局模型参数中获取全局信息。、
  • 在基于个性化头部的方法中,FedPer[2] 和 FedRep [6] 学习了一个全局特征提取器和一个特定于客户端的头。前者使用特征提取器在本地训练头部,而后者在每次迭代训练特征提取器之前局部微调头部直到收敛。为了弥合传统的 FL 和 pFL,FedRoD [4] 使用全局特征提取器和两个头显式学习两个预测任务。它使用平衡的 softmax (BSM) 损失 [39] 进行全局预测任务,并通过个性化头部处理个性化任务。
  • 在其他 pFL 方法中,FedFomo [56] 使用来自其他客户端的个性化模型计算每个客户端聚合的特定于客户端的权重。FedPHP[27]使用移动平均局部聚合全局模型和旧的个性化模型,以保持历史个性化信息。它还通过广泛使用的最大平均差异 (MMD) 损失 [10, 37] 传递全局特征提取器中的信息。
  • 上述 pFL 方法只专注于利用模型参数的全局和个性化信息,但不深入挖掘数据。

图2(a)CPN模块(红色圆角矩形):样本特征信息𝒉𝑖和客户机信息𝒗𝑖,能输出对应的策略向量(𝒓𝑖和𝒔𝑖)

菱形则表示特征信息分离操作,使用条件策略将信息𝒉𝑖通过红色菱形分离为 𝒓𝑖 ⊙ 𝒉𝑖 和𝒔𝑖 ⊙ 𝒉𝑖 。用策略向量即可提取得到全局特征信息𝒓𝑖 ⊙ 𝒉𝑖 和个性化特征信息𝒔𝑖 ⊙ 𝒉𝑖

然后交由全局头部𝑾ℎ𝑑 和个性化头部𝑾ℎ𝑑𝑖 分别处理。最后将输出合并(即加和),得到最终输出值

除了特征向量和向量𝒗𝑖 , 标准矩形和圆形矩形分别表示层和模块。

带虚线边框的圆角矩形 𝑾^ℎ𝑑𝑖 在等式(6)中

𝑾𝑓𝑒 (灰色边界)不是个性化模型的一部分,数据只在训练过程中向前流动。在训练过程中,数据在所有线中流动,但在推理过程中,数据只在实线中流动。

(b)分别显示了特征提取器、头部和CPN的上传和下载流.我们在实践中上传或下载它们作为服务器和每个客户端之间的联合

feature extractors是指用来提取输入数据特征的部分,通常是指卷积神经网络(CNN)中的卷积层和池化层,用来提取输入数据的特征表示。而 “the heads” 通常指的是在深度学习模型中负责执行最终任务(如分类、回归等)的部分,通常是指全连接层(也称为密集层)或输出层。 “the heads” 会接收从"feature extractors"提取的特征,并根据特定的任务进行最终的预测或输出。

条件计算

条件计算是一种根据任务相关的条件输入将动态特性引入模型的技术[11,30,35]。形式上,给定条件输入𝐶 (例如,图像/文本、模型参数矢量或其他辅助信息)和辅助模块𝐴𝑀 (·;𝜃 ), 一个信号𝑆 可以由𝑆 = 𝐴𝑀 (𝐶; 𝜃 ) 生成,并且用于干扰诸如动态路由和特征自适应之类的模型。


该论文提出了一种用于个性化联邦学习的全局和个性化特征信息分离方法FedCP,首次在数据层面实现了全局和个性化信息的分离,为分别处理这两类信息提供了可能。

第一行为样本图片,第二行为全局特征信息,第三行为个性化特征信息。

在实际场景下,由于各个客户机上数据的异质性,每一轮上传到服务器上的客户机模型参数之间具有较大差异,聚合得到的全局模型无法在单个客户机上具有良好的表现。于是研究者们提出个性化联邦学习方法,将学习全局模型的目标,转变为通过全局模型辅助个性化本地模型训练。

在数据异质的情况下进行协同训练,既要考虑个性化(用于应对异质性)又要考虑全局性(用于协同训练)。如何把握全局和个性化这两者之间的关系,是设计个性化联邦学习方法的关键。

大多数现有的个性化联邦学习方法,仅仅从模型参数层面对全局信息和个性化信息进行分离和分别利用(见图4,Per-FedAvg,Ditto,FedRep),却忽略了模型参数中的信息是从数据中学到的这个事实。

虽然每个客户机上的数据是异质的,但异质数据也是在同一个世界中产生的,或多或少都具有一部分全局信息和另一部分个性化信息.

不同客户机上的异质数据示意图,其中蓝色代表全局信息,紫色和粉色代表个性化信息

从数据中分离和利用全局和个性化信息

由于输入空间的数据维度较高,我们便考虑对转换后的特征向量进行处理。如图6所示,在特征空间中,我们通过设计一个辅助的条件策略网络(CPN)来实现特征信息的分离;随后我们利用全局和个性化头部来分别处理分离出来的两类信息。

圆角矩形上的斜线代表“冻结”(即不参与训练、不作更新)菱形则表示特征信息分离操作。半透明的模块表示只在训练过程中参与。实现特征分离的关键在于CPN模块。

个性化特征分离主要通过以下方法和步骤实现:

  1. 使用联邦条件策略(Federated Conditional Policy,FedCP)方法:FedCP生成一个样本特定的策略,将每个客户端的特征分为全局特征信息和个性化特征信息。这两种特征信息分别由全局头和个性化头处理。
  2. 分离特征信息:在FedCP中,通过将策略{α, β}与特征向量x相乘,可以分别获得全局特征信息(x⊙α)和个性化特征信息(x⊙β)。由于特征之间存在连接,所以输出{α, β},而不是布尔值,即α∈(0, 1)且β∈(0, 1)。
  3. 生成样本特定策略:FedCP使用辅助的条件策略网络(Conditional Policy Network, CPN)生成样本特定的策略,以实现特征信息的分离。通过端到端学习,CPN将自动学习生成样本特定策略。
  4. 保存个性化信息:FedCP通过在每个客户端冻结全局头不进行本地训练,以保留全局信息。同时,它通过个性化头存储个性化信息。
  5. 实验验证:FedCP在各种数据集上的实验表现优于其他个性化联邦学习方法,证明了其有效性。

总之,通过FedCP方法和CPN模块,个性化特征分离通过为每个样本生成特定策略,实现了全局特征信息和个性化特征信息的有效分离。


论文

https://mp.weixin.qq.com/s/-gBQbo5rUD_h9Mv3hHRzeA iccv

https://mp.weixin.qq.com/s/1XWGZZa5WIsgsuL7BW1Wyg

https://mp.weixin.qq.com/s/N9h16GPTNg08VKbOYAgNuw kdd

医学个性化联邦

Model-Heterogeneous Semi-Supervised Federated Learning for Medical Image Segmentation

HPFL: hyper-network guided personalized federated learning for multi-center

Specificity-Aware Federated Graph Learning for Brain Disorder Analysis with Functional MRI

A Federated Deep Learning Method for Chronic Disease Diagnosis

FedCE:基于客户端贡献估计的公平联邦医学图像分割

FedRH: Federated Learning Based Remote Healthcare

Fine-Tuning Network in Federated Learning for Personalized Skin Diagnosis

GRACE: A Generalized and Personalized Federated Learning Method for Medical Imaging

AP2FL: Auditable Privacy-Preserving Federated Learning Framework for Electronics in Healthcare

Personalized federated learning for the detection of COVID-19

Adaptive channel-modulated personalized federated learning for magnetic resonance image reconstruction

Federated hospital: a multilevel federated learning architecture for dealing with heterogeneous data distribution in the context of smart hospitals services

Feddp: Dual personalization in federated medical image segmentation

FedSoup: Improving Generalization and Personalization in Federated Learning via Selective Model Interpolation

FedDK: Improving Cyclic Knowledge Distillation for Personalized Healthcare Federated Learning

Medical Federated Model with Mixture of Personalized and Sharing Components

Performance Analysis of Personalized Federated Learning Algorithms for Image Classification

FedFTN: Personalized federated learning with deep feature transformation network for multi-institutional low-count PET denoising.

Personalized Federated Learning for Medical Segmentation using Hypernetworks. iclr

GRACE: A Generalized and Personalized Federated Learning Method for Medical Imaging.

Personalized Retrogress-Resilient Federated Learning Toward Imbalanced Medical Data.

Personalized Retrogress-Resilient Framework for Real-World Medical Federated Learning.

Personalized Retrogress-Resilient Framework for Real-World Medical Federated Learning**.**

个性化联邦


Data-Free Federated Learning blog

以下是一些近期在医疗数据分析领域的联邦学习相关论文推荐:

  1. FedCP: Separating Feature Information for Personalized Federated Learning via Conditional Policy 该论文提出了一种名为FedCP的方法,通过条件策略分离特征信息,实现个性化联邦学习。FedCP在多个数据集和场景下的实验表现优于现有的11种先进方法。
  2. Practical Challenges in Differentially-Private Federated Survival Analysis of Medical Data 这篇论文探讨了在医疗数据分析中进行联邦生存分析的实际挑战,并提出了一种名为DPFed-post的方法,通过在联邦学习过程中加入后处理阶段,提高模型的收敛速度和性能。
  3. Open problems in medical federated learning 这篇论文讨论了医疗联邦学习中的一些开放性问题,包括隐私保护、数据异质性、模型压缩和鲁棒性等方面。

医学联邦

FedA3 I: Annotation Quality-Aware Aggregation for Federated Medical Image Segmentation Against Heterogeneous Annotation Noise

blog 无代码 客户端标注噪声 每个客户端的噪声估计是通过高斯混合模型完成的,然后以层状方式纳入模型聚合中,以增加高质量客户端的权重

FedSODA 无代码 用于组织病理学细胞核和组织分割

https://conferences.miccai.org/2022/papers/

https://conferences.miccai.org/2023/papers/

FedContrast-GPA: Heterogeneous Federated Optimization via Local Contrastive Learning and Global Process-aware Aggregation

无代码 基于中心内部和跨中心的局部原型特征的对比学习框架来增强本地模型更新过程中的特征表达一种简单但十分有效的进程感知模型融合算法,可以有效缓解系统异质导致的落后

Federated Condition Generalization on Low-dose CT Reconstruction via Cross-domain Learning跨域

Federated Uncertainty-Aware Aggregation for Fundus Diabetic Retinopathy Staging 我们开发了一种新的不确定性感知加权模块(UAW),可以根据每个客户端的不确定性评分分布动态调整模型聚合的权重。无代码

FedGrav: An Adaptive Federated Aggregation Algorithm for Multi-institutional Medical Image Segmentation通过计算局部模型之间的亲和度,探索局部模型之间的内在关联,从而提高聚合权值。模型聚合 无代码

FedIIC: Towards Robust Federated Learning for Class-Imbalanced Medical Image Classification code在特征学习中,设计了两个层次的对比学习,以便在FL中对不平衡数据提取更好的类特定特征。在分类器学习中,根据实时难度和类先验动态设置每类边际,帮助模型平等地学习类

FedSoup: Improving Generalization and Personalization in Federated Learning via Selective Model Interpolation code 我们提出了一种新的联邦模型汤方法(即模型参数的选择性插值)来优化局部和全局性能之间的权衡。具体来说,在联邦训练阶段,每个客户机通过监视局部模型和全局模型之间的内插模型的性能来维护自己的全局模型池。这使我们能够缓解过拟合并寻求平坦最小值,这可以显着提高模型的泛化性能

FeSViBS: Federated Split Learning of Vision Transformer with Block Sampling code联邦分离学习 ,并引入了一个块采样模块,该模块利用了服务器上VisionTransformer (ViT)提取的中间特征

Fine-Tuning Network in Federated Learning for Personalized Skin Diagnosis 无代码

GRACE: A Generalized and Personalized Federated Learning Method for Medical Imaging泛化和个性化联邦学习 codeGRACE在客户端的元学习框架下结合了特征对齐正则化,以纠正过度拟合的个性化梯度。同时,GRACE采用一致性增强的重加权聚合在服务器端校准上传的梯度,以实现更好的泛化。

One-shot Federated Learning on Medical Data using Knowledge Distillation with Image Synthesis and Client Model Adaptation code客户端模型自适应知识提取的医学数据一次联邦学习

Rethinking Semi-Supervised Federated Learning: How to co-train fully-labeled and fully-unlabeled client imaging data 无代码 半监督联邦学习

Scale Federated Learning for Label Set Mismatch in Medical Image Classification code 解决标签集不匹配 不同不确定程度的数据采用不同的训练策略,有效利用未标记或部分标记的数据,并在分类层采用分类自适应聚合,避免客户端缺少标签时的不准确聚合


fedsoup

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoup -gr 1000 -did 0 -eg 100 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75 | tee ../tmp/tiny_camelyon17/fedsoup_debug_console.output 

# -eg: Rounds gap for evaluation
# -nc:client
# -hoid:for out-of-federation evaluation Hold-out out-of-federated evaluation set. 1e8 means no hold-out set
# -lr:
# -wa_alpha:FedSoup soup wa alpha Weight averaging ratio of personalized global model pool for FedSoup

python main.py -data tiny_camelyon17 -m resnet -algo FedSoup -gr 20 -did 0 -eg 1 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75

python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoup -gr 1000 -did 0 -eg 100 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75 > result-tiny-came-4-npz.out 2>&1 
0.8973214285714286

python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoup -gr 1000 -did 0 -eg 50 -go fedsoup_mg5_debug -nc 5 -lr 1e-3 -wa_alpha 0.75 --pruning --sparsity_ratio 0.5 --pruning_warmup_round 500 --masking_grad --dynamic_mask | tee ../tmp/tiny_camelyon17/fedsoup_mg5_debug_console.output

python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoup -gr 1000 -did 2 -eg 50 -go fedsoup_mg5_debug -nc 5 -lr 1e-3 -wa_alpha 0.75 --pruning --sparsity_ratio 0.5 --pruning_warmup_round 500 --masking_grad --dynamic_mask > result-tiny-came-npz.out 2>&1 
0.9071428571428571

python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoup -gr 1000 -did 0 -eg 50 -go fedsoup_mg5_debug_2 -nc 5 -lr 1e-3 -wa_alpha 0.75 --pruning --sparsity_ratio 0.5 --pruning_warmup_round 500 --masking_grad --dynamic_mask > result-tiny-came-2-npz.out 2>&1 
0.9042857142857142


python -u main.py -t 1 -jr 1 -nc 20 -nb 200 -data Tiny-imagenet -m cnn -algo FedALA -et 1 -p 2 -s 80 -did 3 -gr 2000 > result-tiny-dir-npz.out 2>&1 

python -u main.py -t 1 -jr 1 -nc 5 -nb 2 -data tiny_camelyon17 -m resnet -algo FedALA -et 1 -p 2 -s 80 -did 3 -gr 1000 > result-tiny-came-npz.out 2>&1 
0.8985714285714286
Averaged Test Accurancy: 0.8736
Averaged Test AUC: 0.9395

python -u main.py -t 1 -lbs 16 -lr 1e-3 -eg 50 -jr 1 -nc 5 -nb 2 -data tiny_camelyon17 -m resnet -algo FedALA -et 1 -p 2 -s 80 -did 3 -gr 1000 > result-tiny-came-16001npz.out 2>&1 
0.8864285714285715
Averaged Test Accurancy: 0.8607
Averaged Test AUC: 0.9283

python -u main.py -go tiny-cam-fedala-lbs10 -t 1 -lbs 10 -lr 1e-3 -eg 50 -jr 1 -nc 5 -nb 2 -data tiny_camelyon17 -m resnet -algo FedALA -et 1 -p 2 -s 80 -did 0 -gr 1000 > result-tiny-came-10001npz.out 2>&1
Averaged Test Accurancy: 0.8307
Averaged Test AUC: 0.9124

python -u main.py -go tiny-cam-fedala-lbs16 -t 1 -lbs 16  -eg 50 -jr 1 -nc 5 -nb 2 -data tiny_camelyon17 -m resnet -algo FedALA -et 1 -p 2 -s 80 -did 1 -gr 1000 > result-tiny-came-16npz.out 2>&1 
Averaged Test Accurancy: 0.8786
Averaged Test AUC: 0.9423

表1 :FedSoup 85.71(0.37) 92.47(0.31) 72.87(1.35) 81.45(1.40)

对应result-tiny-came-4-npz.out hold0

Local Performance: Local Client-Equally Test Accurancy: 0.8777 Local Client-Equally Test AUC: 0.9438 Global Performance: Glocal Client-Equally Test Accurancy: 0.7325 Glocal Client-Equally Test AUC: 0.8042

表2 tta微调

FedSoup 72.87(1.35) 81.45(1.40) 85.36(0.86) 88.86(1.07)

没有

表3 (DG) 的其他结果–OOF

71.97 79.63

OOF Client Test Accurancy: 0.5234 OOF Client Test AUC: 0.6048

表4 no hoid on

FedSoup 89.20(0.53) 95.08(0.61) 86.44(1.12) 93.46(0.74)

对应result-tiny-came-npz.out和result-tiny-came-2-npz.out

Local Performance: Local Client-Equally Test Accurancy: 0.8964 Local Client-Equally Test AUC: 0.9482 Global Performance: Glocal Client-Equally Test Accurancy: 0.8628 Glocal Client-Equally Test AUC: 0.9311

Performance Summarizing… Local Performance: Local Client-Equally Test Accurancy: 0.9050 Local Client-Equally Test AUC: 0.9613 Global Performance: Glocal Client-Equally Test Accurancy: 0.8681 Glocal Client-Equally Test AUC: 0.9387


研究实验超参数和结果

第一项任务涉及使用Camelyon17数据集[1]对来自五个不同来源的病理学图像进行分类,每个来源都被视为客户端。病理学实验共包括4600张图像张,每张图像的分辨率为96×96。我们从原始的 Camelyon17 数据集中取一个随机子集来匹配 FL[18] 中的小数据设置。

Tiny Camelyon 17(TUPAC-2)数据集是一种用于乳腺组织分类的医学影像数据集。该数据集是从肿瘤组织切片图像中提取的,用于研究和开发计算机辅助诊断(CAD)系统和深度学习模型。这个数据集主要用于乳腺癌研究领域,可用于训练和测试医学图像分析算法,以辅助医生诊断乳腺癌。data

第二项任务涉及来自四个不同机构的视网膜眼底图像[9,26,21],每个机构都被视为客户。视网膜眼底实验共包括1264张图像,每张图像的分辨率为128×128。这两个数据集的目标是从正常图像中识别异常图像。

Retinal Fundus-RFMiD

对于每个客户端,我们将75%的数据作为训练集。为了评估我们模型的泛化能力和个性化,我们构建了局部和全局测试集。

在[28]之后,在我们的实验环境中,我们首先通过每个来源/研究所随机采样相同数量的图像来创建一个保持的全局测试集,因此其分布与任何一个客户端都不同。

每个FL客户端的本地测试数据集是来自与其训练集相同来源的剩余样本。每个客户端的本地测试集数量与保留的全局测试集数量大致相同。对于病理学数据集,由于每个受试者可以有多个样本,我们已经确保来自同一受试者的数据只出现在训练或测试集中。

为了与交叉验证设置保持一致,用于后续的域外评估,我们进行了五折留一客户端数据交叉验证,每个折叠中使用不同的随机种子进行三次重复。每次使用不同的随机种子重复三次。

这句话意思是"我们进行了一次五折留一客户数据交叉验证”。这种交叉验证方法将数据集分成五个部分,在每次验证中,将其中一个客户的数据作为验证集,其余客户的数据作为训练集,以此来评估模型的性能。

附录中提供了没有保留客户数据的重复实验的结果。对于PFL方法,我们通过平均每个个性化模型的结果来报告性能。

模型和训练超参数。我们采用ResNet-18体系结构作为主干模型。我们的方法在75%的训练阶段启动局部全局插值,与SWA的默认超参数设置一致。我们使用Adam优化器,学习率为1e−3,动量系数为0.9和0.99,并将批量大小设置为16。我们将本地训练epoch设置为1,并执行总共1000轮通信。


compute_hessian_eigenthings 是一个包含计算Hessian矩阵特征值和特征向量的函数的库。主要作用是在优化算法中用于评估损失函数的二阶导数信息。通过计算Hessian矩阵的特征值和特征向量,可以更好地理解优化空间的形状,并为优化算法提供更准确的方向信息,从而加快模型收敛速度。blog code

pip install –upgrade git+https://github.com/noahgolmant/pytorch-hessian-eigenthings.git@master

pip install –upgrade git+https://github.com/noahgolmant/pytorch-hessian-eigenthings.git@dce2e54a19963b0dfa41b93f531fb7742d46ea04

gpytorch 是一个基于 PyTorch 的高斯过程库,用于概率编程和高斯过程模型的构建。它提供了在 PyTorch 框架下构建、训练和推断高斯过程模型所需的工具和功能。

AttributeError: module ‘argparse’ has no attribute ‘BooleanOptionalAction’–python3.9以上


windows下

1
2
Pillow_SIMD==9.0.0.post1
netifaces==0.11.0
1
2
3
conda create -n fedsoup python=3.9 --y

conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia

paper

abstract

跨孤岛fl,当面对分布变化时,当前的FL算法面临着局部和全局性能之间的权衡。具体来说,个性化FL方法有过度拟合局部数据的倾向,导致局部模型出现陡谷,抑制了其推广到分布外数据的能力。我们提出了一种新的联邦模型汤方法(即模型参数的选择性插值)来优化局部和全局性能之间的权衡。具体来说,在联邦训练阶段,每个客户机通过监视局部模型和全局模型之间的内插模型的性能来维护自己的全局模型池。这使我们能够缓解过拟合并寻求平坦最小值,这可以显着提高模型的泛化性能。视网膜和病理图像分类

introduction

最近的研究[28]发现了当前FL算法的一个重要问题,即当遇到分布变化时,局部和全局性能之间的权衡。个性化FL (PFL)技术通常通过对每个客户机的分布内(ID)数据施加更多权重来解决数据异构问题。例如,FedRep[5]在本地更新期间学习整个网络,并使部分网络免于全局同步。然而,它们有过拟合本地数据的风险[23],特别是当客户端本地数据有限时,并且对out- distribution (OOD)数据的泛化性较差。另一项工作是通过规范局部模型的更新来研究异构性问题。例如,FedProx[15]约束本地更新更接近全局模型。评估FL泛化性的一种有效方法是研究其在联合全局分布上的性能[28],这是指在⋃{Di}上测试FL模型,其中Di表示客户端i的分布4。遗憾的是,现有的研究还没有找到个性化(局部)和共识(全局)模型之间的平衡点

Dj is viewed as the OOD data for client i.

为此,我们的目标是解决 FL 中的以下两个问题:可能导致局部和全局权衡的原因是什么。以及如何实现更好的局部和全局权衡。我们发现 FL 中的过度个性化会导致对本地数据的过度拟合,并将模型捕获到损失景观的尖锐山谷中(对参数扰动高度敏感,参见第 2.2 节中的详细定义),从而限制了其泛化性。避免损失景观中尖锐山谷的有效策略是强制模型获得平坦的最小值。在集中式领域,权重插值weight interpolation已被探索为寻找平坦最小值的一种手段,因为它的解决方案更接近高性能模型的质心,这对应于更平坦的最小值 [11,3,6,24]。然而,对这些插值方法的研究在FL中被忽略了。

在此基础上,我们建议在联合训练过程中跟踪局部模型和全局模型,并执行模型插值来寻求最优平衡。我们的见解是从模型汤方法[27]中提取的,这表明具有相同初始参数的多个训练模型的平均权重可以增强模型的泛化能力。然而,原始模型汤方法需要训练具有不同超参数的大量模型,这在 FL 期间的通信方面可能非常耗时且成本高昂。鉴于 FL 中的通信成本高且无法从头开始训练,我们利用单个训练会话中不同时间点的全局模型作为使模型汤方法 [27] 适应 FL 的成分。

在本文中,我们提出了一种新的联邦模型汤方法(FedSoup),以从局部和全局模型中生成集成模型,从而实现更好的局部-全局权衡。我们将“soup”称为不同联邦模型的组合。我们提出的FedSoup包括两个关键模块。第一种是时间模型选择,旨在选择合适的模型组合成一个模型。第二个模块是联邦模型修补[10],它指的是一种微调技术,旨在增强个性化,而不影响已经令人满意的全局性能

“temporal model selection” 意味着在时间序列数据或其他时间相关数据集上进行模型选择。这可能涉及在不同的时间点上选择不同的模型或参数配置,以适应数据随时间变化的特性。这种方法有助于构建更具有适应性和泛化能力的模型,以更好地应对时间序列数据的特点。

对于第一个模块,时间模型选择,我们使用了基于局部验证性能的贪婪模型选择策略。避免将可能位于不同误差景观盆地的模型合并到本地的损失景观中(如图1所示)。因此,每个客户都拥有他们的个性化全局模型汤,由基于其本地验证集选择的历史全局模型的子集组成。

“景观"在这里指的是指代模型参数空间中的拓扑结构或形状。在机器学习领域,误差景观通常用来描述损失函数在参数空间中的形状,以及模型在该空间中移动时损失值的变化情况。

图1:PFL方法通常最小化局部损失,但存在较高的全局损失。而我们的联邦模型汤方法通过寻求平坦极小值来平衡局部和全局损失。图中的黑点表示省略号,表示其间的多轮模型上传和模型训练。与以前的pFL方法相比,我们的方法引入了全局模型选择模块和全局模型与局部模型进行插值(称为模型修补)。

“seeking flat minima” 意味着在机器学习中寻找“平坦的最小值”。在训练神经网络时,有时候希望模型收敛到一个平坦的局部最小值,而不是一个非常陡峭的局部最小值。这是因为平坦的最小值可能对噪声更具鲁棒性,有助于提高模型的泛化能力。

“local model interpolation with the global model” 意味着在机器学习中,将局部模型与全局模型进行插值。这可能涉及将局部训练的模型与全局模型进行融合或插值,以获得更具泛化能力的模型。这种方法有助于平衡全局和局部模型的性能,以获得更好的整体性能。

第二个模块,联邦模型修补,它通过将局部模型和全局模型汤插入到新的局部模型中,在局部客户端训练中引入模型修补,弥合局部域和全局域之间的差距。它促进了ID测试模型的个性化,并为OOD泛化保持了良好的全局性能。

(i)提出了一种新的FL方法,称为联邦模型流(FedSoup),通过提高平滑度和寻求平坦极小值来提高泛化和个性化。(ii)为FL设计了一种新的时间模型选择机制,该机制维护了具有时间历史全局模型的客户特定模型汤,以满足个性化要求,同时不产生额外的训练成本。(iii)在联合客户端训练中引入了一种创新的局部和全局模型之间的联合模型修补方法,以缓解局部有限数据的过拟合。

method

个性化的目的是最小化本地客户端训练集 Di 上的经验loss,

泛化(全局性能)的目标是通过所有训练客户的训练集 D 上的经验loss最小化 (ERM) 来最小化多both population损失 ED (θ) 和 ET (θ)–(定义了一组看不见的目标域 T)

我们评估了 Di 的局部测试样本的局部性能,并从联合全局分布 D := {Di}N i=1 评估测试样本的全局性能

“minimize both population loss” 意味着在机器学习中,不仅要最小化经验损失(在训练数据上的损失),还要尽量减小整体(或总体)损失。这表示模型不仅要在训练数据上表现良好,还要在整体总体上具有良好的泛化能力。

泛化和平坦最小值

在实践中,深度神经网络中的ERM,即arg-minθõED(θ),可以产生多个解决方案,这些解决方案提供可比较的训练损失,但可推广性水平截然不同[3]。然而,如果没有适当的正则化,模型容易对训练数据进行过度拟合,并且训练模型将陷入损失面的陡峭山谷,这是不太可推广的[4]。ERM失败的一个常见原因是数据分布存在变化,这可能会导致损失景观的变化。如图 1 所示,优化的最小值越尖锐,它对损失景观的变化就越敏感。这导致泛化误差增加。在跨设备 FL 中,每个客户端可能会过度拟合其本地训练数据,导致全局性能较差。这是由于分布偏移问题,它在局部模型[23]中造成了相互冲突的目标。因此,当局部模型收敛到一个尖锐的最小值时,模型的个性化程度(局部性能)越高,泛化能力较差(全局性能)的可能性就越大。

fedsoup

Temporal Model Selection

随机加权平均 (SWA) 是一种更简洁和有效的方法,通过平均权重隐式偏爱平坦最小值。SWA 算法的动机是观察到 SGD 通常在权重空间中找到高性能模型,但很少达到最优集的中心点。通过对迭代上的参数值进行平均,SWA 算法将解移动到更接近该点空间的质心。

引入了一种称为模型汤[27]的选择性加权平均方法来增强微调模型的泛化能力。我们通过利用在FL训练的一次传递中不同时间点训练的全局模型,使这个想法适应新的方法。我们提出了一种模型选择策略,其中每个客户端利用其本地验证集的性能作为监控指标

Federated Model Patching

根据之前关于损失景观的分析,由于不同FL客户端的域差异,不同FL客户端之间存在损失景观偏移。因此,简单地集成全局模型会损害模型的个性化。我们在 FL 的客户端局部训练期间引入了模型修补 [10](即局部和全局模型插值),旨在提高模型个性化并保持良好的全局性能。具体来说,模型修补方法迫使本地客户端不会严重扭曲全局模型,并在局部和全局之间寻求低损失插值模型,鼓励局部和全局模型位于同一个盆地,没有较大的线性连通性障碍。[19]。我们称这个模块为联邦模型修补。

我们提出的 FedSoup 算法只需要一个仔细调整的超参数,即插值开始 epoch。为了减轻当起始 epoch 太晚并且当起始 epoch 太早时防止潜在性能下降的风险,我们将默认插值起始 epoch 设置为总训练时期的 75%,与 SWA 的默认设置对齐。此外,值得一提的是,我们提出的 FedSoup 框架中修改后的模型汤和模型修补模块是相互依赖的。模型修补是一种基于我们修改后的模型汤算法的技术,提供了丰富的模型来探索更平坦的最小值并提高性能。

Experiments

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
{
  "num_clients": 5,
  "num_classes": 2,
  "non_iid": null,
  "balance": null,
  "partition": null,
  "Size of samples for labels in clients": [
    [
      [
        0,
        460
      ],
  "alpha": 0.1,
  "batch_size": 10
}
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
# hoid 0: 指定保留客户的索引。在交叉验证中,可能需要保留特定的客户数据以进行验证,这就需要指定要保留的客户在数据集中的索引。
# nc :client数目
# -wa_alpha:FedSoup个性化全局模型池的权重平均比率
python main.py -data tiny_camelyon17 -m resnet -algo FedSoup -gr 20 -did 0 -eg 1 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75

Algorithm: FedSoup
Local batch size: 16
Local steps: 1
Local learing rate: 0.001
Total number of clients: 4
Clients join in each round: 1.0
Client drop rate: 0.0
Time select: False
Time threthold: 10000
Global rounds: 20
Running times: 1
Dataset: tiny_camelyon17
Local model: resnet
Using device: cpu
Hold-out Client ID: 0
    
============= Running time: 0th =============
    
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=2, bias=True)
)

1
2
3
4
5
6
7
8
server = FedSoup(args, i)

# 设置客户端
self.set_clients(args, clientSoup)
# client0 没有设置
set client:  1 2 3 4
Join ratio / total clients: 1.0 / 5

1
2
3
4
5
6
server.train()

(Pdb) p self.selected_clients
[<flcore.clients.clientsoup.clientSoup object at 0x0000019197EB87C0>, <flcore.clients.clientsoup.clientSoup object at 0x00000191ABB543A0>, <flcore.clients.clientsoup.clientSoup object
at 0x000001919BE7AF40>, <flcore.clients.clientsoup.clientSoup object at 0x0000019194DBA9D0>]

1
2
3
4
# server->client
self.send_models()
# 接收全局模型
client.set_parameters(self.global_model)
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
-------------Round number: 0-------------
Evaluate global model
# out-of-distribution (OOD) 数据的评估为True
# 模型在额外的、不同分布的数据上进行评估,以测试其泛化能力。
self.evaluate(ood_eval=True)

# 获取测试指标
stats = self.test_metrics() # ood_eval=False
# 测试每个client的数据 clientbase.py(113)test_metrics()
c.test_metrics()
(Pdb) p x.shape
torch.Size([16, 3, 96, 96])
# 遍历完所有client0 测试集, 
(Pdb) p len(y_true)# 18个batchsize
18
(Pdb) p y_true[0].shape# batchsize:16
(16, 2)

# 按第一个维度拼接之后,所有数据拼接,去掉了batchsize
(Pdb) p len(y_true) # client的测试数据数
280
(Pdb) p y_true[0].shape
(2,)
(Pdb) p y_true[0]
array([0, 1])
(Pdb) p test_acc
140
(Pdb) p test_num
280
(Pdb) p auc
0.5641454081632653
# 对所有client而言
(Pdb) p ids
[1, 2, 3, 4]
(Pdb) p num_samples
[280, 280, 280, 280]
(Pdb) p tot_correct
[140.0, 140.0, 140.0, 140.0]
(Pdb) p tot_auc# 即使`y_true`和`y_prob`相同,但AUC的计算结果可能不同
[157.9607142857143, 150.17857142857142, 136.8892857142857, 140.6357142857143]

# # acc:每个client样本正确数分别除所有client样本数 相加
# 相当于,每个客户端的正确率*每个客户端数据的比重 相加  
(Pdb) test_acc
0.5
(Pdb) test_auc
0.5229145408163266

(Pdb) p test_acc_list
[0.5, 0.5, 0.5, 0.5]
(Pdb) p test_auc_list
[0.5641454081632653, 0.5363520408163265, 0.48889030612244894, 0.5022704081632654]
(Pdb) accs
[0.5, 0.5, 0.5, 0.5]
(Pdb) aucs
[0.5641454081632653, 0.5363520408163265, 0.48889030612244894, 0.5022704081632654]
(Pdb) local_test_acc
0.5
(Pdb) local_test_auc
0.5229145408163266



OOD eval details:
(Pdb) p ood_eval
True
stats_all = self.test_metrics(ood_eval=True)

for c in self.clients:# 4 self.num_clients 
(Pdb) p c.id
1
(Pdb) p dataset_ids # 其余客户端的id
[0, 2, 3, 4]
c.test_metrics(ood_eval=ood_eval, dataset_ids=dataset_ids)

testloaderfull = self.load_test_data(ood_eval=ood_eval,dataset_ids=dataset_ids)
(Pdb) p batch_size
16
for dataset_id in dataset_ids:# [0, 2, 3, 4]
	if dataset_id == self.hold_out_id:# 0
		train_data = read_client_data(
                 self.dataset, dataset_id, is_train=True)
        # client 0 的训练集
        test_data_list.append(train_data)
        # client 1 的测试集
     test_data = read_client_data(self.dataset, dataset_id, is_train=False)
                test_data_list.append(test_data)
(Pdb) p len(train_data)
640
(Pdb) p len(test_data_list)
5
(Pdb) p len(test_data_list[0]) # client 0 训练集
640
(Pdb) p len(test_data_list[1]) # client 1 测试集
280
(Pdb) p len(test_data_list[2])
280
# client 0的所有数据集+其余client 2, 3, 4数据集
(Pdb) p len(concat_test_data)# 拼接后的测试集
1760# 640   +   280   ×   4 

# client0 的所有数据集
oof_testloaderfull = self.load_test_data(
                    ood_eval=ood_eval, dataset_ids=[self.hold_out_id]
                )
(Pdb) p len(test_data_list)
2
(Pdb) p len(concat_test_data)
920

for x, y in testloaderfull:
for x, y in oof_testloaderfull:
return test_acc, test_num, auc, oof_test_acc, oof_test_num, oof_auc
(Pdb) p test_acc
880
(Pdb) p test_num
1760
(Pdb) p auc
0.5438849431818182
(Pdb) p oof_test_acc
460
(Pdb) p oof_test_num
920
(Pdb) p oof_auc
0.5537807183364839

# client1上的模型测试client0(hold)数据集
(Pdb) p oof_model_ids
[1]
(Pdb) p oof_tot_correct
[460.0]
(Pdb) p oof_tot_auc
[509.4782608695652]
(Pdb) p oof_num_samples
[920]

# client1上测试 client0数据集 和234的测试集
(Pdb) p model_ids
[1]
(Pdb) p tot_correct
[880.0]
(Pdb) p tot_auc
[957.2375000000001]
(Pdb) p num_samples
[1760]

(Pdb) p model_ids
[1, 2, 3, 4]
(Pdb) p num_samples
[1760, 1760, 1760, 1760]
(Pdb) p tot_correct
[880.0, 880.0, 880.0, 880.0]
(Pdb) p tot_auc
[957.2375000000001, 959.5624999999999, 977.6306818181818, 965.6460227272728]
(Pdb) p oof_model_ids
[1, 2, 3, 4]
(Pdb) p oof_num_samples
[920, 920, 920, 920]
(Pdb) p oof_tot_correct
[460.0, 460.0, 460.0, 460.0]
(Pdb) p oof_tot_auc
[509.4782608695652, 509.4782608695652, 509.4782608695652, 509.4782608695652]


(Pdb) p stats_all
([1, 2, 3, 4], [1760, 1760, 1760, 1760], [880.0, 880.0, 880.0, 880.0], [957.2375000000001, 959.5624999999999, 977.6306818181818, 965.6460227272728], [1, 2, 3, 4], [920, 920, 920, 920],
 [460.0, 460.0, 460.0, 460.0], [509.4782608695652, 509.4782608695652, 509.4782608695652, 509.4782608695652])
(Pdb) p stats # 切片前四个为 client0数据集 和234的测试集
([1, 2, 3, 4], [1760, 1760, 1760, 1760], [880.0, 880.0, 880.0, 880.0], [957.2375000000001, 959.5624999999999, 977.6306818181818, 965.6460227272728])
(Pdb) p test_acc
0.5
(Pdb) p test_auc
0.5483063500774793

(Pdb) p test_acc_list
[0.5, 0.5, 0.5, 0.5]
(Pdb) p test_auc_list
[0.5438849431818182, 0.5452059659090909, 0.5554719783057851, 0.5486625129132232]
(Pdb) p accs
[0.5, 0.5, 0.5, 0.5]
(Pdb) p aucs
[0.5438849431818182, 0.5452059659090909, 0.5554719783057851, 0.5486625129132232]



OOD Performance (client model i on all other dataset j):
# 根据客户端数量加权的测试准确率
OOD Client-Num-Weighted Test Accurancy: 0.5000 
OOD Client-Num-Weighted Test AUC: 0.5483
# 客户端平均测试准确率
OOD Client-Equally Test Accurancy: 0.5000
OOD Client-Equally Test AUC: 0.5483
OOD Std Test Accurancy: 0.0000
OOD Std Test AUC: 0.0045

Performance Summarizing...
Local Performance:# client 在自己测试集上的性能的平均值
Local Client-Equally Test Accurancy: 0.5000
Local Client-Equally Test AUC: 0.5229
# accs:每个client在其余client数据集上的准确率
# append:加上每个client在自身数据集上的平均值
accs.append(local_test_acc)
aucs.append(local_test_auc)
(Pdb) p accs
[0.5, 0.5, 0.5, 0.5, 0.5]
(Pdb) p aucs
[0.5438849431818182, 0.5452059659090909, 0.5554719783057851, 0.5486625129132232, 0.5229145408163266]
# 取平均?得到全局?
# 可以理解为 local_test_acc为每个client在自身数据集的准确率的平均值
# 先取平均 np.mean(accs):每个客户端在其余client数据集上准确率的平均值
# 再取平均为全局模型的准确率
Global Performance:
Glocal Client-Equally Test Accurancy: 0.5000
Glocal Client-Equally Test AUC: 0.5432

# client在holdcliet上的准确率(client 0) 为参与到联邦学习中的client
stats = stats_all[4:]
(Pdb) p stats
([1, 2, 3, 4], [920, 920, 920, 920], [460.0, 460.0, 460.0, 460.0], [509.4782608695652, 509.4782608695652, 509.4782608695652, 509.4782608695652])
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
-------------Round number: 0-------------

Evaluate global model# 每个客户端在自身测试集的准确率
([1, 2, 3, 4], [280, 280, 280, 280], [140.0, 140.0, 140.0, 140.0], [157.9607142857143, 150.17857142857142, 136.8892857142857, 140.6357142857143])
[1, 2, 3, 4]
Client 1 Test Accurancy: 0.5000
Client 1 Test AUC: 0.5641
Client 2 Test Accurancy: 0.5000
Client 2 Test AUC: 0.5364
Client 3 Test Accurancy: 0.5000
Client 3 Test AUC: 0.4889
Client 4 Test Accurancy: 0.5000
Client 4 Test AUC: 0.5023
Averaged Test Accurancy: 0.5000# 平均准确率
Averaged Test AUC: 0.5229
Std Test Accurancy: 0.0000
Std Test AUC: 0.0294

OOD eval details: # ood每个客户端在其他所有客户端数据的准确率
([1, 2, 3, 4], [1760, 1760, 1760, 1760], [880.0, 880.0, 880.0, 880.0], [957.2369318181819, 959.5619318181818, 977.6301136363636, 965.6454545454545])
[1, 2, 3, 4]
Client 1 Test Accurancy: 0.5000
Client 1 Test AUC: 0.5439
Client 2 Test Accurancy: 0.5000
Client 2 Test AUC: 0.5452
Client 3 Test Accurancy: 0.5000
Client 3 Test AUC: 0.5555
Client 4 Test Accurancy: 0.5000
Client 4 Test AUC: 0.5487
OOD Performance (client model i on all other dataset j):
OOD Client-Num-Weighted Test Accurancy: 0.5000# 加权平均
OOD Client-Num-Weighted Test AUC: 0.5483
OOD Client-Equally Test Accurancy: 0.5000# 平均
OOD Client-Equally Test AUC: 0.5483
OOD Std Test Accurancy: 0.0000
OOD Std Test AUC: 0.0045

Performance Summarizing...
Local Performance:# client在本地数据上的准确率
Local Client-Equally Test Accurancy: 0.5000
Local Client-Equally Test AUC: 0.5229
Global Performance:# 全局模型准确率
Glocal Client-Equally Test Accurancy: 0.5000
Glocal Client-Equally Test AUC: 0.5432
# OOF未参与联邦学习的client 0的准确率
([1, 2, 3, 4], [920, 920, 920, 920], [460.0, 460.0, 460.0, 460.0], [509.4771739130435, 509.4771739130435, 509.4771739130435, 509.4771739130435])
[1, 2, 3, 4]
================
OOF Performance:
OOF Client Test Accurancy: 0.5000
OOF Client Test AUC: 0.5538

out-of-distribution (OOD):在不同客户端数据上进行测试,未知分布。

out-of-federated performance(oof):指的是在联邦学习环境之外的性能表现。这个术语通常用于描述模型在非联邦学习设置下的性能,例如在集中式学习或单个数据源的情况下的性能表现。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
client.train()# clientsoup.py(33)train()

self.model.train()

client.train()
#轮次达到0.75时执行权重平均化算法
(Pdb) p self.wa_alpha
0.75


round 15
Begin Weight Averaging......
Original Weight Averaging Accuracy:  0.9392857142857143
Updated Weight Averaging Accuracy:  0.8357142857142857
Remain the same Personalized Global Model.
Begin Weight Averaging......
Original Weight Averaging Accuracy:  0.8607142857142858
Updated Weight Averaging Accuracy:  0.7785714285714286
Remain the same Personalized Global Model.
Begin Weight Averaging......
Original Weight Averaging Accuracy:  0.7642857142857142
Updated Weight Averaging Accuracy:  0.7178571428571429
Remain the same Personalized Global Model.
Begin Weight Averaging......
Original Weight Averaging Accuracy:  0.8857142857142857
Updated Weight Averaging Accuracy:  0.8714285714285714
Remain the same Personalized Global Model.

Original Weight Averaging Accuracy:  0.6214285714285714
Updated Weight Averaging Accuracy:  0.7857142857142857
Update Personalized Global Model......
Client ID:  4
Personalized Global Model Num:  1


1
2
3
Evaluating Post-Fine-Tuning and OOD Performance...

self.evaluate(ood_eval=True)
1
Fine-Tuning with the Last Trained Model......

fedsoup

  1. 选择性模型混合(Selective Model Interpolation):FedSoup引入了一种新的模型混合策略,它允许在全局模型和个性化模型之间进行选择性的插值。这种方法不仅考虑了全局模型的泛化能力,还考虑了个性化模型对本地数据的适应性。
  2. 个性化和泛化的平衡:FedSoup通过调整全局模型和个性化模型的插值比例动态地平衡了模型的泛化和个性化。这种平衡使得模型既能捕捉到跨用户的共同特征,又能适应每个用户的特定数据分布。
  3. 自适应插值权重:FedSoup算法中的选择性模型混合不是静态的,而是基于每个用户的数据和模型性能自适应地调整插值权重。这种自适应机制使得算法能够根据不同用户的数据多样性和模型性能需求灵活调整模型结构。
  4. 降低通信开销:在联邦学习中,模型参数的同步是一个通信密集型的过程。FedSoup通过减少需要传输的参数数量,从而降低了通信开销。这是通过只传输那些在全局模型和个性化模型插值中起到关键作用的参数来实现的。
  5. 实验验证:论文中对FedSoup算法进行了广泛的实验验证,证明了它在多个数据集和不同的联邦学习设置下,相较于现有方法,能够显著提高模型的泛化和个性化性能。

总之,FedSoup算法通过选择性模型混合和自适应插值权重,提供了一种有效的机制来同时增强联邦学习中的泛化和个性化性能,并且降低了通信开销。

FedSoup算法在提升模型泛化能力方面的创新体现在以下几个方面:

  1. 全局与个性化模型的融合:FedSoup算法通过结合全局模型和本地个性化模型来提升泛化能力。全局模型能够捕捉到所有用户数据的共同特征,从而提高模型对于未见数据的泛化能力。个性化模型则针对每个用户的本地数据进行优化,提高模型对特定用户数据的适应性
  2. 自适应插值权重:算法根据每个用户的数据特性和模型性能动态调整全局模型和个性化模型的插值权重。这种自适应机制使得算法能够在保证模型个性化的同时,确保模型在全局范围内的泛化性能。
  3. 模型更新的选择性同步:在联邦学习的过程中,并不是每次都需要将整个模型更新同步给所有用户。FedSoup算法通过选择性地同步模型更新,减少了通信开销,并可能只同步那些对提升泛化能力最为关键的部分,这样可以更加高效地利用带宽和计算资源,从而间接提升模型的泛化能力。
  4. 避免过度拟合:由于联邦学习的数据分散在多个用户手中,局部模型可能会过度拟合到本地数据,从而降低泛化能力。FedSoup通过全局模型的引导,帮助局部模型避免这种过度拟合现象,增强模型对新数据的泛化能力。
  5. 实验验证:论文中提供了详细的实验结果,展示了FedSoup算法在不同的数据集和联邦学习场景下,相较于传统的联邦学习方法,在提升模型泛化能力方面的优势。

综上所述,FedSoup算法通过全局和个性化模型的融合、自适应插值权重、选择性模型同步、避免过度拟合以及实验验证等创新手段,有效地提升了模型的泛化能力。

插值权重,在插值方法的上下文中,是指每个已知数据点对未知位置或值的影响权重。这些权重通常是根据已知数据点之间的距离或其他相关性来计算的。简而言之,插值权重用于确定在插值过程中,各个已知数据点应如何贡献于未知点的估计值。通过适当地调整这些权重,可以更准确地逼近未知点的真实值。

在论文《FedSoup: Improving Generalization and Personalization in Federated Learning via Selective Model Interpolation》中,选择性模型混合(Selective Model Interpolation)是通过以下步骤实现的:—类似于fedala

  1. 模型训练:在联邦学习的框架中,每个客户端都会根据自己的本地数据训练一个个性化模型。同时,所有客户端还会训练一个全局模型,该模型旨在捕捉所有客户端数据的共同特征。
  2. 模型融合策略:在模型训练完成后,FedSoup采用一种融合策略,该策略基于每个客户端的个性化模型和全局模型之间的相似性来选择性地进行模型混合。具体来说,FedSoup计算每个个性化模型与其对应客户端的全局模型的差异度量(例如,KL散度或余弦相似性)。
  3. 插值系数计算:接着,FedSoup为每个客户端计算一个插值系数,该系数表示个性化模型在最终模型中的权重。插值系数的计算基于差异度量和一个预设的温度参数(temperature parameter),后者用于控制插值的平滑程度。
  4. 模型更新:最后,每个客户端根据计算出的插值系数更新自己的模型,即将全局模型与个性化模型按照插值系数融合成一个新的模型。这个新模型既包含了全局模型的泛化能力,又保留了个性化模型对本地数据的适应性。

通过这种选择性模型混合,FedSoup旨在实现更好的模型泛化和个性化。模型泛化指的是模型对新数据的适用性,而个性化则是模型对特定用户数据的适应性。通过这种方式,FedSoup旨在平衡这两个目标,使得模型既能够泛化到新用户,也能够很好地服务于每个参与联邦学习的客户端。

在《FedSoup: Improving Generalization and Personalization in Federated Learning via Selective Model Interpolation》论文中,通过全局模型引导局部模型的方式来避免过度拟合主要体现在模型融合策略上。具体而言,FedSoup算法采用了一种选择性模型混合(selective model interpolation)的策略,该策略结合了全局模型和局部模型的优点,以提高模型在联邦学习环境中的泛化能力和个性化能力。

在这个过程中,每个客户端都会根据自己的本地数据训练一个个性化模型,同时也会从服务器那里接收到一个全局模型。然后,客户端利用全局模型来引导本地模型的训练。这种引导通常是通过在训练过程中将全局模型的参数作为一种先验知识加入到本地模型的训练中,或者在训练结束后,通过模型混合的方式来实现的。

在模型混合阶段,客户端会计算本地模型和全局模型之间的差异,并根据这个差异以及其他因素(比如模型性能、数据分布等)来决定如何混合这两个模型。如果本地模型在本地数据上过度拟合,那么它可能会与全局模型有较大的差异。此时,通过选择性地混合全局模型和本地模型,可以帮助本地模型“解拟合”,即减少模型对本地数据的过拟合现象,从而提高模型对新数据的泛化能力

通过这种方式,FedSoup算法能够在保持模型个性化的同时,确保模型不会仅仅适应本地数据,而是能够泛化到更广泛的数据分布上。这样的策略对于联邦学习中的模型训练尤为重要,因为联邦学习涉及多个客户端,每个客户端的数据分布可能都不尽相同,因此需要模型既要有良好的泛化能力,又要能适应不同客户端的个性化需求。

在论文《FedSoup: Improving Generalization and Personalization in Federated Learning via Selective Model Interpolation》中,模型更新的选择性同步是通过以下步骤实现的:

  1. 每个参与联邦学习的客户端在本地训练自己的个性化模型,同时也训练一个全局模型。
  2. 在每次联邦学习轮次结束时,客户端计算自己的个性化模型与全局模型之间的差异度量,并据此确定是否需要同步更新。
  3. 客户端根据差异度量和一个预设的阈值来决定是否将模型更新发送给服务器。只有当差异超过这个阈值时,客户端才会同步其模型更新。
  4. 服务器收到来自不同客户端的更新后,会根据一定的策略(比如加权平均)来聚合这些更新,生成新的全局模型。
  5. 服务器将新的全局模型发送回所有客户端,客户端再根据自己的插值系数将新的全局模型与本地个性化模型融合,形成新的本地模型。

通过这种方式,只有当客户端的模型更新相对于全局模型有显著差异时,更新才会被同步,这样可以减少不必要的通信开销,并且使得全局模型更快地收敛到一个对所有客户端都有用的状态。同时,选择性同步也有助于保护客户端的隐私,因为它减少了需要传输的信息量。

论文《FedSoup: Improving Generalization and Personalization in Federated Learning via Selective Model Interpolation》中,判断是否需要上传模型更新的机制是基于模型间的差异性。具体来说,客户端会计算本地模型更新与全局模型更新之间的差异度量,例如使用欧氏距离或者其他相似度指标。如果差异超过某个设定的阈值,表明本地模型与全局模型存在较大差异,这时候才需要上传模型更新。

这样做的目的主要有两个:

  1. 减少通信成本:通过只上传那些真正有差异、能够为全局模型带来新信息的模型更新,可以减少网络带宽的使用和通信开销,特别是在大规模分布式系统中,这一点尤为重要。
  2. 提高模型效率:只上传重要的更新有助于加快全局模型收敛速度,因为它避免了冗余的、相似的或者不太有用的更新,从而使得全局模型能够更快地整合所有客户端的有用信息,提高整体模型的性能。

1
2
3
4
last_global_model.data = (1.0 / (self.per_global_model_num + 1.0)) * (
                    self.per_global_model_num * global_param.data.clone()
                    + last_global_model.data.clone()
                )
  1. self.per_global_model_num * global_param.data.clone():这部分计算global_param的当前值与其对应权重的乘积。权重是self.per_global_model_num
  2. last_global_model.data.clone():这部分直接使用last_global_model的当前值,没有乘以任何权重。
  3. 将上述两部分相加,得到global_paramlast_global_model的加权和。
  4. 最后,这个加权和乘以一个归一化因子1.0 / (self.per_global_model_num + 1.0),以确保更新后的last_global_model值在合理的范围内。
  5. global_param的权重随着self.per_global_model_num的增加而增加。last_global_model的实际影响会随着self.per_global_model_num的增加而相对减小。

FedSoupALA

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
python main.py -data tiny_camelast_global_model.data = (1.0 / (self.per_global_model_num + 1.0)) * (
                    self.per_global_model_num * global_param.data.clone()
                    + last_global_model.data.clone()
                )lyon17 -m resnet -algo FedSoup -gr 20 -did 0 -eg 1 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75

nohup python -u main.py -t 1 -jr 1 -nc 20 -nb 10 -data mnist-0.1-npz -m cnn -algo FedALA -et 1 -p 2 -s 80 -did 0 > result-mnist-0.1-npz.out 2>&1


python main.py -data tiny_camelyon17 -m resnet -algo FedSoupALA -gr 20 -did 0 -eg 1 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75 -et 1 -pala 2 -s 80


python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoupALA -gr 1000 -did 0 -eg 100 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75  -et 1 -pala 2 -s 80 > result-tiny-came-npz-ala.out 2>&1 

python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoupALA -gr 1000 -did 0 -eg 100 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75  -et 1 -pala 6 -s 80 > result-tiny-came-npz-ala-neweval-p6.out 2>&1

python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoupALA -gr 1000 -did 0 -eg 100 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75 --pruning --sparsity_ratio 0.5 --pruning_warmup_round 500 --masking_grad --dynamic_mask > result-tiny-came-npz-ala-prun.out 2>&1 
1
2
3
4
# ResNet
(Pdb) p len(params)
62

cifar -- client-20 gr 50

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
client-20 gr 50
# 修改self.evaluate(ood_eval=True)

self.evaluate(ood_eval=True,round = i)
if ood_eval and round ==self.global_rounds

self.evaluate(ood_eval=True,round=self.global_rounds)
self.evaluate(ood_eval=True,round=self.gloabl_rounds)
 -eg 10即可
    
   
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
local_acc = []
for i in range():
    		if i % self.eval_gap == 0:
                print(f"\n-------------Round number: {i}-------------")
                print("\nEvaluate global model")
                #pdb.set_trace()
                self.evaluate()
           	trian 之后     
          	if i % self.eval_gap == 0:
                print("\nEvaluate local model")
                self.evaluate(acc=local_acc)

            if i > 0 and i % 50 == 0:
                print("\nEvaluate ID, OOD and OOF Performance")
                self.evaluate()
                self.evaluate(ood_eval=True)
                # break
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
fedsoup代码中有命令hold中
python main.py -lbs 16 -nc 10 -jr 1 -nb 10 -data Cifar10 -m resnet -algo FedAvg -gr 50  -did 0 -go resnet  

fedavg
nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10 -m resnet -algo FedAvg -gr 50 -eg 10 -did 0 -go resnet > result-fedavg.out 2>&1


---

fedprox
nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10 -m resnet -algo FedProx -gr 50 -eg 10 -did 0 -mu 0.001 -go resnet > result_fedprox.out 2>&1

---

moon
nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10 -m resnet -algo MOON -gr 50 -eg 10 -did 0  -go resnet > result_moon.out 2>&1

fedbn
nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10 -m resnet -algo FedBN -gr 50 -did 0 -eg 10 -go fedbn_debug > result_fedbn.out 2>&1

FedFomo
nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10 -m resnet -algo FedFomo -gr 50 -did 0 -eg 10 -go fedfomo > result_fedfomo.out 2>&1

FedRep
nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10  -m resnet -algo FedRep -gr 50 -did 0 -eg 10 -go fedrep > result_fedrep.out 2>&1

FedBABU
nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10 -m resnet -algo FedBABU -gr 50 -did 0 -eg 10 -go fedbabu > result_fedbabu.out 2>&1

fedala
python -u main.py -t 1 -lbs 16 -lr 1e-3 -eg 50 -jr 1 -nc 5 -nb 2 -data tiny_camelyon17 -m resnet -algo FedALA -et 1 -p 2 -s 80 -did 3 -gr 1000 > result-tiny-came-16001npz.out 2>&1

nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10 -m resnet -algo FedALA -gr 50 -did 0 -eg 10  -et 1 -pala 2 -s 80 -go fedala > result_fedala.out 2>&1


fedsoup
python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoup -gr 1000 -did 2 -eg 50 -go fedsoup_mg5_debug -nc 5 -lr 1e-3 -wa_alpha 0.75 --pruning --sparsity_ratio 0.5 --pruning_warmup_round 500 --masking_grad --dynamic_mask > result-tiny-came-npz.out 2>&1 

nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10 -m resnet -algo FedSoup -gr 50 -eg 10 -did 0 -go resnet  -lr 1e-3 -wa_alpha 0.75 > result-fedsoup.out 2>&1

fedsoupala
python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoupALA -gr 1000 -did 0 -eg 100 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75  -et 1 -pala 2 -s 80 > result-tiny-came-npz-ala.out 2>&1 

nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10 -m resnet -algo FedSoupALA -gr 1000 -did 0 -go resnet  -lr 1e-3 -wa_alpha 0.75  -et 1 -pala 2 -s 80 > result-fedsoupala-1000.out 2>&1; shutdown

nohup python -u main.py -lbs 16 -nc 20 -jr 1 -nb 10 -data Cifar10 -m resnet -algo FedSoupALA -gr 50 -eg 10 -did 0 -go resnet  -lr 1e-3 -wa_alpha 0.75  -et 1 -pala 2 -s 80 > result-fedsoupala.out 2>&1; shutdown

加上-eg 10

camelyon17–hold client 0

method local_acc loacl_auc global_acc global_auc
FedAvg 82.41 90.44 70.18 78.74
FedProx 86.34 92.78 67.42 77.18
MOON 85.71 91.98 70.61 79.27
FedBN 82.32 90.07 65.16 71.61
FedFomo 80.99 86.51 61.00 61.69
FedRep 82.50 89.77 66.87 72.34
FedBABU 85.18 92.39 69.56 77.26
FedSoup 85.71 92.47 72.87 81.45
Fedsoup-4 87.68、87.90 94.23、94.34 73.31、73.18 80.37、80.86
Fedsoup-4-tune 87.77 94.38 73.25 80.42
FedSoupALA本地 67.77、76.85 75.69、85.63 58.63、63.95 67.44、71.08
FedSoupALA-tune 90.54 95.01 71.82 78.43
method local_acc loacl_auc global_acc global_auc
fedasoupala p2 88.93/89.20/89.02 94.48/94.90/94.88 73.01/73.21/72.92 80.21/80.38/80.20
同上_1 80.09/88.04/87.95 88.91/94.68/94.55 71.04/73.28/72.61 78.06/81.08/80.16
gr20 83.04/84.02/89.20 90.31/92.56/95.45 66.23/65.28/71.29 75.01/75.88/81.18
p4 86.52/86.16/86.96 93.60/93.73/93.94 72.05/72.32/72.52 78.35/78.64/78.94
p6 88.93/88.75/89.02 94.62/94.55/94.64 74.57/74.30/74.51 81.92/81.54/81.89
p8 87.50/87.54/87.86 94.20/94.09/94.19 72.95/72.88/73.11 79.04/78.66/78.95
p10 88.66/87.86/88.48 94.04/93.34/93.68 73.16/72.39/73.01 79.84/78.61/79.74
p2_prun 86.07/86.52/86.16 93.15/93.23/93.01 72.41/72.31/73.32 78.75/78.75/78.74

新的评价方式

nohup python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoupALA -gr 1000 -did 0 -eg 100 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75 -et 1 -pala 6 -s 80 > result-tiny-came-npz-ala-neweval-p6.out 2>&1; shutdown

先单独跑FedALA

nohup python -u main.py -lbs 16 -nc 4 -hoid 0 -lr 1e-3-data tiny_camelyon17 -m resnet -algo FedALA -gr 1000 -did 0 -eg 100 -et 1 -pala 6 -s 80 -go fedala > tiny_fedala.out 2>&1

nohup python -u main.py -data tiny_camelyon17 -m resnet -algo FedALA -gr 1000 -did 0 -eg 100 -go fedala -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75 -et 1 -pala 6 -s 80 > tiny-came-fedala-neweval-p6.out 2>&1

fedsoup

nohup python -u main.py -data tiny_camelyon17 -m resnet -algo FedSoup -gr 1000 -did 0 -eg 100 -go fedsoup_debug -nc 4 -hoid 0 -lr 1e-3 -wa_alpha 0.75 -et 1 -pala 6 -s 80 >tiny-came-npz-soup-neweval-p6.out 2>&1;

method local_acc loacl_auc global_acc global_auc
fedasoupala p6 86.96//87.32 99.27/// 71.32//71.29 95.78//97.58
fedala p6 84.91//84.38 99.17//99.11 71.24//71.32 97.37//97.47
fedsoup 84.46//85.18 98.88//99.08 73.03/73.07/ 97.85//97.88

cifar 10 -nohold

50轮次

method local_acc loacl_auc global_acc global_auc
FedAvg 28.70/90.14/91.10 71.77/99.03/99.19 38.22/20.79/21.04 81.09/72.88/73.11
fedavg_1 89.63///90.90 98.91///99.14 29.74/20.73//20.93 72.42/72.63//73.04
FedProx 27.18/89.78/90.74 69.68/98.96/99.13 37.07/20.88/20.99 80.58/72.66/72.88
FedProx_1 89.98///90.85 99.01///99.17 30.52/20.58//20.94 72.93/72.66//72.86
MOON 91.07/91.07/92.10 70.10/70.10/77.00 20.99/20.99/23.30 52.97/52.97/54.69
FedBN 90.82/89.23/91.92 88.15/86.94/94.39 21.33/20.60/22.61 55.27/54.27/57.01
FedFomo 88.93/89.41/89.74 98.83/98.87/98.97 1/19.45/19.90 1/53.94/54.22
FedRep 86.59/88.51/91.09 98.61/98.44/99.14 1/19.82/22.28 1/54.27/56.17
FedBABU 58.46/90.92//92.46 88.67/98.89//99.03 //22.03/23.51 //62.16/64.04
FedALA 89.56///90.58 93.35///94.93 /20.83//20.89 //64.11//64.47
FEDALA_1 89.04/90.77 92.58/94.80/ 20.65/20.79/ /64.04/64.78
FedSoup 26.31/89.65/91.35 57.78/66.91/64.49 18.58/22.18/25.86 53.02/53.18/53.38
Fedsoupala 88.37/91.37/92.44 85.29/87.80/82.10 21.22/21.44/23.85 58.35/57.33/58.35
Fedsoupala_1 90.57/ 86.59/92.48 21.66/24.09 56.94/57.82