Added DominantColor algorithm with ColorThief

Added percentage calculator for collage pasting
This commit is contained in:
Victor Westerlund 2021-04-12 07:21:55 +02:00
parent 2cd99417f1
commit bf6daead3f
6 changed files with 515 additions and 36 deletions

View file

@ -1,4 +1,4 @@
from PIL import Image from PIL import Image, ImageOps
from bisect import bisect_left from bisect import bisect_left
# Create instructions for the Collage constructor # Create instructions for the Collage constructor
@ -6,6 +6,7 @@ class Schematic():
def __init__(self,template,samples): def __init__(self,template,samples):
self.template = template self.template = template
self.samples = samples self.samples = samples
self.length = 0
self.schematic = {} self.schematic = {}
self.create_schematic() self.create_schematic()
@ -21,14 +22,16 @@ class Schematic():
return samples[pos] return samples[pos]
# Build a 2-Dimensional list of matches
def create_schematic(self): def create_schematic(self):
for x in range(1,self.template.size[0]): for y in range(1,self.template.size[0]):
self.schematic[x] = {} self.schematic[y] = {}
for y in range(1,self.template.size[1]): for x in range(1,self.template.size[1]):
r,g,b = self.template.getpixel((x,y)) r,g,b = self.template.getpixel((x,y)) # Extract RGB from current pixel
eyedropper = "%02x%02x%02x" % (r,g,b) eyedropper = "%02x%02x%02x" % (r,g,b) # Convert RGB to HEX
self.schematic[x][y] = self.query_sample(eyedropper) self.schematic[y][x] = self.query_sample(eyedropper)
self.length += 1
print(f"Found best match for index [{x},{y}] ",end="\r",flush="True") print(f"Found best match for index [{x},{y}] ",end="\r",flush="True")
print("") print("")
@ -53,16 +56,18 @@ class Collage():
# Assemble the collage # Assemble the collage
def create_collage(self): def create_collage(self):
schematic = Schematic(self.template,self.samples).schematic build = Schematic(self.template,self.samples)
offset_x = 0 offset_x = 0
offset_y = 0 offset_y = 0
i = 0
print("Pasing samples..")
# Apply each sample by raster scanning # Apply each sample by raster scanning
for x in range(1,self.template.size[0]): for y in range(1,self.template.size[0]):
offset_x = 0 offset_x = 0
for y in range(1,self.template.size[1]): for x in range(1,self.template.size[1]):
key = schematic[x][y] # Get sample index for current pixel key = build.schematic[y][x] # Get sample index for current pixel
resolve_posix = self.samples[key] # Convert sample index to sample set index resolve_posix = self.samples[key] # Convert sample index to sample set index
# Load and resize the requested sample from disk # Load and resize the requested sample from disk
@ -71,13 +76,20 @@ class Collage():
# Add the loaded sample to the collage # Add the loaded sample to the collage
self.collage = self.collage.copy() self.collage = self.collage.copy()
self.collage.paste(sample,(offset_x,offset_y)) self.collage.paste(sample,(offset_y,offset_x))
offset_x += self.size[0] offset_x += self.size[0]
print(f"Pasted sample at index [{x},{y}] ",end="\r",flush="True") progress = round(i / build.length * 100,2)
print(f"Progress: (%) {progress} ",end="\r",flush="True")
i += 1
offset_y += self.size[1] offset_y += self.size[1]
print("") print("")
print("Collage created")
# Correct rotation and reflection
self.collage = self.collage.rotate(-90)
self.collage = ImageOps.mirror(self.collage)
# Save collage to disk # Save collage to disk
def put(self,dest): def put(self,dest):

View file

