-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdt.cc
89 lines (77 loc) · 2.8 KB
/
dt.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <limits.h>
#include <math.h>
#include <stdint.h>
#include <sys/types.h>
#include "mex.h"
/*
* Generalized distance transforms.
* We use a simple O(n log n) divide-and-conquer algorithm instead of the
* theoretically faster linear method, for no particular reason except
* that this is a bit simpler and I wanted to test it out.
*
* The code is a bit convoluted because dt1d can operate either along
* a row or column of an array.
*/
// Integer square of x.
static inline int square(int x) {
  return x * x;
}

/*
 * Recursive worker for the 1-D generalized distance transform.
 *
 * For the destination range [d1, d2] and the candidate source range
 * [s1, s2], the middle destination index is solved first: every candidate s
 * is scored as
 *     src[s*step] + a*(d-s)^2 + b*(d-s)
 * and the lowest score wins (on a tie the lowest source index is kept).
 * The winning score is written to dst[d*step] and the winning source index
 * to ptr[d*step].  The two halves of the destination range are then solved
 * recursively, with the winner bounding their candidate ranges: s1..winner
 * for the lower half and winner..s2 for the upper half.
 *
 * Nothing is written when d1 > d2.
 */
void dt_helper(double *src, double *dst, int *ptr, int step,
               int s1, int s2, int d1, int d2, double a, double b) {
  if (d1 > d2)
    return;

  const int mid = (d1 + d2) >> 1;

  // Start with the first candidate, then scan the remaining ones.
  int best = s1;
  int delta = mid - best;
  double best_cost = src[best*step] + a*square(delta) + b*delta;

  for (int cand = s1 + 1; cand <= s2; ++cand) {
    delta = mid - cand;
    const double cost = src[cand*step] + a*square(delta) + b*delta;
    // Strict comparison: an equal score never displaces an earlier index.
    if (best_cost > cost) {
      best = cand;
      best_cost = cost;
    }
  }

  dst[mid*step] = best_cost;
  ptr[mid*step] = best;

  dt_helper(src, dst, ptr, step, s1, best, d1, mid - 1, a, b);
  dt_helper(src, dst, ptr, step, best, s2, mid + 1, d2, a, b);
}
// 1-D generalized distance transform over n elements spaced `step` apart.
// Every source index 0..n-1 is a candidate for every destination index
// 0..n-1; the actual work is delegated to dt_helper, which fills dst with the
// selected scores and ptr with the selected source indices.
void dt1d(double *src, double *dst, int *ptr, int step, int n,
          double a, double b) {
  const int last = n - 1;
  dt_helper(src, dst, ptr, step, 0, last, 0, last, a, b);
}
// MATLAB entry point.
//   [M, Ix, Iy] = dt(vals, ax, bx, ay, by)
//
// Inputs:
//   vals    full (non-sparse) double matrix
//   ax, bx  quadratic / linear coefficients for the pass along each row
//   ay, by  quadratic / linear coefficients for the pass along each column
//
// Outputs (all the same size as vals):
//   M       double matrix holding the result of the two 1-D passes
//   Ix      int32 matrix, 1-based column index selected for each element
//   Iy      int32 matrix, 1-based row index selected for each element
//
// Errors out through mexErrMsgTxt on a wrong argument count, on a non-double
// or sparse first argument, or when the matrix has too many elements to be
// indexed with int (dt1d and its helper use int indices).
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
  if (nrhs != 5)
    mexErrMsgTxt("Wrong number of inputs");
  if (nlhs != 3)
    mexErrMsgTxt("Wrong number of outputs");
  // A sparse double array passes the class test, but its data buffer only
  // holds the stored nonzeros, so dense indexing would read out of bounds.
  if (mxGetClassID(prhs[0]) != mxDOUBLE_CLASS || mxIsSparse(prhs[0]))
    mexErrMsgTxt("Invalid input");

  // mxGetDimensions returns mwSize, which is not int on every build; keep the
  // pointer in its real type and narrow to int only after a range check.
  const mwSize *dims = mxGetDimensions(prhs[0]);
  if (dims[0] > (mwSize)INT_MAX || dims[1] > (mwSize)INT_MAX ||
      (dims[0] != 0 && dims[1] > (mwSize)INT_MAX / dims[0]))
    mexErrMsgTxt("Input too large");
  const int rows = (int)dims[0];
  const int cols = (int)dims[1];
  const mwSize count = dims[0] * dims[1];

  double *vals = (double *)mxGetPr(prhs[0]);
  double ax = mxGetScalar(prhs[1]);
  double bx = mxGetScalar(prhs[2]);
  double ay = mxGetScalar(prhs[3]);
  double by = mxGetScalar(prhs[4]);

  mxArray *mxM = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL);
  mxArray *mxIx = mxCreateNumericArray(2, dims, mxINT32_CLASS, mxREAL);
  mxArray *mxIy = mxCreateNumericArray(2, dims, mxINT32_CLASS, mxREAL);
  double *M = (double *)mxGetPr(mxM);
  // mxGetData is the accessor for non-double numeric data.
  int32_t *Ix = (int32_t *)mxGetData(mxIx);
  int32_t *Iy = (int32_t *)mxGetData(mxIy);

  double *tmpM = (double *)mxCalloc(count, sizeof(double));
  int32_t *tmpIx = (int32_t *)mxCalloc(count, sizeof(int32_t));
  int32_t *tmpIy = (int32_t *)mxCalloc(count, sizeof(int32_t));

  // Pass 1: transform each column (contiguous in memory, step 1).
  for (int x = 0; x < cols; x++)
    dt1d(vals + x*rows, tmpM + x*rows, tmpIy + x*rows, 1, rows, ay, by);

  // Pass 2: transform each row of the pass-1 result (stride = rows).
  for (int y = 0; y < rows; y++)
    dt1d(tmpM + y, M + y, tmpIx + y, rows, cols, ax, bx);

  // get argmins and adjust for matlab indexing from 1
  for (int x = 0; x < cols; x++) {
    for (int y = 0; y < rows; y++) {
      const int p = x*rows + y;
      Ix[p] = tmpIx[p] + 1;
      // The row index comes from pass 1, looked up in the column that
      // pass 2 selected for this element.
      Iy[p] = tmpIy[tmpIx[p]*rows + y] + 1;
    }
  }

  mxFree(tmpM);
  mxFree(tmpIx);
  mxFree(tmpIy);

  plhs[0] = mxM;
  plhs[1] = mxIx;
  plhs[2] = mxIy;
}