import os
import ast

import maya.cmds as cmds
import maya.OpenMaya as OpenMaya

from pprint import pprint
# Name of the string attribute, stored on shading engine nodes, whose value is a
# Python-literal string parsed by get_render_info() via ast.literal_eval.
shading_info_attribute = 'shading_io_info'


def string_distance(str1, str2):
    """Return the Levenshtein edit distance between ``str1`` and ``str2``.

    The distance is the minimum number of single-element insertions,
    deletions and substitutions needed to turn ``str1`` into ``str2``.
    Only two rows of the dynamic-programming table are kept in memory.
    """
    previous_row = list(range(len(str2) + 1))
    for row_index, char1 in enumerate(str1, start=1):
        current_row = [row_index]
        for col_index, char2 in enumerate(str2, start=1):
            insert_cost = current_row[col_index - 1] + 1
            delete_cost = previous_row[col_index] + 1
            # A bool adds as 0/1: free when the characters match.
            replace_cost = previous_row[col_index - 1] + (char1 != char2)
            current_row.append(min(insert_cost, delete_cost, replace_cost))
        previous_row = current_row
    return previous_row[-1]


def get_translator(wrong_alembic, wrong_rig):
    """Pair each name in ``wrong_alembic`` with its closest name in ``wrong_rig``.

    Closeness is measured with ``string_distance``. When several rig names tie
    for the smallest distance, the one seen last in ``wrong_rig`` wins.
    Several alembic names may map to the same rig name. If ``wrong_rig`` is
    empty, the result is empty.

    The returned dict is ordered by increasing distance; names with equal
    distance keep their order from ``wrong_alembic``.

    :returns: dict ``{alembic name: rig name}``.
    """
    best_match = {}
    for abc_node in wrong_alembic:
        for rig_node in wrong_rig:
            distance = string_distance(abc_node, rig_node)
            previous = best_match.get(abc_node)
            # ``<=`` lets a later rig name with an equal distance take over.
            if previous is None or distance <= previous[0]:
                best_match[abc_node] = (distance, rig_node)

    # sorted() is stable, so equal distances keep wrong_alembic order.
    ranked = sorted(best_match.items(), key=lambda item: item[1][0])
    return {abc_node: rig_node for abc_node, (_, rig_node) in ranked}


def _find_source_texture(tx_path, valid_extensions):
    """Return a sibling of the ``.tx`` file ``tx_path`` with another extension.

    Looks in the directory of ``tx_path`` for a file with the same stem and
    one of ``valid_extensions``. Extensions earlier in ``valid_extensions``
    are preferred. Maya file-name tokens in the stem (anything written as
    ``<...>``, e.g. ``<UDIM>``) are matched as wildcards, so tiled textures are
    found by their tiles.

    :param tx_path: texture path as stored on the texture node.
    :param valid_extensions: ordered extensions (without the dot) to look for.
    :returns: ``tx_path`` with its extension swapped, or ``None`` when no
        matching file exists (or the directory does not exist).
    """
    import re

    stem_path = tx_path.rsplit('.', 1)[0]
    directory, stem = os.path.split(stem_path)
    # os.listdir('') would raise for relative paths; nothing to match then.
    if not os.path.isdir(directory):
        return None

    literal_parts = re.split(r'<[^>]+>', os.path.normcase(stem))
    stem_pattern = re.compile(
        '.+'.join(re.escape(part) for part in literal_parts) + r'\Z')

    found = set()
    for file_name in os.listdir(directory):
        file_stem, _, extension = file_name.rpartition('.')
        if extension in valid_extensions and stem_pattern.match(os.path.normcase(file_stem)):
            found.add(extension)

    for extension in valid_extensions:
        if extension in found:
            return '%s.%s' % (stem_path, extension)
    return None


