ransac.h
Go to the documentation of this file.00001
00015 #ifndef DLR_COMPUTERVISION_RANSAC_H
00016 #define DLR_COMPUTERVISION_RANSAC_H
00017
00018 #include <vector>
00019 #include <dlrComputerVision/randomSampleSelector.h>
00020
00021 namespace dlr {
00022
00023 namespace computerVision {
00024
/**
 * Strategies Ransac may use to decide which pool samples agree with a
 * candidate model.
 */
enum RansacInlierStrategy {
  /**
   * A sample is an inlier when the error reported by the problem's
   * computeError() is strictly less than getNaiveErrorThreshold().
   * This is currently the only strategy Ransac::computeConsensusSet()
   * accepts.
   */
  DLR_CV_NAIVE_ERROR_THRESHOLD = 0
};
00034
00035
/**
 * Generic RANSAC driver.
 *
 * Problem supplies the sample pool, random/subset selection, model
 * fitting, and per-sample error evaluation (see RansacProblem).  The
 * solver repeatedly fits models to random samples, refines each fit on
 * its consensus set, and reports the best model found.
 *
 * @tparam Problem  Problem description type; must expose a nested
 *                  ModelType typedef.
 */
template <class Problem>
class Ransac {
public:

  /** Type of the problem description this solver operates on. */
  typedef Problem ProblemType;

  /** Type of the model produced by the solver (Problem::ModelType). */
  typedef typename Problem::ModelType ResultType;

  /**
   * Builds a solver for a copy of the given problem.
   *
   * @param problem  Problem description; stored by value.
   * @param minimumConsensusSize  Consensus-set size a model must exceed
   *        to be accepted early.  Zero requests a default computed from
   *        the problem's sample size and requiredConfidence.
   * @param requiredConfidence  Desired probability that at least one
   *        random sample contains only inliers; expected in [0, 1).
   * @param inlierProbability  Assumed fraction of inliers in the pool;
   *        expected in (0, 1].
   */
  Ransac(ProblemType const& problem,
         size_t minimumConsensusSize = 0,
         double requiredConfidence = 0.99,
         double inlierProbability = 0.5);

  /** Destroys the solver. */
  virtual
  ~Ransac() {}

  /**
   * Runs the estimation and returns the model it settled on.
   *
   * The success flag from estimate() is not reported; when no model
   * reached the minimum consensus size, the fallback model chosen by
   * estimate() is returned instead.
   */
  virtual ResultType
  getResult() {
    ResultType bestModel;
    this->estimate(bestModel);
    return bestModel;
  }

protected:

  /** Flags each pool sample according to whether it agrees with model. */
  void
  computeConsensusSet(ResultType& model, std::vector<bool>& consensusFlags);

  /** Core RANSAC loop; writes the chosen model and reports success. */
  bool
  estimate(ResultType& model);

  /** True when the consensus set did not change between passes. */
  bool
  isConverged(std::vector<bool> const& consensusFlags,
              std::vector<bool> const& previousConsensusFlags);

  // Declaration order matters: the constructor initializes in this order.
  size_t m_minimumConsensusSize;
  size_t m_numberOfRandomSampleSets;
  ProblemType m_problem;
};
00145
00146
00167 template <class Sample, class Model>
00168 class RansacProblem
00169 : public RandomSampleSelector<Sample>
00170 {
00171 public:
00172
00173
00174
00175
00176
00181 typedef Model ModelType;
00182
00183
00188 typedef Sample SampleType;
00189
00190
00217 typedef typename RandomSampleSelector<SampleType>::SampleSequenceType
00218 SampleSequenceType;
00219
00220
00221
00222
00223
00238 virtual ModelType
00239 estimateModel(SampleSequenceType const& sampleSequence) = 0;
00240
00241
00260 template <class IterType>
00261 void
00262 computeError(ModelType const& model,
00263 SampleSequenceType const& sampleSequence,
00264 IterType ouputIter) {};
00265
00266
00282 virtual double
00283 getNaiveErrorThreshold() = 0;
00284
00285
00286
00287
00288
00289
00308 template <class IterType>
00309 RansacProblem(size_t sampleSize, IterType beginIter, IterType endIter)
00310 : RandomSampleSelector<SampleType>(beginIter, endIter),
00311 m_sampleSize(sampleSize) {}
00312
00313
00317 virtual
00318 ~RansacProblem() {}
00319
00320
00330 virtual size_t
00331 getSampleSize() {return m_sampleSize;}
00332
00333
00341 RansacInlierStrategy
00342 getInlierStrategy() {
00343 return DLR_CV_NAIVE_ERROR_THRESHOLD;
00344 }
00345
00346 protected:
00347
00348 size_t m_sampleSize;
00349
00350 };
00351
00352 }
00353
00354 }
00355
00356
00357
00358
00359
00360
#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <stdexcept>
#include <dlrCommon/exception.h>
#include <dlrNumeric/maxRecorder.h>
00366
00367 namespace dlr {
00368
00369 namespace computerVision {
00370
00371
00372 template <class Problem>
00373 Ransac<Problem>::
00374 Ransac(Problem const& problem,
00375 size_t minimumConsensusSize,
00376 double requiredConfidence,
00377 double inlierProbability)
00378 : m_minimumConsensusSize(minimumConsensusSize),
00379 m_numberOfRandomSampleSets(),
00380 m_problem(problem)
00381 {
00382 size_t sampleSize = m_problem.getSampleSize();
00383
00384
00385 double singlePickConfidence =
00386 std::pow(inlierProbability, static_cast<double>(sampleSize));
00387
00388
00389
00390 double singlePickDisconfidence = 1.0 - singlePickConfidence;
00391
00392
00393
00394
00395 m_numberOfRandomSampleSets =
00396 std::log(1.0 - requiredConfidence) / std::log(singlePickDisconfidence);
00397
00398 if(m_minimumConsensusSize == 0) {
00399
00400
00401
00402
00403
00404
00405 int extraSamples = static_cast<int>(
00406 std::log(1.0 - requiredConfidence) / std::log(0.5) + 0.5);
00407 if(extraSamples < 0) {
00408 extraSamples = 0;
00409 }
00410 m_minimumConsensusSize = sampleSize + extraSamples;
00411 }
00412
00413 }
00414
00415
00416 template <class Problem>
00417 void
00418 Ransac<Problem>::
00419 computeConsensusSet(typename Ransac<Problem>::ResultType& model,
00420 std::vector<bool>& consensusFlags)
00421 {
00422 if(m_problem.getInlierStrategy() != DLR_CV_NAIVE_ERROR_THRESHOLD) {
00423 DLR_THROW(NotImplementedException, "Ransac::computeConsensusSet()",
00424 "Currently only naive error thresholding is supported.");
00425 }
00426
00427
00428 typename Problem::SampleSequenceType testSet = m_problem.getPool();
00429 std::vector<double> errorMetrics(m_problem.getPoolSize());
00430 m_problem.computeError(model, testSet, errorMetrics.begin());
00431
00432
00433 double threshold = m_problem.getNaiveErrorThreshold();
00434 std::transform(
00435 errorMetrics.begin(), errorMetrics.end(), consensusFlags.begin(),
00436 std::bind2nd(std::less<double>(), threshold));
00437 }
00438
00439
00440 template <class Problem>
00441 bool
00442 Ransac<Problem>::
00443 estimate(typename Ransac<Problem>::ResultType& model)
00444 {
00445 dlr::numeric::MaxRecorder<size_t, ResultType> maxRecorder;
00446 for(size_t iteration = 0; iteration < m_numberOfRandomSampleSets;
00447 ++iteration) {
00448
00449
00450 typename ProblemType::SampleSequenceType trialSet =
00451 m_problem.getRandomSample(m_problem.getSampleSize());
00452
00453 std::vector<bool> consensusFlags(m_problem.getPoolSize());
00454 std::vector<bool> previousConsensusFlags(m_problem.getPoolSize(),
00455 false);
00456 while(1) {
00457
00458 model = m_problem.estimateModel(trialSet);
00459
00460
00461
00462 this->computeConsensusSet(model, consensusFlags);
00463
00464
00465 if(this->isConverged(consensusFlags, previousConsensusFlags)) {
00466 break;
00467 }
00468
00469
00470
00471 trialSet = m_problem.getSubset(
00472 consensusFlags.begin(), consensusFlags.end());
00473 previousConsensusFlags = consensusFlags;
00474 }
00475
00476
00477
00478 size_t consensusSetSize = std::count(
00479 consensusFlags.begin(), consensusFlags.end(), true);
00480 if(consensusSetSize > m_minimumConsensusSize) {
00481
00482 return true;
00483 }
00484
00485
00486
00487 maxRecorder.test(consensusSetSize, model);
00488 }
00489
00490
00491
00492
00493 model = maxRecorder.getPayload();
00494 return false;
00495 }
00496
00497
00498 template <class Problem>
00499 bool
00500 Ransac<Problem>::
00501 isConverged(const std::vector<bool>& consensusFlags,
00502 const std::vector<bool>& previousConsensusFlags)
00503 {
00504 if(consensusFlags.size() != previousConsensusFlags.size()) {
00505 return false;
00506 }
00507 return std::equal(consensusFlags.begin(), consensusFlags.end(),
00508 previousConsensusFlags.begin());
00509 }
00510
00511 }
00512
00513 }
00514
00515 #endif