vptree.h 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. /*
  2. *
  3. * Copyright (c) 2014, Laurens van der Maaten (Delft University of Technology)
  4. * All rights reserved.
  5. *
  6. * Redistribution and use in source and binary forms, with or without
  7. * modification, are permitted provided that the following conditions are met:
  8. * 1. Redistributions of source code must retain the above copyright
  9. * notice, this list of conditions and the following disclaimer.
  10. * 2. Redistributions in binary form must reproduce the above copyright
  11. * notice, this list of conditions and the following disclaimer in the
  12. * documentation and/or other materials provided with the distribution.
  13. * 3. All advertising materials mentioning features or use of this software
  14. * must display the following acknowledgement:
  15. * This product includes software developed by the Delft University of Technology.
  16. * 4. Neither the name of the Delft University of Technology nor the names of
  17. * its contributors may be used to endorse or promote products derived from
  18. * this software without specific prior written permission.
  19. *
  20. * THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS
  21. * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
  22. * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
  23. * EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  24. * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
  25. * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
  26. * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  27. * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
  28. * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
  29. * OF SUCH DAMAGE.
  30. *
  31. */
  32. /* This code was adapted, with minor modifications, from Steve Hanov's great tutorial at http://stevehanov.ca/blog/index.php?id=130 */
  33. #include <stdlib.h>
  34. #include <algorithm>
  35. #include <vector>
  36. #include <stdio.h>
  37. #include <queue>
  38. #include <limits>
  39. #include <cmath>
  40. #ifndef VPTREE_H
  41. #define VPTREE_H
  42. class DataPoint
  43. {
  44. int _ind;
  45. public:
  46. double* _x;
  47. int _D;
  48. DataPoint() {
  49. _D = 1;
  50. _ind = -1;
  51. _x = NULL;
  52. }
  53. DataPoint(int D, int ind, double* x) {
  54. _D = D;
  55. _ind = ind;
  56. _x = (double*) malloc(_D * sizeof(double));
  57. for(int d = 0; d < _D; d++) _x[d] = x[d];
  58. }
  59. DataPoint(const DataPoint& other) { // this makes a deep copy -- should not free anything
  60. if(this != &other) {
  61. _D = other.dimensionality();
  62. _ind = other.index();
  63. _x = (double*) malloc(_D * sizeof(double));
  64. for(int d = 0; d < _D; d++) _x[d] = other.x(d);
  65. }
  66. }
  67. ~DataPoint() { if(_x != NULL) free(_x); }
  68. DataPoint& operator= (const DataPoint& other) { // asignment should free old object
  69. if(this != &other) {
  70. if(_x != NULL) free(_x);
  71. _D = other.dimensionality();
  72. _ind = other.index();
  73. _x = (double*) malloc(_D * sizeof(double));
  74. for(int d = 0; d < _D; d++) _x[d] = other.x(d);
  75. }
  76. return *this;
  77. }
  78. int index() const { return _ind; }
  79. int dimensionality() const { return _D; }
  80. double x(int d) const { return _x[d]; }
  81. };
  82. inline double euclidean_distance(const DataPoint &t1, const DataPoint &t2) {
  83. double dd = .0;
  84. double* x1 = t1._x;
  85. double* x2 = t2._x;
  86. double diff;
  87. for(int d = 0; d < t1._D; d++) {
  88. diff = (x1[d] - x2[d]);
  89. dd += diff * diff;
  90. }
  91. return sqrt(dd);
  92. }
  93. template<typename T, double (*distance)( const T&, const T& )>
  94. class VpTree
  95. {
  96. public:
  97. // Default constructor
  98. VpTree() : _root(0) {}
  99. // Destructor
  100. ~VpTree() {
  101. delete _root;
  102. }
  103. // Function to create a new VpTree from data
  104. void create(const std::vector<T>& items) {
  105. delete _root;
  106. _items = items;
  107. _root = buildFromPoints(0, items.size());
  108. }
  109. // Function that uses the tree to find the k nearest neighbors of target
  110. void search(const T& target, int k, std::vector<T>* results, std::vector<double>* distances)
  111. {
  112. // Use a priority queue to store intermediate results on
  113. std::priority_queue<HeapItem> heap;
  114. // Variable that tracks the distance to the farthest point in our results
  115. _tau = DBL_MAX;
  116. // Perform the search
  117. search(_root, target, k, heap);
  118. // Gather final results
  119. results->clear(); distances->clear();
  120. while(!heap.empty()) {
  121. results->push_back(_items[heap.top().index]);
  122. distances->push_back(heap.top().dist);
  123. heap.pop();
  124. }
  125. // Results are in reverse order
  126. std::reverse(results->begin(), results->end());
  127. std::reverse(distances->begin(), distances->end());
  128. }
  129. private:
  130. std::vector<T> _items;
  131. double _tau;
  132. // Single node of a VP tree (has a point and radius; left children are closer to point than the radius)
  133. struct Node
  134. {
  135. int index; // index of point in node
  136. double threshold; // radius(?)
  137. Node* left; // points closer by than threshold
  138. Node* right; // points farther away than threshold
  139. Node() :
  140. index(0), threshold(0.), left(0), right(0) {}
  141. ~Node() { // destructor
  142. delete left;
  143. delete right;
  144. }
  145. }* _root;
  146. // An item on the intermediate result queue
  147. struct HeapItem {
  148. HeapItem( int index, double dist) :
  149. index(index), dist(dist) {}
  150. int index;
  151. double dist;
  152. bool operator<(const HeapItem& o) const {
  153. return dist < o.dist;
  154. }
  155. };
  156. // Distance comparator for use in std::nth_element
  157. struct DistanceComparator
  158. {
  159. const T& item;
  160. DistanceComparator(const T& item) : item(item) {}
  161. bool operator()(const T& a, const T& b) {
  162. return distance(item, a) < distance(item, b);
  163. }
  164. };
  165. // Function that (recursively) fills the tree
  166. Node* buildFromPoints( int lower, int upper )
  167. {
  168. if (upper == lower) { // indicates that we're done here!
  169. return NULL;
  170. }
  171. // Lower index is center of current node
  172. Node* node = new Node();
  173. node->index = lower;
  174. if (upper - lower > 1) { // if we did not arrive at leaf yet
  175. // Choose an arbitrary point and move it to the start
  176. int i = (int) ((double)rand() / RAND_MAX * (upper - lower - 1)) + lower;
  177. std::swap(_items[lower], _items[i]);
  178. // Partition around the median distance
  179. int median = (upper + lower) / 2;
  180. std::nth_element(_items.begin() + lower + 1,
  181. _items.begin() + median,
  182. _items.begin() + upper,
  183. DistanceComparator(_items[lower]));
  184. // Threshold of the new node will be the distance to the median
  185. node->threshold = distance(_items[lower], _items[median]);
  186. // Recursively build tree
  187. node->index = lower;
  188. node->left = buildFromPoints(lower + 1, median);
  189. node->right = buildFromPoints(median, upper);
  190. }
  191. // Return result
  192. return node;
  193. }
  194. // Helper function that searches the tree
  195. void search(Node* node, const T& target, int k, std::priority_queue<HeapItem>& heap)
  196. {
  197. if(node == NULL) return; // indicates that we're done here
  198. // Compute distance between target and current node
  199. double dist = distance(_items[node->index], target);
  200. // If current node within radius tau
  201. if(dist < _tau) {
  202. if(heap.size() == k) heap.pop(); // remove furthest node from result list (if we already have k results)
  203. heap.push(HeapItem(node->index, dist)); // add current node to result list
  204. if(heap.size() == k) _tau = heap.top().dist; // update value of tau (farthest point in result list)
  205. }
  206. // Return if we arrived at a leaf
  207. if(node->left == NULL && node->right == NULL) {
  208. return;
  209. }
  210. // If the target lies within the radius of ball
  211. if(dist < node->threshold) {
  212. if(dist - _tau <= node->threshold) { // if there can still be neighbors inside the ball, recursively search left child first
  213. search(node->left, target, k, heap);
  214. }
  215. if(dist + _tau >= node->threshold) { // if there can still be neighbors outside the ball, recursively search right child
  216. search(node->right, target, k, heap);
  217. }
  218. // If the target lies outsize the radius of the ball
  219. } else {
  220. if(dist + _tau >= node->threshold) { // if there can still be neighbors outside the ball, recursively search right child first
  221. search(node->right, target, k, heap);
  222. }
  223. if (dist - _tau <= node->threshold) { // if there can still be neighbors inside the ball, recursively search left child
  224. search(node->left, target, k, heap);
  225. }
  226. }
  227. }
  228. };
  229. #endif