From bf6daead3f24f41c66c13e2bace034069a2b1304 Mon Sep 17 00:00:00 2001 From: Victor Westerlund Date: Mon, 12 Apr 2021 07:21:55 +0200 Subject: [PATCH] Added DominantColor algorithm with ColorThief Added percentage calculator for collage pasting --- classes/Collage.py | 38 ++-- classes/Color.py | 49 ++++- classes/Samples.py | 25 ++- classes/lib/__init__.py | 0 classes/lib/colorthief.py | 422 ++++++++++++++++++++++++++++++++++++++ create.py | 17 +- 6 files changed, 515 insertions(+), 36 deletions(-) create mode 100644 classes/lib/__init__.py create mode 100644 classes/lib/colorthief.py diff --git a/classes/Collage.py b/classes/Collage.py index a208a44..d117ef3 100644 --- a/classes/Collage.py +++ b/classes/Collage.py @@ -1,4 +1,4 @@ -from PIL import Image +from PIL import Image, ImageOps from bisect import bisect_left # Create instructions for the Collage constructor @@ -6,6 +6,7 @@ class Schematic(): def __init__(self,template,samples): self.template = template self.samples = samples + self.length = 0 self.schematic = {} self.create_schematic() @@ -21,14 +22,16 @@ class Schematic(): return samples[pos] + # Build a 2-Dimensional list of matches def create_schematic(self): - for x in range(1,self.template.size[0]): - self.schematic[x] = {} - for y in range(1,self.template.size[1]): - r,g,b = self.template.getpixel((x,y)) - eyedropper = "%02x%02x%02x" % (r,g,b) + for y in range(1,self.template.size[0]): + self.schematic[y] = {} + for x in range(1,self.template.size[1]): + r,g,b = self.template.getpixel((x,y)) # Extract RGB from current pixel + eyedropper = "%02x%02x%02x" % (r,g,b) # Convert RGB to HEX - self.schematic[x][y] = self.query_sample(eyedropper) + self.schematic[y][x] = self.query_sample(eyedropper) + self.length += 1 print(f"Found best match for index [{x},{y}] ",end="\r",flush="True") print("") @@ -53,16 +56,18 @@ class Collage(): # Assemble the collage def create_collage(self): - schematic = Schematic(self.template,self.samples).schematic + build = Schematic(self.template,self.samples) offset_x = 0 offset_y = 0 + i = 0 + print("Pasing samples..") # Apply each sample by raster scanning - for x in range(1,self.template.size[0]): + for y in range(1,self.template.size[0]): offset_x = 0 - for y in range(1,self.template.size[1]): - key = schematic[x][y] # Get sample index for current pixel + for x in range(1,self.template.size[1]): + key = build.schematic[y][x] # Get sample index for current pixel resolve_posix = self.samples[key] # Convert sample index to sample set index # Load and resize the requested sample from disk @@ -71,13 +76,20 @@ class Collage(): # Add the loaded sample to the collage self.collage = self.collage.copy() - self.collage.paste(sample,(offset_x,offset_y)) + self.collage.paste(sample,(offset_y,offset_x)) offset_x += self.size[0] - print(f"Pasted sample at index [{x},{y}] ",end="\r",flush="True") + progress = round(i / build.length * 100,2) + print(f"Progress: (%) {progress} ",end="\r",flush="True") + i += 1 offset_y += self.size[1] print("") + print("Collage created") + + # Correct rotation and reflection + self.collage = self.collage.rotate(-90) + self.collage = ImageOps.mirror(self.collage) # Save collage to disk def put(self,dest): diff --git a/classes/Color.py b/classes/Color.py index 8c6736f..f1fa5a9 100644 --- a/classes/Color.py +++ b/classes/Color.py @@ -1,11 +1,26 @@ -from PIL import Image +from PIL import Image, ImageFilter +from .lib.colorthief import ColorThief -# Calculate the average color of a sample -class AverageColor(): +# Image loader +class PrepImage(): def __init__(self,image): self.image = Image.open(image) self.image = self.image.resize((50,50)) # Downscale image to improve performance + # Format RGB output as HEX (without #) + def hex(self): + return "%02x%02x%02x" % self.rgb() + +# Calculate the average color of a sample +class AverageColor(PrepImage): + def __init__(self,image): + super(AverageColor,self).__init__(image) + + # Normalize colors with a blur + def blur(self): + blur = ImageFilter.GaussianBlur(10) + self.image = self.image.filter(blur) + def rgb(self): width,height = self.image.size rgb = [0,0,0] @@ -26,6 +41,28 @@ class AverageColor(): return tuple(map(average,rgb)) - # Format RGB output as HEX (without #) - def hex(self): - return "%02x%02x%02x" % self.rgb() \ No newline at end of file +class DominantColor(ColorThief,PrepImage): + def __init__(self,image): + super(DominantColor,self).__init__(image) + + def rgb(self): + return self.get_color(quality=1) + +# class DominantColor(PrepImage): +# def __init__(self,image): +# super(DominantColor,self).__init__(image) +# self.palette_size = 16 + +# def get_palette(self): +# palettised = self.image.convert("P",palette=Image.ADAPTIVE,colors=self.palette_size) + +# palette = palettised.getpalette() +# colors = sorted(palettised.getcolors(),reverse=True) +# index = colors[0][1] +# dominant_color = palette[index * 3:index * + 3] + +# return dominant_color + +# def rgb(self): +# colors = self.get_palette() +# return tuple(colors) \ No newline at end of file diff --git a/classes/Samples.py b/classes/Samples.py index 0f3715a..15a0fd2 100644 --- a/classes/Samples.py +++ b/classes/Samples.py @@ -1,7 +1,7 @@ import zlib import json from collections import OrderedDict -from .Color import AverageColor +from .Color import AverageColor, DominantColor from pathlib import Path # Generate a unique identifier for the current sample set @@ -42,14 +42,9 @@ class Samples(SamplesFingerprint): self.samples = {} # HEX from color calc algorithm super(Samples,self).__init__() + self.run(force) - try: - self.load_sample_set() - except: - self.map_color() - self.save_sample_set() - - # Get the pixel value for each sample using a desired algorithm + # Get the pixel value for each sample using a desired color extraction algorithm def map_color(self): for i,sample in enumerate(self.samples_posix): color = AverageColor(sample).hex() # Get the average color of a sample as HEX @@ -70,4 +65,16 @@ class Samples(SamplesFingerprint): def load_sample_set(self): with open(self.memory) as f: self.samples = json.load(f) - print(f"Loaded {len(self.samples)} samples from set {self.hash}") \ No newline at end of file + print(f"Loaded {len(self.samples)} samples from set {self.hash}") + + def run(self,force): + if(force): + self.map_color() + self.save_sample_set() + + # Attempt to load sample set from memory + try: + self.load_sample_set() + except: + self.map_color() + self.save_sample_set() \ No newline at end of file diff --git a/classes/lib/__init__.py b/classes/lib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/classes/lib/colorthief.py b/classes/lib/colorthief.py new file mode 100644 index 0000000..84eece7 --- /dev/null +++ b/classes/lib/colorthief.py @@ -0,0 +1,422 @@ +# -*- coding: utf-8 -*- +""" + colorthief + ~~~~~~~~~~ + + Grabbing the color palette from an image. + + :copyright: (c) 2015 by Shipeng Feng. + :license: BSD, see LICENSE for more details. +""" +__version__ = '0.2.1' + +import math + +from PIL import Image + + +class cached_property(object): + """Decorator that creates converts a method with a single + self argument into a property cached on the instance. + """ + def __init__(self, func): + self.func = func + + def __get__(self, instance, type): + res = instance.__dict__[self.func.__name__] = self.func(instance) + return res + + +class ColorThief(object): + """Color thief main class.""" + def __init__(self, file): + """Create one color thief for one image. + + :param file: A filename (string) or a file object. The file object + must implement `read()`, `seek()`, and `tell()` methods, + and be opened in binary mode. + """ + self.image = Image.open(file) + + def get_color(self, quality=10): + """Get the dominant color. + + :param quality: quality settings, 1 is the highest quality, the bigger + the number, the faster a color will be returned but + the greater the likelihood that it will not be the + visually most dominant color + :return tuple: (r, g, b) + """ + palette = self.get_palette(5, quality) + return palette[0] + + def get_palette(self, color_count=10, quality=10): + """Build a color palette. We are using the median cut algorithm to + cluster similar colors. + + :param color_count: the size of the palette, max number of colors + :param quality: quality settings, 1 is the highest quality, the bigger + the number, the faster the palette generation, but the + greater the likelihood that colors will be missed. + :return list: a list of tuple in the form (r, g, b) + """ + image = self.image.convert('RGBA') + width, height = image.size + pixels = image.getdata() + pixel_count = width * height + valid_pixels = [] + for i in range(0, pixel_count, quality): + r, g, b, a = pixels[i] + # If pixel is mostly opaque and not white + if a >= 125: + if not (r > 250 and g > 250 and b > 250): + valid_pixels.append((r, g, b)) + + # Send array to quantize function which clusters values + # using median cut algorithm + cmap = MMCQ.quantize(valid_pixels, color_count) + return cmap.palette + + +class MMCQ(object): + """Basic Python port of the MMCQ (modified median cut quantization) + algorithm from the Leptonica library (http://www.leptonica.com/). + """ + + SIGBITS = 5 + RSHIFT = 8 - SIGBITS + MAX_ITERATION = 1000 + FRACT_BY_POPULATIONS = 0.75 + + @staticmethod + def get_color_index(r, g, b): + return (r << (2 * MMCQ.SIGBITS)) + (g << MMCQ.SIGBITS) + b + + @staticmethod + def get_histo(pixels): + """histo (1-d array, giving the number of pixels in each quantized + region of color space) + """ + histo = dict() + for pixel in pixels: + rval = pixel[0] >> MMCQ.RSHIFT + gval = pixel[1] >> MMCQ.RSHIFT + bval = pixel[2] >> MMCQ.RSHIFT + index = MMCQ.get_color_index(rval, gval, bval) + histo[index] = histo.setdefault(index, 0) + 1 + return histo + + @staticmethod + def vbox_from_pixels(pixels, histo): + rmin = 1000000 + rmax = 0 + gmin = 1000000 + gmax = 0 + bmin = 1000000 + bmax = 0 + for pixel in pixels: + rval = pixel[0] >> MMCQ.RSHIFT + gval = pixel[1] >> MMCQ.RSHIFT + bval = pixel[2] >> MMCQ.RSHIFT + rmin = min(rval, rmin) + rmax = max(rval, rmax) + gmin = min(gval, gmin) + gmax = max(gval, gmax) + bmin = min(bval, bmin) + bmax = max(bval, bmax) + return VBox(rmin, rmax, gmin, gmax, bmin, bmax, histo) + + @staticmethod + def median_cut_apply(histo, vbox): + if not vbox.count: + return (None, None) + + rw = vbox.r2 - vbox.r1 + 1 + gw = vbox.g2 - vbox.g1 + 1 + bw = vbox.b2 - vbox.b1 + 1 + maxw = max([rw, gw, bw]) + # only one pixel, no split + if vbox.count == 1: + return (vbox.copy, None) + # Find the partial sum arrays along the selected axis. + total = 0 + sum_ = 0 + partialsum = {} + lookaheadsum = {} + do_cut_color = None + if maxw == rw: + do_cut_color = 'r' + for i in range(vbox.r1, vbox.r2+1): + sum_ = 0 + for j in range(vbox.g1, vbox.g2+1): + for k in range(vbox.b1, vbox.b2+1): + index = MMCQ.get_color_index(i, j, k) + sum_ += histo.get(index, 0) + total += sum_ + partialsum[i] = total + elif maxw == gw: + do_cut_color = 'g' + for i in range(vbox.g1, vbox.g2+1): + sum_ = 0 + for j in range(vbox.r1, vbox.r2+1): + for k in range(vbox.b1, vbox.b2+1): + index = MMCQ.get_color_index(j, i, k) + sum_ += histo.get(index, 0) + total += sum_ + partialsum[i] = total + else: # maxw == bw + do_cut_color = 'b' + for i in range(vbox.b1, vbox.b2+1): + sum_ = 0 + for j in range(vbox.r1, vbox.r2+1): + for k in range(vbox.g1, vbox.g2+1): + index = MMCQ.get_color_index(j, k, i) + sum_ += histo.get(index, 0) + total += sum_ + partialsum[i] = total + for i, d in partialsum.items(): + lookaheadsum[i] = total - d + + # determine the cut planes + dim1 = do_cut_color + '1' + dim2 = do_cut_color + '2' + dim1_val = getattr(vbox, dim1) + dim2_val = getattr(vbox, dim2) + for i in range(dim1_val, dim2_val+1): + if partialsum[i] > (total / 2): + vbox1 = vbox.copy + vbox2 = vbox.copy + left = i - dim1_val + right = dim2_val - i + if left <= right: + d2 = min([dim2_val - 1, int(i + right / 2)]) + else: + d2 = max([dim1_val, int(i - 1 - left / 2)]) + # avoid 0-count boxes + while not partialsum.get(d2, False): + d2 += 1 + count2 = lookaheadsum.get(d2) + while not count2 and partialsum.get(d2-1, False): + d2 -= 1 + count2 = lookaheadsum.get(d2) + # set dimensions + setattr(vbox1, dim2, d2) + setattr(vbox2, dim1, getattr(vbox1, dim2) + 1) + return (vbox1, vbox2) + return (None, None) + + @staticmethod + def quantize(pixels, max_color): + """Quantize. + + :param pixels: a list of pixel in the form (r, g, b) + :param max_color: max number of colors + """ + if not pixels: + raise Exception('Empty pixels when quantize.') + if max_color < 2 or max_color > 256: + raise Exception('Wrong number of max colors when quantize.') + + histo = MMCQ.get_histo(pixels) + + # check that we aren't below maxcolors already + if len(histo) <= max_color: + # generate the new colors from the histo and return + pass + + # get the beginning vbox from the colors + vbox = MMCQ.vbox_from_pixels(pixels, histo) + pq = PQueue(lambda x: x.count) + pq.push(vbox) + + # inner function to do the iteration + def iter_(lh, target): + n_color = 1 + n_iter = 0 + while n_iter < MMCQ.MAX_ITERATION: + vbox = lh.pop() + if not vbox.count: # just put it back + lh.push(vbox) + n_iter += 1 + continue + # do the cut + vbox1, vbox2 = MMCQ.median_cut_apply(histo, vbox) + if not vbox1: + raise Exception("vbox1 not defined; shouldn't happen!") + lh.push(vbox1) + if vbox2: # vbox2 can be null + lh.push(vbox2) + n_color += 1 + if n_color >= target: + return + if n_iter > MMCQ.MAX_ITERATION: + return + n_iter += 1 + + # first set of colors, sorted by population + iter_(pq, MMCQ.FRACT_BY_POPULATIONS * max_color) + + # Re-sort by the product of pixel occupancy times the size in + # color space. + pq2 = PQueue(lambda x: x.count * x.volume) + while pq.size(): + pq2.push(pq.pop()) + + # next set - generate the median cuts using the (npix * vol) sorting. + iter_(pq2, max_color - pq2.size()) + + # calculate the actual colors + cmap = CMap() + while pq2.size(): + cmap.push(pq2.pop()) + return cmap + + +class VBox(object): + """3d color space box""" + def __init__(self, r1, r2, g1, g2, b1, b2, histo): + self.r1 = r1 + self.r2 = r2 + self.g1 = g1 + self.g2 = g2 + self.b1 = b1 + self.b2 = b2 + self.histo = histo + + @cached_property + def volume(self): + sub_r = self.r2 - self.r1 + sub_g = self.g2 - self.g1 + sub_b = self.b2 - self.b1 + return (sub_r + 1) * (sub_g + 1) * (sub_b + 1) + + @property + def copy(self): + return VBox(self.r1, self.r2, self.g1, self.g2, + self.b1, self.b2, self.histo) + + @cached_property + def avg(self): + ntot = 0 + mult = 1 << (8 - MMCQ.SIGBITS) + r_sum = 0 + g_sum = 0 + b_sum = 0 + for i in range(self.r1, self.r2 + 1): + for j in range(self.g1, self.g2 + 1): + for k in range(self.b1, self.b2 + 1): + histoindex = MMCQ.get_color_index(i, j, k) + hval = self.histo.get(histoindex, 0) + ntot += hval + r_sum += hval * (i + 0.5) * mult + g_sum += hval * (j + 0.5) * mult + b_sum += hval * (k + 0.5) * mult + + if ntot: + r_avg = int(r_sum / ntot) + g_avg = int(g_sum / ntot) + b_avg = int(b_sum / ntot) + else: + r_avg = int(mult * (self.r1 + self.r2 + 1) / 2) + g_avg = int(mult * (self.g1 + self.g2 + 1) / 2) + b_avg = int(mult * (self.b1 + self.b2 + 1) / 2) + + return r_avg, g_avg, b_avg + + def contains(self, pixel): + rval = pixel[0] >> MMCQ.RSHIFT + gval = pixel[1] >> MMCQ.RSHIFT + bval = pixel[2] >> MMCQ.RSHIFT + return all([ + rval >= self.r1, + rval <= self.r2, + gval >= self.g1, + gval <= self.g2, + bval >= self.b1, + bval <= self.b2, + ]) + + @cached_property + def count(self): + npix = 0 + for i in range(self.r1, self.r2 + 1): + for j in range(self.g1, self.g2 + 1): + for k in range(self.b1, self.b2 + 1): + index = MMCQ.get_color_index(i, j, k) + npix += self.histo.get(index, 0) + return npix + + +class CMap(object): + """Color map""" + def __init__(self): + self.vboxes = PQueue(lambda x: x['vbox'].count * x['vbox'].volume) + + @property + def palette(self): + return self.vboxes.map(lambda x: x['color']) + + def push(self, vbox): + self.vboxes.push({ + 'vbox': vbox, + 'color': vbox.avg, + }) + + def size(self): + return self.vboxes.size() + + def nearest(self, color): + d1 = None + p_color = None + for i in range(self.vboxes.size()): + vbox = self.vboxes.peek(i) + d2 = math.sqrt( + math.pow(color[0] - vbox['color'][0], 2) + + math.pow(color[1] - vbox['color'][1], 2) + + math.pow(color[2] - vbox['color'][2], 2) + ) + if d1 is None or d2 < d1: + d1 = d2 + p_color = vbox['color'] + return p_color + + def map(self, color): + for i in range(self.vboxes.size()): + vbox = self.vboxes.peek(i) + if vbox['vbox'].contains(color): + return vbox['color'] + return self.nearest(color) + + +class PQueue(object): + """Simple priority queue.""" + def __init__(self, sort_key): + self.sort_key = sort_key + self.contents = [] + self._sorted = False + + def sort(self): + self.contents.sort(key=self.sort_key) + self._sorted = True + + def push(self, o): + self.contents.append(o) + self._sorted = False + + def peek(self, index=None): + if not self._sorted: + self.sort() + if index is None: + index = len(self.contents) - 1 + return self.contents[index] + + def pop(self): + if not self._sorted: + self.sort() + return self.contents.pop() + + def size(self): + return len(self.contents) + + def map(self, f): + return list(map(f, self.contents)) \ No newline at end of file diff --git a/create.py b/create.py index 78a9d11..12b40a0 100644 --- a/create.py +++ b/create.py @@ -1,21 +1,22 @@ import sys from classes.Samples import Samples from classes.Collage import Collage -from pathlib import Path +# sys.argv[1] = Input image +# sys.argv[2] = Output image + +sample_scale = (20,20) force = False # Will generate a new sample set every time when true # Prompt IO declaration if no CLI arguments provided if(len(sys.argv) < 2): - input_file = input("Input image (.jpg):\n") - output_file = input("Output image (.jpg):\n") -else: - input_file = sys.argv[1] - output_file = sys.argv[2] + sys.argv.insert(1,input("Input image (.jpg):\n")) + sys.argv.insert(1,input("Output image (.jpg):\n")) # Load all images from the "samples/" folder samples = Samples("samples",force) # Create a collage from the loaded samples -collage = Collage(input_file,samples) -collage.put(output_file) \ No newline at end of file +collage = Collage(sys.argv[1],samples) +collage.size = sample_scale +collage.put(sys.argv[2]) \ No newline at end of file