001/*******************************************************************************
002 * This software is provided as a supplement to the authors' textbooks on digital
003 *  image processing published by Springer-Verlag in various languages and editions.
004 * Permission to use and distribute this software is granted under the BSD 2-Clause 
005 * "Simplified" License (see http://opensource.org/licenses/BSD-2-Clause). 
006 * Copyright (c) 2006-2016 Wilhelm Burger, Mark J. Burge. All rights reserved. 
007 * Visit http://imagingbook.com for additional details.
008 *******************************************************************************/
009
010package imagingbook.pub.edgepreservingfilters;
011
012import ij.IJ;
013import ij.plugin.filter.Convolver;
014import ij.process.ColorProcessor;
015import ij.process.FloatProcessor;
016import ij.process.ImageProcessor;
017
018// TODO: convert to subclass of GenericFilter using ImageAccessor (see BilateralFilter)
019
020/**
021 * This class implements the Anisotropic Diffusion filter proposed by David Tschumperle 
022 * in D. Tschumperle and R. Deriche, "Diffusion PDEs on vector-valued images", 
023 * IEEE Signal Processing Magazine, vol. 19, no. 5, pp. 16-25 (Sep. 2002). It is based 
024 * on an earlier C++ (CImg) implementation (pde_TschumperleDeriche2d.cpp) by the original
025 * author, made available under the CeCILL v2.0 license 
026 * (http://www.cecill.info/licences/Licence_CeCILL_V2-en.html).
027 * 
028 * This class is based on the ImageJ API and intended to be used in ImageJ plugins.
029 * How to use: consult the source code of the related ImageJ plugins for examples.
030 * 
031 * @author W. Burger
032 * @version 2013/05/30
033 */
034
035public class TschumperleDericheFilter {
036        
037        public static class Parameters {
038                /** Number of smoothing iterations */
039                public int iterations = 20;     
040                /** Adapting time step */
041                public double dt = 20.0;                
042                /** Gradient smoothing (sigma of Gaussian) */
043                public double sigmaG  = 0.5;    
044                /** Structure tensor smoothing (sigma of Gaussian) */
045                public double sigmaS  = 0.5;    
046                /** Diff. limiter along minimal var. (small value = strong smoothing) */
047                public float a1 = 0.25f;                
048                /** Diff. limiter along maximal var. (small value = strong smoothing) */
049                public float a2 = 0.90f;
050                /** Set true to apply the filter in linear RGB (assumes sRGB input) */
051                public boolean useLinearRgb = false;
052        }
053        
054        private final Parameters params;
055        private final int T;                    // number of iterations
056        
057        private int M;  // image width
058        private int N;  // image height
059        private int K;  // number of color channels, k = 0,...,K-1
060
061        private float[][][] I;          // float image data:            I[k][u][v] for color channel k
062        private float[][][] Dx;         // image x-gradient:            Dx[k][u][v] for color channel k
063        private float[][][] Dy;         // image y-gradient:            Dx[k][u][v] for color channel k
064        private float[][][] G;          // 2x2 structure tensor:        G[i][u][v], i=0,1,2 (only 3 elements because of symmetry)
065        private float[][][] A;          // 2x2 tensor field:            A[i][u][v], i=0,1,2 (only 3 elements because of symmetry)
066        private float[][][] B;          // scalar local velocity:   B[k][u][v]  for channel k (== beta_k)
067//      private float[][][] Hk;         // Hessian matrix for channel k: Hk[i][u][v], i=0,1,2
068        FloatProcessor tmpFp;           // used as temporary storage for blurring
069        
070        private float initial_max;
071        private float initial_min;
072        
073        // constructor - uses only default settings:
074        public TschumperleDericheFilter() {
075                this(new Parameters());
076        }
077        
078        // constructor - use for setting individual parameters:
079        public TschumperleDericheFilter(Parameters params) {
080                this.params = params;
081                T = params.iterations;
082        }
083        
084        /* This method applies the filter to the given image (ip). 
085         * Note that ip is destructively modified.
086         */
087        public void applyTo(ImageProcessor ip) {        
088                initialize(ip);
089                // main iteration loop
090                for (int n = 1; n <= T; n++) {
091                        IJ.showProgress(n, T);
092                        
093                        // Step 1:
094                        calculateGradients(I, Dx, Dy);
095                        
096                        // Step 2:
097                        smoothGradients(Dx, Dy);
098                        
099                        // Step 3: Hessian matrix is only calculated locally as part of Step 8.
100                        
101                        // Step 4:
102                        calculateStructureMatrix(Dx, Dy, G);
103                        // Step 5:
104                        smoothStructureMatrix(G);
105
106                        // Step 6-7:
107                        calculateGeometryMatrix(G, A);
108                        
109                        // Step 8:
110                        float maxVelocity = calculateVelocities(I, A, B);
111                        
112                        double alpha = params.dt / maxVelocity;
113                        updateImage(I, B, alpha);
114                }
115                copyResultToImage(ip);
116                cleanUp();
117        } 
118        
119        // -------------------------------------------------------------------------
120        
121        /*
122         * Create temporary arrays, copy image data and calculate
123         * initial image statistics (all in one pass).
124         */
125        private void initialize(ImageProcessor ip) {
126                M = ip.getWidth(); 
127                N = ip.getHeight(); 
128                K = (ip instanceof ColorProcessor) ? 3 : 1;
129                I  = new float[K][M][N];
130                Dx  = new float[K][M][N];
131                Dy  = new float[K][M][N];
132                G = new float[3][M][N];
133                A = new float[3][M][N];
134                B = new float[K][M][N];
135//              Hk = new float[3][M][N];
136
137                if (ip instanceof ColorProcessor) {
138                        final int[] pixel = new int[K]; 
139                        for (int u = 0; u < M; u++) {
140                                for (int v = 0; v < N; v++) {
141                                        ip.getPixel(u, v, pixel);
142                                        for (int k = 0; k < K; k++) {
143                                                float c = pixel[k];
144                                                I[k][u][v] = params.useLinearRgb ? srgbToRgb(c) : c;
145                                        }
146                                }
147                        }
148                }
149                else {  // 8-bit, 16-bit or 32-bit (float) processor
150                        for (int u = 0; u < M; u++) {
151                                for (int v = 0; v < N; v++) {
152                                        I[0][u][v] = ip.getf(u,v);
153                                }
154                        }
155                }
156                getImageMinMax();
157        }
158        
159        void getImageMinMax() {
160                float max = Float.MIN_VALUE;
161                float min = Float.MAX_VALUE;
162                for (int u = 0; u < M; u++) {
163                        for (int v = 0; v < N; v++) {
164                                for (int k = 0; k < K; k++) {
165                                        float p = I[k][u][v];
166                                        if (p>max) max = p;
167                                        if (p<min) min = p;
168                                }
169                        }
170                }
171                initial_max = max;
172                initial_min = min;
173        }
174        
175        void cleanUp() {
176                I  = null;              Dx = null;              Dy = null;
177                G = null;               A = null;               B = null;
178//              Hk = null;              
179                tmpFp = null;
180        }
181        
182        void calculateGradients(float[][][] I, float[][][] Dx, float[][][] Dy) {
183                // these Gradient kernels produce reduced artifacts
184                final float c1 = (float) (2 - Math.sqrt(2.0)) / 4;
185                final float c2 = (float) (Math.sqrt(2.0) - 1) / 2;
186                
187                final float[][] Hdx = 
188                        {{-c1, 0, c1},
189                         {-c2, 0, c2},
190                         {-c1, 0, c1}};
191                
192                final float[][] Hdy = 
193                        {{-c1, -c2, -c1},
194                         {  0,   0,   0},
195                         { c1,  c2,  c1}};
196
197                for (int k = 0; k < K; k++) {
198                        convolve2dArray(I[k], Dx[k], Hdx);
199                        convolve2dArray(I[k], Dy[k], Hdy);
200                }
201        }
202        
203        void smoothGradients(float[][][] Dx, float[][][] Dy) {
204                for (int k = 0; k < Dx.length; k++) {
205                        gaussianBlur(Dx[k], params.sigmaG);
206                }
207                for (int k = 0; k < Dy.length; k++) {
208                        gaussianBlur(Dy[k], params.sigmaG);
209                }
210        }
211        
212        void calculateStructureMatrix(float[][][] Dx, float[][][] Dy, float[][][] G) {
213                // compute structure tensor field G
214                // G = new float[width][height][3]; // must be clean for each slice
215                for (int u = 0; u < M; u++) {
216                        for (int v = 0; v < N; v++) {
217                                G[0][u][v]= 0.0f;
218                                G[1][u][v]= 0.0f;
219                                G[2][u][v]= 0.0f;
220                                for (int k = 0; k < K; k++) {
221                                        //version 0.2 normalization
222                                        float fx = Dx[k][u][v];
223                                        float fy = Dy[k][u][v];
224                                        G[0][u][v] += fx * fx;
225                                        G[1][u][v] += fx * fy;
226                                        G[2][u][v] += fy * fy;
227                                }
228                        }
229                }
230        }
231        
232        void smoothStructureMatrix(float[][][] G) {
233                for (int i = 0; i < G.length; i++) {
234                        gaussianBlur(G[i], params.sigmaS);
235                }
236        }
237        
238        /*
239         * Compute the local geometry matrix A (used to drive the diffusion process)
240         * from the structure matrix G.
241         */
242        void calculateGeometryMatrix(float[][][] G, float[][][] A) {
243                final double[] lambda12 = new double[2];        // eigenvalues
244                final double[] e1 = new double[2];                      // eigenvectors
245                final double[] e2 = new double[2];
246                final double a1 = params.a1;
247                final double a2 = params.a2;
248                for (int u = 0; u < M; u++) {
249                        for (int v = 0; v < N; v++) {
250                                final double G0 = G[0][u][v];   // elements of local geometry matrix (2x2)
251                                final double G1 = G[1][u][v];
252                                final double G2 = G[2][u][v];
253                                // calculate eigenvalues:
254                                if (!realEigenValues2x2(G0, G1, G1, G2, lambda12, e1, e2)) {
255                                        throw new RuntimeException("eigenvalues undefined in " + 
256                                                                TschumperleDericheFilter.class.getSimpleName());
257                                }
258                                final double val1 = lambda12[0];
259                                final double val2 = lambda12[1];
260                                final double arg = 1.0 + val1 + val2;
261                                final float c1 = (float) Math.pow(arg, -a1);
262                                final float c2 = (float) Math.pow(arg, -a2);
263                                
264                                // calculate eigenvectors:
265                                normalize(e1);
266                                final float ex = (float) e1[0];
267                                final float ey = (float) e1[1];
268                                final float exx = ex * ex;
269                                final float exy = ex * ey;
270                                final float eyy = ey * ey;
271                                A[0][u][v] = c1 * eyy + c2 * exx;
272                                A[1][u][v] = (c2 - c1)* exy;
273                                A[2][u][v] = c1 * exx + c2 * eyy;
274                        }
275                }
276        }
277        
278        // Calculate the Hessian matrix Hk for a single position (u,v) in image Ik.
279        void calculateHessianMatrix(float[][] Ik, int u, int v, float[] Hk) {
280                final int pu = (u > 0) ? u-1 : 0; 
281                final int nu = (u < M-1) ? u+1 : M-1;
282                final int pv = (v > 0) ? v-1 : 0; 
283                final int nv = (v < N-1) ? v+1 : N-1;
284                float icc = Ik[u][v];
285                Hk[0] = Ik[pu][v] + Ik[nu][v] - 2 * icc;                                                                // = H_xx(u,v)
286                Hk[1] = 0.25f * (Ik[pu][pv] + Ik[nu][nv] - Ik[pu][nv] - Ik[nu][pv]);    // = H_xy(u,v)
287                Hk[2] = Ik[u][nv] + Ik[u][pv] - 2 * icc;                                                                // = H_yy(u,v)
288        }
289        
290        /*
291         * Calculate the local image velocity B(k,u,v) from the geometry matrix A(i,u,v)
292         * and the Hessian matrix Hkuv.
293         */
294        float calculateVelocities(float[][][] I, float[][][] A, float[][][] B) {
295                float maxV = Float.MIN_VALUE;
296                float minV = Float.MAX_VALUE;
297                final float[] Hkuv = new float[3];
298                for (int k = 0; k < K; k++) {
299                        for (int u = 0; u < M; u++) {
300                                for (int v = 0; v < N; v++) {
301                                        calculateHessianMatrix(I[k], u, v, Hkuv);
302                                        final float a = A[0][u][v];
303                                        final float b = A[1][u][v];
304                                        final float c = A[2][u][v];                                     
305                                        final float ixx = Hkuv[0]; 
306                                        final float ixy = Hkuv[1]; 
307                                        final float iyy = Hkuv[2];
308                                        final float vel = a * ixx + 2 * b * ixy + c * iyy; 
309                                        // find min/max velocity for time-step adaptation
310                                        if (vel > maxV) maxV = vel;
311                                        if (vel < minV) minV = vel;
312                                        B[k][u][v] = vel;
313                                }
314                        }
315                }
316                return Math.max(Math.abs(maxV), Math.abs(minV));
317        }
318
319        // Calculate the Hessian matrix Hk for the whole (single-channel) image Ik.
320//      void calculateHessianMatrix(float[][] Ik, float[][][] Hk) {
321//              for (int u = 0; u < M; u++) {
322//                      final int pu = (u > 0) ? u-1 : 0; 
323//                      final int nu = (u < M-1) ? u+1 : M-1;
324//                      for (int v = 0; v < N; v++) {
325//                              final int pv = (v > 0) ? v-1 : 0; 
326//                              final int nv = (v < N-1) ? v+1 : N-1;
327//                              float icc = Ik[u][v];
328//                              Hk[0][u][v] = Ik[pu][v] + Ik[nu][v] - 2 * icc;                                                          // = H_xx(u,v)
329//                              Hk[1][u][v] = 0.25f * (Ik[pu][pv] + Ik[nu][nv] - Ik[pu][nv] - Ik[nu][pv]);      // = H_xy(u,v)
330//                              Hk[2][u][v] = Ik[u][nv] + Ik[u][pv] - 2 * icc;                                                          // = H_yy(u,v)
331//                      }
332//              }
333//      }
334        
335//      float calculateVelocity(float[][][] I, float[][][] A, float[][][] B) {
336//              float maxV = Float.MIN_VALUE;
337//              float minV = Float.MAX_VALUE;
338//              for (int k = 0; k < K; k++) {
339//                      // calculate the Hessian matrix for channel k:
340//                      calculateHessianMatrix(I[k], Hk);
341//                      for (int u = 0; u < M; u++) {
342//                              for (int v = 0; v < N; v++) {
343//                                      float a = A[0][u][v];
344//                                      float b = A[1][u][v];
345//                                      float c = A[2][u][v];                                   
346//                                      float ixx = Hk[0][u][v]; 
347//                                      float ixy = Hk[1][u][v]; 
348//                                      float iyy = Hk[2][u][v];
349//                                      float vel = a * ixx + 2 * b * ixy + c * iyy; 
350//                                      // find min/max velocity for time-step adaptation
351//                                      if (vel > maxV) maxV = vel;
352//                                      if (vel < minV) minV = vel;
353//                                      B[k][u][v] = vel;
354//                              }
355//                      }
356//              }
357//              return Math.max(Math.abs(maxV), Math.abs(minV));
358//      }
359        
360        void updateImage(float[][][] I, float[][][] B, double alpha) {
361                final float alphaF = (float) alpha;
362                for (int k = 0; k < K; k++) {
363                        for (int u = 0; u < M; u++) {
364                                for (int v = 0; v < N; v++) {
365                                        float inew = I[k][u][v] + alphaF * B[k][u][v];
366                                        // clamp image to the original range (brute!)
367                                        if (inew < initial_min) inew = initial_min;
368                                        if (inew > initial_max) inew = initial_max;
369                                        I[k][u][v] = inew;
370                                }
371                        }
372                }
373        }
374        
375        void copyResultToImage(ImageProcessor ip) {
376                final int[] pixel = new int[K];
377                if (ip instanceof ColorProcessor) {
378                        for (int u = 0; u < M; u++) {
379                                for (int v = 0; v < N; v++) {
380                                        for (int k = 0; k < K; k++) {
381                                                int c = params.useLinearRgb ? 
382                                                                Math.round(rgbToSrgb(I[k][u][v])) : 
383                                                                Math.round(I[k][u][v]);
384                                                if (c < 0) c = 0;
385                                                if (c > 255) c = 255;
386                                                pixel[k] = c;
387                                        }
388                                        ip.putPixel(u,v,pixel);
389                                }
390                        }
391                }
392                else {  // 8-bit, 16-bit or 32-bit (float) processor
393                        for (int u = 0; u < M; u++) {
394                                for (int v = 0; v < N; v++) {
395                                        ip.setf(u, v, I[0][u][v]);
396                                }
397                        }
398                }
399        }
400        
401        // Utility methods -------------------------------------------------
402        
403        /*
404         * Blur the 2D array source with a Gaussian kernel of width sigma
405         * and store the result in target.
406         */
407        void gaussianBlur(float[][] source, float[][] target, double sigma) {
408                if (sigma < 0.1) return;
409                if (source.length != target.length || source[0].length != target[0].length) {
410                        throw new Error("source/target arrays have different dimensions");
411                }
412                float[][] Hgx = makeGaussKernel1D(sigma, true);         // horizontal 1D kernel
413                float[][] Hgy = makeGaussKernel1D(sigma, false);        // vertical 1D kernel
414                convolve2dArray(source, target, Hgx, Hgy);
415        }
416        
417        void gaussianBlur(float[][] source, double sigma) {     // source = target
418                gaussianBlur(source, source, sigma);
419        }
420        
421        // ----------------------------------------------------------------
422        
423        /*
424         * Convolve the 2D array source successively with a sequence of kernels
425         * and store the result in target.
426         * This should eventually be implemented without an ImageJ FloatProcessor!
427         */
428        void convolve2dArray(float[][] source, float[][] target, float[][]... kernels) {
429                if (source.length != target.length || source[0].length != target[0].length) {
430                        throw new Error("source/target arrays have different dimensions");
431                }
432                int w = source.length;
433                int h = source[0].length;
434                if (tmpFp == null || tmpFp.getWidth() != w || tmpFp.getHeight() != h) {
435                        tmpFp = new FloatProcessor(w, h);
436                }
437                // copy data to FloatProcessors         
438                for (int u = 0; u < w; u++) {
439                        for (int v = 0; v < h; v++) {
440                                tmpFp.setf(u, v, source[u][v]);
441                        }
442                }
443                
444                Convolver conv = new Convolver();
445                conv.setNormalize(false);
446                // convolve with all specified kernels
447                for (float[][] H : kernels) {
448                        if (H == null) break;
449                        int wH = H.length;
450                        int hH = H[0].length;
451                        float[] H1 = flatten(H);
452                        conv.convolveFloat(tmpFp, H1, wH, hH);
453                }
454                // copy data back to array      
455                for (int u = 0; u < w; u++) {
456                        for (int v = 0; v < h; v++) {
457                                target[u][v] = tmpFp.getf(u, v);
458                        }
459                }
460        }
461        
462        /*
463         * Copy a 2D float  array into a 1D float array
464         */
465        float[] flatten (float[][] arr2d) {
466                int w = arr2d.length;
467                int h = arr2d[0].length;
468                float[] arr1d = new float[w*h];
469                int k = 0;
470                for (int i = 0; i < w; i++) {
471                        for (int j = 0; j < h; j++) {
472                                arr1d[k] = arr2d[i][j];
473                                k++;
474                        }
475                }
476                return arr1d;
477        }
478        
479        /*
480         * Construct a 2D Gaussian filter kernel large enough to avoid truncation effects.
481         * Returns a 1D kernel as a 2D array, so it can be used flexibly in horizontal
482         * or vertical direction.
483         */
484        private float[][] makeGaussKernel1D(double sigma, boolean horizontal){
485                // Construct a 2D Gaussian filter kernel large enough
486                // to avoid truncation effects.
487                final double sigma2 = sigma * sigma;
488                final double scale = 1.0 / (Math.sqrt(2 * Math.PI) * sigma);    
489                final int rad = Math.max((int) (3.5 * sigma), 1); 
490                int size = rad + 1 +rad;        //center cell = kernel[rad]
491                float[][] kernel = (horizontal) ?  
492                                new float[size][1] : 
493                                new float[1][size]; 
494                double sum = 0;
495                for (int i = 0; i < size; i++) {
496                        double x = rad - i;
497                        float val = (float) (scale * Math.exp(-0.5 * (x*x) / sigma2));
498                        if (horizontal) 
499                                kernel[i][0] =  val;
500                        else 
501                                kernel[0][i] =  val;
502                        sum = sum + val;
503                }
504                
505                // normalize (just to be safe)
506                for (int i = 0; i < kernel.length; i++) {
507                        for (int j = 0; j < kernel[i].length; j++) {
508                                kernel[i][j] = (float) (kernel[i][j] / sum);
509                        }
510                }
511                return kernel;
512        }
513        
514        // ----------------------------------------------------------------
515        
516        void normalize(double[] vec) {
517                double sum = 0;
518                for (double v : vec) {
519                        sum = sum + v * v;
520                }
521                if (sum > 0.000001) {
522                        double s = 1 / Math.sqrt(sum);
523                        for (int i = 0; i < vec.length; i++) {
524                                vec[i] = vec[i] * s;
525                        }
526                }
527        }
528        
529        boolean realEigenValues2x2 (
530                        double A, double B, double C, double D, 
531                        double[] lam12, double[] x1, double[] x2) {
532                final double R = (A + D) / 2;
533                final double S = (A - D) / 2;
534                final double V = S * S + B * C;
535                if (V < 0) 
536                        return false; // matrix has no real eigenvalues
537                else {
538                        double T = Math.sqrt(V);
539                        lam12[0] = R + T;       // lambda_1
540                        lam12[1] = R - T;       // lambda_2
541                        if ((A - D) >= 0) {
542                                x1[0] = S + T;  //e_1x
543                                x1[1] = C;              //e_1y                  
544                                x2[0] = B;              //e_2x
545                                x2[1] = -S - T; //e_2y          
546                        } 
547                        else {
548                                x1[0] = B;              //e_1x
549                                x1[1] = -S + T; //e_1y  
550                                x2[0] = S - T;  //e_2x
551                                x2[1] = C;              //e_2y  
552                        }
553                        return true;
554                }
555        }
556        
557        //  RGB/sRGB conversion -----------------------------------
558        // TODO: move this to lib.colorimage.sRgbUtil class
559        
560        float srgbToRgb(float nc) {
561                float nc01 = nc/255;
562                return (float)(gammaInv(nc01) * 255);
563        }
564
565        float rgbToSrgb(float lc) {
566                float lc01 = lc/255;
567                return (float) (gammaFwd(lc01) * 255);
568        }
569        
570        
571        double gammaFwd(double lc) {    // input: linear component value
572                return (lc > 0.0031308) ?
573                        (1.055 * Math.pow(lc, 1/2.4) - 0.055) :
574                        (lc * 12.92);
575    }
576    
577    double gammaInv(double nc) {        // input: nonlinear component value
578        return (nc > 0.03928) ?
579                        Math.pow((nc + 0.055)/1.055, 2.4) :
580                        (nc / 12.92);
581    }
582
583}
584