averaging.py 331 B

12345678910111213141516
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Python version: 3.6
  4. import copy
  5. import torch
  6. from torch import nn
  7. def average_weights(w):
  8. w_avg = copy.deepcopy(w[0])
  9. for k in w_avg.keys():
  10. for i in range(1, len(w)):
  11. w_avg[k] += w[i][k]
  12. w_avg[k] = torch.div(w_avg[k], len(w))
  13. return w_avg