|
@@ -1,18 +0,0 @@
|
|
|
-#!/usr/bin/env python
|
|
|
-# -*- coding: utf-8 -*-
|
|
|
-# Python version: 3.6
|
|
|
-
|
|
|
-import copy
|
|
|
-import torch
|
|
|
-
|
|
|
-
|
|
|
-def average_weights(w):
|
|
|
- """
|
|
|
- Returns the average of the weights.
|
|
|
- """
|
|
|
- w_avg = copy.deepcopy(w[0])
|
|
|
- for key in w_avg.keys():
|
|
|
- for i in range(1, len(w)):
|
|
|
- w_avg[key] += w[i][key]
|
|
|
- w_avg[key] = torch.div(w_avg[key], len(w))
|
|
|
- return w_avg
|