import os
import shutil
from pprint import pprint
import glob
import re

import maya.cmds as cmds

import library.core.config_manager as config_manager


# Not referenced in the code visible here; kept for other callers.
input_connections = []

# Source attributes that surf_nodes() follows when walking upstream; any
# other incoming connection is ignored.
output_connections = [
    'outColor',
    'outValue',
    'outAlpha',
    'outColorR',
    'outColorG',
    'outColorB',
]

# Leaf node types that carry a texture path, mapped to the attribute holding it.
end_nodes = {
    'file': 'fileTextureName',
    'aiImage': 'filename',
}

# Node types whose input attribute name becomes the texture "layer".
shading_nodes = ['aiStandardSurface', 'standardSurface', 'aiLayerShader']

# Texture channel codes; not referenced in the code visible here.
channel_codes = ['MSK', 'DFC', 'MTL', 'NRM', 'RLR', 'HGT', 'IBL', 'EMC']



"""
Example (apparently) of the mapping surf_nodes() returns for a shadingEngine:
shading-node input attribute -> texture node feeding it. This is an excerpt
of a pprint dump; the enclosing dictionaries are truncated.

'baseColor': 'LEO_ch_leo01Body_DFC_1001_png',
'coatRoughness': 'LEO_ch_leo01Body_RLR_1001_png',
'diffuseRoughness': 'LEO_ch_bodyVelvet_MSK_1004_1',
'emissionColor': 'LEO_ch_leo01Body_DFC_1001_png',
'metalness': 'LEO_ch_leo01Body_MTL_1001_png',
'normalCamera': 'LEO_ch_leo01Body_NRM_1001_png',
'specularRoughness': 'LEO_ch_leo01Body_RLR_1001_png',
'subsurface': 'LEO_ch_leo01BodySSS_MSK_1001',
'subsurfaceColor': 'LEO_ch_leo01Body_DFC_1001_png'}},
"""
def surf_nodes(node_name, layer='NON', _path=None):
    """Walk upstream from *node_name* and collect texture nodes per layer.

    Only incoming connections whose source attribute is listed in
    ``output_connections`` are followed. When *node_name* is one of
    ``shading_nodes``, the attribute on it that receives the connection
    (e.g. ``baseColor``) becomes the layer name for everything found
    upstream. The walk stops at node types listed in ``end_nodes``.

    Args:
        node_name: Maya node to start from (e.g. a shadingEngine).
        layer: Layer name assigned to texture nodes found before any shading
            node is reached.
        _path: Internal. Nodes on the current recursion branch, used to cut
            dependency-graph cycles.

    Returns:
        dict mapping layer name -> texture node name. If several texture
        nodes land on the same layer, the last one visited wins.
    """
    print('-- node: ', node_name)
    connected_nodes = {}
    node_connections = cmds.listConnections(node_name, d=False, s=True, p=True, c=True)
    print('connections', node_connections)
    if not node_connections:
        return {}

    this_type = cmds.nodeType(node_name)
    print('node type', this_type)
    is_shading_node = this_type in shading_nodes

    # Track only the current branch. A texture node shared by several layers
    # (e.g. baseColor and emissionColor) is still reached once per layer,
    # while a real cycle is cut instead of recursing forever.
    path = set() if _path is None else _path
    path.add(node_name)
    try:
        # With c=True the result alternates: plug on node_name, then the
        # upstream plug connected to it.
        for local_plug, source_plug in zip(node_connections[::2], node_connections[1::2]):
            node, attribute = source_plug.split('.', 1)
            if attribute not in output_connections:
                continue

            if is_shading_node:
                layer = local_plug.split('.')[-1]

            if cmds.nodeType(node) in end_nodes:
                connected_nodes[layer] = node
            elif node not in path:
                connected_nodes.update(surf_nodes(node, layer, path))
    finally:
        path.discard(node_name)

    return connected_nodes



def replace_path_items(path):
    """Replace each ``<...>`` token in *path* with a ``*`` wildcard.

    Every ``<`` emits one ``*`` and starts skipping characters, including a
    ``<`` met while already skipping, which emits another ``*``. A ``>`` ends
    the skip and is itself dropped. A stray ``>`` outside a token is dropped,
    and an unclosed ``<`` swallows the rest of the string.

    Example: ``'tex.<UDIM>.png'`` -> ``'tex.*.png'``.
    """
    pieces = []
    inside_token = False
    for ch in path:
        if ch == '<':
            pieces.append('*')
            inside_token = True
            continue
        if ch == '>':
            inside_token = False
            continue
        if not inside_token:
            pieces.append(ch)
    return ''.join(pieces)

