import cv2
import sys
import numpy as np
from tqdm import tqdm, trange
from scipy.fft import fft2, ifft2
from scipy.signal import correlate2d, fftconvolve
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt


# Comparison function applied to every block before clustering:
# 'real' (raw pixels), 'acf' (autocorrelation); any other value selects 'fft'.
target = 'fft'
infile = "normalized_image.png"
outfile = "classified_image.png"

block_size = 50     # nominal block edge length in pixels
num_clusters = 10   # number of k-means clusters
view_w = 800        # display width passed to plot_image

# Overlay palette of BGR tuples (drawn on OpenCV BGR images); cluster i uses
# colors[i]. Three rows of six hues, the varying channel set to 0, 80, 160.
colors = [hue
          for level in (0, 80, 160)
          for hue in ((255, level, 0), (level, 255, 0), (0, level, 255),
                      (255, 255, level), (level, 255, 255), (255, level, 255))]
ncolors = len(colors)
alpha = 0.8   # weight of the original image when blending the color overlay


def usage():
    """Print the command-line synopsis (arguments are positional, see update_vars)."""
    program = sys.argv[0]
    print()
    print(f"Usage: python {program} infile block_size nclusters comparison_func")

def update_vars():
    """Override module-level settings from positional command-line arguments.

    argv[1] -> infile, argv[2] -> block_size (int), argv[3] -> num_clusters
    (int), argv[4] -> target. Missing trailing arguments leave the current
    values untouched; a non-integer block_size / nclusters raises ValueError.
    """
    global target, infile
    global block_size, num_clusters, view_w

    args = sys.argv[1:]
    count = len(args)
    if count > 0:
        infile = args[0]
    if count > 1:
        block_size = int(args[1])
    if count > 2:
        num_clusters = int(args[2])
    if count > 3:
        target = args[3]

def read_image(infile, color = cv2.IMREAD_GRAYSCALE):
    """Load an image file with cv2.imread.

    Args:
        infile: path of the image file.
        color: cv2.imread flag; grayscale by default.

    Returns:
        Whatever cv2.imread returns. Per the OpenCV documentation this is None
        (not an exception) when the file cannot be read, so callers must check.
    """
    message = f"Read image [{infile}]"
    print()
    print(message)
    img = cv2.imread(infile, color)
    return img

def save_image(img, outfile):
    """Write ``img`` to ``outfile`` via cv2.imwrite.

    cv2.imwrite signals some failures (e.g. an unwritable path) only through
    its boolean return value, which used to be ignored; a warning is now
    printed to stderr in that case.

    Returns:
        The boolean result of cv2.imwrite (True on success).
    """
    print()
    print(f"Save converted image to [{outfile}]")
    ok = cv2.imwrite(outfile, img)
    if not ok:
        print(f"Warning: failed to save image to [{outfile}]", file=sys.stderr)
    return ok

def plot_image(images, titles, view_w = 600, mode = 'cv2'):
    """Show images side by side, all resized to a common display width.

    Every image is resized to the size derived from the FIRST image, so the
    images are expected to share that image's aspect ratio.

    Args:
        images: list of images, each either 2-D (grayscale) or 3-channel BGR.
        titles: window / subplot titles, parallel to ``images``.
        view_w: display width in pixels.
        mode: 'cv2' opens one OpenCV window per image and waits for a key;
            any other value draws matplotlib subplots and waits for ENTER.
    """
    print()
    print("plot")

    # shape is (h, w) for grayscale and (h, w, channels) for color images
    height, width = images[0].shape[:2]
    print(f"image shape: {width} x {height}")
    mag = view_w / width
    print(f"view width: {view_w} = {width} * {mag:.3f}")

    # the epsilon keeps float error from truncating e.g. 799.99999 down to 799
    _size = (int(width * mag + 1.0e-5), int(height * mag + 1.0e-5))
    images_resized = [cv2.resize(img, _size) for img in images]

    if mode == 'cv2':
        for img, title in zip(images_resized, titles):
            cv2.imshow(title, img)

        cv2.waitKey(0)
        cv2.destroyAllWindows()
    else:
        nimages = len(images)
        plt.figure(figsize=(15, 5))
        for i, (img, title) in enumerate(zip(images_resized, titles)):
            plt.subplot(1, nimages, i + 1)
            plt.title(title)
            # COLOR_BGR2RGB rejects single-channel input, so grayscale images
            # need their own conversion to be shown as gray (not a colormap).
            if img.ndim == 2:
                rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            else:
                rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            plt.imshow(rgb)

        plt.pause(1.0e-5)
        input("\nPress ENTER to terminate>>\n")

