import importlib
import logging
import os

import maya.cmds as cmds
import maya.OpenMaya as OpenMaya

from pxr import Usd, UsdGeom, Sdf

import maya_assemblies.lib.representations.representation as representation
import maya_assemblies.lib.shader_helpers as shader_helpers

importlib.reload(shader_helpers)


class SetUsdRepresentation(representation.Representation):
    """Assembly representation that imports a USD "set" file with mayaUSDImport.

    Variant sets found on the asset root prim are mirrored as enum attributes
    on the assembly node, and the selected variants are passed to the import.
    """

    typename = "set"

    _log = logging.getLogger(__name__)

    @staticmethod
    def _default_variant(variant_set, values):
        """Return the USD-authored selection if it is valid, else the last variant."""
        selection = variant_set.GetVariantSelection()
        if selection in values:
            return selection
        return values[-1]

    def get_prim_vars(self):
        """Sync the asset's USD variant sets with enum attributes on the assembly node.

        Re-opens ``self.path`` into ``self.stage`` and uses ``/asset`` as the
        asset root, falling back to ``/World``. For every non-empty variant set
        (in name order):

        * if the assembly node already has an attribute of that name, its enum
          list is refreshed, its previous value is kept when still valid
          (otherwise ``'recommended'``, then the USD selection, then the last
          variant), and the choice is written back to the stage;
        * otherwise a new enum attribute is created, defaulting to the USD
          selection (or the last variant if the selection is not valid).

        Fills ``self.prim_var`` with ``[prim_path, variant_set_name, value]``
        lists (the form passed to ``mayaUSDImport(primVariant=...)``). It is
        left empty when neither root prim exists.
        """
        self.prim_var = []
        self.stage = Usd.Stage.Open(self.path)
        asset_prim = self.stage.GetPrimAtPath('/asset')
        if not asset_prim.IsValid():
            asset_prim = self.stage.GetPrimAtPath('/World')
        if not asset_prim.IsValid():
            # No known root: nothing to mirror (previously raised on an invalid prim).
            return

        self.stage.SetDefaultPrim(asset_prim)
        # Use the prim actually inspected, so /World-rooted files get correct selections.
        asset_path = str(asset_prim.GetPath())
        assembly = self.get_assembly_name()
        variant_sets = asset_prim.GetVariantSets()

        for variant_set_name in sorted(variant_sets.GetNames()):
            variant_set = variant_sets.GetVariantSet(variant_set_name)
            values = variant_set.GetVariantNames()
            if not values:
                continue

            attr = '%s.%s' % (assembly, variant_set_name)
            if cmds.attributeQuery(variant_set_name, n=assembly, exists=True):
                # Read the current choice before the enum list is rewritten.
                value = cmds.getAttr(attr, asString=True)
                cmds.addAttr(attr, en=':'.join(values), e=True)

                if value not in values:
                    # 'recommended' may not exist in every variant set; never index a missing name.
                    if 'recommended' in values:
                        value = 'recommended'
                    else:
                        value = self._default_variant(variant_set, values)

                cmds.setAttr(attr, values.index(value))
                variant_set.SetVariantSelection(value)
            else:
                value = self._default_variant(variant_set, values)
                cmds.addAttr(assembly,
                             ln=variant_set_name,
                             at='enum',
                             en=':'.join(values),
                             category='refresh_geo',
                             defaultValue=values.index(value))

            self.prim_var.append([asset_path, variant_set_name, value])

    def get_prim(self):
        """Return the prim to import from ``self.stage``, or ``None``.

        Looks up ``/asset`` (falling back to ``/World/Camera``). If that prim
        has exactly one child, the child is returned instead of the prim.
        Requires ``self.stage`` to be open.
        """
        prim = self.stage.GetPrimAtPath('/asset')
        if not prim.IsValid():
            # NOTE(review): get_prim_vars falls back to '/World' while this falls
            # back to '/World/Camera' -- confirm which root is intended.
            prim = self.stage.GetPrimAtPath('/World/Camera')

        if not prim.IsValid():
            return None

        children = prim.GetChildren()
        if len(children) == 1:
            return children[0]
        return prim

    def activate(self):
        """Replace the assembly's namespace contents with a fresh USD import.

        ``self.assembly.is_activating`` is True for the duration of the call
        and is always cleared afterwards, even on failure.

        Returns:
            True on success; False if no importable prim was found or
            ``mayaUSDImport`` raised ``RuntimeError``.
        """
        self.assembly.is_activating = True
        try:
            return self._import_usd()
        finally:
            # Previously a failed import left the flag stuck at True.
            self.assembly.is_activating = False

    def _import_usd(self):
        """Delete previously loaded nodes, import the USD file, and lock transforms."""
        parent_assembly = self.get_assembly_name()
        self.stage = Usd.Stage.Open(self.path)

        namespace = self.get_assembly_namespace()
        previous_nodes = cmds.namespaceInfo(namespace, ls=True, dagPath=True)
        if previous_nodes:
            cmds.delete(previous_nodes)
        # Do not keep names of nodes that no longer exist if the import fails below.
        self.loaded_nodes = []

        prim = self.get_prim()
        if prim is None:
            self._log.warning('No importable prim found in %s', self.path)
            return False
        prim_path = str(prim.GetPath())
        self.get_prim_vars()
        self._log.debug('load usd: path=%s prim=%s variants=%s',
                        self.path, prim_path, self.prim_var)

        try:
            cmds.mayaUSDImport(file=self.path,
                               primVariant=self.prim_var,
                               primPath=prim_path,
                               parent=parent_assembly,
                               readAnimData=False)
        except RuntimeError:
            self._log.warning('mayaUSDImport failed for %s', self.path, exc_info=True)
            return False

        # namespaceInfo returns None for an empty namespace.
        self.loaded_nodes = cmds.namespaceInfo(namespace, ls=True, dagPath=True) or []

        # Lock the imported transforms' translate/rotate/scale channels.
        for node in self.loaded_nodes:
            if cmds.nodeType(node) == 'transform':
                cmds.setAttr('%s.translate' % node, lock=True, e=True)
                cmds.setAttr('%s.rotate' % node, lock=True, e=True)
                cmds.setAttr('%s.scale' % node, lock=True, e=True)

        return True

    def can_apply_edits(self):
        """Report that this representation accepts edits (always True)."""
        return True

