import matplotlib.pyplot as plt
import matplotlib.image as mpimg
#import tkinter as tk
#from tkinter import filedialog
import numpy as np
from PIL import Image
import pixelateImg as pixelate
import smooshImg as smoosh
import os
# 256-colour ANSI codes, emitted as "\033[38;5;<code>m", one per letter of
# the closing "Hooray!!" banner.
rainbowOrder = [9, 3, 2, 6, 4, 20, 5]
# 256-colour ANSI codes for the progress bar printed by errB; the entry is
# picked from the completion percentage.
greenOrder = [1, 202, 214, 11, 154, 46]

def splice_4d_array(array_4d, start, end):
    """Slice the innermost sequences of a four-level nested structure.

    Every innermost item is replaced by ``item[start:end]``; the three outer
    levels are rebuilt as plain Python lists.

    Args:
        array_4d: Iterable of iterables of iterables of sliceable items.
        start: Slice start applied to each innermost item.
        end: Slice end applied to each innermost item.

    Returns:
        A list of lists of lists holding the sliced innermost items.
    """
    return [
        [[cell[start:end] for cell in line] for line in plane]
        for plane in array_4d
    ]

# Collage size requested by the user. Kept as the raw string here; it is
# converted with int() where the big picture is scaled down.
y = input("Enter width of the collage: ")
# Tile selection compares this answer against "y" only, so any other reply
# leaves forced variety switched off.
uni = input("Would you like forced variety in images (y/n)? ")
def getImg():
    """Prompt for the "big picture" path and load the image.

    Returns:
        A tuple ``(img, file_path, ex)`` where ``img`` is the pixel array,
        ``file_path`` is the path exactly as typed, and ``ex`` is 3 when the
        last four characters of the path contain a dot (".jpg", ".png") and
        4 otherwise ("jpeg").
    """
    file_path = input("Enter your file path for the \"big picture :)\":\n")
    img = mpimg.imread(file_path)
    ex = 3 if '.' in file_path[-4:] else 4
    # imread yields floats in [0, 1] for PNG but integer arrays for other
    # formats. Only the float form needs rescaling to 0-255; multiplying a
    # uint8 JPEG array by 255 would wrap around and corrupt the pixels.
    if np.issubdtype(img.dtype, np.floating):
        img = (img * 255).astype(np.uint8)
    return img, file_path, ex

def getImgs():
    """Prompt for a folder and load every image in it as a small tile.

    Files ending in .jpg/.png (case-insensitive) are processed first, then
    files ending in .jpeg. Each file is passed through ``smoosh.run`` into a
    temporary "tmp.<ext>" file and then through ``pixelate.pixelate`` into
    'tmptmptmptmp.png', which is read back and scaled to 0-255 ``np.uint8``.
    Presumably smoosh resizes to 800x800 and pixelate shrinks by the given
    factor -- TODO confirm against those modules.

    Returns:
        List of tile arrays in processing order.
    """
    folder = input("Tiny images file path (/Users/...): ")
    print("Loading and pixelating tiny images")

    def _files_ending_with(suffixes):
        # Paths of regular files in the folder whose lowercased name matches.
        return [
            entry.path
            for entry in os.scandir(folder)
            if entry.is_file() and entry.name.lower().endswith(suffixes)
        ]

    # Each group carries the number of trailing characters that make up the
    # extension, which is reused for the temporary file's name.
    groups = (
        (_files_ending_with((".jpg", ".png")), 3),
        (_files_ending_with(".jpeg"), 4),
    )

    tiles = []
    count = 0
    for paths, ext_len in groups:
        for source in paths:
            count += 1
            print(f"{count} image(s)")
            tmp_name = "tmp." + str(source[-ext_len:])
            smoosh.run(source, tmp_name, 800, 800)
            resized = mpimg.imread(tmp_name)
            pixelate.pixelate(tmp_name, 250/resized.shape[0], 'tmptmptmptmp.png')
            tile = mpimg.imread('tmptmptmptmp.png') * 255
            tiles.append(tile.astype(np.uint8))
    return tiles

