import math
import os
import sys
import Imath

from datetime import datetime


from PIL import Image, ImageDraw, ImageFont, ImageSequence

import OpenEXR
import photoshop_tools.psd_tools as psd_tools

from pprint import pprint
# NOTE(review): these variables are set before numpy and cv2 are imported
# further down, presumably because those libraries read them at import time —
# confirm before reordering the imports.
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = "1"
os.environ['NPY_PROMOTION_STATE'] = "legacy"

# Absolute path of the directory that contains this module.
this_folder = os.path.dirname(os.path.abspath(__file__))
# "resources" directory next to this module's folder; stitch_images() loads
# its background template from here.
resources_folder = os.path.abspath(os.path.join(this_folder, '../resources'))
# NOTE: executed at import time, so every importer of this module prints the path.
print(resources_folder)

import numpy as np
import cv2

# Lookup table from an image count (1-9) to a list of integers.
# NOTE(review): presumably the number of images per row of a contact sheet.
# Nothing in the visible code reads this table (get_layout() computes its own
# distribution) — confirm whether it is still needed.
images_layout = {1: [1],
                 2: [2],
                 3: [2, 1],
                 4: [2, 2],
                 5: [3, 2],
                 6: [3, 3],
                 7: [3, 2, 2],
                 8: [3, 3, 2],
                 9: [3, 3, 3],
                 }

# RGBA fill colour of the canvas that psd_to_png() composites layers onto.
# NOTE: the first assignment is overwritten immediately by the second, so only
# (25, 25, 24, 0) is ever in effect.
BG_COLOR = (254, 252, 240, 255)
BG_COLOR = (25, 25, 24, 0)

def EncodeToSRGB(v):
    """Raise ``v`` to the power ``1 / 2.2`` and return the result.

    This is a plain power curve, not the piecewise sRGB transfer function.
    ``v`` may be anything that supports ``**`` with a float exponent
    (a Python number or a numpy array).
    """
    exponent = 1 / 2.2
    return v ** exponent

def psd_to_png(psd_path):
    """Flatten the visible layers of a PSD file into a single RGBA PIL image.

    A canvas of the document size is filled with ``BG_COLOR`` and every
    visible layer is alpha-composited onto it in iteration order. Groups are
    rendered with ``layer.compose()``, other layers with ``layer.topil()``.

    Args:
        psd_path: Path of the ``.psd`` file to open.

    Returns:
        A ``PIL.Image.Image`` in ``RGBA`` mode with the size of the document.
    """
    psd = psd_tools.PSDImage.open(psd_path)

    width, height = psd.size
    final_image = Image.new("RGBA", (width, height), BG_COLOR)
    skip_first_layer = len(psd) > 1

    for index, layer in enumerate(psd):
        # In a multi-layer document the first layer is left out
        # (presumably the background layer — TODO confirm).
        if index == 0 and skip_first_layer:
            continue
        if not layer.is_visible():
            continue

        if type(layer).__name__ == 'Group':
            layer_image = layer.compose()
        else:
            layer_image = layer.topil()
        if not layer_image:
            # Nothing could be rendered for this layer.
            continue

        # Image.alpha_composite() only accepts RGBA operands; a layer rendered
        # in another mode (e.g. RGB) would otherwise raise ValueError.
        if layer_image.mode != "RGBA":
            layer_image = layer_image.convert("RGBA")

        # The rendered layer only covers its own bounding box. Cropping with
        # the negated bbox origin re-frames it onto the full document canvas,
        # padding the area outside the layer with transparent pixels.
        left_side = -1 * layer.bbox[0]
        top_side = -1 * layer.bbox[1]
        layer_image = layer_image.crop(
            (left_side, top_side, width + left_side, height + top_side))

        final_image = Image.alpha_composite(final_image, layer_image)

    return final_image

