001/*******************************************************************************
002 * This software is provided as a supplement to the authors' textbooks on digital
003 *  image processing published by Springer-Verlag in various languages and editions.
004 * Permission to use and distribute this software is granted under the BSD 2-Clause 
005 * "Simplified" License (see http://opensource.org/licenses/BSD-2-Clause). 
006 * Copyright (c) 2006-2016 Wilhelm Burger, Mark J. Burge. All rights reserved. 
007 * Visit http://imagingbook.com for additional details.
008 *******************************************************************************/
009
010package imagingbook.pub.color.quantize;
011
012import java.util.LinkedHashSet;
013import java.util.LinkedList;
014import java.util.List;
015import java.util.Locale;
016import java.util.Random;
017import java.util.Set;
018
019import imagingbook.pub.color.statistics.ColorHistogram;
020
021/**
022 * This class implements color quantization using k-means clustering
023 * of image pixels in RGB color space. Two modes of selecting
024 * the colors for the initial clusters are provided:
025 * (a) random sampling of the input colors,
026 * (b) using the most frequent colors.
027 * 
028 * @author WB
029 * @version 2017/01/04
030 */
031public class KMeansClusteringQuantizer extends ColorQuantizer {
032        
033        private final Parameters params;
034        private final int[][] colormap;
035        private final Cluster[] clusters;
036        private final double totalError;
037        
038        public enum SamplingMethod {
039                Random, Most_Frequent
040        };
041        
042        public static class Parameters {
043                /** Maximum number of quantized colors. */
044                public int maxColors = 16;
045                /** Maximum number of clustering iterations */
046                public int maxIterations = 500;
047                /** The method used for selecting the initial color samples. */
048                public SamplingMethod samplMethod = SamplingMethod.Random;
049                
050                void check() {
051                        if (maxColors < 2 || maxColors > 256 || maxIterations < 1) {
052                                throw new IllegalArgumentException();
053                        }
054                }
055        }
056        
057        // --------------------------------------------------------------
058
059        /**
060         * Creates a new quantizer instance from the supplied sequence
061         * of color values (assumed to be ARGB-encoded integers).
062         * 
063         * @param pixels Sequence of input color values.
064         * @param params Parameter object.
065         */
066        public KMeansClusteringQuantizer(int[] pixels, Parameters params) {
067                params.check();
068                this.params = params;
069                clusters = makeClusters(pixels);
070                totalError = cluster(pixels);
071                colormap = makeColorMap();
072        }
073        
074        public KMeansClusteringQuantizer(int[] pixels) {
075                this(pixels, new Parameters());
076        }
077        
078        // --------------------------------------------------------------
079
080        private Cluster[] makeClusters(int[] pixels) {
081                int Kmax = Math.min(pixels.length, params.maxColors);
082                int[] samples = getColorSamples(pixels, Kmax);
083                int k = Math.min(samples.length, Kmax);
084                Cluster[] cls = new Cluster[k]; // create an array of K clusters
085                for (int i = 0; i < k; i++) {
086                        cls[i] = new Cluster(samples[i]); // initialize cluster center
087                }
088                return cls; 
089        }
090        
091        /**
092         * We randomly pick k distinct colors from the original image
093         * pixels.
094         * @param pixels
095         * @param k
096         * @return
097         */
098        private int[] getColorSamples(int[] pixels, int k) {
099                switch (params.samplMethod) {
100                case Random:
101                        return getRandomColors(pixels, k);
102                case Most_Frequent:
103                        return getMostFrequentColors(pixels, k);
104                default:
105                        return null;
106                }
107        }
108
109        private int[] getRandomColors(int[] pixels, int k) {
110                Random rng = new Random();
111                Set<Integer> pixelSet = new LinkedHashSet<Integer>();
112                while (pixelSet.size() < k) {
113                        Integer next = rng.nextInt(pixels.length);
114                        int p = pixels[next];
115                        // adding to a set automatically does a containment check
116                        pixelSet.add(p);
117                }
118                int[] s = new int[k];
119                int i = 0;
120                for (Integer p : pixelSet) {
121                        s[i] = p;
122                        i++;
123                }
124                return s;
125        }
126        
127        private int[] getMostFrequentColors(int[] pixels, int k) {
128                ColorHistogram colorHist = new ColorHistogram(pixels, true);
129                k = Math.min(k, colorHist.getNumberOfColors());
130                int[] s = new int[k];
131                for (int i = 0; i < k; i++) {
132                        s[i] = colorHist.getColor(i);
133                }
134                return s;
135        }
136        
137        private double cluster(int[] pixels) {
138                int changed = Integer.MAX_VALUE;
139                double distSum = Double.POSITIVE_INFINITY;
140                int j = 0;
141                while (changed > 0 && j < params.maxIterations) {
142                        distSum = assignSamples(pixels);
143                        changed = updateClusters();
144                        j++;
145                }
146                return distSum;
147        }
148
149        
150        private double assignSamples(int[] pixels) {
151                double distSum = 0;
152                for (int p : pixels) {
153                        double dist = addToClosestCluster(p);
154                        distSum = distSum + dist;
155                }
156                return distSum;
157        }
158        
159        private int updateClusters() {
160                int changed = 0;
161                for (Cluster c : clusters) {
162                        changed = changed + c.upDate();
163                }
164                return changed;
165        }
166        
167        private double addToClosestCluster(int p) {
168                double minDist = Double.POSITIVE_INFINITY;
169                Cluster closest = null;
170                for (Cluster c : clusters) {
171                        double d = c.getDistance(p);
172                        if (d < minDist) {
173                                minDist = d;
174                                closest = c;
175                        }
176                }
177                closest.addPixel(p);
178                return minDist;
179        }
180
181        private int[][] makeColorMap() {
182                List<int[]> colList = new LinkedList<>();
183                for (Cluster c : clusters) {
184                        if (!c.isEmpty()) {
185                                colList.add(c.getCenterColor());
186                        }
187                }               
188                return colList.toArray(new int[0][]);
189        }
190        
191        /**
192         * Lists the color clusters to System.out (intended for debugging only).
193         */
194        public void listClusters() {
195                for (Cluster c : clusters) {
196                        System.out.println(c.toString());
197                }
198        }
199        
200        /**
201         * Returns the total error of this clustering, calculated as the sum of
202         * the squared distances of the color samples to the associated cluster
203         * center. This calculation is performed during the final iteration.
204         * 
205         * @return The sum of the squared distances between samples and cluster centers.
206         */
207        public double getTotalError() {
208                return totalError;
209        }
210
211        // ------- methods required by abstract super class -----------------------
212        
213        @Override
214        public int[][] getColorMap() {
215                return colormap;
216        }
217        
218        // ------------------------------------------------------------------------
219        
220        private class Cluster {
221                int sRed, sGrn, sBlu;           // RGB sum of contained pixels
222                int pCounter;                           // pixel counter, used during pixel assignment
223                int population = 0;                     // number of contained pixels
224                double cRed, cGrn, cBlu;        // center of this cluster
225
226                Cluster(int p) {
227                        int[] rgb = intToRgb(p);
228                        cRed = rgb[0];
229                        cGrn = rgb[1];
230                        cBlu = rgb[2];
231                        reset();
232                }
233
234                public int[] getCenterColor() {
235                        int[] rgb = new int[] {
236                                        (int) Math.round(cRed),
237                                        (int) Math.round(cGrn),
238                                        (int) Math.round(cBlu)
239                                        };
240
241                        return rgb;
242                }
243
244                public boolean isEmpty() {
245                        return (population == 0);
246                }
247
248                void reset() {  // Used at the start of the pixel assignment.
249                        sRed = 0;
250                        sGrn = 0;
251                        sBlu = 0;
252                        pCounter = 0;
253                }
254                
255                void addPixel(int p) {
256                        int[] rgb = intToRgb(p);
257                        sRed += rgb[0];
258                        sGrn += rgb[1];
259                        sBlu += rgb[2];
260                        pCounter = pCounter + 1;
261                }
262                
263                /**
264                 * This method is invoked after all samples have been assigned.
265                 * It updates the cluster's center and returns true if its
266                 * population changed from the previous clustering.
267                 * @return true if the population of this cluster has changed.
268                 */
269                int upDate() {
270                        if (pCounter > 0) {
271                                double scale = 1.0 / pCounter;
272                                cRed = sRed * scale;
273                                cGrn = sGrn * scale;
274                                cBlu = sBlu * scale;
275                        }
276                        int changed = Math.abs(pCounter - population);  // change in cluster population
277                        population = pCounter;
278                        reset();
279                        return changed; 
280                }
281                
282                /**
283                 * Calculates and returns the squared Euclidean distance between the color p
284                 * and this cluster's center in RGB space.
285                 * @param p Color sample
286                 * @return Squared distance to the cluster center
287                 */
288                double getDistance(int p) {
289                        int[] rgb = intToRgb(p);
290                        final double dR = rgb[0] - cRed;
291                        final double dG = rgb[1] - cGrn;
292                        final double dB = rgb[2] - cBlu;
293                        return dR * dR + dG * dG + dB * dB;
294                }
295                
296                @Override
297                public String toString() {
298                        return String.format(Locale.US, Cluster.class.getSimpleName() +
299                                        ": center=(%.1f,%.1f,%.1f), population=%d", cRed, cGrn, cBlu, population);
300                }
301        }
302        
303        
304} 
305