kdTree.h

Go to the documentation of this file.
00001 
00015 #ifndef DLR_COMPUTERVISION_KDTREE_H
00016 #define DLR_COMPUTERVISION_KDTREE_H
00017 
00018 #include <functional>
00019 #include <vector>
00020 
00021 
00022 namespace dlr {
00023 
00024   namespace computerVision {
00025 
00049     template <unsigned int Dimension, class Type>
00050     class KDComparator
00051       : public std::binary_function<Type, Type, bool>
00052     {
00053       public:
00054       
00064       KDComparator(unsigned int axis = 0)
00065         : m_axis(axis) {};
00066       
00067       
00071       virtual
00072       ~KDComparator() {}
00073       
00074 
00088       double
00089       computeDistance(Type const& arg0, Type const& arg1) const {
00090         double distance = 0.0;
00091         for(size_t ii = 0; ii < Dimension; ++ii) {
00092           double newTerm = arg0[ii] - arg1[ii];
00093           distance += newTerm * newTerm;
00094         }
00095         return distance;
00096       }
00097 
00098 
00116       double
00117       getPrimarySeparation(Type const& arg0, Type const& arg1) const {
00118         double difference = arg0[m_axis] - arg1[m_axis];
00119         return difference * difference;
00120       }
00121 
00122 
00136       bool isEqual(Type const& arg0, Type const& arg1) const {
00137         return arg0 == arg1;
00138       }
00139       
00140         
00154       bool
00155       operator()(Type const& arg0, Type const& arg1) const {
00156         if(arg0[m_axis] < arg1[m_axis]) {
00157           return true;
00158         }
00159         if(arg0[m_axis] > arg1[m_axis]) {
00160           return false;
00161         }
00162         for(size_t ii = m_axis + 1; ii < Dimension; ++ii) {
00163           if(arg0[ii] < arg1[ii]) {
00164             return true;
00165           }
00166           if(arg0[ii] > arg1[ii]) {
00167             return false;
00168           }
00169         }
00170         for(size_t ii = 0; ii < m_axis; ++ii) {
00171           if(arg0[ii] < arg1[ii]) {
00172             return true;
00173           }
00174           if(arg0[ii] > arg1[ii]) {
00175             return false;
00176           }
00177         }
00178         return false;
00179       }
00180       
00181       
00182     private:
00183 
00184       unsigned int m_axis;
00185       
00186     };
00187 
00188 
00189     
00230     template <unsigned int Dimension, class Type>
00231     class KDTree {
00232     public:
00233 
00237       KDTree();
00238 
00239 
00254       template <class Iter>
00255       KDTree(Iter beginIter, Iter endIter);
00256 
00257 
00262       virtual
00263       ~KDTree();
00264 
00265 
00279       bool
00280       find(Type const& point) const;
00281 
00282 
00303       Type const&
00304       findNearest(Type const& point, double& distance) const;
00305 
00306 
00307       // void
00308       // rebalance();
00309       
00310       
00311     protected:
00312 
00313       template <class Iter>
00314       KDTree(Iter beginIter, Iter endIter, size_t vectorSize, size_t level);
00315 
00316 
00317       template <class Iter>
00318       void
00319       construct(Iter beginIter, Iter endIter, size_t vectorSize, size_t level);
00320 
00321       void
00322       findNearestIterative(Type const& point,
00323                            Type const*& bestPointPtr,
00324                            double& bestDistance) const;
00325 
00326       
00327       void
00328       findNearestRecursive(Type const& point,
00329                            Type const*& bestPointPtr,
00330                            double& bestDistance) const;
00331       
00332 
00333       KDComparator<Dimension, Type> m_comparator;
00334       Type m_point;
00335       KDTree* m_leftChild;
00336       KDTree* m_rightChild;
00337     };
00338 
00339   } // namespace computerVision
00340   
00341 } // namespace dlr
00342 
00343 
00344 /* ============ Definitions of inline & template functions ============ */
00345 
00346 
00347 #include <algorithm>
00348 #include <cmath>
00349 #include <limits>
00350 #include <stack>
00351 
00352 namespace dlr {
00353 
00354   namespace computerVision {
00355 
00356 
00357     template <unsigned int Dimension, class Type>
00358     KDTree<Dimension, Type>::
00359     KDTree()
00360       : m_comparator(0),
00361         m_point(),
00362         m_leftChild(0),
00363         m_rightChild(0)
00364     {}
00365 
00366 
00367     template <unsigned int Dimension, class Type>
00368     template <class Iter>
00369     KDTree<Dimension, Type>::
00370     KDTree(Iter beginIter, Iter endIter)
00371       : m_comparator(0),
00372         m_point(),
00373         m_leftChild(0),
00374         m_rightChild(0)
00375     {
00376       if(beginIter == endIter) {
00377         return;
00378       }
00379 
00380       std::vector<Type> pointVector;
00381       std::copy(beginIter, endIter, std::back_inserter(pointVector));
00382 
00383       this->construct(pointVector.begin(), pointVector.end(),
00384                       pointVector.size(), 0);
00385     }
00386       
00387     
00388     // The destructor cleans up any system resources during destruction.
00389     template <unsigned int Dimension, class Type>
00390     KDTree<Dimension, Type>::
00391     ~KDTree() {
00392       if(m_leftChild != 0) {
00393         delete m_leftChild;
00394       }
00395       if(m_rightChild != 0) {
00396         delete m_rightChild;
00397       }
00398     }
00399 
00400 
00401     template <unsigned int Dimension, class Type>
00402     bool
00403     KDTree<Dimension, Type>::
00404     find(Type const& point) const
00405     {
00406       if(m_comparator.isEqual(m_point, point)) {
00407         return true;
00408       }
00409 
00410       if(m_comparator(point, m_point)) {
00411         if(m_leftChild == 0) {
00412           return false;
00413         }
00414         return m_leftChild->find(point);
00415       } else {
00416         if(m_rightChild == 0) {
00417           return false;
00418         }
00419         return m_rightChild->find(point);
00420       }
00421     }
00422     
00423 
00424     template <unsigned int Dimension, class Type>
00425     Type const&
00426     KDTree<Dimension, Type>::
00427     findNearest(Type const& point, double& distance) const
00428     {
00429       distance = std::numeric_limits<double>::max();
00430       Type const* bestPointPtr = &m_point;
00431       // this->findNearestRecursive(point, bestPointPtr, distance);
00432       this->findNearestIterative(point, bestPointPtr, distance);
00433       return *bestPointPtr;
00434     }
00435 
00436 
00437     /* ================ Protected ================= */
00438 
00439     template <unsigned int Dimension, class Type>
00440     template <class Iter>
00441     KDTree<Dimension, Type>::
00442     KDTree(Iter beginIter, Iter endIter, size_t vectorSize, size_t level)
00443       : m_comparator(level % Dimension),
00444         m_point(),
00445         m_leftChild(0),
00446         m_rightChild(0)
00447     {
00448       this->construct(beginIter, endIter, vectorSize, level);
00449     }
00450 
00451 
00452     template <unsigned int Dimension, class Type>
00453     template <class Iter>
00454     void
00455     KDTree<Dimension, Type>::
00456     construct(Iter beginIter, Iter endIter, size_t vectorSize, size_t level)
00457     {
00458       std::sort(beginIter, endIter, m_comparator);
00459       size_t partitionIndex = vectorSize / 2;
00460 
00461       m_point = *(beginIter + partitionIndex);
00462 
00463       if(vectorSize == 1) {
00464         return;
00465       }
00466       m_leftChild = new KDTree(
00467         beginIter, beginIter + partitionIndex,
00468         partitionIndex, level + 1);
00469 
00470       if(vectorSize == 2) {
00471         return;
00472       }
00473       m_rightChild = new KDTree(
00474         beginIter + (partitionIndex + 1), endIter,
00475         vectorSize - (partitionIndex + 1), level + 1);
00476     }
00477 
00478 
00479     template <unsigned int Dimension, class Type>
00480     void
00481     KDTree<Dimension, Type>::
00482     findNearestIterative(Type const& point,
00483                          Type const*& bestPointPtr,
00484                          double& bestDistance) const
00485     {
00486       // Contents of this stack are std::pairs in which the first
00487       // element points to an un-searched tree, and the second element
00488       // is a lower bound on the distance from argument point to the
00489       // points contained in the un-searched tree.
00490       std::stack< std::pair< KDTree<Dimension, Type> const*, double> >
00491         kdTreeStack;
00492       kdTreeStack.push(std::make_pair(this, 0.0));
00493 
00494       while(!(kdTreeStack.empty())) {
00495         KDTree<Dimension, Type> const* currentTree = kdTreeStack.top().first;
00496         double bound = kdTreeStack.top().second;
00497         kdTreeStack.pop();
00498 
00499         if(bestDistance < bound) {
00500           continue;
00501         }
00502         
00503         double myDistance = currentTree->m_comparator.computeDistance(
00504           point, currentTree->m_point);
00505         if(myDistance < bestDistance) {
00506           bestDistance = myDistance;
00507           bestPointPtr = &currentTree->m_point;
00508         }
00509 
00510         if(currentTree->m_leftChild == 0 && currentTree->m_rightChild == 0) {
00511           continue;
00512         }
00513       
00514         KDTree* nearChildPtr;
00515         KDTree* farChildPtr;
00516         bool isLeft = currentTree->m_comparator(point, currentTree->m_point);
00517         if(isLeft) {
00518           nearChildPtr = currentTree->m_leftChild;
00519           farChildPtr = currentTree->m_rightChild;
00520         } else {
00521           nearChildPtr = currentTree->m_rightChild;
00522           farChildPtr = currentTree->m_leftChild;
00523         }
00524 
00525         // Push remote child first, and near child second, so that
00526         // near child will be popped first, increasing the chance that
00527         // remote child will be eliminated.
00528 
00529         // Only push the remote child if it's plausible that it
00530         // contains a closer point than our best so far.  This
00531         // duplicates a similar test above. Not sure if it's worth
00532         // duplicating the test to avoid the occasional extra
00533         // push/pop.
00534         if(farChildPtr) {
00535           double remoteDistanceLowerBound =
00536             currentTree->m_comparator.getPrimarySeparation(
00537               point, currentTree->m_point);
00538           if(remoteDistanceLowerBound < bestDistance) {
00539             kdTreeStack.push(
00540               std::make_pair(farChildPtr, remoteDistanceLowerBound));
00541           }
00542         }
00543 
00544         if(nearChildPtr) {
00545           kdTreeStack.push(std::make_pair(nearChildPtr, 0.0));
00546         }
00547       }
00548     }
00549 
00550       
00551     template <unsigned int Dimension, class Type>
00552     void
00553     KDTree<Dimension, Type>::
00554     findNearestRecursive(Type const& point,
00555                          Type const*& bestPointPtr,
00556                          double& bestDistance) const
00557     {
00558       double myDistance = m_comparator.computeDistance(point, m_point);
00559       if(myDistance < bestDistance) {
00560         bestDistance = myDistance;
00561         bestPointPtr = &m_point;
00562       }
00563 
00564       if(m_leftChild == 0 && m_rightChild == 0) {
00565         return;
00566       }
00567       
00568       bool isLeft = m_comparator(point, m_point);
00569       double remoteDistanceLowerBound = m_comparator.getPrimarySeparation(
00570         point, m_point);
00571 
00572       KDTree* nearChildPtr;
00573       KDTree* farChildPtr;
00574       if(isLeft) {
00575         nearChildPtr = m_leftChild;
00576         farChildPtr = m_rightChild;
00577       } else {
00578         nearChildPtr = m_rightChild;
00579         farChildPtr = m_leftChild;
00580       }        
00581 
00582       if(nearChildPtr) {
00583         nearChildPtr->findNearestRecursive(point, bestPointPtr, bestDistance);
00584       }
00585 
00586       if((remoteDistanceLowerBound < bestDistance) && farChildPtr) {
00587         farChildPtr->findNearestRecursive(point, bestPointPtr, bestDistance);
00588       }
00589     }
00590     
00591   } // namespace computerVision
00592   
00593 } // namespace dlr
00594 
00595 #endif /* #ifndef DLR_COMPUTERVISION_KDTREE_H */

Generated on Wed Nov 25 12:15:05 2009 for dlrComputerVision Utility Library by  doxygen 1.5.8