def exr_to_png(exr_file_path):
    """Read an OpenEXR file and return it as an 8-bit RGBA PIL image.

    The ``R``, ``G`` and ``B`` channels are read as float32, multiplied by 255
    and clipped to ``[0, 255]``. No transfer curve (such as ``EncodeToSRGB``)
    is applied, so values are converted exactly as stored in the file.
    If the file has no ``A`` channel the result is fully opaque.

    Args:
        exr_file_path: Path of the ``.exr`` file to read.

    Returns:
        A ``PIL.Image.Image`` in ``RGBA`` mode sized to the file's data window.
    """
    exr_file = OpenEXR.InputFile(exr_file_path)
    try:
        header = exr_file.header()
        dw = header['dataWindow']
        width = dw.max.x - dw.min.x + 1
        height = dw.max.y - dw.min.y + 1

        # Request every channel as 32-bit float.
        pt = Imath.PixelType(Imath.PixelType.FLOAT)

        def read_channel(name):
            # Raw channel bytes -> (height, width) float32 array.
            raw = exr_file.channel(name, pt)
            return np.frombuffer(raw, dtype=np.float32).reshape(height, width)

        planes = [read_channel(name) for name in ('R', 'G', 'B')]
        if 'A' in header['channels']:
            planes.append(read_channel('A'))
        else:
            # No alpha channel in the file: treat every pixel as opaque.
            planes.append(np.ones((height, width), dtype=np.float32))
    finally:
        # Release the file handle even when reading fails.
        exr_file.close()

    # Combine the channels into one (height, width, 4) array and scale to
    # [0, 255]. NaN samples are mapped to 0 so the uint8 cast is well defined.
    img_array = np.stack(planes, axis=-1)
    img_array = np.nan_to_num(img_array * 255, nan=0.0)
    img_array = np.clip(img_array, 0, 255).astype(np.uint8)

    # Wrap the numpy array in a PIL image.
    pil_image = Image.fromarray(img_array, 'RGBA')

    return pil_image


def open_image(image_path):
    """Open an image file and return it as a PIL image.

    The loader is chosen from the file extension, compared case-insensitively:
    ``.exr`` files go through ``exr_to_png()``, ``.psd`` files through
    ``psd_to_png()`` and everything else through ``PIL.Image.open()``.

    Args:
        image_path: Path of the image file.

    Returns:
        A ``PIL.Image.Image``.
    """
    # splitext + lower() so that "RENDER.EXR" or "Art.PSD" pick the right loader.
    extension = os.path.splitext(image_path)[1].lower()
    if extension == '.exr':
        return exr_to_png(image_path)
    if extension == '.psd':
        return psd_to_png(image_path)
    return Image.open(image_path)


def has_transparency(img):
    """Return True when ``img`` carries transparency information.

    An image counts as transparent when either

    * its ``info`` dict has a non-None ``"transparency"`` entry, or
    * its mode has an alpha band (``RGBA`` or ``LA``) whose minimum value is
      below 255, i.e. at least one pixel is not fully opaque.

    Args:
        img: A ``PIL.Image.Image`` (any object exposing ``info``, ``mode`` and
            ``getextrema()`` works).

    Returns:
        bool
    """
    if img.info.get("transparency") is not None:
        return True

    # A palette image without a "transparency" entry has no transparent index,
    # so no per-colour scan is needed for mode "P".
    if img.mode in ("RGBA", "LA"):
        # getextrema() yields one (min, max) pair per band; alpha is the last.
        alpha_min = img.getextrema()[-1][0]
        return alpha_min < 255

    return False

def _fit_inside(image_size, cell_size):
    """Return ``image_size`` scaled to fit inside ``cell_size``, keeping aspect ratio."""
    image_width, image_height = image_size
    cell_width, cell_height = cell_size

    # Fit to the cell width first; fall back to the cell height when the
    # result would be too tall.
    scale = cell_width / image_width
    scaled_width = cell_width
    scaled_height = int(image_height * scale)
    if scaled_height > cell_height:
        scale = cell_height / image_height
        scaled_height = cell_height
        scaled_width = int(image_width * scale)

    # Never return a zero dimension: resizing to an empty size fails.
    return max(1, scaled_width), max(1, scaled_height)