def split_image_into_blocks(image, block_size_h, block_size_w):
    """Tile a 2-D image into non-overlapping block_size_h x block_size_w blocks.

    Partial blocks along the bottom and right edges are discarded.

    Returns:
        (blocks, positions): the blocks (slices of ``image``) and the
        (row, col) of each block's top-left corner, in row-major order.
    """
    height, width = image.shape
    print("split image to blocks")
    blocks, positions = [], []
    for top in tqdm(range(0, height, block_size_h)):
        for left in range(0, width, block_size_w):
            tile = image[top:top + block_size_h, left:left + block_size_w]
            # slicing truncates tiles at the edges; keep only full-size ones
            if tile.shape != (block_size_h, block_size_w):
                continue
            blocks.append(tile)
            positions.append((top, left))
    return blocks, positions

def autocorrelation_blocks(blocks):
    """Convert each block into its 2-D autocorrelation, scaled to peak at 255.

    The autocorrelation is computed as ifft2(|fft2(block)|**2).real, i.e. the
    circular (wrap-around) autocorrelation of the raw block; the block mean is
    not subtracted.

    Args:
        blocks: sequence of 2-D arrays.

    Returns:
        list of float arrays, one per block, with the same shapes.
    """
    print("auto correlation functions:")
    acf_blocks = []
    for block in tqdm(blocks):
        power_spectrum = np.abs(fft2(block)) ** 2
        acf = ifft2(power_spectrum).real

        vmax = acf.max()
        # The zero-lag value (sum of squares) is the maximum, so vmax is 0
        # only for an all-zero block; leave it unscaled instead of dividing
        # by zero.
        if vmax > 0.0:
            acf *= 1.0 / vmax * 255.0

        acf_blocks.append(acf)

    return acf_blocks

def fft_blocks(blocks):
    """Convert each block into its log-magnitude spectrum scaled to 0..255.

    The spectrum is fftshift-ed so the zero-frequency component is at the
    block centre, then divided by its maximum and multiplied by 255.

    Args:
        blocks: sequence of 2-D arrays.

    Returns:
        list of float arrays, one per block, with the same shapes.
    """
    print("convert blocks by fft")
    use_thres = False   # if True, multiply by `thres` and clip at 255
    thres = 2.0
    normalize = False   # if True, min-max stretch each spectrum to 0..255

    spectra = []
    for block in tqdm(blocks):
        fshift = np.fft.fftshift(np.fft.fft2(block))

        # Clamp magnitudes to >= 1 before the log: log(0) = -inf (and tiny
        # magnitudes give huge negatives), e.g. for uniform blocks whose
        # non-DC coefficients are (numerically) zero. KMeans rejects
        # non-finite input. Magnitudes >= 1 are unaffected.
        magnitude_spectrum = np.log(np.maximum(np.abs(fshift), 1.0))

        vmax = magnitude_spectrum.max()
        # vmax is 0 only when every magnitude is <= 1 (e.g. an all-zero
        # block); skip scaling instead of dividing by zero.
        if vmax > 0.0:
            magnitude_spectrum *= 1.0 / vmax * 255.0

        if use_thres:
            magnitude_spectrum = np.where(magnitude_spectrum > 255.0 / thres,
                                          255.0, magnitude_spectrum * thres)

        if normalize:
            magnitude_spectrum = cv2.normalize(magnitude_spectrum, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)

        spectra.append(magnitude_spectrum)

    return spectra

def comparison_function(blocks, target = 'fft'):
    """Transform blocks into the representation used for clustering.

    'real' returns ``blocks`` unchanged, 'acf' their autocorrelations, and
    any other value (including the default 'fft') their FFT spectra.
    """
    if target == 'real':
        return blocks
    if target == 'acf':
        return autocorrelation_blocks(blocks)
    return fft_blocks(blocks)