def import_shaders(assembly, shader_path):
    """Collect the shading engines in an assembly's representation namespace.

    Texture nodes (``file`` / ``aiImage``) in the same namespace whose path
    ends in ``.tx`` are repathed to a sibling ``.exr``/``.png``/``.jpg`` with
    the same stem, when such a file exists on disk.

    NOTE(review): the import of ``shader_path`` is disabled (see the comment
    below). As written, this only operates on nodes that are already present
    in the representation namespace.

    :param assembly: assembly wrapper exposing ``thisMObject()``.
    :param shader_path: shader file to import (unused while the import is
        disabled).
    :returns: list of shadingEngine names found in the namespace.
    """
    fnassembly = OpenMaya.MFnAssembly(assembly.thisMObject())
    shader_namespace = fnassembly.getAbsoluteRepNamespace()

    # Disabled import, kept for reference (note it returned False, not a list):
    # cmds.namespace(set=':')
    # if not cmds.namespace(exists=shader_namespace):
    #     cmds.namespace(add=shader_namespace)
    # cmds.namespace(set=shader_namespace)
    # try:
    #     fnassembly.importFile(shader_path, None, False, None, True)
    # except RuntimeError:
    #     return False

    print('shader_namespace', shader_namespace)
    # namespaceInfo returns None for an empty namespace.
    loaded_nodes = cmds.namespaceInfo(shader_namespace, ls=True, dagPath=True) or []
    valid_textures = ['exr', 'png', 'jpg']
    texture_node_types = {'file': 'fileTextureName', 'aiImage': 'filename'}

    clean_nodes = []
    for node in loaded_nodes:
        full_node = '%s:%s' % (shader_namespace, node.split(':')[-1])
        if not cmds.objExists(full_node):
            continue
        try:
            node_type = cmds.nodeType(full_node)
        except (RuntimeError, ValueError):
            # e.g. the short name matches several nodes.
            continue

        if node_type == 'shadingEngine':
            clean_nodes.append(full_node)
        elif node_type in texture_node_types:
            plug = '%s.%s' % (full_node, texture_node_types[node_type])
            texture_path = cmds.getAttr(plug)
            if not texture_path or texture_path.rsplit('.', 1)[-1] != 'tx':
                continue
            source_path = _find_source_texture(texture_path, valid_textures)
            if source_path:
                cmds.setAttr(plug, source_path, type='string')

    return clean_nodes


def get_render_info(sg_node):
    """Return the shading info stored on ``sg_node``.

    Reads the string attribute named by ``shading_info_attribute`` and parses
    it with ``ast.literal_eval``. Callers treat the result as a dict keyed by
    node name. Returns ``{}`` when the attribute is missing or empty.
    Malformed text makes ``ast.literal_eval`` raise (``ValueError`` /
    ``SyntaxError``).
    """
    has_info = cmds.attributeQuery(shading_info_attribute, node=sg_node, exists=True)
    raw_value = cmds.getAttr('%s.%s' % (sg_node, shading_info_attribute)) if has_info else None
    if not raw_value:
        return {}
    return ast.literal_eval(raw_value)

def get_assigned_meshes(sg_nodes):
    """List the node names recorded in the shading info of ``sg_nodes``.

    Collects, in order, the keys of each shading engine's info (see
    get_render_info). Keys whose first character is ``{`` are skipped.
    Duplicates across shading engines are kept.
    """
    return [
        node_key
        for sg_name in sg_nodes
        for node_key in get_render_info(sg_name).keys()
        if node_key[0] != '{'
    ]


def get_nodes_translator(sg_nodes, root, namespace):
    """Map shader-data node names that are missing in the scene to scene names.

    Compares two sets of names:

    - names recorded in the shading info of ``sg_nodes`` (see
      get_assigned_meshes);
    - namespace-stripped names of the transforms that own a mesh under
      ``root``.

    Names found on only one side are paired by ``get_translator``, which picks
    the closest name by edit distance.

    :param sg_nodes: shadingEngine names whose shading info is read.
    :param root: DAG node whose descendants are scanned for meshes.
    :param namespace: unused; kept for caller compatibility.
    :returns: dict ``{name in shader data: name in scene}``.
    """
    to_assign_nodes = set(get_assigned_meshes(sg_nodes))

    # Full paths keep meshes distinct even when short names repeat. With bare
    # short names, nodeType() raised on duplicates and those meshes were
    # silently skipped. listRelatives returns None when nothing matches.
    mesh_paths = cmds.listRelatives(root, ad=True, type='mesh', fullPath=True) or []

    comp_nodes = set()
    for mesh_path in mesh_paths:
        path_parts = mesh_path.split('|')
        if len(path_parts) < 2:
            continue
        # The path component just above the shape is its parent transform.
        comp_nodes.add(path_parts[-2].split(':')[-1])

    wrong_assign = list(comp_nodes - to_assign_nodes)
    wrong_name = list(to_assign_nodes - comp_nodes)
    return get_translator(wrong_name, wrong_assign)


