import maya.cmds as cmds
import mayaUsd

from pxr import Usd

# Names of the properties whose edited values get_set_edits copies into the
# exported layer. Properties with any other name are ignored.
# NOTE(review): 'variants' does not look like a USD attribute name (variant
# selections are copied separately in get_set_edits) — confirm it is needed.
valid_attributes = [
    'xformOp:rotateXYZ',
    'xformOp:translate',
    'xformOp:scale',
    'variants',
]

def get_set_edits(parent_prim, new_root_prim, edit_layer, new_stage, original_prim):
    """Author overrides in ``new_stage`` for the edits found under ``parent_prim``.

    For every child of ``parent_prim`` that has a prim spec in ``edit_layer``, an
    ``over`` with the same name is created under ``new_root_prim``. It receives:

    * every variant selection authored on the edit spec;
    * the edit spec's ``active`` state, when the edit spec authors it and it
      differs from the original prim;
    * the value of each attribute named in ``valid_attributes`` that is authored
      in ``edit_layer`` and differs from the original prim's value.

    The walk recurses into each edited child, except children that carry an
    ``asset_name`` property.

    Args:
        parent_prim: Usd.Prim on the edited stage whose children are inspected.
        new_root_prim: Usd.Prim on ``new_stage`` mirroring ``parent_prim``.
        edit_layer: Sdf.Layer whose specs define what counts as an edit.
        new_stage: Usd.Stage receiving the overrides.
        original_prim: Usd.Prim in the original set file mirroring
            ``parent_prim``. May be invalid or None when ``parent_prim`` has no
            counterpart there; original values are then treated as absent.
    """
    for child in parent_prim.GetAllChildren():
        edit_prim = edit_layer.GetPrimAtPath(child.GetPath())
        if not edit_prim:
            continue

        child_name = child.GetName()
        new_prim = new_stage.OverridePrim(new_root_prim.GetPath().AppendChild(child_name))
        # Counterpart of `child` (not of `parent_prim`) in the original set file.
        # It is invalid/None when the child does not exist there, e.g. because
        # the edit layer added it.
        new_original_prim = original_prim.GetChild(child_name) if original_prim else None

        variant_selections = edit_prim.variantSelections
        if variant_selections:
            variant_sets = new_prim.GetVariantSets()
            for set_name, selection in variant_selections.items():
                variant_sets.AddVariantSet(set_name).SetVariantSelection(selection)

        # A prim missing from the original file is treated as active. Only
        # export `active` when the edit layer actually authored it.
        original_active = new_original_prim.IsActive() if new_original_prim else True
        if edit_prim.HasActive() and edit_prim.active != original_active:
            new_prim.SetActive(edit_prim.active)

        for prim_property in child.GetProperties():
            property_name = prim_property.GetName()
            if property_name not in valid_attributes:
                continue
            # Skip values that come from other layers; only edits are exported.
            if not edit_prim.GetPropertyAtPath(prim_property.GetPath()):
                continue

            property_value = prim_property.Get()
            old_attribute = new_original_prim.GetAttribute(property_name) if new_original_prim else None
            old_value = old_attribute.Get() if old_attribute else None

            # Author the attribute only when the value changed, so the output does
            # not accumulate value-less attribute declarations.
            if property_value != old_value:
                new_attribute = new_prim.CreateAttribute(property_name, prim_property.GetTypeName())
                new_attribute.Set(property_value)

        # Prims with an `asset_name` property are treated as asset roots and are
        # not descended into.
        if not child.HasProperty('asset_name'):
            get_set_edits(child, new_prim, edit_layer, new_stage, new_original_prim)



def clean_empty_prims(stage):
    """Collect the prims of ``stage`` that have at least one child prim.

    NOTE(review): despite its name this does not remove anything yet; the
    collected prims are only returned. TODO: implement pruning of empty override
    prims, or drop this function.

    ``TraverseAll`` is used instead of ``Traverse`` because the output stage of
    this module is authored only with ``OverridePrim``. The default ``Traverse``
    predicate skips prims that are not defined.

    Args:
        stage: Usd.Stage to inspect.

    Returns:
        list of Usd.Prim: every prim, defined or not, that has a child prim.
    """
    # Usd.Prim has no HasChild(); an empty GetAllChildren() result is falsy.
    return [prim for prim in stage.TraverseAll() if prim.GetAllChildren()]



