import os
import sys
import ast
import shutil
import weakref
import importlib
import yaml

import maya.OpenMaya as OpenMaya
import maya.OpenMayaMPx as OpenMayaMPx
import maya.cmds as cmds

import maya_assemblies.lib.master_assembly as master_assembly

import maya_assemblies.lib.helpers as helpers


importlib.reload(master_assembly)
importlib.reload(helpers)


class Assembly(master_assembly.MasterAssembly):
    """Scene-assembly node whose representations are driven by a YAML config.

    The node's ``assemblyType`` attribute selects an entry in
    ``config/assembly_representations.yaml`` (inside the ``maya_assemblies``
    package). That entry lists the representation classes to instantiate,
    the default representation, an optional icon, whether shader/transform
    edits are tracked and extra dynamic attributes to add to the node.
    """

    typename = "Assembly"
    id = OpenMaya.MTypeId(70000)
    icon_name = "out_assemblyReference.png"

    # Root folder that localized copies of the geometry file are written to
    # when ``localized`` is True.
    local_root = 'C:/projects'

    # Maya Attributes. The attributes actually used by this class are created
    # in ``initialize``; ``agrid_uid``, ``arep_namespace`` and ``ainitial_rep``
    # are never assigned there and stay null MObjects. They are kept only
    # for backward compatibility with code that may reference them.
    agrid_uid = OpenMaya.MObject()
    active_representation = OpenMaya.MObject()
    arep_namespace = OpenMaya.MObject()
    ainitial_rep = OpenMaya.MObject()
    # NOTE(review): looks like a map of asHashable(userNode) -> weak reference
    # to the Python instance (see _get_mpx_derived_object); presumably filled
    # by the base class -- confirm.
    tracking_dict = {}

    def __init__(self):
        # Representation name -> representation instance (see initialize_assembly).
        self.representations = {}
        self.active_rep = None
        self.is_updating_namespace = False
        self.is_activating = False
        # Dynamic attributes this node added from the YAML config.
        self.added_attributes = []
        self.allow_shaders = False
        self.allow_transforms = False
        # Representation activated when no initial state is stored. Initialised
        # here so postLoad() also works when initialize_assembly() returned early.
        self._default_representation = ''

        self.localized = False
        self.is_loaded = False
        self.loader_attributes_callback = None
        # Id of the most recently registered node callback (kept for
        # compatibility); node_callbacks holds all of them.
        self.set_attributes_callback = None
        self.node_callbacks = []
        self.connect_attributes_callback = None
        # Plug name -> setAttr value string / connected plug name, recorded by
        # add_set_edit() and replayed by save_edits().
        self.attribute_set_edits = {}
        self.attribute_connect_edits = {}
        self.applying_edit = False
        # True while save_edits() re-creates the assembly edits.
        self.handle_edit = False
        self.force_refresh = False

        super(Assembly, self).__init__()

    @property
    def geometry_path(self):
        """Value of the ``path`` (geometry file) attribute."""
        return OpenMaya.MPlug(self.thisMObject(), self.geometry_file).asString()

    @property
    def shader_path(self):
        """Value of the ``shadingPath`` attribute."""
        return OpenMaya.MPlug(self.thisMObject(), self.shading_file).asString()

    @classmethod
    def _get_mpx_derived_object(cls, assembly):
        """Return the Python ``Assembly`` instance behind the node *assembly*.

        Looks the user node up in ``tracking_dict`` and calls the stored value
        (presumably a weak reference, so the result may be None -- confirm).
        """
        fndep = OpenMaya.MFnDependencyNode(assembly)
        return cls.tracking_dict[OpenMayaMPx.asHashable(fndep.userNode())]()

    def initialize_assembly(self):
        """(Re)build the representations for this node from the YAML config.

        Steps:

        * Read the ``path`` and ``assemblyType`` attributes.
        * Optionally localize the geometry file.
        * Instantiate every representation class listed for the assembly
          type, plus a ``<name>_low`` variant (except for 'Locator') when a
          sibling ``*_low.abc`` file exists.
        * Set the node icon.
        * Sync the dynamic attributes declared in the config.

        :return: -1 when ``path`` or ``assemblyType`` is empty, otherwise None.
            When the type has no config entry, None is also returned and the
            default representation is reset to ''.
        """
        import maya_assemblies as assemblies
        import maya_assemblies.lib.representations  # noqa: F401

        path = self.geometry_path
        ass_type = self.get_assembly_type()
        # Check before touching the file system: localizing an empty path
        # would try to copy ''.
        if not path or not ass_type:
            return -1

        if self.localized:
            path = self._localize_path(path)

        root_folder = os.path.dirname(os.path.abspath(assemblies.__file__))
        config_path = os.path.join(root_folder, 'config', 'assembly_representations.yaml')
        with open(config_path, 'r') as config_stream:
            representations_config = yaml.safe_load(config_stream) or {}

        if ass_type not in representations_config:
            self._default_representation = ''
            return

        low_path = None
        if '_high.abc' in path:
            candidate = path.replace('_high.abc', '_low.abc')
            if os.path.exists(candidate):
                low_path = candidate

        assembly_config = representations_config[ass_type]
        self.allow_shaders = assembly_config.get('allow_shaders', False)
        self.allow_transforms = assembly_config.get('allow_transforms', False)
        self.representations = {}

        root_module = 'maya_assemblies.lib.representations'
        for rep_data in assembly_config['representations']:
            for rep_name, rep_mod_class in rep_data.items():
                # 'module.Class'; rsplit so that 'subpkg.module.Class' works too.
                module_name, class_name = rep_mod_class.rsplit('.', 1)
                module = importlib.import_module('%s.%s' % (root_module, module_name))
                # Reloaded on every call, presumably so representation edits are
                # picked up without restarting Maya.
                importlib.reload(module)
                representation = getattr(module, class_name)
                self.representations[rep_name] = representation(self, rep_name, path)

                if low_path and rep_name != 'Locator':
                    low_name = '%s_low' % rep_name
                    self.representations[low_name] = representation(self, low_name, low_path)

        self._default_representation = assembly_config['default']

        # Only touch the icon when the config names one.
        icon_name = assembly_config.get('icon')
        if icon_name:
            fndep = OpenMaya.MFnDependencyNode(self.thisMObject())
            icon_plug = fndep.findPlug("iconName", True)
            icon_plug.setString('%s/icons/%s' % (root_folder, icon_name))

        self._sync_dynamic_attributes(assembly_config.get('attributes') or {})

    def _localize_path(self, path):
        """Return a copy of *path* under ``local_root``, copying it on first use."""
        local_path = '%s/%s' % (self.local_root, path.split('/', 1)[-1])
        if not os.path.exists(local_path):
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            shutil.copy2(path, local_path)
        # Use the local copy whether it was just created or already existed.
        return local_path

    def _sync_dynamic_attributes(self, attributes_config):
        """Add the config's dynamic attributes to the node and delete stale ones.

        :param attributes_config: mapping of long attribute name -> dict with
            ``shortName``, ``dataType`` and ``niceName``, plus optional
            ``asFilename`` and ``category`` keys.
        """
        assembly_name = self.get_assembly_name()

        for attribute_name, attribute_data in attributes_config.items():
            if attribute_name in self.added_attributes:
                continue
            # The attribute can already be present (e.g. stored with the scene);
            # adding it again would make addAttr fail.
            if cmds.attributeQuery(attribute_name, node=assembly_name, exists=True):
                continue
            add_kwargs = {
                'longName': attribute_name,
                'shortName': attribute_data['shortName'],
                'dataType': attribute_data['dataType'],
                'niceName': attribute_data['niceName'],
                'usedAsFilename': attribute_data.get('asFilename', False),
            }
            category = attribute_data.get('category')
            if category:
                add_kwargs['category'] = category
            cmds.addAttr(assembly_name, **add_kwargs)

        # Remove attributes that a previous assembly type added but the
        # current one no longer declares.
        for attribute in self.added_attributes:
            if attribute in attributes_config:
                continue
            if cmds.attributeQuery(attribute, node=assembly_name, exists=True):
                cmds.deleteAttr('%s.%s' % (assembly_name, attribute))

        self.added_attributes = list(attributes_config.keys())

    @classmethod
    def initialize(cls):
        """Create the static attributes of the node type.

        Each ``cls.<name>`` receives the attribute MObject returned by the
        base-class helpers. The ``category`` values are read back in
        ``add_assembly_node_callbacks``: 'refresh_geo' re-initializes the
        assembly and 'shader' reloads the shading data.
        """
        cls.sg_id = cls.add_string_attribute('shotgrid_id', 'sg')
        cls.active_representation = cls.add_string_attribute("activeRepresentation", "arep")
        cls.assembly_type = cls.add_string_attribute("assemblyType", "at")
        cls.representation_namespace = cls.add_string_attribute('representationNamespace', 'rns')
        cls.initial_representation = cls.add_string_attribute("initialRepresentation", "irp")

        cls.geometry_file = cls.add_string_attribute('path', 'gep', as_path=True, category='refresh_geo')

        cls.asset_name = cls.add_string_attribute('assetName', 'an', category='asset_data')
        cls.asset_type = cls.add_string_attribute('assetType', 'atp', category='asset_data')

        cls.variant = cls.add_string_attribute('variant', 'var', category='asset_data')
        cls.version = cls.add_string_attribute('version', 'ver', category='asset_data')
        cls.file_tag = cls.add_string_attribute('fileTag', 'ft', category='asset_data')

        cls.pipeline_step = cls.add_string_attribute('pipelineStep', 'ps', category='asset_data')
        cls.publish_hash = cls.add_string_attribute('hash', 'hs', category='asset_data')

        cls.instance_number = cls.add_integer_attribute('instanceNumber', 'in', min=0, default=0, category='asset_data')
        cls.level_of_detail = cls.add_integer_attribute('LevelOfDetail', 'ld', min=0, default=4, category='shader')

        cls.shading_file = cls.add_string_attribute('shadingPath', 'shpa', as_path=True, category='refresh_shading')

    def postLoad(self):
        """Activate the stored (or default) representation and restore children.

        This overrides MPxAssembly.postLoad and is also called from
        ``add_assembly_node_callbacks`` when the geometry path changes.

        ``initialRepresentation`` holds a Python-literal dict:
        ``{'Active': <rep name>, 'Set': {<child name>: {'active': <rep>}}}``.
        An empty value activates the default representation.
        """
        this_node = self.thisMObject()
        afn = OpenMaya.MFnAssembly(this_node)

        stored_state = OpenMaya.MPlug(this_node, self.initial_representation).asString()
        if not stored_state:
            afn.activate(self._default_representation)
            return

        # literal_eval only accepts Python literals, so no code is executed.
        representation_data = ast.literal_eval(stored_state)

        if not afn.isTopLevel():
            # Nested assemblies must not have their namespace changed.
            namespace_plug = OpenMaya.MPlug(this_node, self.representation_namespace)
            namespace_plug.setLocked(True)

        afn.activate(representation_data.get('Active', self._default_representation))
        afn = OpenMaya.MFnAssembly(this_node)

        children_state = representation_data.get('Set') or {}
        sub_assemblies = afn.getSubAssemblies()
        for index in range(sub_assemblies.length()):
            child_assembly = OpenMaya.MFnAssembly(sub_assemblies[index])
            child_name = child_assembly.name().split(':')[-1]
            if child_name in children_state:
                child_assembly.activate(children_state[child_name]['active'])

    def postConstructor(self):
        """Watch this node's own attribute changes, then run the base post-constructor."""
        self.loader_attributes_callback = OpenMaya.MNodeMessage.addAttributeChangedCallback(
            self.thisMObject(), self.add_assembly_node_callbacks)
        super(Assembly, self).postConstructor()

    def supportsEdits(self):
        """This assembly type supports assembly edits."""
        return True

    def beforeSave(self):
        """Write the edits (and, for 'Set' assemblies, child state) before saving."""
        self.save_edits()
        if self.get_assembly_type() == 'Set':
            self.save_children_representation()

    def save_children_representation(self):
        """Store the active representation of every sub-assembly.

        NOTE(review): currently a stub. It walks the sub-assemblies but
        records nothing. TODO: implement or remove.
        """
        afn = OpenMaya.MFnAssembly(self.thisMObject())
        sub_assemblies = afn.getSubAssemblies()
        for index in range(sub_assemblies.length()):
            child_assembly = OpenMaya.MFnAssembly(sub_assemblies[index])  # noqa: F841
            # TODO: record child_assembly's active representation.

    def save_edits(self):
        """Rebuild this node's assembly edits from the recorded set/connect edits.

        Steps:

        * Store the initial representation via ``helpers``.
        * If the active representation accepts edits, remove every existing
          edit that the root assembly holds for this node.
        * Re-add one connect edit per recorded connection and one set edit
          per recorded value.
        """
        assembly_node = self.thisMObject()
        helpers.save_initial_representation(assembly_node)

        active_rep = self.getActive()
        if not active_rep or not self.canRepApplyEdits(active_rep):
            return

        node_name = self.get_assembly_name()
        namespace = self.getRepNamespace()
        root = helpers.get_root_assembly(assembly_node)

        iterator = OpenMaya.MItEdits(root, assembly_node, OpenMaya.MItEdits.ALL_EDITS)
        while not iterator.isDone():
            iterator.removeCurrentEdit()
            iterator.next()

        self.handle_edit = True
        try:
            for source_connection in self.attribute_connect_edits:
                if source_connection.split(':')[0] != namespace:
                    source_connection = '%s:%s' % (namespace, source_connection)

                if not cmds.objExists(source_connection):
                    continue

                # With c=True the result is a flat list: [this, other, this, other, ...].
                inputs = cmds.listConnections(source_connection, p=True, c=True, s=True, d=False) or []
                outputs = cmds.listConnections(source_connection, p=True, c=True, d=True, s=False) or []

                if inputs:
                    # (this plug, upstream plug) -> connect upstream -> this.
                    pairs = [(upstream, this) for this, upstream in zip(inputs[::2], inputs[1::2])]
                else:
                    # (this plug, downstream plug) -> connect this -> downstream.
                    pairs = list(zip(outputs[::2], outputs[1::2]))

                for source_attr, target_attr in pairs:
                    source_attr = source_attr.split(':', 1)[-1].split('[')[0]
                    target_attr = target_attr.split('[')[0]
                    self.addConnectAttrEdit('..:%s' % node_name,
                                            '..:%s' % source_attr,
                                            '..:%s' % target_attr)

            for source_attr, value in self.attribute_set_edits.items():
                source_attr = source_attr.split(':', 1)[-1].split('[')[0]
                self.addSetAttrEdit('..:%s' % node_name,
                                    '..:%s:%s' % (namespace, source_attr),
                                    ' %s' % value)
        finally:
            # Always clear the flag, even if adding an edit raised.
            self.handle_edit = False

    def clean_shading_data(self):
        """Reassign every node under the assembly to ``initialShadingGroup``.

        Afterwards, clear the ``shaderAssignment`` attribute if the node has it.
        Does nothing unless the active representation accepts edits.
        """
        active_rep = self.getActive()
        if not active_rep or not self.canRepApplyEdits(active_rep):
            return

        assembly_name = self.get_assembly_name()

        # Full paths: short names can be ambiguous, and expanding them with
        # ls() would also pick same-named nodes outside this assembly.
        descendants = cmds.listRelatives(assembly_name, ad=True, f=True)
        if not descendants:
            return

        for node in descendants:
            cmds.sets(node, e=True, forceElement='initialShadingGroup')

        if cmds.attributeQuery('shaderAssignment', node=assembly_name, exists=True):
            cmds.setAttr('%s.shaderAssignment' % assembly_name, '', type='string')

    def add_callbacks(self):
        """Register ``add_set_edit`` on transforms and/or meshes under the assembly.

        Transforms are watched when ``allow_transforms`` is set, meshes when
        ``allow_shaders`` is set. Every callback id is appended to
        ``node_callbacks``; ``set_attributes_callback`` keeps the last one.
        """
        if not self.allow_transforms and not self.allow_shaders:
            return

        nodes = cmds.listRelatives(self.get_assembly_name(), ad=True, f=True)
        if not nodes:
            return

        for node in nodes:
            node_type = cmds.nodeType(node)
            watched = ((node_type == 'transform' and self.allow_transforms)
                       or (node_type == 'mesh' and self.allow_shaders))
            if not watched:
                continue
            callback_id = OpenMaya.MNodeMessage.addAttributeChangedCallback(
                get_mobject_from_name(node), self.add_set_edit)
            self.node_callbacks.append(callback_id)
            self.set_attributes_callback = callback_id

    def get_assembly_name(self):
        """Return the partial DAG path name of this assembly node."""
        return OpenMaya.MFnDagNode(self.thisMObject()).partialPathName()

    def get_assembly_type(self):
        """Return the value of the ``assemblyType`` attribute."""
        return OpenMaya.MPlug(self.thisMObject(), self.assembly_type).asString()

    def get_asset_name(self):
        """Return the value of the ``assetName`` attribute."""
        return OpenMaya.MPlug(self.thisMObject(), self.asset_name).asString()

    def get_variant(self):
        """Return the value of the ``variant`` attribute."""
        return OpenMaya.MPlug(self.thisMObject(), self.variant).asString()

    def get_lod(self):
        """Return the value of the ``LevelOfDetail`` attribute."""
        return OpenMaya.MPlug(self.thisMObject(), self.level_of_detail).asInt()

    def get_shader_path(self):
        """Return the dynamic ``shaderFilePath`` attribute, or None if the node lacks it."""
        assembly_name = self.get_assembly_name()
        if not cmds.attributeQuery('shaderFilePath', n=assembly_name, exists=True):
            return None
        return cmds.getAttr('%s.shaderFilePath' % assembly_name)

    def load_shading_data(self):
        """Apply the shader file to the active representation.

        Only applies to 'Geometry' and 'Rig' assemblies. Nothing happens if
        there is no active representation, no shader path, or the file does
        not exist.
        """
        active_rep = self.getActive()
        if not active_rep:
            return

        if self.get_assembly_type() not in ('Geometry', 'Rig'):
            return

        activated_representation = self.representations.get(active_rep)
        if activated_representation is None:
            return

        shader_path = self.get_shader_path()
        if not shader_path or not os.path.exists(shader_path):
            return

        activated_representation.apply_shaders(shader_path)

    def add_set_edit(self, msg, plug, other_plug, data):
        """Attribute-changed callback that records edits on nodes inside the assembly.

        * kAttributeSet on an unconnected plug: stores plug name -> value
          string (derived from the plug's setAttr command) in
          ``attribute_set_edits``.
        * kConnectionMade: stores plug name -> other plug name in
          ``attribute_connect_edits``.

        Both maps are turned into assembly edits by ``save_edits``.
        """
        if not msg:
            return

        if msg & OpenMaya.MNodeMessage.kAttributeSet:
            attribute = plug.name()
            if cmds.listConnections(attribute, p=True, c=True, s=True):
                return

            # API 1.0 fills an MStringArray output argument.
            set_attr_cmds = OpenMaya.MStringArray()
            plug.getSetAttrCmds(set_attr_cmds, OpenMaya.MPlug.kAll, True)
            if set_attr_cmds.length() == 0:
                return

            # Drop the leading 'setAttr' and plug-name tokens.
            tokens = set_attr_cmds[0].split(' ')[2:]
            if not tokens:
                return
            # Replace the last token (presumably the value plus its ';') with
            # the current value.
            tokens[-1] = str(cmds.getAttr(attribute))
            value = ' '.join(tokens)

            if (not self.applying_edit
                    and attribute not in self.attribute_connect_edits
                    and self.attribute_set_edits.get(attribute) != value):
                self.attribute_set_edits[attribute] = value

        if msg & OpenMaya.MNodeMessage.kConnectionMade:
            self.attribute_connect_edits[plug.name()] = other_plug.name()

    def add_assembly_node_callbacks(self, msg, plug, other_plug, data):
        """React to attribute changes on the assembly node itself.

        Behaviour depends on the changed attribute's categories (see
        ``initialize``):

        * 'refresh_geo': re-initialize the assembly and re-run ``postLoad``.
          This is skipped while in I/O before the node is loaded.
        * 'shader': reload the shading data.
        """
        if not msg & OpenMaya.MNodeMessage.kAttributeSet:
            return

        # API 1.0 fills an MStringArray output argument.
        category_array = OpenMaya.MStringArray()
        OpenMaya.MFnAttribute(plug.attribute()).getCategories(category_array)
        categories = [category_array[i] for i in range(category_array.length())]

        if 'refresh_geo' in categories:
            if self.initialize_assembly() == -1:
                return

            if helpers.in_io() and not self.is_loaded:
                return

            self.force_refresh = True
            try:
                self.postLoad()
            finally:
                self.force_refresh = False

        if 'shader' in categories:
            self.load_shading_data()