def replace_attr_pattern(match_pattern):
    """Expand the ``<attr:NAME>`` token of a file-name pattern.

    Args:
        match_pattern: File base name that may contain ``<attr:NAME>``.

    Returns:
        tuple ``(glob_pattern, regex_pattern, attribute_name)``:

        * ``glob_pattern``: every token replaced by ``[a-zA-Z]*``.
        * ``regex_pattern``: the first token replaced by a named group
          ``attribute`` capturing letters; any repeat of the same token is
          replaced by a back-reference to that group.
        * ``attribute_name``: NAME.

        Without a token, the input is returned twice, followed by ``''``.
    """
    # Raw string: '\_' in a plain literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    found = re.search(r'<attr:([a-zA-Z_]*)', match_pattern)
    if not found:
        return match_pattern, match_pattern, ''

    attribute = found.group(1)
    token = '<attr:%s>' % attribute
    # Only one named group may exist; repeated tokens must reuse it, otherwise
    # the regex fails to compile with "redefinition of group name".
    regex_pattern = match_pattern.replace(token, '(?P<attribute>[a-zA-Z]*)', 1)
    regex_pattern = regex_pattern.replace(token, '(?P=attribute)')
    glob_pattern = match_pattern.replace(token, '[a-zA-Z]*')
    return glob_pattern, regex_pattern, attribute


def copy_files(pattern, node, layer, output_folder):
    """Copy every file matching a texture path pattern into *output_folder*.

    Supported tokens in the file name of *pattern*: ``<UDIM>`` and ``<f>``
    (four digits each) and ``<attr:NAME>`` (letters). Each match is copied to
    ``<output_folder>/<layer>/<short>_<layer>[_<attr>][_<frame>][.<udim>].<ext>``.

    Args:
        pattern: Texture path as stored on the file node, possibly with tokens.
        node: Node name. Its second ``_``-separated part is the file-name
            prefix; the whole name is used when it contains no ``_``.
        layer: Layer/channel code, used as sub-folder and name suffix.
        output_folder: Destination root folder.

    Returns:
        The tokenised output path of the last copied file, or ``''`` when no
        file matched.
    """
    pattern = pattern.replace('\\', '/')

    name_parts = node.split('_')
    # Default names such as 'aiStandardSurface1SG' have no '_'; indexing [1]
    # unconditionally raised IndexError for them.
    node_short_name = name_parts[1] if len(name_parts) > 1 else node

    basename = os.path.basename(pattern)
    dirname = os.path.dirname(pattern)
    extension = basename.split('.')[-1]
    glob_basename, match_basename, attribute_name = replace_attr_pattern(basename)

    match_basename = match_basename.replace('<UDIM>', '(?P<udim>[0-9][0-9][0-9][0-9])')
    match_basename = match_basename.replace('<f>', '(?P<frame>[0-9][0-9][0-9][0-9])')
    glob_basename = glob_basename.replace('<UDIM>', '[0-9][0-9][0-9][0-9]')
    glob_basename = glob_basename.replace('<f>', '[0-9][0-9][0-9][0-9]')

    # glob(root_dir=...) needs Python 3.10+. Escaping the directory keeps
    # characters such as '[' in folder names literal.
    search = os.path.join(glob.escape(dirname), glob_basename)
    found = sorted(os.path.basename(p) for p in glob.glob(search))

    output_dirname = '%s/%s' % (output_folder, layer)
    output_pattern = ''
    for file_name in found:
        match = re.match(match_basename, file_name)
        if not match:
            continue

        # Avoid producing '/name' (a root-absolute path) when dirname is empty.
        full_source_path = '%s/%s' % (dirname, file_name) if dirname else file_name
        match_dict = match.groupdict()

        new_name = '%s_%s' % (node_short_name, layer)
        new_pattern = new_name
        if 'attribute' in match_dict:
            new_name = '%s_%s' % (new_name, match_dict['attribute'])
            new_pattern = '%s_<attr:%s>' % (new_pattern, attribute_name)
        if 'frame' in match_dict:
            new_name = '%s_%s' % (new_name, match_dict['frame'])
            new_pattern = '%s_<f>' % new_pattern
        if 'udim' in match_dict:
            new_name = '%s.%s' % (new_name, match_dict['udim'])
            new_pattern = '%s.<UDIM>' % new_pattern

        new_name = '%s.%s' % (new_name, extension)
        new_pattern = '%s.%s' % (new_pattern, extension)
        output_pattern = '%s/%s' % (output_dirname, new_pattern)
        output_path = '%s/%s' % (output_dirname, new_name)

        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(output_dirname, exist_ok=True)
        shutil.copy2(full_source_path, output_path)
        print('file', output_path)

    print('out pattern', output_pattern)
    return output_pattern