def assign_render_settings(sg_node, mesh, level_of_detail, subdivs):
    """Add ``mesh`` to the shading group ``sg_node`` when it is an existing mesh.

    Nodes that do not exist or are not of type ``mesh`` are ignored.
    NOTE(review): ``level_of_detail`` and ``subdivs`` are accepted but not
    applied here.
    """
    is_mesh = cmds.objExists(mesh) and cmds.nodeType(mesh) == 'mesh'
    if not is_mesh:
        return
    cmds.sets(mesh, e=True, forceElement=sg_node)




def _resolve_mesh_shapes(node_name):
    """Return the nodes that a shader recorded for ``node_name`` should go to.

    Every scene node matching ``node_name`` is considered, since short names
    may be ambiguous. Transforms are replaced by their non-intermediate
    shapes; intermediate shapes are what the old ``'Orig'`` name check tried
    to skip, and child transforms are no longer picked up. Other nodes are
    returned unchanged. A missing node yields an empty list.
    """
    targets = []
    for match in cmds.ls(node_name, long=True) or []:
        if cmds.nodeType(match) == 'transform':
            shapes = cmds.listRelatives(match, shapes=True, noIntermediate=True, fullPath=True)
            targets.extend(shapes or [])
        else:
            targets.append(match)
    return targets


def _apply_single_shader(sg_node, loaded_nodes, namespace, level_of_detail):
    """Assign ``sg_node`` to every path in ``loaded_nodes``, re-rooted into ``namespace``."""
    node_data = get_render_info(sg_node)
    subdivisions = node_data.get('rsMaxTessellationSubdivs', 2)
    for mesh_node in loaded_nodes:
        # Rebuild the DAG path with every component moved into ``namespace``.
        mesh_path = '|'.join(
            '%s:%s' % (namespace, bit.split(':')[-1])
            for bit in mesh_node.split('|')
            if bit != ''
        )
        assign_render_settings(sg_node, mesh_path, level_of_detail, subdivisions)


def _apply_shaders_from_data(sg_nodes, root, namespace, level_of_detail):
    """Assign each shading engine to the nodes listed in its shading info.

    Names missing from the scene are redirected through get_nodes_translator.
    Names that still do not resolve are skipped instead of raising.
    """
    translator = get_nodes_translator(sg_nodes, root, namespace)
    for key, value in translator.items():
        print('Can\'t find geo %s, assigning it shader to %s' % (key, value))

    for sg_name in sg_nodes:
        print('Shader to apply', sg_name)
        for node_name, node_data in get_render_info(sg_name).items():
            # Keys starting with '{' are not node names; skipped as before.
            if node_name[0] == '{':
                continue
            node_name = translator.get(node_name, node_name)
            subdivisions = node_data.get('rsMaxTessellationSubdivs', 2)
            for target in _resolve_mesh_shapes(':%s:%s' % (namespace, node_name)):
                assign_render_settings(sg_name, target, level_of_detail, subdivisions)


def apply_shaders(assembly, shader_path, loaded_nodes):
    """Assign the assembly's shading engines to its meshes.

    If import_shaders yields exactly one shading engine, it is assigned to
    every path in ``loaded_nodes``. Otherwise each shading engine is assigned
    to the nodes recorded in its shading info, with missing names mapped to
    the closest scene names. In both cases the UV tile previews are then
    regenerated.

    :param assembly: assembly wrapper (uses ``getRepNamespace``, ``lod``,
        ``get_assembly_name`` and what import_shaders needs).
    :param shader_path: passed through to import_shaders.
    :param loaded_nodes: DAG paths used in the single-shader case.
    """
    import maya.mel as mel

    shaders_to_apply = import_shaders(assembly, shader_path)
    namespace = assembly.getRepNamespace()
    level_of_detail = assembly.lod

    if len(shaders_to_apply) == 1:
        _apply_single_shader(shaders_to_apply[0], loaded_nodes, namespace, level_of_detail)
    else:
        root = assembly.get_assembly_name()
        _apply_shaders_from_data(shaders_to_apply, root, namespace, level_of_detail)

    mel.eval('generateAllUvTilePreviews')