ms.cpp 79 KB


  1. /*******************************************************
  2. Mean Shift Analysis Library
  3. =============================================
  4. The mean shift library is a collection of routines
  5. that use the mean shift algorithm. Using this algorithm,
  6. the output needed to analyze a given input data set
  7. is generated.
  8. MeanShift Base Class:
  9. ====================
  10. The mean shift library of routines is realized
  11. via the creation of a MeanShift base class. This class
  12. provides a mechanism for calculating the mean shift vector
  13. at a specified data point, using an arbitrary N-dimensional
  14. data set, and a user-defined kernel.
  15. For image processing the mean shift base class also allows
  16. for the definition of a data set that is on a two-dimensional
  17. lattice. The amount of time needed to compute the mean shift
  18. vector using such a data set is much less than that of an
  19. arbitrary one. Because images usually contain many data points,
  20. defining the image input data points as being on a lattice
  21. greatly improves computation time and makes algorithms such
  22. as image filtering practical.
  23. The definition of the MeanShift class is provided below. Its
  24. prototype is provided in 'ms.h'.
  25. The theory is described in the papers:
  26. D. Comaniciu, P. Meer: Mean Shift: A robust approach toward feature
  27. space analysis.
  28. C. Christoudias, B. Georgescu, P. Meer: Synergism in low level vision.
  29. and they are available at:
  30. http://www.caip.rutgers.edu/riul/research/papers/
  31. Implemented by Chris M. Christoudias, Bogdan Georgescu
  32. ********************************************************/
  33. //Include Needed Libraries
  34. #include "ms.h"
  35. #include <string.h>
  36. #include <stdlib.h>
  37. #include <stdio.h>
  38. #include <math.h>
  39. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  40. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  41. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ PUBLIC METHODS @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  42. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  43. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  44. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  45. /*** Constructor/Destructor ***/
  46. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  47. /*******************************************************/
  48. /*Class Constructor */
  49. /*******************************************************/
  50. /*Post: */
  51. /* The MeanShift class has been properly */
  52. /* initialized. */
  53. /*******************************************************/
  54. MeanShift::MeanShift ( void )
  55. {
  56. //intialize input data set parameters...
  57. P = NULL;
  58. L = 0;
  59. N = 0;
  60. kp = 0;
  61. //initialize input data set storage structures...
  62. data = NULL;
  63. //initialize input data set kd-tree
  64. root = NULL;
  65. forest = NULL;
  66. range = NULL;
  67. //intialize lattice structure...
  68. height = 0;
  69. width = 0;
  70. //intialize kernel strucuture...
  71. h = NULL;
  72. kernel = NULL;
  73. w = NULL;
  74. offset = NULL;
  75. increment = NULL;
  76. uniformKernel = false;
  77. //initialize weight function linked list...
  78. head = cur = NULL;
  79. //intialize mean shift processing data structures...
  80. uv = NULL;
  81. //set lattice weight map to null
  82. weightMap = NULL;
  83. //indicate that the lattice weight map is undefined
  84. weightMapDefined = false;
  85. //allocate memory for error message buffer...
  86. ErrorMessage = new char [256];
  87. //initialize error status to OKAY
  88. ErrorStatus = EL_OKAY;
  89. //Initialize class state...
  90. class_state.INPUT_DEFINED = false;
  91. class_state.KERNEL_DEFINED = false;
  92. class_state.LATTICE_DEFINED = false;
  93. class_state.OUTPUT_DEFINED = false;
  94. }
  95. /*******************************************************/
  96. /*Class Destructor */
  97. /*******************************************************/
  98. /*Post: */
  99. /* The MeanShift class has been properly */
  100. /* destroyed. */
  101. /*******************************************************/
  102. MeanShift::~MeanShift ( void )
  103. {
  104. delete [] ErrorMessage;
  105. if ( weightMap )
  106. {
  107. delete [] weightMap;
  108. }
  109. //de-allocate memory used to store
  110. //user defined weight functions
  111. ClearWeightFunctions();
  112. //de-allocate memory used for kernel
  113. DestroyKernel();
  114. //de-allocate memory used for input
  115. ResetInput();
  116. }
  117. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  118. /*** Creation/Initialization of Mean Shift Kernel ***/
  119. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  120. /*******************************************************/
  121. /*Define Kernel */
  122. /*******************************************************/
  123. /*Creats custom user defined Kernel to be used by the */
  124. /*mean shift procedure. */
  125. /*******************************************************/
  126. /*Pre: */
  127. /* - kernel is an array of kernelTypes specifying */
  128. /* the type of kernel to be used on each sub- */
  129. /* space of the input data set x */
  130. /* - h is the set of bandwidths used to define the*/
  131. /* the search window */
  132. /* - P is a one dimensional array of integers of */
  133. /* size kp, that specifies the dimension of each*/
  134. /* subspace of the input data set x */
  135. /* - kp is the total number of subspaces used to */
  136. /* the input data set x */
  137. /*Post: */
  138. /* - the custom kernel has been created for use */
  139. /* by the mean shift procedure. */
  140. /*******************************************************/
  141. void MeanShift::DefineKernel ( kernelType *kernel_, float *h_, int *P_, int kp_ )
  142. {
  143. // Declare variables
  144. int i, kN;
  145. //if a kernel has already been created then destroy it
  146. if ( kp )
  147. DestroyKernel();
  148. //Obtain kp...
  149. if ( ( kp = kp_ ) <= 0 )
  150. {
  151. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "CreateKernel", ( char * ) "Subspace count (kp) is zero or negative." );
  152. return;
  153. }
  154. //Allocate memory for h, P, kernel, offset, and increment
  155. if ( ( ! ( P = new int [kp] ) ) || ( ! ( h = new float [kp] ) ) || ( ! ( kernel = new kernelType [kp] ) ) ||
  156. ( ! ( offset = new float [kp] ) ) || ( ! ( increment = new double [kp] ) ) )
  157. {
  158. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "CreateKernel", ( char * ) "Not enough memory available to create kernel." );
  159. return;
  160. }
  161. //Populate h, P and kernel, also use P to calculate
  162. //the dimension (N_) of the potential input data set x
  163. kN = 0;
  164. for ( i = 0; i < kp; i++ )
  165. {
  166. if ( ( h[i] = h_[i] ) <= 0 )
  167. {
  168. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "CreateKernel", ( char * ) "Negative or zero valued bandwidths are prohibited." );
  169. return;
  170. }
  171. if ( ( P[i] = P_[i] ) <= 0 )
  172. {
  173. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "CreateKernel", ( char * ) "Negative or zero valued subspace dimensions are prohibited." );
  174. return;
  175. }
  176. kernel[i] = kernel_[i];
  177. kN += P[i];
  178. }
  179. //Allocate memory for range vector and uv using N_
  180. if ( ( ! ( range = new float [2*kN] ) ) || ( ! ( uv = new double [kN] ) ) )
  181. {
  182. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "CreateKernel", ( char * ) "Not enough memory available to create kernel." );
  183. return;
  184. }
  185. // Generate weight function lookup table
  186. // using above information and user
  187. // defined weight function list
  188. generateLookupTable();
  189. //check for errors
  190. if ( ErrorStatus == EL_ERROR )
  191. return;
  192. //indicate that the kernel has been defined
  193. class_state.KERNEL_DEFINED = true;
  194. //done.
  195. return;
  196. }
  197. /*******************************************************/
  198. /*Add Weight Function */
  199. /*******************************************************/
  200. /*Adds a weight function to the Mean Shift class to be */
  201. /*used by the mean shift procedure */
  202. /*******************************************************/
  203. /*Pre: */
  204. /* - g(u) is the normalized weight function with */
  205. /* respect to u = (norm(x-xi))^2/h^2 */
  206. /* - sampleNumber is the number of samples to be */
  207. /* taken of g(u) over halfWindow interval */
  208. /* - halfWindow is the radius of g(u) such that */
  209. /* g(u) is defined for 0 <= u <= halfWindow */
  210. /* - subspace is the subspace number for which */
  211. /* g(u) is to be applied during the mean shift */
  212. /* procedure. */
  213. /*Post: */
  214. /* - g(u) has been added to the Mean Shift class */
  215. /* private data structure to be used by the */
  216. /* mean shift procedure. */
  217. /* - if a weight function has already been spec- */
  218. /* ified for the specified subspace, the weight */
  219. /* function for this subspace has been replaced.*/
  220. /*******************************************************/
  221. void MeanShift::AddWeightFunction ( double g ( double ), float halfWindow, int sampleNumber, int subspace )
  222. {
  223. // Declare Variables
  224. int i;
  225. double increment;
  226. // Search to see if a weight function has already been
  227. // defined for specified subspace, if not then insert
  228. // into the head of the weight function list, otherwise
  229. // replace entry
  230. // Perform Search
  231. cur = head;
  232. while ( ( cur ) && ( cur->subspace != subspace ) )
  233. cur = cur->next;
  234. // Entry Exists - Replace It!
  235. // Otherwise insert at the head of the the weight functon list
  236. if ( cur )
  237. delete cur->w;
  238. else
  239. {
  240. cur = new userWeightFunct;
  241. cur->next = head;
  242. head = cur;
  243. }
  244. // Generate lookup table
  245. increment = halfWindow / ( double ) ( sampleNumber );
  246. cur->w = new double [sampleNumber+1];
  247. for ( i = 0; i <= sampleNumber; i++ )
  248. cur->w[i] = g ( ( double ) ( i * increment ) );
  249. // Set weight function parameters
  250. cur->halfWindow = halfWindow;
  251. cur->sampleNumber = sampleNumber;
  252. cur->subspace = subspace;
  253. //done.
  254. return;
  255. }
  256. /*******************************************************/
  257. /*Clear Weight Functions */
  258. /*******************************************************/
  259. /*Clears user defined weight from the Mean Shift class */
  260. /*private data structure. */
  261. /*******************************************************/
  262. /*Post: */
  263. /* - all user defined weight functions ahve been */
  264. /* cleared from the private data structure of */
  265. /* the mean shift class. */
  266. /*******************************************************/
  267. void MeanShift::ClearWeightFunctions ( void )
  268. {
  269. while ( head )
  270. {
  271. delete head->w;
  272. cur = head;
  273. head = head->next;
  274. delete cur;
  275. }
  276. }
  277. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  278. /*** Input Data Set Declaration ***/
  279. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  280. /*******************************************************/
  281. /*Define Input */
  282. /*******************************************************/
  283. /*Uploads input data set x into the mean shift class. */
  284. /*******************************************************/
  285. /*Pre: */
  286. /* - x is a one dimensional array of L N-dimen- */
  287. /* ional data points. */
  288. /*Post: */
  289. /* - x has been uploaded into the mean shift */
  290. /* class. */
  291. /* - the height and width of a previous data set */
  292. /* has been undefined. */
  293. /*******************************************************/
  294. void MeanShift::DefineInput ( float *x, int L_, int N_ )
  295. {
  296. //if input data is defined de-allocate memory, and
  297. //re-initialize the input data structure
  298. if ( ( class_state.INPUT_DEFINED ) || ( class_state.LATTICE_DEFINED ) )
  299. ResetInput();
  300. //make sure x is not NULL...
  301. if ( !x )
  302. {
  303. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "UploadInput", ( char * ) "Input data set is NULL." );
  304. return;
  305. }
  306. //Obtain L and N
  307. if ( ( ( L = L_ ) <= 0 ) || ( ( N = N_ ) <= 0 ) )
  308. {
  309. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "UploadInput", ( char * ) "Input data set has negative or zero length or dimension." );
  310. return;
  311. }
  312. //Allocate memory for data
  313. if ( ! ( data = new float [L*N] ) )
  314. {
  315. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "UploadInput", ( char * ) "Not enough memory." );
  316. return;
  317. }
  318. //Allocate memory for input data set, and copy
  319. //x into the private data members of the mean
  320. //shift class
  321. InitializeInput ( x );
  322. //check for errors
  323. if ( ErrorStatus == EL_ERROR )
  324. return;
  325. // Load x into the MeanShift object using
  326. // using a kd-tree, resulting in better
  327. // range searching of the input data points
  328. // x - also upload window centers into
  329. // msRawData
  330. CreateBST();
  331. //indicate that the input has been recently defined
  332. class_state.INPUT_DEFINED = true;
  333. class_state.LATTICE_DEFINED = false;
  334. class_state.OUTPUT_DEFINED = false;
  335. //done.
  336. return;
  337. }
  338. /*******************************************************/
  339. /*Define Lattice */
  340. /*******************************************************/
  341. /*Defines the height and width of the input lattice. */
  342. /*******************************************************/
  343. /*Pre: */
  344. /* - ht is the height of the lattice */
  345. /* - wt is the width of the lattice */
  346. /*Post: */
  347. /* - the height and width of the lattice has been */
  348. /* specified. */
  349. /* - if a data set is presently loaded into the */
  350. /* mean shift class, an error is flagged if the */
  351. /* number of elements in that data set does not */
  352. /* equal the product ht*wt. */
  353. /*******************************************************/
  354. void MeanShift::DefineLInput ( float *x, int ht, int wt, int N_ )
  355. {
  356. //if input data is defined de-allocate memory, and
  357. //re-initialize the input data structure
  358. if ( ( class_state.INPUT_DEFINED ) || ( class_state.LATTICE_DEFINED ) )
  359. ResetInput();
  360. //Obtain lattice height and width
  361. if ( ( ( height = ht ) <= 0 ) || ( ( width = wt ) <= 0 ) )
  362. {
  363. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "DefineLInput", ( char * ) "Lattice defined using zero or negative height and/or width." );
  364. return;
  365. }
  366. //Obtain input data dimension
  367. if ( ( N = N_ ) <= 0 )
  368. {
  369. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "DefineInput", ( char * ) "Input defined using zero or negative dimension." );
  370. return;
  371. }
  372. //compute the data length, L, of input data set
  373. //using height and width
  374. L = height * width;
  375. //Allocate memory for input data set, and copy
  376. //x into the private data members of the mean
  377. //shift class
  378. InitializeInput ( x );
  379. //check for errors
  380. if ( ErrorStatus == EL_ERROR )
  381. return;
  382. //allocate memory for weight map
  383. if ( ! ( weightMap = new float [L] ) )
  384. {
  385. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "InitializeInput", ( char * ) "Not enough memory." );
  386. return;
  387. }
  388. //initialize weightMap to an array of zeros
  389. memset ( weightMap, 0, L* ( sizeof ( float ) ) );
  390. //Indicate that a lattice input has recently been
  391. //defined
  392. class_state.LATTICE_DEFINED = true;
  393. class_state.INPUT_DEFINED = false;
  394. class_state.OUTPUT_DEFINED = false;
  395. //done.
  396. return;
  397. }
  398. /*******************************************************/
  399. /*Set Lattice Weight Map */
  400. /*******************************************************/
  401. /*Populates the lattice weight map with specified */
  402. /*weight values. */
  403. /*******************************************************/
  404. /*Pre: */
  405. /* - wm is a floating point array of size L */
  406. /* specifying for each data point a weight */
  407. /* value */
  408. /*Post: */
  409. /* - wm has been used to populate the lattice */
  410. /* weight map. */
  411. /*******************************************************/
  412. void MeanShift::SetLatticeWeightMap ( float *wm )
  413. {
  414. //make sure wm is not NULL
  415. if ( !wm )
  416. {
  417. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "SetWeightMap", ( char * ) "Specified weight map is NULL." );
  418. return;
  419. }
  420. //populate weightMap using wm
  421. int i;
  422. for ( i = 0; i < L; i++ )
  423. weightMap[i] = wm[i];
  424. //indicate that a lattice weight map has been specified
  425. weightMapDefined = true;
  426. //done.
  427. return;
  428. }
  429. /*******************************************************/
  430. /*Remove Lattice Weight Map */
  431. /*******************************************************/
  432. /*Removes the lattice weight map. */
  433. /*******************************************************/
  434. /*Post: */
  435. /* - the lattice weight map has been removed. */
  436. /* - if a weight map did not exist NO error is */
  437. /* flagged. */
  438. /*******************************************************/
  439. void MeanShift::RemoveLatticeWeightMap ( void )
  440. {
  441. //only remove weight map if it exists, otherwise
  442. //do nothing...
  443. if ( weightMapDefined )
  444. {
  445. //set values of lattice weight map to zero
  446. memset ( weightMap, 0, L*sizeof ( float ) );
  447. //indicate that a lattice weight map is no longer
  448. //defined
  449. weightMapDefined = false;
  450. }
  451. //done.
  452. return;
  453. }
  454. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  455. /*** Mean Shift Operations ***/
  456. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  457. /*******************************************************/
  458. /*Mean Shift Vector */
  459. /*******************************************************/
  460. /*Calculates the mean shift vector at a specified data */
  461. /*point yk. */
  462. /*******************************************************/
  463. /*Pre: */
  464. /* - a kernel has been created */
  465. /* - a data set has been uploaded */
  466. /* - Mh is an N dimensional mean shift vector */
  467. /* - yk is an N dimensional data point */
  468. /*Post: */
  469. /* - the mean shift vector at yk has been */
  470. /* calculated and stored in and returned by Mh. */
  471. /*******************************************************/
  472. void MeanShift::msVector ( double *Mh, double *yk )
  473. {
  474. //make sure that Mh and/or yk are not NULL...
  475. if ( ( !Mh ) || ( !yk ) )
  476. {
  477. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "msVector", ( char * ) "Invalid argument(s) passed to this method." );
  478. return;
  479. }
  480. //make sure that a kernel has been created, data has
  481. //been uploaded, and that they are consistent with one
  482. //another...
  483. classConsistencyCheck ( N, false );
  484. //calculate mean shift vector at yk using created kernel
  485. //and uploaded data set
  486. MSVector ( Mh, yk );
  487. //done.
  488. return;
  489. }
  490. /*******************************************************/
  491. /*Lattice Mean Shift Vector */
  492. /*******************************************************/
  493. /*Calculates the mean shift vector at a specified data */
  494. /*point yk, assuming that the data set exhists on a */
  495. /*height x width two dimensional lattice. */
  496. /*******************************************************/
  497. /*Pre: */
  498. /* - a kernel has been created */
  499. /* - a data set has been uploaded */
  500. /* - the height and width of the lattice has been */
  501. /* specified using method DefineLattice() */
  502. /* - Mh is an N dimensional mean shift vector */
  503. /* - yk is an N dimensional data point */
  504. /*Post: */
  505. /* - the mean shift vector at yk has been */
  506. /* calculated and stored in and returned by Mh. */
  507. /* - Mh was calculated using the defined input */
  508. /* lattice. */
  509. /*******************************************************/
  510. void MeanShift::latticeMSVector ( double *Mh, double *yk )
  511. {
  512. //make sure that Mh and/or yk are not NULL...
  513. if ( ( !Mh ) || ( !yk ) )
  514. {
  515. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "lmsVector", ( char * ) "Invalid argument(s) passed to this method." );
  516. return;
  517. }
  518. //make sure that a kernel has been created, data has
  519. //been uploaded, and that they are consistent with one
  520. //another...
  521. classConsistencyCheck ( N + 2, true );
  522. //calculate mean shift vector at yk using created kernel
  523. //and uploaded data set
  524. LatticeMSVector ( Mh, yk );
  525. //done.
  526. return;
  527. }
  528. /*******************************************************/
  529. /*Find Mode */
  530. /*******************************************************/
  531. /*Calculates the mode of a specified data point yk. */
  532. /*******************************************************/
  533. /*Pre: */
  534. /* - a kernel has been created */
  535. /* - a data set has been uploaded */
  536. /* - mode is the N dimensional mode of the N-dim- */
  537. /* ensional data point yk */
  538. /*Post: */
  539. /* - the mode of yk has been calculated and */
  540. /* stored in mode. */
  541. /*******************************************************/
  542. void MeanShift::FindMode ( double *mode, double *yk )
  543. {
  544. //make sure that mode and/or yk are not NULL...
  545. if ( ( !mode ) || ( !yk ) )
  546. {
  547. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "FindMode", ( char * ) "Invalid argument(s) passed to this method." );
  548. return;
  549. }
  550. //make sure that a kernel has been created, data has
  551. //been uploaded, and that they are consistent with one
  552. //another...
  553. classConsistencyCheck ( N, false );
  554. //allocate memory for Mh
  555. double *Mh = new double [N];
  556. //copy yk into mode
  557. int i;
  558. for ( i = 0; i < N; i++ )
  559. mode[i] = yk[i];
  560. //calculate mean shift vector at yk
  561. MSVector ( Mh, yk );
  562. //calculate mvAbs = |Mh|^2
  563. double mvAbs = 0;
  564. for ( i = 0; i < N; i++ )
  565. mvAbs += Mh[i] * Mh[i];
  566. //shift mode until convergence (mvAbs = 0)...
  567. int iterationCount = 1;
  568. while ( ( mvAbs >= EPSILON2 ) && ( iterationCount < LIMIT ) )
  569. {
  570. //shift mode...
  571. for ( i = 0; i < N; i++ )
  572. mode[i] += Mh[i];
  573. //re-calculate mean shift vector at new
  574. //window location have center defined by
  575. //mode
  576. MSVector ( Mh, mode );
  577. //calculate mvAbs = |Mh|^2
  578. mvAbs = 0;
  579. for ( i = 0; i < N; i++ )
  580. mvAbs += Mh[i] * Mh[i];
  581. //increment interation count...
  582. iterationCount++;
  583. }
  584. //shift mode...
  585. for ( i = 0; i < N; i++ )
  586. mode[i] += Mh[i];
  587. //de-allocate memory
  588. delete [] Mh;
  589. //done.
  590. return;
  591. }
  592. /*******************************************************/
  593. /*Find Lattice Mode */
  594. /*******************************************************/
  595. /*Calculates the mode of a specified data point yk, */
  596. /*assuming that the data set exhists on a height x */
  597. /*width two dimensional lattice. */
  598. /*******************************************************/
  599. /*Pre: */
  600. /* - a kernel has been created */
  601. /* - a data set has been uploaded */
  602. /* - the height and width of the lattice has been */
  603. /* specified using method DefineLattice() */
  604. /* - mode is the N dimensional mode of the N-dim- */
  605. /* ensional data point yk */
  606. /*Post: */
  607. /* - the mode of yk has been calculated and */
  608. /* stored in mode. */
  609. /* - mode was calculated using the defined input */
  610. /* lattice. */
  611. /*******************************************************/
  612. void MeanShift::FindLMode ( double *mode, double *yk )
  613. {
  614. //make sure that mode and/or yk are not NULL...
  615. if ( ( !mode ) || ( !yk ) )
  616. {
  617. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "FindLMode", ( char * ) "Invalid argument(s) passed to this method." );
  618. return;
  619. }
  620. //make sure the lattice height and width have been defined...
  621. if ( !height )
  622. {
  623. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "FindLMode", ( char * ) "Lattice height and width is undefined." );
  624. return;
  625. }
  626. //make sure that a kernel has been created, data has
  627. //been uploaded, and that they are consistent with one
  628. //another...
  629. classConsistencyCheck ( N + 2, true );
  630. //define gridN
  631. int gridN = N + 2;
  632. //allocate memory for Mh
  633. double *Mh = new double [gridN];
  634. //copy yk into mode
  635. int i;
  636. for ( i = 0; i < gridN; i++ )
  637. mode[i] = yk[i];
  638. //calculate mean shift vector at yk
  639. LatticeMSVector ( Mh, mode );
  640. //calculate mvAbs = |Mh|^2
  641. double mvAbs = 0;
  642. for ( i = 0; i < gridN; i++ )
  643. mvAbs += Mh[i] * Mh[i];
  644. //shift mode until convergence (mvAbs = 0)...
  645. int iterationCount = 1;
  646. while ( ( mvAbs >= EPSILON2 ) && ( iterationCount < LIMIT ) )
  647. {
  648. //shift mode...
  649. for ( i = 0; i < gridN; i++ )
  650. mode[i] += Mh[i];
  651. //re-calculate mean shift vector at new
  652. //window location have center defined by
  653. //mode
  654. LatticeMSVector ( Mh, mode );
  655. //calculate mvAbs = |Mh|^2
  656. mvAbs = 0;
  657. for ( i = 0; i < gridN; i++ )
  658. mvAbs += Mh[i] * Mh[i];
  659. //increment interation count...
  660. iterationCount++;
  661. }
  662. //shift mode...
  663. for ( i = 0; i < gridN; i++ )
  664. mode[i] += Mh[i];
  665. //de-allocate memory
  666. delete [] Mh;
  667. //done.
  668. return;
  669. }
  670. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  671. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  672. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ PROTECTED METHODS @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  673. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  674. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  675. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  676. /* Mean Shift: Using kd-Tree */
  677. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  678. /*******************************************************/
  679. /*Mean Shift Vector */
  680. /*******************************************************/
  681. /*Computes the mean shift vector at a window location */
  682. /*yk using input data set x using a custom, user defin-*/
  683. /*ed kernel. */
  684. /*******************************************************/
  685. /*Pre: */
  686. /* - input data has been uploaded into the private*/
  687. /* data members of the MeanShift class */
  688. /* - a window center yk has been defined */
  689. /* - uniformKernel indicates the which type of */
  690. /* kernel to be used by this procedure: uniform */
  691. /* or general */
  692. /*Post: */
  693. /* - the mean shift vector calculated at yk */
  694. /* using a either a custom, user defined kernel */
  695. /* or a uniform kernel is returned */
  696. /*******************************************************/
  697. void MeanShift::MSVector ( double *Mh_ptr, double *yk_ptr )
  698. {
  699. // Declare Variables
  700. int i, j;
  701. // Initialize mean shift vector
  702. for ( i = 0; i < N; i++ )
  703. Mh_ptr[i] = 0;
  704. // Initialize wsum to zero, the sum of the weights of each
  705. // data point found to lie within the search window (sphere)
  706. wsum = 0;
  707. // Build Range Vector using h[i] and yk
  708. int s = 0;
  709. // The flag uniformKernel is used to determine which
  710. // kernel function is to be used in the calculation
  711. // of the mean shift vector
  712. if ( uniformKernel )
  713. {
  714. for ( i = 0; i < kp; i++ )
  715. {
  716. for ( j = 0; j < P[i]; j++ )
  717. {
  718. range[2* ( s+j ) ] = ( float ) ( yk_ptr[s+j] - h[i] );
  719. range[2* ( s+j ) +1] = ( float ) ( yk_ptr[s+j] + h[i] );
  720. }
  721. s += P[i];
  722. }
  723. }
  724. else
  725. {
  726. for ( i = 0; i < kp; i++ )
  727. {
  728. for ( j = 0; j < P[i]; j++ )
  729. {
  730. range[2* ( s+j ) ] = ( float ) ( yk_ptr[s+j] - h[i] * float ( sqrt ( offset[i] ) ) );
  731. range[2* ( s+j ) +1] = ( float ) ( yk_ptr[s+j] + h[i] * float ( sqrt ( offset[i] ) ) );
  732. }
  733. s += P[i];
  734. }
  735. }
  736. // Traverse through the data set x, performing the
  737. // weighted sum of each point xi that lies within
  738. // the search window (sphere) using a general,
  739. // user defined kernel or uniform kernel depending
  740. // on the uniformKernel flag
  741. if ( uniformKernel )
  742. uniformSearch ( root, 0, Mh_ptr, yk_ptr );
  743. else
  744. generalSearch ( root, 0, Mh_ptr, yk_ptr );
  745. // Calculate the mean shift vector using Mh and wsum
  746. for ( i = 0; i < N; i++ )
  747. {
  748. // Divide Sum by wsum
  749. Mh_ptr[i] /= wsum;
  750. // Calculate mean shift vector: Mh(yk) = y(k+1) - y(k)
  751. Mh_ptr[i] -= yk_ptr[i];
  752. }
  753. //done.
  754. return;
  755. }
  756. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  757. /* Mean Shift: Using Lattice */
  758. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  759. /*******************************************************/
  760. /*Lattice Mean Shift Vector */
  761. /*******************************************************/
  762. /*Computes the mean shift vector at a specfied window */
  763. /*yk using the lattice data structure. */
  764. /*******************************************************/
  765. /*Pre: */
  766. /* - Mh_ptr and yh_ptr are arrays of doubles con- */
  767. /* aining N+2 elements */
  768. /* - Mh_ptr is the mean shift vector calculated */
  769. /* at window center yk_ptr */
  770. /*Post: */
  771. /* - the mean shift vector at the window center */
  772. /* pointed to by yk_ptr has been calculated and */
  773. /* stored in the memory location pointed to by */
  774. /* Mh_ptr */
  775. /*******************************************************/
  776. void MeanShift::LatticeMSVector ( double *Mh_ptr, double *yk_ptr )
  777. {
  778. // Initialize mean shift vector
  779. register int i;
  780. for ( i = 0; i < N + 2; i++ )
  781. Mh_ptr[i] = 0;
  782. // Initialize wsum
  783. wsum = 0;
  784. // Perform lattice search summing
  785. // all the points that lie within the search
  786. // window defined using the kernel specified
  787. //by uniformKernel
  788. if ( uniformKernel )
  789. uniformLSearch ( Mh_ptr, yk_ptr );
  790. else
  791. generalLSearch ( Mh_ptr, yk_ptr );
  792. // Compute mean shift vector using sum computed
  793. // by lattice search, wsum, and yk_ptr:
  794. // Mh = Mh/wsum - yk_ptr
  795. if ( wsum > 0 )
  796. {
  797. for ( i = 0; i < N + 2; i++ )
  798. Mh_ptr[i] = Mh_ptr[i] / wsum - yk_ptr[i];
  799. }
  800. else
  801. {
  802. for ( i = 0; i < N + 2; i++ )
  803. Mh_ptr[i] = 0;
  804. }
  805. // done.
  806. return;
  807. }
  808. /*******************************************************/
  809. /*Optimized Lattice Mean Shift Vector */
  810. /*******************************************************/
  811. /*Computes the mean shift vector at a specfied window */
  812. /*yk using the lattice data structure. Also the points */
  813. /*that lie within the window are stored into the basin */
  814. /*of attraction structure used by the optimized mean */
  815. /*shift algorithms. */
  816. /*******************************************************/
  817. /*Pre: */
  818. /* - Mh_ptr and yh_ptr are arrays of doubles con- */
  819. /* aining N+2 elements */
  820. /* - Mh_ptr is the mean shift vector calculated */
  821. /* at window center yk_ptr */
  822. /*Post: */
  823. /* - the mean shift vector at the window center */
  824. /* pointed to by yk_ptr has been calculated and */
  825. /* stored in the memory location pointed to by */
  826. /* Mh_ptr */
  827. /* - the data points lying within h of of yk_ptr */
  828. /* have been stored into the basin of attract- */
  829. /* ion data structure. */
  830. /*******************************************************/
  831. void MeanShift::OptLatticeMSVector ( double *Mh_ptr, double *yk_ptr )
  832. {
  833. // Initialize mean shift vector
  834. register int i;
  835. for ( i = 0; i < N + 2; i++ )
  836. Mh_ptr[i] = 0;
  837. // Initialize wsum
  838. wsum = 0;
  839. // Perform lattice search summing
  840. // all the points that lie within the search
  841. // window defined using the kernel specified
  842. //by uniformKernel
  843. if ( uniformKernel )
  844. optUniformLSearch ( Mh_ptr, yk_ptr );
  845. else
  846. optGeneralLSearch ( Mh_ptr, yk_ptr );
  847. // Compute mean shift vector using sum computed
  848. // by lattice search, wsum, and yk_ptr:
  849. // Mh = Mh/wsum - yk_ptr
  850. if ( wsum > 0 )
  851. {
  852. for ( i = 0; i < N + 2; i++ )
  853. Mh_ptr[i] = Mh_ptr[i] / wsum - yk_ptr[i];
  854. } else
  855. {
  856. for ( i = 0; i < N + 2; i++ )
  857. Mh_ptr[i] = 0;
  858. }
  859. // done.
  860. return;
  861. }
  862. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  863. /*** Kernel-Input Data Consistency ***/
  864. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  865. /*******************************************************/
  866. /*Class Consistency Check */
  867. /*******************************************************/
  868. /*Checks the state of the class prior to the applicat- */
  869. /*ion of mean shift. */
  870. /*******************************************************/
  871. /*Pre: */
  872. /* - iN is the specified dimension of the input, */
  873. /* iN = N for a general input data set, iN = N */
  874. /* + 2 for a input set defined using a lattice */
  875. /*Post: */
  876. /* - if the kernel has not been created, an input */
  877. /* has not been defined and/or the specified */
  878. /* input dimension (iN) does not match that of */
  879. /* the kernel a fatal error is flagged. */
  880. /*******************************************************/
  881. void MeanShift::classConsistencyCheck ( int iN, bool usingLattice )
  882. {
  883. //make sure that kernel has been created...
  884. if ( class_state.KERNEL_DEFINED == false )
  885. {
  886. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "classConsistencyCheck", ( char * ) "Kernel not created." );
  887. return;
  888. }
  889. //make sure input data set has been loaded into mean shift object...
  890. if ( ( class_state.INPUT_DEFINED == false ) && ( !usingLattice ) )
  891. {
  892. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "classConsistencyCheck", ( char * ) "No input data specified." );
  893. return;
  894. }
  895. //make sure that the lattice is defined if it is being used
  896. if ( ( class_state.LATTICE_DEFINED == false ) && ( usingLattice ) )
  897. {
  898. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "classConsistencyCheck", ( char * ) "Latice not created." );
  899. return;
  900. }
  901. //make sure that dimension of the kernel and the input data set
  902. //agree
  903. //calculate dimension of kernel (kN)
  904. int i, kN = 0;
  905. for ( i = 0; i < kp; i++ )
  906. kN += P[i];
  907. //perform comparison...
  908. if ( iN != kN )
  909. {
  910. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "classConsitencyCheck", ( char * ) "Kernel dimension does not match defined input data dimension." );
  911. return;
  912. }
  913. //done.
  914. return;
  915. }
  916. /*/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  917. /*** Class Error Handler ***/
  918. /*\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  919. /*******************************************************/
  920. /*Error Handler */
  921. /*******************************************************/
  922. /*Class error handler. */
  923. /*******************************************************/
  924. /*Pre: */
  925. /* - className is the name of the class that fl- */
  926. /* agged an error */
  927. /* - methodName is the name of the method that */
  928. /* flagged an error */
  929. /* - errmsg is the error message given by the */
  930. /* calling function */
  931. /*Post: */
  932. /* - the error message errmsg is flagged on beh- */
  933. /* ave of method methodName belonging to class */
  934. /* className: */
  935. /* */
  936. /* (1) ErrorMessage has been updated with the */
  937. /* appropriate error message using the arg- */
  938. /* ments passed to this method. */
  939. /* (2) ErrorStatus is set to ERROR */
  940. /* (ErrorStatus = 1) */
  941. /*******************************************************/
  942. void MeanShift::ErrorHandler ( char *className, char *methodName, char* errmsg )
  943. {
  944. //store trace into error message
  945. strcpy ( ErrorMessage, className );
  946. strcat ( ErrorMessage, ( char * ) "::" );
  947. strcat ( ErrorMessage, methodName );
  948. strcat ( ErrorMessage, ( char * ) " Error: (char *)" );
  949. //store message into error message
  950. strcat ( ErrorMessage, errmsg );
  951. //set error status to ERROR
  952. ErrorStatus = EL_ERROR;
  953. }
  954. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  955. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  956. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ PRIVATE METHODS @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  957. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  958. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  959. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  960. /*** Kernel Creation/Manipulation ***/
  961. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  962. /*******************************************************/
  963. /*Generate Lookup Table */
  964. /*******************************************************/
  965. /*A weight function look up table is generated. */
  966. /*******************************************************/
  967. /*Pre: */
  968. /* - kernel is an array of kernelTypes specifying */
  969. /* the type of kernel to be used on each sub- */
  970. /* space of the input data set x */
  971. /* - kp is the total number of subspaces used to */
  972. /* the input data set x */
  973. /* - the above information has been pre-loaded */
  974. /* into the MeanShift class private members */
  975. /*Post: */
  976. /* - a lookup table is generated for the weight */
  977. /* function of the resulting kernel */
  978. /* - uniformKernel is set to true if the kernel */
  979. /* to be used is uniform, false is returned */
  980. /* otherwise */
  981. /* - if a user defined weight function is requred */
  982. /* for a given subspace but not defined in the */
  983. /* user defined weight function list, an error */
  984. /* is flagged and the program is halted */
  985. /*******************************************************/
  986. void MeanShift::generateLookupTable ( void )
  987. {
  988. // Declare Variables
  989. int i, j;
  990. // Allocate memory for lookup table w
  991. w = new double*[kp];
  992. // Traverse through kernel generating weight function
  993. // lookup table w
  994. // Assume kernel is uniform
  995. uniformKernel = true;
  996. for ( i = 0; i < kp; i++ )
  997. {
  998. switch ( kernel[i] )
  999. {
  1000. // *Uniform Kernel* has weight funciton w(u) = 1
  1001. // therefore, a weight funciton lookup table is
  1002. // not needed for this kernel --> w[i] = NULL indicates
  1003. // this
  1004. case Uniform:
  1005. w [i] = NULL; //weight function not needed for this kernel
  1006. offset [i] = 1; //uniform kernel has u < 1.0
  1007. increment[i] = 1; //has no meaning
  1008. break;
  1009. // *Gaussian Kernel* has weight function w(u) = constant*exp(-u^2/[2h[i]^2])
  1010. case Gaussian:
  1011. // Set uniformKernel to false
  1012. uniformKernel = false;
  1013. // generate weight function using expression,
  1014. // exp(-u/2), where u = norm(xi - x)^2/h^2
  1015. // Allocate memory for weight table
  1016. w[i] = new double [GAUSS_NUM_ELS+1];
  1017. for ( j = 0; j <= GAUSS_NUM_ELS; j++ )
  1018. w[i][j] = exp ( -j * GAUSS_INCREMENT / 2 );
  1019. // Set offset = offset^2, and set increment
  1020. offset [i] = ( float ) ( GAUSS_LIMIT * GAUSS_LIMIT );
  1021. increment[i] = GAUSS_INCREMENT;
  1022. // done
  1023. break;
  1024. // *User Define Kernel* uses the weight function wf(u)
  1025. case UserDefined:
  1026. // Set uniformKernel to false
  1027. uniformKernel = false;
  1028. // Search for user defined weight function
  1029. // defined for subspace (i+1)
  1030. cur = head;
  1031. while ( ( cur ) && ( cur->subspace != ( i + 1 ) ) )
  1032. cur = cur->next;
  1033. // If a user defined subspace has not been found
  1034. // for this subspace, flag an error
  1035. if ( cur == NULL )
  1036. {
  1037. fprintf ( stderr, ( char * ) "\ngenerateLookupTable Fatal Error: User defined kernel for subspace %d undefined.\n\nAborting Program.\n\n", i + 1 );
  1038. exit ( 1 );
  1039. }
  1040. // Otherwise, copy weight function lookup table to w[i]
  1041. w[i] = new double [cur->sampleNumber+1];
  1042. for ( j = 0; j <= cur->sampleNumber; j++ )
  1043. w[i][j] = cur->w[j];
  1044. // Set offset and increment accordingly
  1045. offset [i] = ( float ) ( cur->halfWindow );
  1046. increment[i] = cur->halfWindow / ( float ) ( cur->sampleNumber );
  1047. // done
  1048. break;
  1049. default:
  1050. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "generateLookupTable", ( char * ) "Unknown kernel type." );
  1051. }
  1052. }
  1053. }
  1054. /*******************************************************/
  1055. /*Destroy Kernel */
  1056. /*******************************************************/
  1057. /*Destroys and initializes kernel. */
  1058. /*******************************************************/
  1059. /*Post: */
  1060. /* - memory for the kernel private data members */
  1061. /* have been destroyed and the kernel has been */
  1062. /* initialized for re-use. */
  1063. /*******************************************************/
  1064. void MeanShift::DestroyKernel ( void )
  1065. {
  1066. //de-allocate memory...
  1067. if ( kernel ) delete [] kernel;
  1068. if ( h ) delete [] h;
  1069. if ( P ) delete [] P;
  1070. if ( range ) delete [] range;
  1071. if ( uv ) delete [] uv;
  1072. if ( increment ) delete [] increment;
  1073. if ( offset ) delete [] offset;
  1074. if ( kp > 0 )
  1075. {
  1076. if ( w )
  1077. {
  1078. int i;
  1079. for ( i = 0; i < kp; i++ )
  1080. delete [] w[i];
  1081. delete [] w;
  1082. }
  1083. w = NULL;
  1084. }
  1085. //intialize kernel for re-use...
  1086. kp = 0;
  1087. kernel = NULL;
  1088. h = NULL;
  1089. P = NULL;
  1090. range = NULL;
  1091. increment = NULL;
  1092. uv = NULL;
  1093. offset = NULL;
  1094. //done.
  1095. return;
  1096. }
  1097. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  1098. /*** Input Data Initialization/Destruction ***/
  1099. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  1100. /*******************************************************/
  1101. /*Create Binary Search Tree */
  1102. /*******************************************************/
  1103. /*Uploads input data set x into a kd-BST. */
  1104. /*******************************************************/
  1105. /*Pre: */
  1106. /* - x is a one dimensional array of L N-dimensi- */
  1107. /* onal data points */
  1108. /*Post: */
  1109. /* - x has been uploaded into a balanced kd-BST */
  1110. /* data structure for use by the mean shift */
  1111. /* procedure */
  1112. /*******************************************************/
  1113. void MeanShift::CreateBST ( void )
  1114. {
  1115. // Create BST using data....
  1116. // Allocate memory for tree
  1117. forest = new tree[L];
  1118. // Populate 'forest' of tree's with
  1119. // the values stored in x
  1120. int i;
  1121. for ( i = 0; i < L; i++ )
  1122. {
  1123. forest[i].x = &data[i*N];
  1124. forest[i].right = NULL;
  1125. forest[i].left = NULL;
  1126. forest[i].parent = NULL;
  1127. }
  1128. // Build balanced Nd-tree from the
  1129. // forest of trees generated above
  1130. // retaining the root of this tree
  1131. root = BuildKDTree ( forest, L, 0, NULL );
  1132. //done.
  1133. return;
  1134. }
  1135. /*******************************************************/
  1136. /*Initialize Input */
  1137. /*******************************************************/
  1138. /*Allocates memory for and initializes the input data */
  1139. /*structure. */
  1140. /*******************************************************/
  1141. /*Pre: */
  1142. /* - x is a floating point array of L, N dimens- */
  1143. /* ional input data points */
  1144. /*Post: */
  1145. /* - memory has been allocated for the input data */
  1146. /* structure and x has been stored using into */
  1147. /* the mean shift class using the resulting */
  1148. /* structure. */
  1149. /*******************************************************/
  1150. void MeanShift::InitializeInput ( float *x )
  1151. {
  1152. //allocate memory for input data set
  1153. if ( ! ( data = new float [L*N] ) )
  1154. {
  1155. ErrorHandler ( ( char * ) "MeanShift", ( char * ) "InitializeInput", ( char * ) "Not enough memory." );
  1156. return;
  1157. }
  1158. //copy x into data
  1159. int i;
  1160. for ( i = 0; i < L*N; i++ )
  1161. data[i] = x[i];
  1162. //done.
  1163. return;
  1164. }
  1165. /*******************************************************/
  1166. /*Reset Input */
  1167. /*******************************************************/
  1168. /*De-allocates memory for and re-intializes input data */
  1169. /*structure. */
  1170. /*******************************************************/
  1171. /*Post: */
  1172. /* - the memory of the input data structure has */
  1173. /* been de-allocated and this strucuture has */
  1174. /* been initialized for re-use. */
  1175. /*******************************************************/
  1176. void MeanShift::ResetInput ( void )
  1177. {
  1178. //de-allocate memory of input data structure (BST)
  1179. if ( data ) delete [] data;
  1180. if ( forest ) delete [] forest;
  1181. //initialize input data structure for re-use
  1182. data = NULL;
  1183. forest = NULL;
  1184. root = NULL;
  1185. L = 0;
  1186. N = 0;
  1187. width = 0;
  1188. height = 0;
  1189. //re-set class input to indicate that
  1190. //an input is not longer stored by
  1191. //the private data members of this class
  1192. class_state.INPUT_DEFINED = class_state.LATTICE_DEFINED = false;
  1193. }
  1194. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  1195. /*** k-dimensional Binary Search Tree ***/
  1196. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  1197. /*******************************************************/
  1198. /*Build KD Tree (for Tree Structure) */
  1199. /*******************************************************/
  1200. /*Builds a KD Tree given a forest of tree's. */
  1201. /*******************************************************/
  1202. /*Pre: */
  1203. /* - subset is a subset of L un-ordered tree nodes*/
  1204. /* each containing an N-dimensional data point */
  1205. /* - d is the depth of the subset, used to specify*/
  1206. /* the dimension used to construct the tree at */
  1207. /* the given depth */
  1208. /* - parent is the parent tree of subset */
  1209. /*Post: */
  1210. /* - a balanced KD tree has been constructed using*/
  1211. /* the forest subset, the root of this tree has */
  1212. /* been returned */
  1213. /*******************************************************/
  1214. tree *MeanShift::BuildKDTree ( tree *subset, int length, int d, tree* parent )
  1215. {
  1216. // If the subset is a single tree
  1217. // then return this tree otherwise
  1218. // partition the subset and place
  1219. // these subsets recursively into
  1220. // the left and right sub-trees having
  1221. // their root specified by the median
  1222. // of this subset in dimension d
  1223. if ( length == 1 )
  1224. {
  1225. subset->parent = parent;
  1226. return subset;
  1227. }
  1228. else if ( length > 1 )
  1229. {
  1230. // Sort Subset
  1231. QuickMedian ( subset, 0, length - 1, d );
  1232. // Get Median of Subset and Partition
  1233. // it into two sub-trees - create
  1234. // a tree with its root being the median
  1235. // of the subset and its left and right
  1236. // children being the medians of the subsets
  1237. int median = length / 2;
  1238. subset[median].parent = parent;
  1239. subset[median].left = BuildKDTree ( subset , median , ( d + 1 ) % N, &subset[median] );
  1240. subset[median].right = BuildKDTree ( &subset[median+1], length - median - 1, ( d + 1 ) % N, &subset[median] );
  1241. // Output tree structure
  1242. return &subset[median];
  1243. }
  1244. else
  1245. return NULL;
  1246. //done.
  1247. }
  1248. /*******************************************************/
  1249. /*Quick Median (for Tree Structure) */
  1250. /*******************************************************/
  1251. /*Finds the median element in an un-ordered set, re- */
  1252. /*structuring the set such that points less than the */
  1253. /*median point are located to the left of the median */
  1254. /*and points greater than the median point are located */
  1255. /*to the right. */
  1256. /*******************************************************/
  1257. /*Pre: */
  1258. /* - arr is a subset of tree nodes whose leftmost */
  1259. /* element is specified by left and rightmost */
  1260. /* element is specified by left */
  1261. /* - d is the dimension of the data set stored by */
  1262. /* the tree structure that is used to find */
  1263. /* the median */
  1264. /*Post: */
  1265. /* - the median point is found and the subset */
  1266. /* of trees is re-ordered such that all trees */
  1267. /* whose data points with d dimensional value */
  1268. /* less than that of the median tree node are */
  1269. /* located to the left of the median tree node, */
  1270. /* otherwise they are located to the right */
  1271. /*******************************************************/
  1272. void MeanShift::QuickMedian ( tree *arr, int left, int right, int d )
  1273. {
  1274. unsigned long k;
  1275. unsigned long n;
  1276. float* a;
  1277. float* temp;
  1278. n = right - left + 1;
  1279. k = n / 2 + 1;
  1280. unsigned long i, ir, j, l, mid;
  1281. l = 1;
  1282. ir = n;
  1283. for ( ;; )
  1284. {
  1285. if ( ir <= l + 1 )
  1286. {
  1287. if ( ir == l + 1 && arr[ir-1].x[d] < arr[l-1].x[d] )
  1288. {
  1289. SWAP ( arr[l-1].x, arr[ir-1].x )
  1290. }
  1291. return;
  1292. } else
  1293. {
  1294. mid = ( l + ir ) >> 1;
  1295. SWAP ( arr[mid-1].x, arr[l+1-1].x )
  1296. if ( arr[l-1].x[d] > arr[ir-1].x[d] )
  1297. {
  1298. SWAP ( arr[l-1].x, arr[ir-1].x )
  1299. }
  1300. if ( arr[l+1-1].x[d] > arr[ir-1].x[d] )
  1301. {
  1302. SWAP ( arr[l+1-1].x, arr[ir-1].x )
  1303. }
  1304. if ( arr[l-1].x[d] > arr[l+1-1].x[d] )
  1305. {
  1306. SWAP ( arr[l-1].x, arr[l+1-1].x )
  1307. }
  1308. i = l + 1;
  1309. j = ir;
  1310. a = arr[l+1-1].x;
  1311. for ( ;; ) {
  1312. do i++;
  1313. while ( arr[i-1].x[d] < a[d] );
  1314. do j--;
  1315. while ( arr[j-1].x[d] > a[d] );
  1316. if ( j < i ) break;
  1317. SWAP ( arr[i-1].x, arr[j-1].x )
  1318. }
  1319. arr[l+1-1].x = arr[j-1].x;
  1320. arr[j-1].x = a;
  1321. if ( j >= k ) ir = j - 1;
  1322. if ( j <= k ) l = i;
  1323. }
  1324. }
  1325. }
  1326. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  1327. /*** Mean Shift: Using kd-Tree ***/
  1328. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  1329. /*******************************************************/
  1330. /*Uniform Search */
  1331. /*******************************************************/
  1332. /*Searches the input data using a kd-tree, performs the*/
  1333. /*sum on the data within the Hypercube defined by the */
  1334. /*tree using a uniform kernel. */
  1335. /*******************************************************/
  1336. /*Pre: */
  1337. /* - gt is a possibly NULL pointer to a kd tree */
  1338. /* - Mh_ptr is a pointer to the mean shift vector */
  1339. /* being calculated */
  1340. /* - yk_ptr is a pointer to the current window */
  1341. /* center location */
  1342. /* - gd is the depth of the current subtree */
  1343. /*Post: */
  1344. /* - the mean of the points within the Hypercube */
  1345. /* of the kd tree is computed using a uniform */
  1346. /* kernel */
  1347. /*******************************************************/
  1348. void MeanShift::uniformSearch ( tree *gt, int gd, double *Mh_ptr, double *yk_ptr )
  1349. {
  1350. tree* c_t;
  1351. int c_d;
  1352. int i;
  1353. int actionType;
  1354. c_t = gt;
  1355. c_d = gd;
  1356. actionType = 0;
  1357. double el, diff;
  1358. int k, j, s;
  1359. while ( c_t != NULL )
  1360. {
  1361. switch ( actionType ) {
  1362. case 0: // forward
  1363. if ( ( c_t->x[c_d] > range[2*c_d] ) && ( ( c_t->left ) != NULL ) )
  1364. {
  1365. c_t = c_t->left;
  1366. c_d = ( c_d + 1 ) % N;
  1367. } else
  1368. {
  1369. actionType = 1;
  1370. }
  1371. break;
  1372. case 1: // backleft
  1373. for ( i = 0; i < N; i++ )
  1374. {
  1375. if ( ( c_t->x[i] < range[2*i] ) || ( c_t->x[i] > range[2*i+1] ) )
  1376. break;
  1377. }
  1378. if ( i == N )
  1379. {
  1380. // *** Visit Tree ***
  1381. // Check if xi is in the window centered about yk_ptr
  1382. // If so - use it to compute y(k+1)
  1383. diff = 0;
  1384. j = 0;
  1385. s = 0;
  1386. while ( ( diff < 1.0 ) && ( j < kp ) ) // Partial Distortion Search (PDS)
  1387. {
  1388. // test each sub-dimension independently
  1389. diff = 0;
  1390. for ( k = 0; k < P[j]; k++ )
  1391. {
  1392. el = ( c_t->x[s+k] - yk_ptr[s+k] ) / h[j];
  1393. diff += el * el;
  1394. }
  1395. s += P[j]; // next subspace
  1396. j++;
  1397. }
  1398. if ( diff < 1.0 )
  1399. {
  1400. wsum += 1;
  1401. for ( j = 0; j < N; j++ )
  1402. Mh_ptr[j] += c_t->x[j];
  1403. }
  1404. }
  1405. if ( ( c_t->x[c_d] < range[2*c_d+1] ) && ( ( c_t->right ) != NULL ) )
  1406. {
  1407. c_t = c_t->right;
  1408. c_d = ( c_d + 1 ) % N;
  1409. actionType = 0;
  1410. } else
  1411. {
  1412. actionType = 2;
  1413. }
  1414. break;
  1415. case 2: // backright
  1416. c_d = ( c_d + N - 1 ) % N;
  1417. if ( c_t->parent == NULL )
  1418. {
  1419. c_t = NULL;
  1420. break;
  1421. }
  1422. if ( c_t->parent->left == c_t )
  1423. actionType = 1;
  1424. else
  1425. actionType = 2;
  1426. c_t = c_t->parent;
  1427. break;
  1428. }
  1429. }
  1430. }
  1431. /*******************************************************/
  1432. /*General Search */
  1433. /*******************************************************/
  1434. /*Searches the input data using a kd tree, performs the*/
  1435. /*sum on the data within the Hypercube defined by the */
  1436. /*tree using a general kernel. */
  1437. /*******************************************************/
  1438. /*Pre: */
  1439. /* - gt is a possibly NULL pointer to a kd tree */
  1440. /* - Mh_ptr is a pointer to the mean shift vector */
  1441. /* being calculated */
  1442. /* - yk_ptr is a pointer to the current window */
  1443. /* center location */
  1444. /* - gd is the depth of the current subtree */
  1445. /*Post: */
  1446. /* - the mean of the points within the Hypercube */
  1447. /* of the kd tree is computed using a general */
  1448. /* kernel */
  1449. /*******************************************************/
  1450. void MeanShift::generalSearch ( tree *gt, int gd, double *Mh_ptr, double *yk_ptr )
  1451. {
  1452. tree* c_t;
  1453. int c_d;
  1454. int i;
  1455. int actionType;
  1456. c_t = gt;
  1457. c_d = gd;
  1458. actionType = 0;
  1459. double el, diff, u, tw, y0, y1;
  1460. int k, j, s, x0, x1;
  1461. while ( c_t != NULL )
  1462. {
  1463. switch ( actionType ) {
  1464. case 0: // forward
  1465. if ( ( c_t->x[c_d] > range[2*c_d] ) && ( ( c_t->left ) != NULL ) )
  1466. {
  1467. c_t = c_t->left;
  1468. c_d = ( c_d + 1 ) % N;
  1469. } else
  1470. {
  1471. actionType = 1;
  1472. }
  1473. break;
  1474. case 1: // backleft
  1475. for ( i = 0; i < N; i++ )
  1476. {
  1477. if ( ( c_t->x[i] < range[2*i] ) || ( c_t->x[i] > range[2*i+1] ) )
  1478. break;
  1479. }
  1480. if ( i == N )
  1481. {
  1482. // *** Visit Tree ***
  1483. // Check if xi is in the window centered about yk_ptr
  1484. // If so - use it to compute y(k+1)
  1485. s = 0;
  1486. for ( j = 0; j < kp; j++ )
  1487. {
  1488. // test each sub-dimension independently
  1489. diff = 0;
  1490. for ( k = 0; k < P[j]; k++ )
  1491. {
  1492. el = ( c_t->x[s+k] - yk_ptr[s+k] ) / h[j];
  1493. diff += uv[s+k] = el * el; // Update uv and diff
  1494. if ( diff >= offset[j] ) // Partial Distortion Search (PDS)
  1495. break;
  1496. }
  1497. if ( diff >= offset[j] ) // PDS
  1498. break;
  1499. s += P[j]; // next subspace
  1500. }
  1501. // j == kp indicates that all subspaces passed the test:
  1502. // the data point is within the search window
  1503. if ( j == kp ) j--;
  1504. if ( diff < offset[j] )
  1505. {
  1506. // Initialize total weight to 1
  1507. tw = 1;
  1508. // Calculate weight factor using weight function
  1509. // lookup tables and uv
  1510. s = 0;
  1511. for ( j = 0; j < kp; j++ )
  1512. {
  1513. if ( kernel[j] ) // not uniform kernel
  1514. {
  1515. // Compute u[i]
  1516. u = 0;
  1517. for ( k = 0; k < P[j]; k++ )
  1518. u += uv[s+k];
  1519. // Accumulate tw using calculated u
  1520. // and weight function lookup table
  1521. // Linear interpolate values given by
  1522. // lookup table
  1523. // Calculate x0 and x1, the points surounding
  1524. // u
  1525. x0 = ( int ) ( u / increment[j] );
  1526. x1 = x0 + 1;
  1527. // Get y0 and y1 from the lookup table
  1528. y0 = w[j][x0];
  1529. y1 = w[j][x1];
  1530. // Accumulate tw using linear interpolation
  1531. tw *= ( ( ( double ) ( x1 ) * increment[j] - u ) * y0 + ( u - ( double ) ( x0 ) * increment[j] ) * y1 ) / ( double ) ( x1 * increment[j] - x0 * increment[j] );
  1532. }
  1533. s += P[j]; // next subspace
  1534. }
  1535. // Perform weighted sum using xi
  1536. for ( j = 0; j < N; j++ )
  1537. Mh_ptr[j] += tw * c_t->x[j];
  1538. // Increment wsum by tw
  1539. wsum += tw;
  1540. }
  1541. }
  1542. if ( ( c_t->x[c_d] < range[2*c_d+1] ) && ( ( c_t->right ) != NULL ) )
  1543. {
  1544. c_t = c_t->right;
  1545. c_d = ( c_d + 1 ) % N;
  1546. actionType = 0;
  1547. } else
  1548. {
  1549. actionType = 2;
  1550. }
  1551. break;
  1552. case 2: // backright
  1553. c_d = ( c_d + N - 1 ) % N;
  1554. if ( c_t->parent == NULL )
  1555. {
  1556. c_t = NULL;
  1557. break;
  1558. }
  1559. if ( c_t->parent->left == c_t )
  1560. actionType = 1;
  1561. else
  1562. actionType = 2;
  1563. c_t = c_t->parent;
  1564. break;
  1565. }
  1566. }
  1567. }
  1568. /*/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\*/
  1569. /*** Mean Shift: Using Lattice ***/
  1570. /*\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/\/*/
  1571. /*******************************************************/
  1572. /*Uniform Lattice Search */
  1573. /*******************************************************/
  1574. /*Performs search on data set for all points lying */
  1575. /*within the search window defined using a uniform */
  1576. /*kernel. Their point-wise sum and count is computed */
  1577. /*and returned. */
  1578. /* */
  1579. /*NOTE: This method is the only method in the */
  1580. /* MeanShift class that uses the weight */
  1581. /* map asside from optUniformLSearch. */
  1582. /*******************************************************/
  1583. /*Pre: */
  1584. /* - Mh_ptr is a length N array of doubles */
  1585. /* - yk_ptr is a length N array of doubles */
  1586. /* - Mh_ptr is the sum of the data points found */
  1587. /* within search window having center yk_ptr */
  1588. /*Post: */
  1589. /* - a search on the data set using the lattice */
  1590. /* has been performed, and all points found to */
  1591. /* lie within the search window defined using */
  1592. /* a uniform kernel are summed and counted. */
  1593. /* - their point wise sum is pointed to by Mh_ptr */
  1594. /* and their count is stored by wsum. */
  1595. /*******************************************************/
  1596. void MeanShift::uniformLSearch ( double *Mh_ptr, double *yk_ptr )
  1597. {
  1598. //Declare variables
  1599. register int i, j, k;
  1600. int s, p, dataPoint, lN;
  1601. double diff, el, dx, dy, tx, weight;
  1602. //Define lattice data dimension...
  1603. lN = N + 2;
  1604. //Define bounds of lattice...
  1605. //the lattice is a 2dimensional subspace whose
  1606. //search window bandwidth is specified by
  1607. //h[0]:
  1608. tx = yk_ptr[0] - h[0] + DELTA + 0.99;
  1609. if ( tx < 0 )
  1610. LowerBoundX = 0;
  1611. else
  1612. LowerBoundX = ( int ) tx;
  1613. tx = yk_ptr[1] - h[0] + DELTA + 0.99;
  1614. if ( tx < 0 )
  1615. LowerBoundY = 0;
  1616. else
  1617. LowerBoundY = ( int ) tx;
  1618. tx = yk_ptr[0] + h[0] - DELTA;
  1619. if ( tx >= width )
  1620. UpperBoundX = width - 1;
  1621. else
  1622. UpperBoundX = ( int ) tx;
  1623. tx = yk_ptr[1] + h[0] - DELTA;
  1624. if ( tx >= height )
  1625. UpperBoundY = height - 1;
  1626. else
  1627. UpperBoundY = ( int ) tx;
  1628. //Perform search using lattice
  1629. for ( i = LowerBoundY; i <= UpperBoundY; i++ )
  1630. for ( j = LowerBoundX; j <= UpperBoundX; j++ )
  1631. {
  1632. //get index into data array
  1633. dataPoint = N * ( i * width + j );
  1634. //Determine if inside search window
  1635. k = 1;
  1636. s = 0;
  1637. dx = j - yk_ptr[0];
  1638. dy = i - yk_ptr[1];
  1639. diff = ( dx * dx + dy * dy ) / ( h[0] * h[0] );
  1640. while ( ( diff < 1.0 ) && ( k != kp ) ) // Partial Distortion Search
  1641. {
  1642. //Calculate distance squared of sub-space s
  1643. diff = 0;
  1644. for ( p = 0; p < P[k]; p++ )
  1645. {
  1646. el = ( data[dataPoint+p+s] - yk_ptr[p+s+2] ) / h[k];
  1647. if ( ( !p ) && ( yk_ptr[2] > 80 ) )
  1648. diff += 4 * el * el;
  1649. else
  1650. diff += el * el;
  1651. }
  1652. //next subspace
  1653. s += P[k];
  1654. k++;
  1655. }
  1656. //if its inside search window perform sum and count
  1657. if ( diff < 1.0 )
  1658. {
  1659. weight = 1 - weightMap[i*width+j];
  1660. Mh_ptr[0] += weight * j;
  1661. Mh_ptr[1] += weight * i;
  1662. for ( k = 2; k < lN; k++ )
  1663. Mh_ptr[k] += weight * data[dataPoint+k-2];
  1664. wsum += weight;
  1665. }
  1666. //done.
  1667. }
  1668. //done.
  1669. return;
  1670. }
  1671. /*******************************************************/
  1672. /*Optimized Uniform Latice Search */
  1673. /*******************************************************/
  1674. /*Performs search on data set for all points lying */
  1675. /*within the search window defined using a uniform */
  1676. /*kernel. Their point-wise sum and count is computed */
  1677. /*and returned. Also the points that lie within the */
  1678. /*window are stored into the basin of attraction stru- */
  1679. /*cture used by the optimized mean shift algorithms. */
  1680. /* */
  1681. /*NOTE: This method is the only method in the */
  1682. /* MeanShift class that uses the weight */
  1683. /* map asside from uniformLSearch. */
  1684. /*******************************************************/
  1685. /*Pre: */
  1686. /* - Mh_ptr is a length N array of doubles */
  1687. /* - yk_ptr is a length N array of doubles */
  1688. /* - Mh_ptr is the sum of the data points found */
  1689. /* within search window having center yk_ptr */
  1690. /*Post: */
  1691. /* - a search on the data set using the latice */
  1692. /* has been performed, and all points found to */
  1693. /* lie within the search window defined using */
  1694. /* a uniform kernel are summed and counted. */
  1695. /* - their point wise sum is pointed to by Mh_ptr */
  1696. /* and their count is stored by wsum. */
  1697. /* - the data points lying within h of of yk_ptr */
  1698. /* have been stored into the basin of attract- */
  1699. /* ion data structure. */
  1700. /*******************************************************/
  1701. void MeanShift::optUniformLSearch ( double *Mh_ptr, double *yk_ptr )
  1702. {
  1703. //Declare variables
  1704. register int i, j, k;
  1705. int s, p, dataPoint, pointIndx, lN;
  1706. double diff, el, dx, dy, tx, weight;
  1707. //Define latice data dimension...
  1708. lN = N + 2;
  1709. //Define bounds of latice...
  1710. //the latice is a 2dimensional subspace whose
  1711. //search window bandwidth is specified by
  1712. //h[0]:
  1713. tx = yk_ptr[0] - h[0] + DELTA + 0.99;
  1714. if ( tx < 0 )
  1715. LowerBoundX = 0;
  1716. else
  1717. LowerBoundX = ( int ) tx;
  1718. tx = yk_ptr[1] - h[0] + DELTA + 0.99;
  1719. if ( tx < 0 )
  1720. LowerBoundY = 0;
  1721. else
  1722. LowerBoundY = ( int ) tx;
  1723. tx = yk_ptr[0] + h[0] - DELTA;
  1724. if ( tx >= width )
  1725. UpperBoundX = width - 1;
  1726. else
  1727. UpperBoundX = ( int ) tx;
  1728. tx = yk_ptr[1] + h[0] - DELTA;
  1729. if ( tx >= height )
  1730. UpperBoundY = height - 1;
  1731. else
  1732. UpperBoundY = ( int ) tx;
  1733. //Perform search using latice
  1734. for ( i = LowerBoundY; i <= UpperBoundY; i++ )
  1735. for ( j = LowerBoundX; j <= UpperBoundX; j++ )
  1736. {
  1737. //get index into data array
  1738. pointIndx = i * width + j;
  1739. dataPoint = N * ( pointIndx );
  1740. //Determine if inside search window
  1741. k = 1;
  1742. s = 0;
  1743. dx = j - yk_ptr[0];
  1744. dy = i - yk_ptr[1];
  1745. diff = ( dx * dx + dy * dy ) / ( h[0] * h[0] );
  1746. while ( ( diff < 1.0 ) && ( k != kp ) ) // Partial Distortion Search
  1747. {
  1748. //Calculate distance squared of sub-space s
  1749. diff = 0;
  1750. for ( p = 0; p < P[k]; p++ )
  1751. {
  1752. el = ( data[dataPoint+p+s] - yk_ptr[p+s+2] ) / h[k];
  1753. if ( ( !p ) && ( yk_ptr[2] > 80 ) )
  1754. diff += 4 * el * el;
  1755. else
  1756. diff += el * el;
  1757. }
  1758. //next subspace
  1759. s += P[k];
  1760. k++;
  1761. }
  1762. //if its inside search window perform sum and count
  1763. if ( diff < 1.0 )
  1764. {
  1765. weight = 1 - weightMap[i*width+j];
  1766. Mh_ptr[0] += weight * j;
  1767. Mh_ptr[1] += weight * i;
  1768. for ( k = 2; k < lN; k++ )
  1769. Mh_ptr[k] += weight * data[dataPoint+k-2];
  1770. wsum += weight;
  1771. //set basin of attraction mode table
  1772. if ( diff < 0.5 )
  1773. {
  1774. if ( modeTable[pointIndx] == 0 )
  1775. {
  1776. pointList[pointCount++] = pointIndx;
  1777. modeTable[pointIndx] = 2;
  1778. }
  1779. }
  1780. }
  1781. //done.
  1782. }
  1783. //done.
  1784. return;
  1785. }
  1786. /*******************************************************/
  1787. /*General Lattice Search */
  1788. /*******************************************************/
  1789. /*Performs search on data set for all points lying */
  1790. /*within the search window defined using a general */
  1791. /*kernel. Their point-wise sum and count is computed */
  1792. /*and returned. */
  1793. /*******************************************************/
  1794. /*Pre: */
  1795. /* - Mh_ptr is a length N array of doubles */
  1796. /* - yk_ptr is a length N array of doubles */
  1797. /* - Mh_ptr is the sum of the data points found */
  1798. /* within search window having center yk_ptr */
  1799. /*Post: */
  1800. /* - a search on the data set using the lattice */
  1801. /* has been performed, and all points found to */
  1802. /* lie within the search window defined using */
  1803. /* a general kernel are summed and counted */
  1804. /* - their point wise sum is pointed to by Mh_ptr */
  1805. /* and their count is stored by wsum */
  1806. /*******************************************************/
  1807. void MeanShift::generalLSearch ( double *Mh_ptr, double *yk_ptr )
  1808. {
  1809. //Declare variables
  1810. register int i, j, k;
  1811. int s, p, dataPoint, lN, x0, x1;
  1812. double diff, el, dx, dy, tw, u, y0, y1, tx;
  1813. //Define lattice data dimension...
  1814. lN = N + 2;
  1815. //Define bounds of lattice...
  1816. //the lattice is a 2dimensional subspace whose
  1817. //search window bandwidth is specified by
  1818. //h[0]:
  1819. tx = yk_ptr[0] - h[0] + DELTA + 0.99;
  1820. if ( tx < 0 )
  1821. LowerBoundX = 0;
  1822. else
  1823. LowerBoundX = ( int ) tx;
  1824. tx = yk_ptr[1] - h[0] + DELTA + 0.99;
  1825. if ( tx < 0 )
  1826. LowerBoundY = 0;
  1827. else
  1828. LowerBoundY = ( int ) tx;
  1829. tx = yk_ptr[0] + h[0] - DELTA;
  1830. if ( tx >= width )
  1831. UpperBoundX = width - 1;
  1832. else
  1833. UpperBoundX = ( int ) tx;
  1834. tx = yk_ptr[1] + h[0] - DELTA;
  1835. if ( tx >= height )
  1836. UpperBoundY = height - 1;
  1837. else
  1838. UpperBoundY = ( int ) tx;
  1839. //Perform search using lattice
  1840. for ( i = LowerBoundY; i <= UpperBoundY; i++ )
  1841. for ( j = LowerBoundX; j <= UpperBoundX; j++ )
  1842. {
  1843. //get index into data array
  1844. dataPoint = N * ( i * width + j );
  1845. //Determine if inside search window
  1846. k = 1;
  1847. s = 0;
  1848. dx = j - yk_ptr[0];
  1849. dy = i - yk_ptr[1];
  1850. uv[0] = ( dx * dx ) / ( h[0] * h[0] );
  1851. uv[1] = ( dy * dy ) / ( h[0] * h[0] );
  1852. diff = uv[0] + uv[1];
  1853. while ( ( diff < offset[k-1] ) && ( k != kp ) ) // Partial Distortion Search
  1854. {
  1855. //Calculate distance squared of sub-space s
  1856. diff = 0;
  1857. for ( p = 0; p < P[k]; p++ )
  1858. {
  1859. el = ( data[dataPoint+p+s] - yk_ptr[p+s+2] ) / h[k];
  1860. diff += uv[p+s+2] = el * el;
  1861. }
  1862. //next subspace
  1863. s += P[k];
  1864. k++;
  1865. }
  1866. //if its inside search window perform weighted sum and count
  1867. if ( diff < offset[k-1] )
  1868. {
  1869. // Initialize total weight to 1
  1870. tw = 1;
  1871. // Calculate weight factor using weight function
  1872. // lookup tables and uv
  1873. s = 0;
  1874. for ( k = 0; k < kp; k++ )
  1875. {
  1876. if ( kernel[k] ) // not uniform kernel
  1877. {
  1878. // Compute u[i]
  1879. u = 0;
  1880. for ( p = 0; p < P[k]; p++ )
  1881. u += uv[s+p];
  1882. // Accumulate tw using calculated u
  1883. // and weight function lookup table
  1884. // Linear interpolate values given by
  1885. // lookup table
  1886. // Calculate x0 and x1, the points surounding
  1887. // u
  1888. x0 = ( int ) ( u / increment[k] );
  1889. x1 = x0 + 1;
  1890. // Get y0 and y1 from the lookup table
  1891. y0 = w[k][x0];
  1892. y1 = w[k][x1];
  1893. // Accumulate tw using linear interpolation
  1894. tw *= ( ( ( double ) ( x1 ) * increment[k] - u ) * y0 + ( u - ( double ) ( x0 ) * increment[k] ) * y1 ) / ( double ) ( x1 * increment[k] - x0 * increment[k] );
  1895. }
  1896. s += P[k]; // next subspace
  1897. }
  1898. // Perform weighted sum using xi
  1899. Mh_ptr[0] += tw * j;
  1900. Mh_ptr[1] += tw * i;
  1901. for ( k = 0; k < N; k++ )
  1902. Mh_ptr[k+2] += tw * data[dataPoint+k];
  1903. // Increment wsum by tw
  1904. wsum += tw;
  1905. }
  1906. //done.
  1907. }
  1908. //done.
  1909. return;
  1910. }
  1911. /*******************************************************/
  1912. /*Optimized General Lattice Search */
  1913. /*******************************************************/
  1914. /*Performs search on data set for all points lying */
  1915. /*within the search window defined using a general */
  1916. /*kernel. Their point-wise sum and count is computed */
  1917. /*and returned. Also the points that lie within the */
  1918. /*window are stored into the basin of attraction stru- */
  1919. /*cture used by the optimized mean shift algorithms. */
  1920. /*******************************************************/
  1921. /*Pre: */
  1922. /* - Mh_ptr is a length N array of doubles */
  1923. /* - yk_ptr is a length N array of doubles */
  1924. /* - Mh_ptr is the sum of the data points found */
  1925. /* within search window having center yk_ptr */
  1926. /*Post: */
  1927. /* - a search on the data set using the lattice */
  1928. /* has been performed, and all points found to */
  1929. /* lie within the search window defined using */
  1930. /* a general kernel are summed and counted */
  1931. /* - their point wise sum is pointed to by Mh_ptr */
  1932. /* and their count is stored by wsum */
  1933. /* - the data points lying within h*offset of */
  1934. /* yk_ptr have been stored into the basin of */
  1935. /* attraction data structure. */
  1936. /*******************************************************/
  1937. void MeanShift::optGeneralLSearch ( double *Mh_ptr, double *yk_ptr )
  1938. {
  1939. //Declare variables
  1940. register int i, j, k;
  1941. int s, p, dataPoint, pointIndx, lN, x0, x1;
  1942. double diff, el, dx, dy, tw, u, y0, y1, tx;
  1943. //Define lattice data dimension...
  1944. lN = N + 2;
  1945. //Define bounds of lattice...
  1946. //the lattice is a 2dimensional subspace whose
  1947. //search window bandwidth is specified by
  1948. //h[0]:
  1949. tx = yk_ptr[0] - h[0] + DELTA + 0.99;
  1950. if ( tx < 0 )
  1951. LowerBoundX = 0;
  1952. else
  1953. LowerBoundX = ( int ) tx;
  1954. tx = yk_ptr[1] - h[0] + DELTA + 0.99;
  1955. if ( tx < 0 )
  1956. LowerBoundY = 0;
  1957. else
  1958. LowerBoundY = ( int ) tx;
  1959. tx = yk_ptr[0] + h[0] - DELTA;
  1960. if ( tx >= width )
  1961. UpperBoundX = width - 1;
  1962. else
  1963. UpperBoundX = ( int ) tx;
  1964. tx = yk_ptr[1] + h[0] - DELTA;
  1965. if ( tx >= height )
  1966. UpperBoundY = height - 1;
  1967. else
  1968. UpperBoundY = ( int ) tx;
  1969. //Perform search using lattice
  1970. for ( i = LowerBoundY; i <= UpperBoundY; i++ )
  1971. for ( j = LowerBoundX; j <= UpperBoundX; j++ )
  1972. {
  1973. //get index into data array
  1974. pointIndx = i * width + j;
  1975. dataPoint = N * ( i * width + j );
  1976. //Determine if inside search window
  1977. k = 1;
  1978. s = 0;
  1979. dx = j - yk_ptr[0];
  1980. dy = i - yk_ptr[1];
  1981. uv[0] = ( dx * dx ) / ( h[0] * h[0] );
  1982. uv[1] = ( dy * dy ) / ( h[0] * h[0] );
  1983. diff = uv[0] + uv[1];
  1984. while ( ( diff < offset[k-1] ) && ( k != kp ) ) // Partial Distortion Search
  1985. {
  1986. //Calculate distance squared of sub-space s
  1987. diff = 0;
  1988. for ( p = 0; p < P[k]; p++ )
  1989. {
  1990. el = ( data[dataPoint+p+s] - yk_ptr[p+s+2] ) / h[k];
  1991. diff += uv[p+s+2] = el * el;
  1992. }
  1993. //next subspace
  1994. s += P[k];
  1995. k++;
  1996. }
  1997. //if its inside search window perform weighted sum and count
  1998. if ( diff < offset[k-1] )
  1999. {
  2000. // Initialize total weight to 1
  2001. tw = 1;
  2002. // Calculate weight factor using weight function
  2003. // lookup tables and uv
  2004. s = 0;
  2005. for ( k = 0; k < kp; k++ )
  2006. {
  2007. if ( kernel[k] ) // not uniform kernel
  2008. {
  2009. // Compute u[i]
  2010. u = 0;
  2011. for ( p = 0; p < P[k]; p++ )
  2012. u += uv[s+p];
  2013. // Accumulate tw using calculated u
  2014. // and weight function lookup table
  2015. // Linear interpolate values given by
  2016. // lookup table
  2017. // Calculate x0 and x1, the points surounding
  2018. // u
  2019. x0 = ( int ) ( u / increment[k] );
  2020. x1 = x0 + 1;
  2021. // Get y0 and y1 from the lookup table
  2022. y0 = w[k][x0];
  2023. y1 = w[k][x1];
  2024. // Accumulate tw using linear interpolation
  2025. tw *= ( ( ( double ) ( x1 ) * increment[k] - u ) * y0 + ( u - ( double ) ( x0 ) * increment[k] ) * y1 ) / ( double ) ( x1 * increment[k] - x0 * increment[k] );
  2026. }
  2027. s += P[k]; // next subspace
  2028. }
  2029. // Perform weighted sum using xi
  2030. Mh_ptr[0] += tw * j;
  2031. Mh_ptr[1] += tw * i;
  2032. for ( k = 0; k < N; k++ )
  2033. Mh_ptr[k+2] += tw * data[dataPoint+k];
  2034. // Increment wsum by tw
  2035. wsum += tw;
  2036. //set basin of attraction mode table
  2037. if ( modeTable[pointIndx] == 0 )
  2038. {
  2039. pointList[pointCount++] = pointIndx;
  2040. modeTable[pointIndx] = 2;
  2041. }
  2042. }
  2043. //done.
  2044. }
  2045. //done.
  2046. return;
  2047. }
  2048. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  2049. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  2050. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ END OF CLASS DEFINITION @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  2051. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/
  2052. /*@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@*/