00001
00015 #ifndef DLR_COMPUTERVISION_KDTREE_H
00016 #define DLR_COMPUTERVISION_KDTREE_H
00017
00018 #include <functional>
00019 #include <vector>
00020
00021
00022 namespace dlr {
00023
00024 namespace computerVision {
00025
/**
 * Comparison functor used by KDTree.
 *
 * It orders points primarily by their coordinate on one axis, selected at
 * construction.  Ties are broken by the remaining coordinates, visited
 * cyclically: m_axis + 1, ..., Dimension - 1, then 0, ..., m_axis - 1.
 * Only the first Dimension coordinates participate in the ordering and in
 * the distance computations.
 *
 * Requirements on Type:
 * - operator[] for indices in [0, Dimension), returning values that support
 *   operator< and operator> and are convertible to double;
 * - operator==, used by isEqual().
 */
template <unsigned int Dimension, class Type>
class KDComparator
{
public:

  // Adaptable-functor typedefs.  These were previously inherited from
  // std::binary_function, which was deprecated in C++11 and removed in
  // C++17.  Declaring them directly keeps existing users compiling.
  typedef Type first_argument_type;
  typedef Type second_argument_type;
  typedef bool result_type;


  /**
   * Constructor.
   *
   * @param axis Index of the coordinate that dominates the ordering
   * implemented by operator().  It is expected to be less than Dimension.
   */
  KDComparator(unsigned int axis = 0)
    : m_axis(axis) {}


  /**
   * Destructor.
   */
  virtual
  ~KDComparator() {}


  /**
   * Computes the squared Euclidean distance between the first Dimension
   * coordinates of two points.
   *
   * @param arg0 First point.
   * @param arg1 Second point.
   * @return Sum over ii in [0, Dimension) of (arg0[ii] - arg1[ii])^2.
   */
  double
  computeDistance(Type const& arg0, Type const& arg1) const {
    double distance = 0.0;
    for(size_t ii = 0; ii < Dimension; ++ii) {
      // Convert before subtracting so that unsigned element types do not
      // wrap around when arg1[ii] > arg0[ii].
      double newTerm =
        static_cast<double>(arg0[ii]) - static_cast<double>(arg1[ii]);
      distance += newTerm * newTerm;
    }
    return distance;
  }


  /**
   * Computes the squared difference between two points along the primary
   * axis only.
   *
   * This is one term of the sum in computeDistance(), so it is a lower
   * bound on that distance.  KDTree uses it to prune its search.
   *
   * @param arg0 First point.
   * @param arg1 Second point.
   * @return (arg0[axis] - arg1[axis])^2.
   */
  double
  getPrimarySeparation(Type const& arg0, Type const& arg1) const {
    // Convert first for the same unsigned-wraparound reason as above.
    double difference =
      static_cast<double>(arg0[m_axis]) - static_cast<double>(arg1[m_axis]);
    return difference * difference;
  }


  /**
   * Tests two points for equality.
   *
   * @param arg0 First point.
   * @param arg1 Second point.
   * @return The result of arg0 == arg1.  Note that this is Type's own
   * operator==, which may consider more than Dimension coordinates.
   */
  bool isEqual(Type const& arg0, Type const& arg1) const {
    return arg0 == arg1;
  }


  /**
   * Strict "less than" comparison.
   *
   * The primary axis is compared first; on a tie, the remaining axes are
   * compared in cyclic order.
   *
   * @param arg0 First point.
   * @param arg1 Second point.
   * @return true if arg0 orders before arg1; false if arg1 orders before
   * arg0 or all Dimension coordinates are equal.
   */
  bool
  operator()(Type const& arg0, Type const& arg1) const {
    if(arg0[m_axis] < arg1[m_axis]) {
      return true;
    }
    if(arg0[m_axis] > arg1[m_axis]) {
      return false;
    }
    // Tie-break on the axes after the primary one...
    for(size_t ii = m_axis + 1; ii < Dimension; ++ii) {
      if(arg0[ii] < arg1[ii]) {
        return true;
      }
      if(arg0[ii] > arg1[ii]) {
        return false;
      }
    }
    // ...then wrap around to the axes before it.
    for(size_t ii = 0; ii < m_axis; ++ii) {
      if(arg0[ii] < arg1[ii]) {
        return true;
      }
      if(arg0[ii] > arg1[ii]) {
        return false;
      }
    }
    return false;
  }


private:

  // Index of the coordinate that dominates the ordering.
  unsigned int m_axis;

};
00187
00188
00189
00230 template <unsigned int Dimension, class Type>
00231 class KDTree {
00232 public:
00233
00237 KDTree();
00238
00239
00254 template <class Iter>
00255 KDTree(Iter beginIter, Iter endIter);
00256
00257
00262 virtual
00263 ~KDTree();
00264
00265
00279 bool
00280 find(Type const& point) const;
00281
00282
00303 Type const&
00304 findNearest(Type const& point, double& distance) const;
00305
00306
00307
00308
00309
00310
00311 protected:
00312
00313 template <class Iter>
00314 KDTree(Iter beginIter, Iter endIter, size_t vectorSize, size_t level);
00315
00316
00317 template <class Iter>
00318 void
00319 construct(Iter beginIter, Iter endIter, size_t vectorSize, size_t level);
00320
00321 void
00322 findNearestIterative(Type const& point,
00323 Type const*& bestPointPtr,
00324 double& bestDistance) const;
00325
00326
00327 void
00328 findNearestRecursive(Type const& point,
00329 Type const*& bestPointPtr,
00330 double& bestDistance) const;
00331
00332
00333 KDComparator<Dimension, Type> m_comparator;
00334 Type m_point;
00335 KDTree* m_leftChild;
00336 KDTree* m_rightChild;
00337 };
00338
00339 }
00340
00341 }
00342
00343
00344
00345
00346
00347 #include <algorithm>
00348 #include <cmath>
00349 #include <limits>
00350 #include <stack>
00351
00352 namespace dlr {
00353
00354 namespace computerVision {
00355
00356
00357 template <unsigned int Dimension, class Type>
00358 KDTree<Dimension, Type>::
00359 KDTree()
00360 : m_comparator(0),
00361 m_point(),
00362 m_leftChild(0),
00363 m_rightChild(0)
00364 {}
00365
00366
00367 template <unsigned int Dimension, class Type>
00368 template <class Iter>
00369 KDTree<Dimension, Type>::
00370 KDTree(Iter beginIter, Iter endIter)
00371 : m_comparator(0),
00372 m_point(),
00373 m_leftChild(0),
00374 m_rightChild(0)
00375 {
00376 if(beginIter == endIter) {
00377 return;
00378 }
00379
00380 std::vector<Type> pointVector;
00381 std::copy(beginIter, endIter, std::back_inserter(pointVector));
00382
00383 this->construct(pointVector.begin(), pointVector.end(),
00384 pointVector.size(), 0);
00385 }
00386
00387
00388
00389 template <unsigned int Dimension, class Type>
00390 KDTree<Dimension, Type>::
00391 ~KDTree() {
00392 if(m_leftChild != 0) {
00393 delete m_leftChild;
00394 }
00395 if(m_rightChild != 0) {
00396 delete m_rightChild;
00397 }
00398 }
00399
00400
00401 template <unsigned int Dimension, class Type>
00402 bool
00403 KDTree<Dimension, Type>::
00404 find(Type const& point) const
00405 {
00406 if(m_comparator.isEqual(m_point, point)) {
00407 return true;
00408 }
00409
00410 if(m_comparator(point, m_point)) {
00411 if(m_leftChild == 0) {
00412 return false;
00413 }
00414 return m_leftChild->find(point);
00415 } else {
00416 if(m_rightChild == 0) {
00417 return false;
00418 }
00419 return m_rightChild->find(point);
00420 }
00421 }
00422
00423
00424 template <unsigned int Dimension, class Type>
00425 Type const&
00426 KDTree<Dimension, Type>::
00427 findNearest(Type const& point, double& distance) const
00428 {
00429 distance = std::numeric_limits<double>::max();
00430 Type const* bestPointPtr = &m_point;
00431
00432 this->findNearestIterative(point, bestPointPtr, distance);
00433 return *bestPointPtr;
00434 }
00435
00436
00437
00438
00439 template <unsigned int Dimension, class Type>
00440 template <class Iter>
00441 KDTree<Dimension, Type>::
00442 KDTree(Iter beginIter, Iter endIter, size_t vectorSize, size_t level)
00443 : m_comparator(level % Dimension),
00444 m_point(),
00445 m_leftChild(0),
00446 m_rightChild(0)
00447 {
00448 this->construct(beginIter, endIter, vectorSize, level);
00449 }
00450
00451
00452 template <unsigned int Dimension, class Type>
00453 template <class Iter>
00454 void
00455 KDTree<Dimension, Type>::
00456 construct(Iter beginIter, Iter endIter, size_t vectorSize, size_t level)
00457 {
00458 std::sort(beginIter, endIter, m_comparator);
00459 size_t partitionIndex = vectorSize / 2;
00460
00461 m_point = *(beginIter + partitionIndex);
00462
00463 if(vectorSize == 1) {
00464 return;
00465 }
00466 m_leftChild = new KDTree(
00467 beginIter, beginIter + partitionIndex,
00468 partitionIndex, level + 1);
00469
00470 if(vectorSize == 2) {
00471 return;
00472 }
00473 m_rightChild = new KDTree(
00474 beginIter + (partitionIndex + 1), endIter,
00475 vectorSize - (partitionIndex + 1), level + 1);
00476 }
00477
00478
00479 template <unsigned int Dimension, class Type>
00480 void
00481 KDTree<Dimension, Type>::
00482 findNearestIterative(Type const& point,
00483 Type const*& bestPointPtr,
00484 double& bestDistance) const
00485 {
00486
00487
00488
00489
00490 std::stack< std::pair< KDTree<Dimension, Type> const*, double> >
00491 kdTreeStack;
00492 kdTreeStack.push(std::make_pair(this, 0.0));
00493
00494 while(!(kdTreeStack.empty())) {
00495 KDTree<Dimension, Type> const* currentTree = kdTreeStack.top().first;
00496 double bound = kdTreeStack.top().second;
00497 kdTreeStack.pop();
00498
00499 if(bestDistance < bound) {
00500 continue;
00501 }
00502
00503 double myDistance = currentTree->m_comparator.computeDistance(
00504 point, currentTree->m_point);
00505 if(myDistance < bestDistance) {
00506 bestDistance = myDistance;
00507 bestPointPtr = ¤tTree->m_point;
00508 }
00509
00510 if(currentTree->m_leftChild == 0 && currentTree->m_rightChild == 0) {
00511 continue;
00512 }
00513
00514 KDTree* nearChildPtr;
00515 KDTree* farChildPtr;
00516 bool isLeft = currentTree->m_comparator(point, currentTree->m_point);
00517 if(isLeft) {
00518 nearChildPtr = currentTree->m_leftChild;
00519 farChildPtr = currentTree->m_rightChild;
00520 } else {
00521 nearChildPtr = currentTree->m_rightChild;
00522 farChildPtr = currentTree->m_leftChild;
00523 }
00524
00525
00526
00527
00528
00529
00530
00531
00532
00533
00534 if(farChildPtr) {
00535 double remoteDistanceLowerBound =
00536 currentTree->m_comparator.getPrimarySeparation(
00537 point, currentTree->m_point);
00538 if(remoteDistanceLowerBound < bestDistance) {
00539 kdTreeStack.push(
00540 std::make_pair(farChildPtr, remoteDistanceLowerBound));
00541 }
00542 }
00543
00544 if(nearChildPtr) {
00545 kdTreeStack.push(std::make_pair(nearChildPtr, 0.0));
00546 }
00547 }
00548 }
00549
00550
00551 template <unsigned int Dimension, class Type>
00552 void
00553 KDTree<Dimension, Type>::
00554 findNearestRecursive(Type const& point,
00555 Type const*& bestPointPtr,
00556 double& bestDistance) const
00557 {
00558 double myDistance = m_comparator.computeDistance(point, m_point);
00559 if(myDistance < bestDistance) {
00560 bestDistance = myDistance;
00561 bestPointPtr = &m_point;
00562 }
00563
00564 if(m_leftChild == 0 && m_rightChild == 0) {
00565 return;
00566 }
00567
00568 bool isLeft = m_comparator(point, m_point);
00569 double remoteDistanceLowerBound = m_comparator.getPrimarySeparation(
00570 point, m_point);
00571
00572 KDTree* nearChildPtr;
00573 KDTree* farChildPtr;
00574 if(isLeft) {
00575 nearChildPtr = m_leftChild;
00576 farChildPtr = m_rightChild;
00577 } else {
00578 nearChildPtr = m_rightChild;
00579 farChildPtr = m_leftChild;
00580 }
00581
00582 if(nearChildPtr) {
00583 nearChildPtr->findNearestRecursive(point, bestPointPtr, bestDistance);
00584 }
00585
00586 if((remoteDistanceLowerBound < bestDistance) && farChildPtr) {
00587 farChildPtr->findNearestRecursive(point, bestPointPtr, bestDistance);
00588 }
00589 }
00590
00591 }
00592
00593 }
00594
00595 #endif