import numpy as np
from numpy import fft
from skimage.color import rgb2gray
import copy
def cart2pol(x, y):
    """Convert Cartesian coordinates to polar form.

    Returns the tuple ``(phi, rho)``: the angle given by ``np.arctan2(y, x)``
    followed by the Euclidean distance from the origin.
    """
    angle = np.arctan2(y, x)
    radius = np.sqrt(x ** 2 + y ** 2)
    return angle, radius
def pol2cart(phi, rho):
    """Convert polar coordinates back to Cartesian ones.

    ``phi`` is the angle in radians and ``rho`` the radius; returns the
    tuple ``(x, y)``.
    """
    x_coord, y_coord = rho * np.cos(phi), rho * np.sin(phi)
    return x_coord, y_coord
def sfMatch(images, rescaling=0, tarmag=None):
    """
    Match the rotational Fourier spectrum of a stack of images. For documentation see [1]_

    The amplitude spectrum of every image is rescaled one integer-radius
    annulus at a time, so that the summed amplitude in each annulus equals
    the summed amplitude of the target spectrum in that annulus. Phases are
    kept. Frequencies whose annulus index exceeds
    ``floor(max(height, width) / 2)`` are set to zero.

    Parameters
    ----------
    images : list of ndarray
        Images of identical 2-D shape. 3-D (colour) arrays are first
        converted with ``rgb2gray``. The list itself is not modified.
    rescaling : int, optional
        ``0`` (default) returns the filtered images as float arrays multiplied
        back by 255. Any other value is forwarded as ``option`` to
        ``rescale_shine``.
    tarmag : ndarray, optional
        Target amplitude spectrum: ``fftshift``-ed, the same shape as the
        images, and computed from ``image / 255``. Defaults to the mean
        amplitude spectrum of the inputs.

    Returns
    -------
    list of ndarray

    Raises
    ------
    AssertionError
        If ``images`` is not a non-empty list, the images differ in size, or
        ``tarmag`` does not match the image size.

    References
    ----------
    .. [1] Willenbockel, V., Sadr, J., Fiset, D. et al. Controlling low-level image properties: The SHINE toolbox. Behavior Research Methods 42, 671–684 (2010). https://doi.org/10.3758/BRM.42.3.671
    """
    assert isinstance(images, list), 'The input must be a list.'
    assert len(images) > 0, 'The input list must not be empty.'
    # Grayscale working copies; the caller's list is left untouched.
    gray = [rgb2gray(im) if len(im.shape) == 3 else im for im in images]
    numin = len(gray)
    xs, ys = gray[0].shape
    spectra = np.zeros((xs, ys, numin), dtype=complex)
    for idx, im in enumerate(gray):
        assert im.shape == (xs, ys), 'All images must have the same size.'
        spectra[:, :, idx] = fft.fftshift(fft.fft2(im / 255))
    mags = np.abs(spectra)

    if tarmag is None:
        tarmag = np.mean(mags, 2)
    tarmag = np.asarray(tarmag)
    assert tarmag.shape == (xs, ys), 'The target spectrum must have the same size as the images.'

    # Frequency coordinates of the fftshift-ed spectrum (SHINE's -n/2:n/2-1 ranges).
    f1 = np.linspace(-ys / 2, ys / 2 - 1, ys)
    f2 = np.linspace(-xs / 2, xs / 2 - 1, xs)
    XX, YY = np.meshgrid(f1, f2)
    # Round half away from zero like MATLAB's round(). np.round rounds half to
    # even, which yields radius -1 for odd/even sizes and shifts every annulus.
    r = np.floor(np.hypot(XX, YY) + 0.5)
    if xs % 2 == 1 or ys % 2 == 1:
        r -= 1
    r_idx = r.astype(int)  # non-negative for every size
    bins = r_idx.ravel()
    outside = r > np.floor(max(xs, ys) / 2)

    # Target energy per annulus; it is the same for every image.
    en_new = np.bincount(bins, weights=tarmag.ravel())

    output_images = []
    for idx in range(numin):
        en_old = np.bincount(bins, weights=mags[:, :, idx].ravel())
        # Annuli without energy cannot be rescaled. A 0 coefficient keeps them
        # at zero instead of producing 0 * inf = nan.
        coefficient = np.zeros_like(en_new)
        np.divide(en_new, en_old, out=coefficient, where=en_old != 0)
        cmat = coefficient[r_idx]
        cmat[outside] = 0
        # Scaling the complex spectrum scales its magnitude and keeps its phase.
        output = np.real(fft.ifft2(fft.ifftshift(spectra[:, :, idx] * cmat)))
        if rescaling == 0:
            output = output * 255
        output_images.append(output)
    if rescaling != 0:
        output_images = rescale_shine(output_images, rescaling)
    return output_images
