# Heatmap.py
# Renders visit-count heatmaps of recorded walk positions over the 70 x 35
# surface and marks the scene's obstacles.
  1. import UnityEngine as ue
  2. import numpy as np
  3. import seaborn as sns
  4. import matplotlib.pyplot as plt
  5. import matplotlib.patches as patches
  6. from matplotlib.colors import LinearSegmentedColormap
  7. import pandas as pd
  8. WIDTH = int(70)
  9. HEIGHT = int(35)
  10. OBSTACLE_PATH = "Assets/Data_image/obstacle.pkl"
  11. POSITION_PATH = ue.Application.dataPath + '/Data_position/Walk1.csv'
  12. HEATMAP_PATH = "Assets/Data_image/heatmap1.png"
  13. # Generate only if obstacles change
  14. def set_obstacles():
  15. global list_obstacle_length
  16. positions = pd.DataFrame(np.zeros((HEIGHT, WIDTH)))
  17. obstacles = ue.Object.FindObjectsOfType(ue.GameObject)
  18. for obstacle in obstacles:
  19. if(obstacle.layer == 15 or obstacle.layer == 12):
  20. startWidth = int(obstacle.transform.position.x - obstacle.transform.localScale.x / 2)
  21. endWidth = int(obstacle.transform.position.x + obstacle.transform.localScale.x / 2)
  22. startHeight = int(obstacle.transform.position.z - obstacle.transform.localScale.z / 2)
  23. endHeight = int(obstacle.transform.position.z + obstacle.transform.localScale.z / 2)
  24. for currentW in range(startWidth, endWidth, 1):
  25. for currentH in range(startHeight, endHeight, 1):
  26. positions[currentW][currentH] = -500
  27. positions.to_pickle(OBSTACLE_PATH)
  28. def set_patches(plt):
  29. obstacles = ue.Object.FindObjectsOfType(ue.GameObject)
  30. for obstacle in obstacles:
  31. if(obstacle.layer == 15):
  32. startWidth = int(obstacle.transform.position.x - obstacle.transform.localScale.x / 2)
  33. endWidth = int(obstacle.transform.position.x + obstacle.transform.localScale.x / 2)
  34. startHeight = int(obstacle.transform.position.z - obstacle.transform.localScale.z / 2)
  35. endHeight = int(obstacle.transform.position.z + obstacle.transform.localScale.z / 2)
  36. # plt.gca().add_patch(
  37. plt.add_patch(
  38. patches.Rectangle(
  39. (startWidth, startHeight),
  40. endWidth - startWidth,
  41. endHeight - startHeight,
  42. fill=False,
  43. color='black'
  44. )
  45. )
  46. # 1. Get position data from csv file
  47. data = pd.read_csv(POSITION_PATH, sep=';', usecols=["Position x", "Position z"], decimal=',', dtype={'Position x': float, 'Position z': float})
  48. data = data.round(0)
  49. # 2. Group by positions and count appearance
  50. data_count = data.groupby(['Position x', 'Position z']).size().reset_index(name='counts')
  51. # 3. Create wide-form DataFrame for generating heatmap
  52. positions = data_count.loc[:,:].reset_index().pivot(index='Position z', columns='Position x', values='counts')
  53. # 4. Fill missing values
  54. positions.fillna(0, inplace=True)
  55. # 5. reindex DataFrame (70,35) size of Surface; (70, 35) first x width then z height
  56. positions = positions.reindex_axis(range(0, HEIGHT), axis=0, fill_value=0)
  57. positions = positions.reindex_axis(range(0, WIDTH), axis=1, fill_value=0)
  58. # 6. Get obstacles (obstacles, market stalls) and paste specific value in positions
  59. # Therefore save pkl file in folder and read from it afterwards
  60. set_obstacles()
  61. positions_heatmap = pd.read_pickle(OBSTACLE_PATH)
  62. # 6.1 Merge positions data with obstacles data
  63. positions_heatmap.where(positions_heatmap != 0, positions, inplace=True)
  64. # 6.2 Debug Output
  65. # positions_heatmap = pd.DataFrame(positions_heatmap)
  66. # positions_heatmap.to_html('Assets/Data_image/positions_heatmap.html')
  67. # 7. Plot the heatmap
  68. fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2)
  69. ax4.remove()
  70. # cmap = LinearSegmentedColormap.from_list(name='greenToRed', colors=['grey', 'limegreen', 'chartreuse', 'yellow', 'darkorange', 'red'])
  71. cmap19 = LinearSegmentedColormap.from_list(name='2019', colors=['grey', (0.40,0.76,0.65), (0.11,0.62,0.47)])
  72. cmap20 = LinearSegmentedColormap.from_list(name='2020', colors=['grey', (0.99,0.55,0.38), (0.85,0.37,0.01)])
  73. cmap21 = LinearSegmentedColormap.from_list(name='2021', colors=['grey', (0.55,0.63,0.80), (0.46,0.44,0.70)])
  74. heatmap = sns.heatmap(positions_heatmap, cmap=cmap19, cbar=False, square=True, yticklabels=False, xticklabels=False, ax=ax1)
  75. heatmap.invert_yaxis()
  76. heatmap.set_title(cmap19.name)
  77. heatmap2 = sns.heatmap(positions_heatmap, cmap=cmap20, cbar=False, square=True, yticklabels=False, xticklabels=False, ax=ax2)
  78. heatmap2.invert_yaxis()
  79. heatmap2.set_title(cmap20.name)
  80. heatmap3 = sns.heatmap(positions_heatmap, cmap=cmap21, cbar=False, square=True, yticklabels=False, xticklabels=False, ax=ax3)
  81. heatmap3.invert_yaxis()
  82. heatmap3.set_title(cmap21.name)
  83. plt.xlabel('')
  84. plt.ylabel('')
  85. # 9. Mark the Market stalls
  86. set_patches(ax1)
  87. set_patches(ax2)
  88. set_patches(ax3)
  89. # 9.1
  90. plt.show()
  91. # 10. Save Heatmap
  92. # heatmap.get_figure().savefig(HEATMAP_PATH, transparent=True)
  93. fig.savefig(HEATMAP_PATH, transparent=True)