statistic.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import glob
  2. import os
  3. import pandas as pd
  4. import matplotlib.pyplot as plt
  5. import time
  6. import numpy as np
  7. import seaborn as sns
  8. import math
  9. path = os.getcwd()
  10. def get_average(records):
  11. """
  12. 平均值
  13. """
  14. return sum(records) / len(records)
  15. def get_variance(records):
  16. """
  17. 方差 反映一个数据集的离散程度
  18. """
  19. average = get_average(records)
  20. return sum([(x - average) ** 2 for x in records]) / len(records)
  21. def get_standard_deviation(records):
  22. """
  23. 标准差 == 均方差 反映一个数据集的离散程度
  24. """
  25. variance = get_variance(records)
  26. return math.sqrt(variance)
  27. def get_rms(records):
  28. """
  29. 均方根值 反映的是有效值而不是平均值
  30. """
  31. return math.sqrt(sum([x ** 2 for x in records]) / len(records))
  32. def get_mse(records_real, records_predict):
  33. """
  34. 均方误差 估计值与真值 偏差
  35. """
  36. if len(records_real) == len(records_predict):
  37. return sum([(x - y) ** 2 for x, y in zip(records_real, records_predict)]) / len(records_real)
  38. else:
  39. return None
  40. def get_rmse(records_real, records_predict):
  41. """
  42. 均方根误差:是均方误差的算术平方根
  43. """
  44. mse = get_mse(records_real, records_predict)
  45. if mse:
  46. return math.sqrt(mse)
  47. else:
  48. return None
  49. def get_mae(records_real, records_predict):
  50. """
  51. 平均绝对误差
  52. """
  53. if len(records_real) == len(records_predict):
  54. return sum([abs(x - y) for x, y in zip(records_real, records_predict)]) / len(records_real)
  55. else:
  56. return None
  57. def writeSDCSV(filename):
  58. file = pd.read_csv("Mean.csv")
  59. conditions = file['condition']
  60. dict = {}
  61. dict['conditon'] = conditions
  62. for scale in scales:
  63. temp = []
  64. for condition in conditions:
  65. col = df_merged.groupby('condition').get_group(condition)
  66. col = col[scale]
  67. temp.append(get_standard_deviation(col))
  68. dict[scale] = temp
  69. df = pd.DataFrame(dict)
  70. df.to_csv(filename)
  71. def draw(scale):
  72. conditions = file['condition']
  73. result = file[scale]
  74. plt.figure(figsize=(9, 6), dpi=100)
  75. sd = pd.read_csv(SD)
  76. std_err = sd[scale]
  77. error_params=dict(elinewidth=1,ecolor='black',capsize=5)
  78. plt.bar(conditions, result, width=0.35, color=colors,alpha=a,yerr=std_err,error_kw=error_params)
  79. plt.title(scale,fontsize=15)
  80. plt.ylabel('score')
  81. plt.grid(alpha=0, linestyle=':')
  82. plt.savefig(scale, dpi=300)
  83. #plt.show()
  84. def drawTogether():
  85. scales = ["mental-demand","physical-demand","temporal-demand","performance", "effort","frustration"]
  86. plt.figure(figsize=(15,7))
  87. x = np.arange(len(scales))
  88. total_width, n = 0.8, 4
  89. width = total_width / n
  90. for i in range(0,4):
  91. result = []
  92. for scale in scales:
  93. result.append(file.iloc[i][scale])
  94. plt.bar(x+width*(i-1),result,width=width,color=colors[i],label=file.iloc[i]["condition"],alpha=a)
  95. plt.legend()
  96. plt.title("TLX Average",fontsize=15)
  97. plt.xticks(x+width/2,scales)
  98. #plt.show()
  99. plt.savefig("summary.jpg",dpi=300)
  100. # Merge all the .csv file start with "HectorVR", and
  101. all_files = glob.glob(os.path.join(path, "HectorVR*.csv"))
  102. df_from_each_file = (pd.read_csv(f, sep=',') for f in all_files)
  103. df_merged = pd.concat(df_from_each_file, ignore_index=True)
  104. # Save the file to Merged.csv in the same folder
  105. df_merged.to_csv( "Merged.csv")
  106. # save the results in csv
  107. file = df_merged.groupby(["condition"]).mean()
  108. file.to_csv("Mean.csv")
  109. scales = ["mental-demand","physical-demand","temporal-demand","performance", "effort","frustration","total"]
  110. SD = "standard_deviation.csv"
  111. writeSDCSV(SD)
  112. file = pd.read_csv("Mean.csv")
  113. colors = sns.color_palette()
  114. a = 0.6
  115. for scale in scales:
  116. draw(scale)
  117. drawTogether()