import os
import importlib
from pprint import pprint
import library.core.config_manager as config_manager

import maya.cmds as cmds
import usd.lib.usd_manager as usd_manager
import maya_assemblies.lib.representations.set_usd_representation as set_usd_representation
from pxr import Usd, Sdf

importlib.reload(usd_manager)

class SetUsdLoader():
    """Rebuild a set's Maya hierarchy from a published set USD file.

    Every active child of the ``/asset`` prim becomes either a plain
    transform (prims without an ``asset_name`` property, recursed into) or an
    assembly node (prims that carry ``asset_name``). Translate/rotate/scale
    xform ops are copied onto the created nodes.
    """

    # USD xform op attribute -> Maya transform attribute copied onto each node.
    _XFORM_ATTRS = (
        ('xformOp:translate', 'translate'),
        ('xformOp:rotateXYZ', 'rotate'),
        ('xformOp:scale', 'scale'),
    )

    # Variant sets mirrored onto enum attributes of UsdAssembly nodes, in the
    # order they are applied.
    _VARIANT_ATTRS = ('model_variant', 'model_version', 'shader_variant', 'shader_version')

    def __init__(self, usd_path, root_name='set'):
        """Open ``usd_path`` and build the Maya hierarchy under ``root_name``.

        Args:
            usd_path: Path of the set USD file to open.
            root_name: Maya node the ``/asset`` prim maps onto.
                NOTE(review): this class never creates that node; it appears
                to be expected to exist already — confirm with callers.
        """
        self.usd_path = usd_path

        stage = Usd.Stage.Open(self.usd_path)
        prim = stage.GetPrimAtPath('/asset')
        self.get_children(prim, maya_root=root_name)

    def get_value(self, prim, attribute):
        """Return the current value of ``attribute`` on ``prim``."""
        return prim.GetAttribute(attribute).Get()

    @staticmethod
    def _enum_indices(enum_spec):
        """Parse a Maya ``listEnum`` string ('a:b=5:c') into {name: index}.

        Entries without an explicit ``=index`` take the previous index + 1,
        starting at 0; the first occurrence of a duplicated name wins.
        """
        indices = {}
        next_index = 0
        for entry in enum_spec.split(':'):
            name, has_index, index = entry.partition('=')
            if has_index:
                next_index = int(index)
            indices.setdefault(name, next_index)
            next_index += 1
        return indices

    def set_enum_attribute(self, node_name, attribute, value):
        """Set enum ``attribute`` on ``node_name`` to the entry named ``value``.

        Does nothing when the attribute does not exist, is not an enum, or
        ``value`` is None/'None'. Unknown values fall back to the
        'recommended' entry; if that entry is absent too, nothing is set.
        """
        if not cmds.attributeQuery(attribute, n=node_name, exists=True):
            return
        if value is None or value == 'None':
            return
        enum_spec = cmds.attributeQuery(attribute, n=node_name, listEnum=True)
        if not enum_spec:
            # Not an enum attribute: there is no entry to select.
            return
        indices = self._enum_indices(enum_spec[0])
        if value not in indices:
            if 'recommended' not in indices:
                return
            value = 'recommended'
        cmds.setAttr('%s.%s' % (node_name, attribute), indices[value])

    def get_variant(self, variant_sets, variant_name):
        """Return the current selection of variant set ``variant_name``.

        Uses ``GetVariantSet`` instead of ``AddVariantSet``: the latter
        authors a new variant set into the stage's edit target when it does
        not exist yet, which a read-only query must not do.
        """
        return variant_sets.GetVariantSet(variant_name).GetVariantSelection()

    @staticmethod
    def _maya_parent_path(prim_path, maya_root):
        """Map a USD prim path below the root prim to a Maya DAG parent path.

        The root prim component (e.g. ``/asset``) is replaced by
        ``maya_root``. Returns None when the result is the Maya world
        (no root and no relative path), so no parent should be passed.
        """
        # '/asset/a/b'.split('/') -> ['', 'asset', 'a', 'b']; drop the root.
        relative_path = '|'.join(str(prim_path).split('/')[2:])
        if not relative_path:
            # Avoid a trailing '|' (e.g. 'set|') for the root prim itself.
            return maya_root or None
        return '%s|%s' % (maya_root, relative_path)

    @staticmethod
    def _payload_asset_path(prim):
        """Return the asset path of the first payload on ``prim``, or ''."""
        payloads = prim.GetMetadata('payload')
        if not payloads:
            return ''
        items = payloads.ApplyOperations([])
        return items[0].assetPath if items else ''

    def _apply_variants(self, node, prim):
        """Copy the prim's variant selections onto the node's enum attributes."""
        variant_sets = prim.GetVariantSets()
        for variant_name in self._VARIANT_ATTRS:
            self.set_enum_attribute(node, variant_name,
                                    self.get_variant(variant_sets, variant_name))

    def _create_assembly(self, prim, node_name, parent_kwargs):
        """Create and configure a SetAssembly/UsdAssembly node for an asset prim."""
        path_value = self._payload_asset_path(prim)
        asset_name_value = self.get_value(prim, 'asset_name')
        asset_type_value = self.get_value(prim, 'asset_type')

        if asset_type_value == 'Sets':
            new_node = cmds.createNode('SetAssembly', name=node_name, **parent_kwargs)
            cmds.setAttr('%s.setPath' % new_node, path_value, type='string')
        else:
            # Default must be applied before writing assetType, otherwise an
            # unauthored/empty type is written as-is.
            if not asset_type_value:
                asset_type_value = 'SetProps'
            new_node = cmds.createNode('UsdAssembly', name=node_name, **parent_kwargs)
            cmds.setAttr('%s.usdPath' % new_node, path_value, type='string')
            cmds.setAttr('%s.assetType' % new_node, asset_type_value, type='string')
            self._apply_variants(new_node, prim)

        if asset_name_value is not None:
            cmds.setAttr('%s.assetName' % new_node, asset_name_value, type='string')
        return new_node

    def _apply_xform(self, node, prim, properties):
        """Copy translate/rotateXYZ/scale xform op values onto Maya ``node``."""
        for usd_attr, maya_attr in self._XFORM_ATTRS:
            if usd_attr not in properties:
                continue
            value = prim.GetAttribute(usd_attr).Get()
            if value is None:
                # Property exists but has no value; nothing to copy.
                continue
            cmds.setAttr('%s.%s' % (node, maya_attr), *value, type='double3')

    def get_children(self, prim, maya_root=''):
        """Recreate the active children of ``prim`` in Maya.

        Children without an ``asset_name`` property become transforms and are
        recursed into; children with it become assembly nodes.

        Args:
            prim: USD prim whose children are converted.
            maya_root: Maya node standing in for the root USD prim ('' for
                the Maya world).
        """
        maya_parent = self._maya_parent_path(prim.GetPath(), maya_root)
        # At world level no parent flag is passed at all.
        parent_kwargs = {'p': maya_parent} if maya_parent else {}

        for child in prim.GetChildren():
            if not child.IsActive():
                continue
            child_name = child.GetName()
            properties = set(child.GetPropertyNames())

            if 'asset_name' in properties:
                new_node = self._create_assembly(child, child_name, parent_kwargs)
            else:
                new_node = cmds.createNode('transform', n=child_name, **parent_kwargs)
                self.get_children(child, maya_root=maya_root)

            # Use the name Maya actually returned: it may differ from the USD
            # prim name (renames) and the short name may not be unique.
            self._apply_xform(new_node, child, properties)

    def add_sub_layer(self, sub_layer_path: str, root_layer) -> 'Sdf.Layer | None':
        """Append the layer at ``sub_layer_path`` to ``root_layer``'s sublayers.

        Returns the sub layer, or None when ``Sdf.Layer.FindOrOpen`` returns None.
        """
        sub_layer = Sdf.Layer.FindOrOpen(sub_layer_path)
        if sub_layer is None:
            return None
        root_layer.subLayerPaths.append(sub_layer.identifier)
        return sub_layer

def load_set(step_data=None, project='', current_step=''):
    """Build the Maya set hierarchy from a published set USD file.

    Args:
        step_data: Pipeline step mapping; read here for ``published_folder``
            and a ``files`` mapping whose ``usd`` entry is the file name
            inside that folder.
        project: Unused here; kept for interface compatibility.
        current_step: Unused here; kept for interface compatibility.

    Returns:
        The ``SetUsdLoader`` that built the hierarchy, or None when there is
        no step data, no USD file entry, or the file does not exist.
    """
    if not step_data:
        return None
    usd_file = (step_data.get('files') or {}).get('usd')
    if not usd_file:
        # Without this guard the path would become '<folder>/None'.
        return None

    set_usd_path = '%s/%s' % (step_data['published_folder'], usd_file)
    print(set_usd_path)
    if not os.path.exists(set_usd_path):
        return None

    print('exists')
    return SetUsdLoader(set_usd_path)