def initializePlugin(mobject):
    """Initialize the plug-in.

    Steps:

    * Register the Assembly node.
    * Register its assembly type, label and representation procs.
    * Register the type with the file path editor.
    * In interactive sessions, import the Attribute Editor template module.

    :param mobject: plug-in MObject passed in by Maya.
    :raises: re-raises any registration error after logging it to stderr.
    """
    mplugin = OpenMayaMPx.MFnPlugin(mobject)
    try:
        mplugin.registerNode(Assembly.typename, Assembly.id, Assembly.creator, Assembly.initialize,
                             OpenMayaMPx.MPxNode.kAssembly, "drawdb/geometry/transform")

        cmds.assembly(edit=True, type=Assembly.typename, label=Assembly.typename)

        cmds.assembly(edit=True, type=Assembly.typename, repTypeLabelProc=Assembly.representation_label)

        cmds.assembly(edit=True, type=Assembly.typename, listRepTypesProc=Assembly.list_representation_types)

        # Register the assembly nodes to the filePathEditor
        cmds.filePathEditor(registerType=Assembly.typename, typeLabel=Assembly.typename, temporary=True)
        if not cmds.about(batch=True):
            # Imported only for its side effects (presumably registering the AE template).
            import maya_assemblies.lib.templates.AEAssemblyTemplate as AEAssemblyTemplate  # noqa: F401

    except Exception:
        sys.stderr.write("Failed to register node: " + Assembly.typename + "\n")
        raise