@ -1,11 +1,26 @@
from PIL import Image from PIL import Image, ImageFilter
from .lib.colorthief import ColorThief
# Calculate the average color of a sample # Image loader
class AverageColor(): class PrepImage():
def __init__(self,image): def __init__(self,image):
self.image = Image.open(image) self.image = Image.open(image)
self.image = self.image.resize((50,50)) # Downscale image to improve performance self.image = self.image.resize((50,50)) # Downscale image to improve performance
# Format RGB output as HEX (without #)
def hex(self):
return "%02x%02x%02x" % self.rgb()
# Calculate the average color of a sample
class AverageColor(PrepImage):
def __init__(self,image):
super(AverageColor,self).__init__(image)
# Normalize colors with a blur
def blur(self):
blur = ImageFilter.GaussianBlur(10)
self.image = self.image.filter(blur)
def rgb(self): def rgb(self):
width,height = self.image.size width,height = self.image.size
rgb = [0,0,0] rgb = [0,0,0]
@ -26,6 +41,28 @@ class AverageColor():
return tuple(map(average,rgb)) return tuple(map(average,rgb))
# Format RGB output as HEX (without #) class DominantColor(ColorThief,PrepImage):
def hex(self): def __init__(self,image):
return "%02x%02x%02x" % self.rgb() super(DominantColor,self).__init__(image)
def rgb(self):
return self.get_color(quality=1)
# class DominantColor(PrepImage):
# def __init__(self,image):
# super(DominantColor,self).__init__(image)
# self.palette_size = 16
# def get_palette(self):
# palettised = self.image.convert("P",palette=Image.ADAPTIVE,colors=self.palette_size)
# palette = palettised.getpalette()
# colors = sorted(palettised.getcolors(),reverse=True)
# index = colors[0][1]
# dominant_color = palette[index * 3:index * + 3]
# return dominant_color
# def rgb(self):
# colors = self.get_palette()
# return tuple(colors)

View file

@ -1,7 +1,7 @@
import zlib import zlib
import json import json
from collections import OrderedDict from collections import OrderedDict
from .Color import AverageColor from .Color import AverageColor, DominantColor
from pathlib import Path from pathlib import Path
# Generate a unique identifier for the current sample set # Generate a unique identifier for the current sample set
@ -42,14 +42,9 @@ class Samples(SamplesFingerprint):
self.samples = {} # HEX from color calc algorithm self.samples = {} # HEX from color calc algorithm
super(Samples,self).__init__() super(Samples,self).__init__()
self.run(force)
try: # Get the pixel value for each sample using a desired color extraction algorithm
self.load_sample_set()
except:
self.map_color()
self.save_sample_set()
# Get the pixel value for each sample using a desired algorithm
def map_color(self): def map_color(self):
for i,sample in enumerate(self.samples_posix): for i,sample in enumerate(self.samples_posix):
color = AverageColor(sample).hex() # Get the average color of a sample as HEX color = AverageColor(sample).hex() # Get the average color of a sample as HEX
@ -70,4 +65,16 @@ class Samples(SamplesFingerprint):
def load_sample_set(self): def load_sample_set(self):
with open(self.memory) as f: with open(self.memory) as f:
self.samples = json.load(f) self.samples = json.load(f)
print(f"Loaded {len(self.samples)} samples from set {self.hash}") print(f"Loaded {len(self.samples)} samples from set {self.hash}")
def run(self,force):
if(force):
self.map_color()
self.save_sample_set()
# Attempt to load sample set from memory
try:
self.load_sample_set()
except:
self.map_color()
self.save_sample_set()

0
classes/lib/__init__.py Normal file
View file

422
classes/lib/colorthief.py Normal file
View file