def basic_check_prim(parent_prim, edit_root_prim, edit_layer, new_stage):
    """Export the edits of every set under ``parent_prim`` into ``new_stage``.

    Children of ``parent_prim`` are skipped in two cases: they have no spec in
    ``edit_layer``, or no ``.usda`` layer other than ``edit_layer`` contributes
    to them. For each remaining child, the first such layer in the child's prim
    stack is opened as the original set file. Its default prim is diffed
    against the child by ``get_set_edits``, and the overrides go under
    ``edit_root_prim``. ``clean_empty_prims`` runs once at the end.

    Args:
        parent_prim: Usd.Prim on the edited stage whose children are the sets.
        edit_root_prim: Usd.Prim on ``new_stage`` receiving the overrides.
        edit_layer: Sdf.Layer holding the edits.
        new_stage: Usd.Stage being authored.
    """
    # Set files opened so far, keyed by layer identifier, so a set file shared
    # by several children is only composed once.
    set_stages = {}

    for child in parent_prim.GetChildren():
        # Check this first, so no set file is opened for children without edits.
        if not edit_layer.GetPrimAtPath(child.GetPath()):
            continue

        set_layer = _find_set_layer(child, edit_layer)
        if not set_layer:
            continue

        set_stage = set_stages.get(set_layer)
        if set_stage is None:
            set_stage = Usd.Stage.Open(set_layer)
            set_stages[set_layer] = set_stage

        get_set_edits(child, edit_root_prim, edit_layer, new_stage, set_stage.GetDefaultPrim())

    clean_empty_prims(new_stage)


def _find_set_layer(prim, edit_layer):
    """Return the identifier of the first ``.usda`` layer in ``prim``'s prim stack.

    ``edit_layer`` is excluded, because the edits must not be diffed against
    themselves. Returns '' when no layer qualifies.
    """
    for spec in prim.GetPrimStack():
        layer = spec.layer
        if layer != edit_layer and '.usda' in layer.identifier:
            return layer.identifier
    return ''


def export_usd_edits(output_path):
    """Export the set edits made in Maya to a new USD file.

    Uses the first ``mayaUsdProxyShape`` in the scene and treats the last layer
    of its stage's layer stack as the edit layer. It writes overrides for the
    edited prims under ``/sets`` to a new layer at ``output_path``, whose
    default prim is ``/asset``. The layer is saved before returning.

    Args:
        output_path: file path of the USD layer to create.

    Raises:
        RuntimeError: if the scene has no mayaUsdProxyShape, or if its stage has
            no valid ``/sets`` prim.
    """
    # Long (full DAG) paths keep the shape name unambiguous.
    shapes = cmds.ls(type="mayaUsdProxyShape", long=True)
    if not shapes:
        raise RuntimeError('No mayaUsdProxyShape found in the scene.')
    shape = shapes[0]
    stage = mayaUsd.ufe.getStage(shape)

    # Debug dump of every layer of the edited stage.
    layer_stack = stage.GetLayerStack()
    for layer in layer_stack:
        print(layer.identifier)
        print(layer.ExportToString())

    # NOTE(review): assumes the last layer of the stack is the one holding the
    # edits — confirm against how the Maya edit target is set up.
    edit_layer = layer_stack[-1]

    # Validate the input before creating the output layer.
    sets_prim = stage.GetPrimAtPath('/sets')
    if not sets_prim:
        raise RuntimeError('The stage of {} has no /sets prim.'.format(shape))

    new_stage = Usd.Stage.CreateNew(output_path)
    edit_root_prim = new_stage.OverridePrim('/asset')
    new_stage.SetDefaultPrim(edit_root_prim)

    basic_check_prim(sets_prim, edit_root_prim, edit_layer, new_stage)

    # TODO(review): confirm whether
    # clean_usd_escene.clean_usd_scene(output_path, config='setedit')
    # should run here; the call is currently disabled.

    root_layer = new_stage.GetRootLayer()
    print(root_layer.ExportToString())
    root_layer.Save()