def uninitializePlugin(mobject):
    """Uninitialize the plug-in: deregister the Assembly node and its assembly type.

    :param mobject: plug-in MObject passed in by Maya.
    :raises: re-raises any deregistration error after logging it to stderr.
    """
    mplugin = OpenMayaMPx.MFnPlugin(mobject)
    try:
        mplugin.deregisterNode(Assembly.id)
        cmds.assembly(edit=True, deregister=Assembly.typename)
    except Exception:
        sys.stderr.write("Failed to deregister node: " + Assembly.typename + "\n")
        raise

def get_assembly_node(node):
    """Return *node*, or its nearest DAG ancestor, whose node type is 'Assembly'.

    :param node: name or DAG path of a node.
    :return: the matching node name, or None if neither *node* nor any of
        its ancestors is an Assembly.
    """
    current = node
    while current:
        if cmds.nodeType(current) == 'Assembly':
            return current
        # path=True returns a unique name; a bare short name can be ambiguous
        # and make nodeType() fail.
        parents = cmds.listRelatives(current, parent=True, path=True)
        current = parents[0] if parents else None
    return None


def get_mobject_from_name(name):
    """Look up the MObject of a node by name.

    :param name: node name or DAG path.
    :return: the node's MObject (first match of the selection lookup).
    :raises RuntimeError: when no object with that name exists.
    """
    if not cmds.objExists(name):
        raise RuntimeError('Object does not exist: {}'.format(name))

    selection = OpenMaya.MSelectionList()
    OpenMaya.MGlobal.getSelectionListByName(name, selection)

    mobject = OpenMaya.MObject()
    selection.getDependNode(0, mobject)
    return mobject


