import os
import importlib

from pprint import pprint

import maya.cmds as cmds
import maya.mel as mel
import maya.OpenMaya as OpenMaya

import maya_assemblies.lib.representations.representation as representation
import maya_assemblies.lib.shader_helpers as shader_helpers

# Re-execute shader_helpers every time this module is imported, so the
# helper functions used below always come from the module's current source.
# NOTE(review): presumably a development aid for picking up edits inside a
# running Maya session -- confirm whether this should ship in production.
importlib.reload(shader_helpers)


class SceneRepresentation(representation.Representation):
    """Representation of type ``scene``.

    Activating it imports ``self.path`` (and, when the assembly defines one,
    the assembly's ``shader_path``) into the assembly, parents the imported
    top-level nodes under the assembly node and assigns the imported shading
    engines to the imported meshes.
    """

    typename = "scene"

    def activate(self):
        """Import the scene (and optional shader file) and apply the shaders.

        Returns:
            bool: always ``True``; also when there is no path to import or no
            shader file to apply.
        """
        if not self.path:
            return True

        namespace = self.get_assembly_namespace()
        shader_path = self.assembly.shader_path

        fnassembly = OpenMaya.MFnAssembly(self.assembly.thisMObject())

        # Honour the user's "ignore version" file preference for the scene.
        ignore_version = cmds.optionVar(query='fileIgnoreVersion') == 1
        fnassembly.importFile(self.path, None, True, None, ignore_version)

        if shader_path:
            fnassembly.importFile(shader_path, None, False, None, True)

        self.loaded_nodes = cmds.namespaceInfo(namespace, ls=True, dagPath=True)
        root = self.get_assembly_name()
        # Maya list queries may hand back None instead of an empty list, so
        # iterate over a safe fallback while keeping the stored value as-is.
        for node in self.loaded_nodes or []:
            # A path that splits into exactly two parts on '|' is presumably
            # of the form '|ns:node', i.e. a top-level DAG node -- those are
            # the ones that get re-parented under the assembly.
            if len(node.split('|')) == 2:
                cmds.parent(node, root)

        if not shader_path:
            return True

        pattern = '%s:*' % namespace
        # Skip intermediate shapes; only the visible meshes receive shaders.
        loaded_mesh = [
            mesh
            for mesh in (cmds.ls(pattern, type='mesh') or [])
            if not cmds.getAttr('%s.intermediateObject' % mesh)
        ]
        loaded_shaders = cmds.ls(pattern, type='shadingEngine') or []
        print(loaded_shaders)

        self.apply_shaders(loaded_shaders, loaded_mesh)

        return True

    @staticmethod
    def _find_source_extension(directory, candidates):
        """Return the first extension in ``candidates`` used by a file in ``directory``.

        Args:
            directory (str): folder to inspect.
            candidates (list[str]): extensions (without dot) in priority order.

        Returns:
            str or None: the chosen extension, or ``None`` when the folder
            cannot be listed or holds no file with a candidate extension.
        """
        try:
            filenames = os.listdir(directory)
        except OSError:
            # Missing/unreadable folder (or an empty path): nothing to offer.
            return None

        present = set()
        for filename in filenames:
            _, dot, extension = filename.rpartition('.')
            if dot:
                present.add(extension)

        for extension in candidates:
            if extension in present:
                return extension
        return None

    def set_textures(self, namespace):
        """Point ``.tx`` texture nodes in ``namespace`` at a source image format.

        Every ``file`` node gets ``uvTilingMode`` set to 3 (presumably the
        UDIM tiling mode -- confirm against the file node documentation).
        For ``file`` and ``aiImage`` nodes whose texture path ends in ``.tx``
        the extension is swapped for the first of ``exr``, ``png``, ``jpg``
        that is used by some file in the texture's directory.

        The check is made per directory, not per exact file name, so it does
        not verify that the rewritten path itself exists on disk.

        Args:
            namespace (str): namespace whose texture nodes are processed.
        """
        valid_textures = ['exr', 'png', 'jpg']
        texture_node_types = {'file': 'fileTextureName', 'aiImage': 'filename'}

        for node_type, attribute in texture_node_types.items():
            loaded_textures = cmds.ls('%s:*' % namespace, type=node_type) or []
            for texture_node in loaded_textures:
                if node_type == 'file':
                    cmds.setAttr('%s.uvTilingMode' % texture_node, 3)

                plug = '%s.%s' % (texture_node, attribute)
                # An unset texture path must not abort the whole pass.
                texture_path = cmds.getAttr(plug) or ''
                basename, dot, extension = texture_path.rpartition('.')
                if not dot or extension != 'tx':
                    continue

                new_extension = self._find_source_extension(
                    os.path.dirname(texture_path), valid_textures
                )
                if new_extension is None:
                    continue

                # Set the attribute once, with a deterministic choice, rather
                # than once per matching file in directory-listing order.
                new_texture_path = '%s.%s' % (basename, new_extension)
                cmds.setAttr(plug, new_texture_path, type='string')

    def apply_shaders(self, shaders_to_apply, meshes):
        """Assign the given shading engines to the imported geometry.

        With exactly one shading engine it is assigned to every mesh in
        ``meshes``. Otherwise each shading engine is assigned to the nodes
        listed in its render info (as returned by
        ``shader_helpers.get_render_info``), after mapping names through
        ``shader_helpers.get_nodes_translator``.

        Args:
            shaders_to_apply (list[str]): shading engine node names.
            meshes (list[str]): mesh node paths; only used in the
                single-shader case.
        """
        namespace = self.get_assembly_namespace()
        self.set_textures(namespace)
        level_of_detail = 1

        if len(shaders_to_apply) == 1:
            sg_node = shaders_to_apply[0]
            node_data = shader_helpers.get_render_info(sg_node)
            for mesh_node in meshes:
                # Rebuild the DAG path so every component carries this
                # assembly's namespace, dropping any namespace it had before.
                node_path_bits = [
                    '%s:%s' % (namespace, bit.split(':')[-1])
                    for bit in mesh_node.split('|')
                    if bit != ''
                ]
                mesh_path = '|'.join(node_path_bits)

                # NOTE(review): in the multi-shader branch below the result of
                # get_render_info is iterated as {node_name: settings}; here
                # the top-level dict is queried directly, so this likely
                # always yields the default of 2 -- confirm the intent.
                subdivisions = node_data.get('rsMaxTessellationSubdivs', 2)
                shader_helpers.assign_render_settings(sg_node, mesh_path, level_of_detail, subdivisions)
            mel.eval('generateAllUvTilePreviews')
            return

        # NOTE(review): activate() uses self.get_assembly_name() while this
        # uses self.assembly.get_assembly_name() -- confirm both are valid.
        root = self.assembly.get_assembly_name()

        translator = shader_helpers.get_nodes_translator(shaders_to_apply, root, namespace)

        for key, value in translator.items():
            print('Can\'t find geo %s, assigning it shader to %s' % (key, value))

        for sg_name in shaders_to_apply:
            if 'standardSurface' in sg_name:
                continue
            print('Shader to apply', sg_name)

            sg_data = shader_helpers.get_render_info(sg_name)
            for node_name, node_data in sg_data.items():
                print('\tNode: %s' % node_name)
                # startswith() also copes with an empty key, unlike [0].
                if node_name.startswith('{'):
                    continue
                if node_name in translator:
                    node_name = translator[node_name]
                full_node_name = '%s:%s' % (namespace, node_name)

                if cmds.objExists(full_node_name) and cmds.nodeType(full_node_name) == 'transform':
                    children = cmds.listRelatives(full_node_name, f=True, c=True)
                    if not children:
                        continue
                    # Target the last child whose path does not contain
                    # 'Orig'; if every child does, the transform itself stays
                    # the target.
                    for child in children:
                        if 'Orig' in child:
                            continue
                        full_node_name = child

                subdivisions = node_data.get('rsMaxTessellationSubdivs', 2)
                shader_helpers.assign_render_settings(sg_name, full_node_name, level_of_detail, subdivisions)
        print('applied')
        mel.eval('generateAllUvTilePreviews')

    def can_apply_edits(self):
        """Report that edits can be applied to this representation.

        Returns:
            bool: always ``True``.
        """
        return True