def main():
    """Classify image blocks with k-means and visualize the result.

    Reads the grayscale image ``infile`` and tiles it into blocks of about
    ``block_size`` pixels. It transforms every block with the comparison
    function selected by ``target`` and clusters the flattened blocks into
    ``num_clusters`` groups. It saves a colour-coded overlay to ``outfile``
    and plots the original, comparison and classified images. Settings are
    module globals, optionally overridden from the command line by
    update_vars().
    """
    update_vars()

    image = read_image(infile)
    if image is None:
        # cv2.imread returns None instead of raising on unreadable files
        print(f"Error: cannot read image [{infile}]", file=sys.stderr)
        usage()
        sys.exit(1)
    height, width = image.shape

    # Pick the number of tiles per axis, then enlarge the tile so the tiles
    # cover the image evenly. max(1, ...) avoids dividing by zero when the
    # image is smaller than block_size.
    nh = max(1, height // block_size)
    block_size_h = height // nh
    nw = max(1, width // block_size)
    block_size_w = width // nw
    print(f"image size: {width} x {height}")
    print(f"block size: {block_size_w} x {block_size_h}")
    print("ratios:", width / block_size_w, height / block_size_h)

    blocks, positions = split_image_into_blocks(image, block_size_h, block_size_w)

    # Comparison function: the per-block representation fed to k-means
    cmp_blocks = comparison_function(blocks, target)
    cmp_blocks1d = [v.flatten() for v in cmp_blocks]

    # k-means clustering; KMeans needs at least as many samples as clusters
    n_clusters = min(num_clusters, len(cmp_blocks1d))
    if n_clusters < num_clusters:
        print(f"Warning: only {len(cmp_blocks1d)} blocks; using {n_clusters} clusters")
    kmeans = KMeans(n_clusters = n_clusters, random_state = 42, n_init = 'auto')
    kmeans.fit(cmp_blocks1d)
    labels = kmeans.labels_

    bgr_image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

    # Draw each block's cluster label. positions hold (row, col), whereas
    # OpenCV points are (x=col, y=row).
    # NOTE(review): output_image1 is built but never saved or shown.
    output_image1 = bgr_image.copy()
    kfontsize = 2.5
    for label, (row, col) in zip(labels, positions):
        cv2.putText(output_image1, str(label), (col, row + block_size_h // 2),
                    cv2.FONT_HERSHEY_SIMPLEX, kfontsize, (0, 255, 0), 2, cv2.LINE_AA)

    # Fill every block with its cluster color; the palette is reused
    # cyclically when there are more clusters than colors.
    output_image2 = bgr_image.copy()
    for label, (row, col) in zip(labels, positions):
        color = colors[label % ncolors]
        # rectangle corners are inclusive, hence the -1
        cv2.rectangle(output_image2, (col, row),
                      (col + block_size_w - 1, row + block_size_h - 1), color, -1)

    # Blend the colored blocks over the original image
    output_image2 = cv2.addWeighted(bgr_image, alpha, output_image2, 1.0 - alpha, 0)

    # Assemble the transformed blocks into one comparison image
    cmp_image = np.zeros(image.shape, dtype = np.float32)
    for cmp_block, (row, col) in zip(cmp_blocks, positions):
        cmp_image[row:row+block_size_h, col:col+block_size_w] = cmp_block

    # Stretch the comparison image to 0..255 for display
    cmp_image = cv2.normalize(cmp_image, None, 0, 255, cv2.NORM_MINMAX)
    cmp_image = cmp_image.astype(np.uint8)

    save_image(output_image2, outfile)

    # plot_image's matplotlib mode converts BGR -> RGB, which requires three
    # channels, so the grayscale images are passed as BGR.
    plot_image([bgr_image, cv2.cvtColor(cmp_image, cv2.COLOR_GRAY2BGR), output_image2],
               ['Original', 'Comparison func', 'Classified'],
               view_w =view_w, mode = 'matplotlib')


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()
    