import os

from pprint import pprint

import mayaUsd
from maya import cmds
from pxr import Usd, UsdGeom


def load_usd_asset(asset_name, layers=None):
    """Load a published USD asset into a new Maya scene via a proxy shape.

    The function forces a new Maya scene, creates a ``mayaUsdProxyShape``
    whose transform is renamed to ``asset_name``, collects layer identifiers
    from the asset's published ``asset_assembly.usda`` and inserts them as
    sublayers of the proxy stage's root layer. It then applies the requested
    variant and version selections on the ``/<asset_name>`` prim of the
    proxy stage.

    Args:
        asset_name: Asset name. Used for the Maya node names, for the
            assembly file path and for the prim path ``/<asset_name>``.
        layers: Optional mapping of layer name (or layer file path) to a
            dict of selections. For each key, the part of its basename
            before the first ``_`` is title-cased to build the variant set
            names ``<Name>_variant`` and ``<Name>_version``. The dict may
            provide ``'variant'`` (default ``'Master'``) and ``'version'``
            (default ``'1'``). When ``None`` or empty, no variant
            selections are made.
    """
    # Destructive: forces a new scene without prompting to save.
    cmds.file(new=True, f=True)
    shape_node = cmds.createNode('mayaUsdProxyShape', name='%sShape' % asset_name)
    parent_node = cmds.listRelatives(shape_node, p=True)
    cmds.rename(parent_node, asset_name)

    asset_path = 'V:/SGD/publish/usd/assets/%s/asset_assembly.usda' % asset_name
    # Long (full DAG) path of the proxy shape; the scene is new, so the first
    # match is the shape created above.
    shape = cmds.ls(type="mayaUsdProxyShape", long=True)[0]
    stage = mayaUsd.ufe.getStage(shape)

    source_stage = Usd.Stage.Open(asset_path)
    asset_prim = source_stage.GetPrimAtPath('/Geometry')
    # NOTE(review): the loop below expects items exposing ``identifier``
    # (as Sdf.Layer does). Confirm that iterating the object returned by
    # GetReferences() yields such items -- TODO verify against the USD API.
    references = asset_prim.GetReferences()

    asset_path_lower = asset_path.lower()
    source_layers = []
    for source_layer in references:
        identifier = source_layer.identifier
        # Skip the assembly file itself.
        if identifier.lower() == asset_path_lower:
            continue
        # Skip identifiers that are not slash-separated paths. The original
        # tested ``asset_path`` here, which never matched; the candidate's
        # own identifier is what needs checking.
        if '/' not in identifier:
            continue
        source_layers.append(identifier)

    root_layer_id = stage.GetRootLayer().identifier
    for index, source_layer in enumerate(source_layers):
        cmds.mayaUsdLayerEditor(root_layer_id, edit=True, insertSubPath=(index, source_layer))
        # TODO: optionally mute layers that were not requested, e.g. with
        # mayaUsdLayerEditor(source_layer, edit=True, muteLayer=(1, shape)).

    # ``layers`` defaults to None; without selections there is nothing to set.
    if not layers:
        return

    asset_prim = stage.GetPrimAtPath('/%s' % asset_name)
    variant_sets = asset_prim.GetVariantSets()
    for layer, layer_data in layers.items():
        set_prefix = os.path.basename(layer).split('_')[0].title()

        variant_set = variant_sets.GetVariantSet('%s_variant' % set_prefix)
        variant_set.SetVariantSelection(layer_data.get('variant', 'Master'))

        version_set = variant_sets.GetVariantSet('%s_version' % set_prefix)
        version_set.SetVariantSelection(layer_data.get('version', '1'))

    # TODO: optionally point the edit target at the session layer
    # (identifier ending in 'session.usda') via mayaUsdEditTarget.


if __name__ == '__main__':
    # load_usd_asset() has no ``muted_layers`` parameter, so passing that
    # keyword raised TypeError. Its ``layers`` argument is a mapping of layer
    # name -> {'variant': ..., 'version': ...}; an empty mapping loads the
    # asset without changing any variant selection.
    # TODO: once layer muting is supported, the layers intended to be muted
    # are: shading, uvs, textures, cfx, cloth, procedural.
    load_usd_asset('cube', layers={})