def getAvg(img):
    """Return the mean of the first three channels of an image.

    The mean is taken over the first two axes (rows and columns) and the
    result is truncated to its first three entries.
    """
    per_channel = np.mean(img, axis=(0, 1))
    return per_channel[:3]

def errB(perc):
    """Print a coloured progress bar for ``perc`` percent complete.

    The bar is ``int(perc * 2 / 5) + 1`` ">" characters long and its colour
    is taken from ``greenOrder`` according to the percentage.

    Args:
        perc: Completion percentage, nominally in the range 0-100.
    """
    # Clamp the colour index: perc == 100 would otherwise index one past the
    # end of greenOrder and raise IndexError; negatives use the first colour.
    last = len(greenOrder) - 1
    colour_index = min(max(int(perc * len(greenOrder) / 100), 0), last)
    bar = (int(perc * 2 / 5) + 1) * ">"
    print(f"\033[38;5;{greenOrder[colour_index]}m" + bar)
    
# --- Choose one tile for every pixel of the shrunken big picture ---------
imgs = getImgs()
if not imgs:
    # Without tiles the lookup below could only fail with an opaque
    # IndexError after the user has answered every prompt.
    raise SystemExit("No .jpg, .jpeg or .png images were found in that folder.")

imgPic, path, ex = getImg()
# NOTE(review): the prompt asks for a *width*, but the scale factor is
# computed from shape[0] (the row count) -- confirm against pixelate.pixelate.
pixelate.pixelate(path, int(y)/imgPic.shape[0], 'tmptmp.png')
imgPic = (mpimg.imread('tmptmp.png') * 255).astype(np.uint8)
plt.imshow(mpimg.imread('tmptmp.png'))
plt.show()

# One slot per pixel of the shrunken picture; each slot will hold a tile.
imgArr = np.zeros(np.shape(imgPic)[:2], dtype=object)

print("Averaging imgs.")
imgAvgs = [getAvg(tile) for tile in imgs]
# (number of tiles, 3) table so all tiles can be compared in one expression.
tile_colours = np.array(imgAvgs, dtype=float)

print("Computing.")
# Rows and columns are visited in random order.
shuffled_arrayA = np.random.permutation(imgPic.shape[0])
shuffled_arrayB = np.random.permutation(imgPic.shape[1])
force_variety = (uni == "y")
total_cells = imgPic.shape[0] * imgPic.shape[1]

print("Making the collage")
for a, row in enumerate(shuffled_arrayA):
    # How often each tile has been used in this row (reset for every row).
    use_counts = np.zeros(len(imgs), dtype=int)
    for b, col in enumerate(shuffled_arrayB):
        # Sum of absolute per-channel differences between this pixel and
        # every tile's average colour, added left to right.
        channel_diffs = np.abs(tile_colours - imgPic[row, col, :3])
        difs = channel_diffs[:, 0] + channel_diffs[:, 1] + channel_diffs[:, 2]
        if force_variety:
            # A tile already used more than three times in this row is out.
            difs[use_counts > 3] = np.inf
        # argmin returns the first minimum, i.e. the lowest tile index wins
        # ties. If nothing scores below 3*255, fall back to tile 0.
        most_similar_index = int(np.argmin(difs))
        if not difs[most_similar_index] < 3*255:
            most_similar_index = 0
        use_counts[most_similar_index] += 1

        errB((a*imgPic.shape[1]+b)*100/total_cells)
        imgArr[row][col] = imgs[most_similar_index]