def _load_label_font(size):
    """Return Arial at ``size``, or PIL's built-in font when Arial is unavailable."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # arial.ttf is not installed on every platform; the built-in bitmap
        # font is smaller but keeps the sheet from failing.
        return ImageFont.load_default()


def stitch_images(main_image,
                  file_list,
                  output_path,
                  width=4096,
                  height=4096,
                  left_margin=.005,
                  top_margin=.033,
                  right_margin=.005,
                  bottom_margin=.17,
                  asset_name='',
                  department='',
                  pipeline_step='',
                  version=0,
                  ):
    """Compose a labelled contact sheet from several images and save it.

    The template ``MATG_Cartouche_Template.png`` from ``resources_folder`` is
    resized to the output size, each input image is scaled to fit its grid
    cell (cells come from ``get_positions()``) and pasted centred in it, and a
    row of text labels (asset, step, version, department, date) is drawn.

    Args:
        main_image: Optional path of the main image. When truthy it is added
            to the images under the key ``'main'``.
        file_list: Mapping of arbitrary keys to image paths. It is not
            modified. ``None`` is treated as empty.
        output_path: File path the sheet is saved to; missing parent
            directories are created.
        width: Output width in pixels.
        height: The output height is ``int(height / (16 / 9))`` pixels.
        left_margin, top_margin, right_margin, bottom_margin: Margins around
            the image grid, as fractions of the output width (left/right) or
            output height (top/bottom).
        asset_name, department, pipeline_step: Strings shown in the labels.
        version: Integer shown zero-padded to three digits.

    Returns:
        ``{'output_paths': {'stiched': output_path}}``
    """
    # Work on a copy so the caller's mapping is left untouched.
    image_paths = dict(file_list or {})
    if main_image:
        image_paths['main'] = main_image

    background_image_path = os.path.join(resources_folder, 'MATG_Cartouche_Template.png')
    bg_image = Image.open(background_image_path)

    final_image_width = width
    final_image_height = int(height / (16.0 / 9.0))
    bg_image = bg_image.resize((final_image_width, final_image_height))

    # Margins are passed as fractions; convert them to pixels.
    left_margin *= final_image_width
    right_margin *= final_image_width
    top_margin *= final_image_height
    bottom_margin *= final_image_height

    inside_image_width = final_image_width - (left_margin + right_margin)
    inside_image_height = final_image_height - (top_margin + bottom_margin)

    total_images = len(image_paths)
    # With no images there is no grid to lay out (and the grid maths would
    # divide by zero); the sheet then only carries the template and labels.
    if total_images:
        square, positions = get_positions(total_images)
        rows = columns = int(square)

        thumbnail_width = int(inside_image_width / float(columns))
        thumbnail_height = int(inside_image_height / float(rows))

        for position, image_name in zip(positions, image_paths.values()):
            image = open_image(image_name)

            # Palette / colour-key transparency is stored in image.info, not
            # in an alpha band. Convert it to real alpha so that the last
            # band used as paste mask below really is the alpha channel.
            if image.mode not in ('RGBA', 'LA') and image.info.get('transparency') is not None:
                image = image.convert('RGBA')

            scaled_width, scaled_height = _fit_inside(
                image.size, (thumbnail_width, thumbnail_height))

            # Centre the scaled image inside its cell.
            image_ofset_x = int((thumbnail_width - scaled_width) / 2.0)
            image_ofset_y = int((thumbnail_height - scaled_height) / 2.0)
            image = image.resize((scaled_width, scaled_height))

            # position is ((x0, y0), (x1, y1)) in fractions of the inner area.
            top_left = position[0]
            left = int(left_margin + image_ofset_x + top_left[0] * inside_image_width)
            top = int(top_margin + image_ofset_y + top_left[1] * inside_image_height)

            if has_transparency(image):
                alpha = image.split()[-1]
                bg_image.paste(image, (left, top), mask=alpha)
            else:
                bg_image.paste(image, (left, top))

    draw = ImageDraw.Draw(bg_image)

    fontsize = 48
    font = _load_label_font(fontsize)
    text_color = (20, 20, 25)
    text_position_y = .917 * final_image_height
    today_date = datetime.today().strftime('%d-%b-%Y')

    # (horizontal position as a fraction of the output width, label text)
    labels = (
        (.16, 'Asset: %s ' % asset_name),
        (.33, 'Step: %s ' % pipeline_step),
        (.43, 'Version: %03d ' % version),
        (.51, 'Department: %s ' % department),
        (.69, 'Date: %s ' % today_date),
    )
    for x_fraction, text in labels:
        text_position_x = x_fraction * final_image_width
        draw.text((text_position_x, text_position_y), text, font=font, fill=text_color)

    # os.path.dirname() is '' for a bare file name; makedirs('') would raise.
    dirname = os.path.dirname(output_path)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    print(output_path)
    bg_image.save(output_path)

    # Only the stitched sheet is written by this function, so it is the only
    # path reported. The key spelling 'stiched' is kept for existing callers.
    return {'output_paths': {'stiched': output_path}}


def get_layout(num_items):
    """Distribute ``num_items`` round-robin over a square grid's slots.

    The grid side is the square root of ``num_items`` rounded up (a fractional
    part of at most 0.0001 is treated as zero). Items are then handed out one
    at a time to slot 0, 1, ..., side - 1, wrapping around to slot 0.

    Returns:
        ``(counts, side)`` where ``counts[i]`` is the number of items that
        landed in slot ``i`` (slots that received nothing are not listed) and
        ``side`` is the integer grid side.
    """
    root = math.sqrt(num_items)
    fraction = math.modf(root)[0]
    side = int(root) + 1 if fraction > 0.0001 else int(root)

    counts = []
    slot = 0
    for _ in range(num_items):
        # First visit to this slot: open it with an empty count.
        if slot == len(counts):
            counts.append(0)
        counts[slot] += 1
        # Advance, wrapping back to the first slot after the last one.
        slot = (slot + 1) % side

    return counts, side

def get_positions(num_items):
    """Compute normalised cell rectangles for ``num_items`` images.

    The counts returned by ``get_layout()`` are used as rows: entry ``y`` is
    the number of cells in row ``y``. Every cell is ``1 / side`` wide and
    high, and each row is centred horizontally.

    Args:
        num_items: Number of cells to place.

    Returns:
        ``(side, cells)`` where ``side`` is the grid side from
        ``get_layout()`` and ``cells`` is a list of
        ``((x0, y0), (x1, y1))`` tuples in fractions of the available area.
        For zero items the result is ``(0, [])``.
    """
    layout, row_width = get_layout(num_items)

    # No items means a grid side of 0; return early instead of dividing by it.
    if not row_width:
        return row_width, []

    cell_size = 1.0 / row_width
    data = []
    for y, items_in_row in enumerate(layout):
        # Shift the row right by half of its unused width to centre it.
        start_pos_x = ((row_width - items_in_row) / 2.0) / row_width
        start_pos_y = (y / row_width)
        end_pos_y = start_pos_y + cell_size
        for _ in range(items_in_row):
            end_pos_x = start_pos_x + cell_size
            data.append(((start_pos_x, start_pos_y), (end_pos_x, end_pos_y)))
            start_pos_x = end_pos_x

    return row_width, data


if __name__ == '__main__':
    # Manual smoke test using developer-local paths; adjust them before running.
    # No extra images: the sheet is built from the main file only.
    files = {}
    main_file = 'C:/projects/CitizenEko.psd'

    output_path = 'C:/projects/big_city_jpg.png'
    asset_name = 'TestAssetRigged'
    pipeline_step = 'Brief'
    version = 4
    department = 'Preprod'
    stitch_images(main_file,
                  files,
                  output_path,
                  asset_name=asset_name,
                  pipeline_step=pipeline_step,
                  version=version,
                  department=department)
