import os


import usd.lib.usd_manager as usd_manager
from pxr import Sdf



def copy_children(*_callback_args):
    """``Sdf.CopySpec`` children predicate that never allows children to be copied.

    Whatever arguments ``Sdf.CopySpec`` passes are ignored; the answer is
    always ``False``.
    """
    allow_children = False
    return allow_children

def copy_value(*_callback_args):
    """``Sdf.CopySpec`` value predicate that allows every field value to be copied.

    Whatever arguments ``Sdf.CopySpec`` passes are ignored; the answer is
    always ``True``.
    """
    allow_value = True
    return allow_value

def copy_prim_to_output(master_prim, output_usd, copy_child=False, variant=None):
    """Copy the prim spec of *master_prim* into the root layer of *output_usd*.

    The spec is read from the root layer of ``master_prim``'s stage and
    written with ``Sdf.CopySpec``. The output stage is then saved.

    :param master_prim: prim on the source stage whose spec is copied.
    :param output_usd: manager object exposing ``root_layer``, ``stage``
        and ``save_stage()`` for the output file.
    :param copy_child: when True the full spec (children included) is
        copied; otherwise field values are copied but children are skipped
        (via the ``copy_value`` / ``copy_children`` predicates).
    :param variant: optional suffix; when given, the destination path is
        ``'<source path>:<variant>'``.
    :return: the prim found on the output stage at the destination path.
    """
    source_layer = master_prim.GetStage().GetRootLayer()
    target_layer = output_usd.root_layer
    source_path = master_prim.GetPath()

    if variant:
        # NOTE(review): '<path>:<variant>' does not look like a well-formed
        # Sdf prim path; no caller in this file passes `variant`. Confirm the
        # intended destination before relying on this branch.
        target_path = '%s:%s' % (source_path, variant)
    else:
        target_path = source_path

    if copy_child:
        Sdf.CopySpec(source_layer, source_path, target_layer, target_path)
    else:
        # Copy only the prim's own fields; children are vetoed.
        Sdf.CopySpec(source_layer, source_path, target_layer, target_path,
                     shouldCopyValueFn=copy_value,
                     shouldCopyChildrenFn=copy_children)

    # Look up the prim where the spec was actually written, not where it came from.
    new_output_prim = output_usd.stage.GetPrimAtPath(target_path)
    output_usd.save_stage()
    return new_output_prim


def create_variant(parent, master_prim, master_variant, output_usd):
    """Author *master_variant* in the 'model_variant' set of *parent*'s output prim.

    The prim spec of ``parent`` is copied into ``output_usd`` without its
    children (``copy_prim_to_output`` defaults). A ``'model_variant'``
    variant set is then added to the resulting output prim, and
    ``master_variant`` is added and selected. When ``master_prim`` is
    truthy, every child of ``master_prim`` is copied underneath that
    variant.

    :param parent: prim whose spec/path identifies the output prim to edit.
    :param master_prim: prim whose children become the variant's contents,
        or a falsy value (e.g. ``None``) to author an empty variant.
    :param master_variant: name of the variant to add and select.
    :param output_usd: manager object wrapping the output stage.
    :return: names of all variants in the 'model_variant' set afterwards.
    """
    variant_set_name = 'model_variant'

    new_prim = copy_prim_to_output(parent, output_usd)
    variant_set = new_prim.GetVariantSets().AddVariantSet(variant_set_name)
    variant_set.AddVariant(master_variant)
    variant_set.SetVariantSelection(master_variant)

    if not master_prim:
        return variant_set.GetVariantNames()

    # Loop invariants: all children live on master_prim's stage, and all of
    # them are written below the same variant-selection path.
    source_layer = master_prim.GetStage().GetRootLayer()
    target_layer = output_usd.root_layer
    variant_parent_path = new_prim.GetPath().AppendVariantSelection(
        variant_set_name, master_variant)

    # Sdf.CopySpec writes to the explicit variant path computed above; the
    # edit context only redirects Usd-level authoring on the output stage.
    with variant_set.GetVariantEditContext():
        for child in master_prim.GetChildren():
            Sdf.CopySpec(source_layer,
                         child.GetPath(),
                         target_layer,
                         variant_parent_path.AppendChild(child.GetName()))

    return variant_set.GetVariantNames()