del imgPic
def makeKey(img, columns=8, rows=10):
    """Lay tile images out on a white-padded grid (the collage "key").

    Tiles are placed left to right, top to bottom. Only the first three
    channels of each tile are used. Cells without a tile stay white, and
    tiles beyond ``rows * columns`` are left out. All tiles are expected to
    share the shape of the first tile.

    Args:
        img: Non-empty sequence of (height, width, channels) tile images
            with values in the 0-255 range.
        columns: Number of tiles per grid row.
        rows: Number of grid rows, or ``None`` to use as many rows as are
            needed to show every tile.

    Returns:
        A float array of shape (rows * height, columns * width, channels)
        with values scaled to 0-1.

    Raises:
        ValueError: If ``img`` is empty or ``columns`` is less than 1.
    """
    print("\033[0mMaking the key")
    if len(img) == 0:
        # There is no tile to take the cell size from.
        raise ValueError("makeKey needs at least one tile image")
    if columns < 1:
        raise ValueError("columns must be at least 1")

    tiles = [np.asarray(tile)[:, :, 0:3] for tile in img]
    tile_h, tile_w, channels = tiles[0].shape
    if rows is None:
        rows = -(-len(tiles) // columns)  # ceiling division

    # Start from an all-white canvas and paste each tile into its cell.
    grid = np.full((rows * tile_h, columns * tile_w, channels), 255.0)
    for index, tile in enumerate(tiles[:rows * columns]):
        r, c = divmod(index, columns)
        grid[r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w] = tile
    return grid / 255
        

            
def append_to_3d_array(array_3d, value_to_append):
    """Return ``array_3d`` with one extra constant channel on the last axis.

    Args:
        array_3d: Array of shape (rows, cols, channels).
        value_to_append: Value that fills the new channel.

    Returns:
        A new array of shape (rows, cols, channels + 1).
    """
    n_rows, n_cols = array_3d.shape[0], array_3d.shape[1]
    extra_channel = np.full((n_rows, n_cols, 1), value_to_append)
    return np.concatenate((array_3d, extra_channel), axis=2)



# --- Stitch the chosen tiles into one image and show the results ---------
print("\033[0mStacking the collage")
imgsToVstack = []
row_count = len(imgArr)
for i, tile_row in enumerate(imgArr):
    # Join the whole row with a single hstack (first three channels only)
    # rather than re-copying the growing strip once per tile.
    imgsToVstack.append(np.hstack([tile[:, :, 0:3] for tile in tile_row]))

    # Progress is counted in tile joins; the value stays just below 100 on
    # the last row. The max() guards keep a one-column collage from
    # dividing by zero.
    joins_per_row = max(len(tile_row) - 1, 1)
    last_join = max(len(tile_row) - 2, 0)
    perc = (i * joins_per_row + last_join)/(row_count*joins_per_row) * 100
    errB(perc)

finalImg = np.vstack(imgsToVstack)
del imgsToVstack
finalKey = makeKey(imgs)
# Drop the per-tile data once the final arrays exist.
del imgs
del imgArr

print("\033[38;5;49mDone!\033[0m")
print("Displaying images")
plt.imshow(finalImg)
plt.show()
plt.imshow(finalKey)
plt.show()
def noName():
    """Ask whether to save the images, repeating until the answer is y or n.

    Returns normally on "y". On "n" it prints the rainbow banner and raises
    ``Exception("Bye!")`` to end the script.

    Raises:
        Exception: When the user answers "n".
    """
    while True:
        answer = input("Would you like to save these images (y/n)? ")
        if answer == "y":
            print("Sounds good.")
            return
        if answer == "n":
            break
        print("Not \"n\" or \"y\". Trying again")

    print("Ok then.")
    text = "Hooray!!"
    # One coloured letter per entry of rainbowOrder, printed as one line.
    banner = "".join(
        f"\033[38;5;{colour}m{letter}\033[0m"
        for colour, letter in zip(rainbowOrder, text)
    )
    print(banner)
    raise Exception("Bye!")
    
# --- Optionally save the collage and its key, then celebrate -------------
noName()

name = input("What name would you like to save them with:\n")
print("Saving files")

# The collage goes to "<name>.png" and the key to "<name>Key.png".
for suffix, picture in (('.png', finalImg), ('Key.png', finalKey)):
    mpimg.imsave(name + suffix, picture)

text = "Hooray!!"
# One coloured letter per entry of rainbowOrder, printed as one line.
print("".join(
    f"\033[38;5;{colour}m{letter}\033[0m"
    for colour, letter in zip(rainbowOrder, text)
))