@ -0,0 +1,422 @@
# -*- coding: utf-8 -*-
"""
colorthief
~~~~~~~~~~
Grabbing the color palette from an image.
:copyright: (c) 2015 by Shipeng Feng.
:license: BSD, see LICENSE for more details.
"""
__version__ = '0.2.1'
import math
from PIL import Image
class cached_property(object):
"""Decorator that creates converts a method with a single
self argument into a property cached on the instance.
"""
def __init__(self, func):
self.func = func
def __get__(self, instance, type):
res = instance.__dict__[self.func.__name__] = self.func(instance)
return res
class ColorThief(object):
"""Color thief main class."""
def __init__(self, file):
"""Create one color thief for one image.
:param file: A filename (string) or a file object. The file object
must implement `read()`, `seek()`, and `tell()` methods,
and be opened in binary mode.
"""
self.image = Image.open(file)
def get_color(self, quality=10):
"""Get the dominant color.
:param quality: quality settings, 1 is the highest quality, the bigger
the number, the faster a color will be returned but
the greater the likelihood that it will not be the
visually most dominant color
:return tuple: (r, g, b)
"""
palette = self.get_palette(5, quality)
return palette[0]
def get_palette(self, color_count=10, quality=10):
"""Build a color palette. We are using the median cut algorithm to
cluster similar colors.
:param color_count: the size of the palette, max number of colors
:param quality: quality settings, 1 is the highest quality, the bigger
the number, the faster the palette generation, but the
greater the likelihood that colors will be missed.
:return list: a list of tuple in the form (r, g, b)
"""
image = self.image.convert('RGBA')
width, height = image.size
pixels = image.getdata()
pixel_count = width * height
valid_pixels = []
for i in range(0, pixel_count, quality):
r, g, b, a = pixels[i]
# If pixel is mostly opaque and not white
if a >= 125:
if not (r > 250 and g > 250 and b > 250):
valid_pixels.append((r, g, b))
# Send array to quantize function which clusters values
# using median cut algorithm
cmap = MMCQ.quantize(valid_pixels, color_count)
return cmap.palette
class MMCQ(object):
"""Basic Python port of the MMCQ (modified median cut quantization)
algorithm from the Leptonica library (http://www.leptonica.com/).
"""
SIGBITS = 5
RSHIFT = 8 - SIGBITS
MAX_ITERATION = 1000
FRACT_BY_POPULATIONS = 0.75
@staticmethod
def get_color_index(r, g, b):
return (r << (2 * MMCQ.SIGBITS)) + (g << MMCQ.SIGBITS) + b
@staticmethod
def get_histo(pixels):
"""histo (1-d array, giving the number of pixels in each quantized
region of color space)
"""
histo = dict()
for pixel in pixels:
rval = pixel[0] >> MMCQ.RSHIFT
gval = pixel[1] >> MMCQ.RSHIFT
bval = pixel[2] >> MMCQ.RSHIFT
index = MMCQ.get_color_index(rval, gval, bval)
histo[index] = histo.setdefault(index, 0) + 1
return histo
@staticmethod
def vbox_from_pixels(pixels, histo):
rmin = 1000000
rmax = 0
gmin = 1000000
gmax = 0
bmin = 1000000
bmax = 0
for pixel in pixels:
rval = pixel[0] >> MMCQ.RSHIFT
gval = pixel[1] >> MMCQ.RSHIFT
bval = pixel[2] >> MMCQ.RSHIFT
rmin = min(rval, rmin)
rmax = max(rval, rmax)
gmin = min(gval, gmin)
gmax = max(gval, gmax)
bmin = min(bval, bmin)
bmax = max(bval, bmax)
return VBox(rmin, rmax, gmin, gmax, bmin, bmax, histo)
@staticmethod
def median_cut_apply(histo, vbox):
if not vbox.count:
return (None, None)
rw = vbox.r2 - vbox.r1 + 1
gw = vbox.g2 - vbox.g1 + 1
bw = vbox.b2 - vbox.b1 + 1
maxw = max([rw, gw, bw])
# only one pixel, no split
if vbox.count == 1:
return (vbox.copy, None)
# Find the partial sum arrays along the selected axis.
total = 0
sum_ = 0
partialsum = {}
lookaheadsum = {}
do_cut_color = None
if maxw == rw:
do_cut_color = 'r'
for i in range(vbox.r1, vbox.r2+1):
sum_ = 0
for j in range(vbox.g1, vbox.g2+1):
for k in range(vbox.b1, vbox.b2+1):
index = MMCQ.get_color_index(i, j, k)
sum_ += histo.get(index, 0)
total += sum_
partialsum[i] = total
elif maxw == gw:
do_cut_color = 'g'
for i in range(vbox.g1, vbox.g2+1):
sum_ = 0
for j in range(vbox.r1, vbox.r2+1):
for k in range(vbox.b1, vbox.b2+1):
index = MMCQ.get_color_index(j, i, k)
sum_ += histo.get(index, 0)
total += sum_
partialsum[i] = total
else: # maxw == bw
do_cut_color = 'b'
for i in range(vbox.b1, vbox.b2+1):
sum_ = 0
for j in range(vbox.r1, vbox.r2+1):
for k in range(vbox.g1, vbox.g2+1):
index = MMCQ.get_color_index(j, k, i)
sum_ += histo.get(index, 0)
total += sum_
partialsum[i] = total
for i, d in partialsum.items():
lookaheadsum[i] = total - d
# determine the cut planes
dim1 = do_cut_color + '1'
dim2 = do_cut_color + '2'
dim1_val = getattr(vbox, dim1)
dim2_val = getattr(vbox, dim2)
for i in range(dim1_val, dim2_val+1):
if partialsum[i] > (total / 2):
vbox1 = vbox.copy
vbox2 = vbox.copy
left = i - dim1_val
right = dim2_val - i
if left <= right:
d2 = min([dim2_val - 1, int(i + right / 2)])
else:
d2 = max([dim1_val, int(i - 1 - left / 2)])
# avoid 0-count boxes
while not partialsum.get(d2, False):
d2 += 1
count2 = lookaheadsum.get(d2)
while not count2 and partialsum.get(d2-1, False):
d2 -= 1
count2 = lookaheadsum.get(d2)
# set dimensions
setattr(vbox1, dim2, d2)
setattr(vbox2, dim1, getattr(vbox1, dim2) + 1)
return (vbox1, vbox2)
return (None, None)
@staticmethod
def quantize(pixels, max_color):
"""Quantize.
:param pixels: a list of pixel in the form (r, g, b)
:param max_color: max number of colors
"""
if not pixels:
raise Exception('Empty pixels when quantize.')
if max_color < 2 or max_color > 256:
raise Exception('Wrong number of max colors when quantize.')
histo = MMCQ.get_histo(pixels)
# check that we aren't below maxcolors already
if len(histo) <= max_color:
# generate the new colors from the histo and return
pass
# get the beginning vbox from the colors
vbox = MMCQ.vbox_from_pixels(pixels, histo)
pq = PQueue(lambda x: x.count)
pq.push(vbox)
# inner function to do the iteration
def iter_(lh, target):
n_color = 1
n_iter = 0
while n_iter < MMCQ.MAX_ITERATION:
vbox = lh.pop()
if not vbox.count: # just put it back
lh.push(vbox)
n_iter += 1
continue
# do the cut
vbox1, vbox2 = MMCQ.median_cut_apply(histo, vbox)
if not vbox1:
raise Exception("vbox1 not defined; shouldn't happen!")
lh.push(vbox1)
if vbox2: # vbox2 can be null
lh.push(vbox2)
n_color += 1
if n_color >= target:
return
if n_iter > MMCQ.MAX_ITERATION:
return
n_iter += 1
# first set of colors, sorted by population
iter_(pq, MMCQ.FRACT_BY_POPULATIONS * max_color)
# Re-sort by the product of pixel occupancy times the size in
# color space.
pq2 = PQueue(lambda x: x.count * x.volume)
while pq.size():
pq2.push(pq.pop())
# next set - generate the median cuts using the (npix * vol) sorting.
iter_(pq2, max_color - pq2.size())
# calculate the actual colors
cmap = CMap()
while pq2.size():
cmap.push(pq2.pop())
return cmap
class VBox(object):
"""3d color space box"""
def __init__(self, r1, r2, g1, g2, b1, b2, histo):
self.r1 = r1
self.r2 = r2
self.g1 = g1
self.g2 = g2
self.b1 = b1
self.b2 = b2
self.histo = histo
@cached_property
def volume(self):
sub_r = self.r2 - self.r1
sub_g = self.g2 - self.g1
sub_b = self.b2 - self.b1
return (sub_r + 1) * (sub_g + 1) * (sub_b + 1)
@property
def copy(self):
return VBox(self.r1, self.r2, self.g1, self.g2,
self.b1, self.b2, self.histo)
@cached_property
def avg(self):
ntot = 0
mult = 1 << (8 - MMCQ.SIGBITS)
r_sum = 0
g_sum = 0
b_sum = 0
for i in range(self.r1, self.r2 + 1):
for j in range(self.g1, self.g2 + 1):
for k in range(self.b1, self.b2 + 1):
histoindex = MMCQ.get_color_index(i, j, k)
hval = self.histo.get(histoindex, 0)
ntot += hval
r_sum += hval * (i + 0.5) * mult
g_sum += hval * (j + 0.5) * mult
b_sum += hval * (k + 0.5) * mult
if ntot:
r_avg = int(r_sum / ntot)
g_avg = int(g_sum / ntot)
b_avg = int(b_sum / ntot)
else:
r_avg = int(mult * (self.r1 + self.r2 + 1) / 2)
g_avg = int(mult * (self.g1 + self.g2 + 1) / 2)
b_avg = int(mult * (self.b1 + self.b2 + 1) / 2)
return r_avg, g_avg, b_avg
def contains(self, pixel):
rval = pixel[0] >> MMCQ.RSHIFT
gval = pixel[1] >> MMCQ.RSHIFT
bval = pixel[2] >> MMCQ.RSHIFT
return all([
rval >= self.r1,
rval <= self.r2,
gval >= self.g1,
gval <= self.g2,
bval >= self.b1,
bval <= self.b2,
])
@cached_property
def count(self):
npix = 0
for i in range(self.r1, self.r2 + 1):
for j in range(self.g1, self.g2 + 1):
for k in range(self.b1, self.b2 + 1):
index = MMCQ.get_color_index(i, j, k)
npix += self.histo.get(index, 0)
return npix
class CMap(object):
"""Color map"""
def __init__(self):
self.vboxes = PQueue(lambda x: x['vbox'].count * x['vbox'].volume)
@property
def palette(self):
return self.vboxes.map(lambda x: x['color'])
def push(self, vbox):
self.vboxes.push({
'vbox': vbox,
'color': vbox.avg,
})
def size(self):
return self.vboxes.size()
def nearest(self, color):
d1 = None
p_color = None
for i in range(self.vboxes.size()):
vbox = self.vboxes.peek(i)
d2 = math.sqrt(
math.pow(color[0] - vbox['color'][0], 2) +
math.pow(color[1] - vbox['color'][1], 2) +
math.pow(color[2] - vbox['color'][2], 2)
)
if d1 is None or d2 < d1:
d1 = d2
p_color = vbox['color']
return p_color
def map(self, color):
for i in range(self.vboxes.size()):
vbox = self.vboxes.peek(i)
if vbox['vbox'].contains(color):
return vbox['color']
return self.nearest(color)
class PQueue(object):
"""Simple priority queue."""
def __init__(self, sort_key):
self.sort_key = sort_key
self.contents = []
self._sorted = False
def sort(self):
self.contents.sort(key=self.sort_key)
self._sorted = True
def push(self, o):
self.contents.append(o)
self._sorted = False
def peek(self, index=None):
if not self._sorted:
self.sort()
if index is None:
index = len(self.contents) - 1
return self.contents[index]
def pop(self):
if not self._sorted:
self.sort()
return self.contents.pop()
def size(self):
return len(self.contents)
def map(self, f):
return list(map(f, self.contents))

View file

@ -1,21 +1,22 @@
import sys import sys
from classes.Samples import Samples from classes.Samples import Samples
from classes.Collage import Collage from classes.Collage import Collage
from pathlib import Path
# sys.argv[1] = Input image
# sys.argv[2] = Output image
sample_scale = (20,20)
force = False # Will generate a new sample set every time when true force = False # Will generate a new sample set every time when true
# Prompt IO declaration if no CLI arguments provided # Prompt IO declaration if no CLI arguments provided
if(len(sys.argv) < 2): if(len(sys.argv) < 2):
input_file = input("Input image (.jpg):\n") sys.argv.insert(1,input("Input image (.jpg):\n"))
output_file = input("Output image (.jpg):\n") sys.argv.insert(1,input("Output image (.jpg):\n"))
else:
input_file = sys.argv[1]
output_file = sys.argv[2]
# Load all images from the "samples/" folder # Load all images from the "samples/" folder
samples = Samples("samples",force) samples = Samples("samples",force)
# Create a collage from the loaded samples # Create a collage from the loaded samples
collage = Collage(input_file,samples) collage = Collage(sys.argv[1],samples)
collage.put(output_file) collage.size = sample_scale
collage.put(sys.argv[2])