def duplicate_root_prims(master_prim, output_usd, master_variant):
    """Merge the children of *master_prim* into *output_usd* as 'model_variant' variants.

    Output children with no same-named child under ``master_prim`` receive
    an ``'empty'`` variant. Every child of ``master_prim`` receives a
    ``master_variant`` variant (see ``create_variant``). Any such child that
    then has at most one variant also gets an ``'empty'`` variant.

    :param master_prim: prim on the source stage whose children are merged.
    :param output_usd: manager object wrapping the output stage.
    :param master_variant: name of the variant authored for this merge.
    :return: None. Returns immediately if ``master_prim`` is not valid.
    """
    if not master_prim.IsValid():
        return

    master_children = master_prim.GetChildren()
    # Set for O(1) membership tests while scanning the output children.
    master_child_names = {child.GetName() for child in master_children}

    output_prim = output_usd.stage.GetPrimAtPath(master_prim.GetPath())
    if output_prim.IsValid():
        for child_prim in output_prim.GetChildren():
            # Find or create the 'model_variant' set on every existing output child.
            child_prim.GetVariantSets().AddVariantSet('model_variant')
            if child_prim.GetName() not in master_child_names:
                # NOTE(review): child_prim is an output-stage prim, so this
                # copies the output spec onto itself before adding 'empty'.
                # Confirm this is intended.
                create_variant(child_prim, None, 'empty', output_usd)

    for prim in master_children:
        # create_variant returns the variant names of the output prim's set.
        variant_names = create_variant(prim, prim, master_variant, output_usd)
        if len(variant_names) <= 1:
            # NOTE(review): create_variant also *selects* 'empty', so a prim
            # seen for the first time ends up with 'empty' selected rather
            # than master_variant. Confirm this is intended.
            create_variant(prim, None, 'empty', output_usd)



def merge_layers(master_layer, variant_master, output_path, project):

    master_usd = usd_manager.UsdManager(project=project)
    master_usd.open(master_layer)
    output_usd = usd_manager.UsdManager(project=project)
    output_usd.open(output_path)

    asset_prim = master_usd.stage.GetPrimAtPath('/asset')
    render_default_prim = master_usd.stage.GetPrimAtPath('/asset/render')

    master_default_prim = master_usd.stage.GetPrimAtPath('/asset/render/x_geo_001_grp')

    asset_output_prim = output_usd.stage.GetPrimAtPath('/asset')

    if not asset_output_prim.IsValid():
        copy_prim_to_output(asset_prim, output_usd)
    render_output_prim = output_usd.stage.GetPrimAtPath('/asset/render')

    if not render_output_prim.IsValid():
        copy_prim_to_output(render_default_prim, output_usd)

    geo_output_prim = output_usd.stage.GetPrimAtPath('/asset/render/x_geo_001_grp')
    if not geo_output_prim.IsValid():
        copy_prim_to_output(master_default_prim, output_usd)

    duplicate_root_prims(master_default_prim, output_usd, variant_master)

    output_usd.save_stage()





def test_merge_layers():
    """Rebuild the sample CharacterTara output by merging several published layers.

    Any existing output file is deleted first. Each (layer, variant) pair is
    then merged into it in order. All paths are site-specific.
    """
    output_path = r'C:\projects\TPT\assets\Main_Characters\CharacterTara\fernando.vizoso\maya\data\CharacterTara_geometry.usda'

    # (published layer, variant name), merged in this order.
    merge_plan = (
        (r'\\Project\SGD\TPT\publish\model\Main_Characters\CharacterTara\GirlA\CharacterTara_GirlA_c90403cf2bbb35fd\v004\usd\CharacterTara_geometry.usd', 'tara'),
        (r'\\Project\SGD\TPT\publish\model\Main_Characters\CharacterTara\tara05\CharacterTara_tara05_6e865012637e08a8\v007\usd\CharacterTara_geometry.usd', 'ball_dress'),
        (r'\\Project\SGD\TPT\publish\model\Main_Characters\CharacterTara\Tara07\CharacterTara_Tara07_0602387e62b96f42\v003\usd\CharacterTara_geometry.usd', 'flex'),
        (r'\\Project\SGD\TPT\publish\model\Main_Characters\Angelica\VarA\Angelica_VarA_ebd99b82113b22a3\v003\usd\Angelica_geometry.usd', 'angelica'),
        (r'\\Project\SGD\TPT\publish\model\Main_Characters\Angelica\VarE\Angelica_VarE_2d9363374d060859\v003\usd/Angelica_geometry.usd', 'game'),
        (r'\\Project\SGD\TPT\publish\model\Main_Characters\Angelica\VarG\Angelica_VarG_ecc64dd4cac40484\v001\usd/Angelica_geometry.usd', 'lisbeth'),
    )

    if os.path.exists(output_path):
        os.remove(output_path)

    for layer_path, variant_name in merge_plan:
        merge_layers(layer_path, variant_name, output_path, project='TPT')


if __name__ == '__main__':
    # Running the module directly performs the sample merge.
    test_merge_layers()