def sort_layers(shader_data, project='TPT'):
    """Pick one texture-layer code for each texture node.

    Each channel in *shader_data* is translated to a code through the
    ``layers`` table of the project's ``texture_layers`` config
    (module ``texture_uploader``). A channel missing from the table gives the
    string ``'None'``. When a node feeds several channels, the code that
    appears earliest among the table's values wins. A current code that is
    not among those values is always replaced by the next candidate.

    Args:
        shader_data: dict channel -> node name (as built by surf_nodes).
        project: Project name handed to ``ConfigSolver``.

    Returns:
        dict node name -> layer code.
    """
    solver = config_manager.ConfigSolver(project=project)
    layer_table = solver.get_config('texture_layers', module='texture_uploader')['layers']
    priority = list(layer_table.values())

    codes_by_node = {}
    for channel, node in shader_data.items():
        candidate = layer_table.get(channel, 'None')
        if node not in codes_by_node or codes_by_node[node] not in priority:
            codes_by_node[node] = candidate
        elif candidate in priority and priority.index(candidate) < priority.index(codes_by_node[node]):
            codes_by_node[node] = candidate

    return codes_by_node

def copy_shader_files(shader, output_folder=None, project='TPT'):
    """Copy every texture used by *shader* into per-layer folders.

    Walks the network upstream of *shader* with surf_nodes(), resolves a
    layer code per texture node with sort_layers(), reads each node's texture
    path and hands it to copy_files().

    Args:
        shader: Node to walk from (shader_inspect passes a shadingEngine).
        output_folder: Destination root. Defaults to the previously
            hard-coded test folder.
        project: Project name used to look up the texture layer config.
    """
    if output_folder is None:
        # NOTE(review): user-specific test location kept only as the default
        # for backward compatibility; callers should pass their own folder.
        output_folder = 'C:/projects/TPT/assets/Main_Characters/Goldy/fernando.vizoso/maya/sourceimages/test'

    shader_files = surf_nodes(shader)
    shader_data = sort_layers(shader_files, project=project)
    for node, channel in shader_data.items():
        attribute = end_nodes[cmds.nodeType(node)]
        file_pattern = cmds.getAttr('%s.%s' % (node, attribute))
        if not file_pattern:
            # Empty (or unset) texture path: nothing to copy. copy_files would
            # fail calling .replace() on None.
            continue
        copy_files(file_pattern, shader, channel, output_folder)

def shader_inspect(root_node):
    """Collect shadingEngines assigned to meshes under *root_node* and copy their textures.

    Args:
        root_node: DAG node whose descendants are inspected.

    Returns:
        dict shadingEngine -> ``{'meshes': [full mesh paths], 'nodes': {}}``.
    """
    # A set, not the string 'mesh': `in` on a string is a substring test.
    valid_nodes = {'mesh'}
    all_shaders = {}

    # listRelatives returns None rather than [] when there are no descendants.
    for node in cmds.listRelatives(root_node, ad=True, f=True) or []:
        if cmds.nodeType(node) not in valid_nodes:
            continue

        # A mesh with per-face assignments is connected to several
        # shadingEngines, and one may be listed more than once; keep each
        # distinct one instead of only the first.
        sg_nodes = dict.fromkeys(cmds.listConnections(node, type='shadingEngine') or [])
        for sg_node in sg_nodes:
            entry = all_shaders.setdefault(sg_node, {'meshes': [], 'nodes': {}})
            entry['meshes'].append(node)

    for sg_group in all_shaders:
        print('-' * 100)
        print(sg_group)
        copy_shader_files(sg_group)

    return all_shaders