def rescale_shine(images, option=1):
    """
    Rescale the intensity of a stack of images to the 0-255 range. For documentation see [1]_

    Parameters
    ----------
    images : list of ndarray
        Images to rescale. 3-D (colour) arrays are first converted with
        ``rgb2gray``. The input list is not modified.
    option : {1, 2}, optional
        1 (default): map the darkest and brightest values found across all
        images to 0 and 255.
        2: map the averages of the per-image minima and maxima to 0 and 255.

    Returns
    -------
    list of ndarray of uint8
        Values are rounded half away from zero and saturated to [0, 255], as
        MATLAB's ``uint8`` conversion does, instead of wrapping around. An
        empty input gives an empty list.

    References
    ----------
    .. [1] Willenbockel, V., Sadr, J., Fiset, D. et al. Controlling low-level image properties: The SHINE toolbox. Behavior Research Methods 42, 671–684 (2010). https://doi.org/10.3758/BRM.42.3.671
    """
    assert isinstance(images, list), 'The input must be a list.'
    assert option == 1 or option == 2, "Invalid rescaling option"
    gray = [rgb2gray(im) if len(im.shape) == 3 else im for im in images]
    if not gray:
        return []
    # Python floats, so integer images are not subtracted in wrapping uint8 arithmetic.
    brightests = [float(np.max(im)) for im in gray]
    darkests = [float(np.min(im)) for im in gray]
    if option == 1:
        low, high = min(darkests), max(brightests)
    else:
        low, high = float(np.mean(darkests)), float(np.mean(brightests))
    output_images = []
    for im in gray:
        # A zero range (flat images) yields nan/inf here. Map them like MATLAB's
        # uint8(): nan -> 0, +inf -> 255, -inf -> 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            rescaled = (im - low) / (high - low) * 255
        rescaled = np.nan_to_num(rescaled, nan=0.0, posinf=255.0, neginf=0.0)
        rescaled = np.floor(np.clip(rescaled, 0, 255) + 0.5)
        output_images.append(rescaled.astype(np.uint8))
    return output_images
def _lum_float_copy(im):
    """Return a floating-point grayscale copy of ``im`` (3-D arrays go through ``rgb2gray``)."""
    if len(im.shape) == 3:
        im = rgb2gray(im)
    # Integer images are promoted so that rescaled values are neither truncated nor wrapped.
    dtype = im.dtype if np.issubdtype(im.dtype, np.floating) else np.float64
    return np.array(im, dtype=dtype)


def _lum_standardize(values, M, S):
    """Return ``values`` shifted and scaled to mean ``M`` and std ``S``; flat input becomes all ``M``."""
    sd = np.std(values)
    if sd != 0:
        return (values - np.mean(values)) / sd * S + M
    return np.full_like(values, M)


def lumMatch(images, mask=None, lum=None):
    """
    Match the luminance (mean) and contrast (standard deviation) of a stack of images. For documentation see [1]_

    Parameters
    ----------
    images : list of ndarray
        Images to normalise. 3-D (colour) arrays are first converted with
        ``rgb2gray``. Neither the list nor the arrays are modified.
    mask : list of ndarray, optional
        One mask per image. Only pixels where the mask equals 1 are used for
        the statistics and rescaled; all other pixels keep their value. An
        empty list is treated like no mask.
    lum : sequence of two numbers, optional
        Target ``(mean, std)``. When omitted, the target is the average of the
        per-image (masked) means and the average of the per-image (masked)
        standard deviations.

    Returns
    -------
    list of ndarray
        New floating-point arrays, one per input image. An empty input gives an
        empty list.

    Raises
    ------
    AssertionError
        On a non-list ``images`` or ``mask``, a mask count different from the
        image count, a mask/image size mismatch, or a mask without ones.

    References
    ----------
    .. [1] Willenbockel, V., Sadr, J., Fiset, D. et al. Controlling low-level image properties: The SHINE toolbox. Behavior Research Methods 42, 671–684 (2010). https://doi.org/10.3758/BRM.42.3.671
    """
    assert isinstance(images, list), 'The input must be a list.'
    assert (mask is None) or isinstance(mask, list), 'The input mask must be a list.'
    gray = [_lum_float_copy(im) for im in images]
    if not gray:
        return []

    # Boolean selection of the pixels to use/rescale in each image.
    if mask:
        assert len(images) == len(mask), "The inputs must have the same length"
        regions = []
        for im, m in zip(gray, mask):
            assert m.size == im.size, "Image and mask are not the same size"
            assert np.sum(m == 1) > 0, 'The mask must contain some ones.'
            regions.append(m == 1)
    else:
        regions = [np.ones(im.shape, dtype=bool) for im in gray]

    if lum is None:
        M = np.mean([np.mean(im[sel]) for im, sel in zip(gray, regions)])
        S = np.mean([np.std(im[sel]) for im, sel in zip(gray, regions)])
    else:
        # The requested target is used as-is (not averaged over the images).
        M, S = lum[0], lum[1]

    for im, sel in zip(gray, regions):
        im[sel] = _lum_standardize(im[sel], M, S)
    return gray