import importlib
import json
from pprint import pprint

import maya.cmds as cmds


# String attribute written on each node to list the custom attributes the
# Maya USD exporter should write out (mayaUSD's user-exported-attributes
# JSON convention).
usd_export_attr = 'USD_UserExportedAttributesJson'

# Custom attributes whose names start with any of these prefixes are tagged
# for export by add_node_attributes().
attributes_prefix = [
    'aa_',
    'mtoa_constant_',
    'USD_',
]
def get_child_prim(prim, prim_name):
    """Depth-first search below ``prim`` for a descendant named ``prim_name``.

    Children are visited in ``GetChildren()`` order. Each child is compared
    first, then its own subtree is searched before the next sibling is
    tried (pre-order traversal). ``prim`` itself is never compared.

    Args:
        prim: Prim whose descendants are searched.
        prim_name: Exact name to compare against ``GetName()``.

    Returns:
        The first matching descendant, or ``None`` when nothing matches.
    """
    for candidate in prim.GetChildren():
        if candidate.GetName() == prim_name:
            return candidate

        # Not this child: look inside its subtree before moving on.
        match = get_child_prim(candidate, prim_name)
        if match:
            return match

    return None

def move_asset(manager, prim, destination_path, set_paths):
    """Rebuild one exported asset as ``<destination>/<asset>`` with render/proxy.

    The asset name is ``prim``'s name with its last ``_<suffix>`` removed
    (``chair_01`` -> ``chair``; names without ``_`` are kept whole). The first
    descendant named ``geo`` (falling back to ``render``) is repathed to
    ``<destination>/<asset>/render`` and given the ``default`` purpose. An
    inactive ``<destination>/<asset>/proxy`` Xform is also defined.

    Args:
        manager: Project USD manager exposing ``stage``, ``repath_prim`` and
            ``set_prim_purpose``.
        prim: Exported prim for a single asset.
        destination_path: Default parent path for the rebuilt asset prim.
        set_paths: Mapping of asset name -> destination path that overrides
            ``destination_path``. Assets listed here also get their transform
            stack reset (``SetResetXformStack(True)``).

    Raises:
        ValueError: If ``prim`` has no ``geo`` or ``render`` descendant.
    """
    print('=== move asset ===')
    print(prim)

    from pxr import UsdGeom

    prim_name = prim.GetName().rsplit('_', 1)[0]

    destination_path = set_paths.get(prim_name, destination_path)
    asset_path = '%s/%s' % (destination_path, prim_name)
    render_path = '%s/render' % asset_path
    proxy_path = '%s/proxy' % asset_path

    geo_prim = get_child_prim(prim, 'geo')
    if not geo_prim:
        print('check in render')
        geo_prim = get_child_prim(prim, 'render')

    print('geo_prim', geo_prim)
    if not geo_prim:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(
            'No "geo" or "render" prim found under %s' % prim.GetPath())

    print(render_path)
    # Define the asset Xform before repathing so the render prim's new parent
    # exists (presumably required by repath_prim - confirm).
    manager.stage.DefinePrim(asset_path, 'Xform')

    manager.repath_prim(geo_prim.GetPath(), render_path)
    manager.set_prim_purpose(render_path, 'default')

    # Take a handle after the repath rather than reusing the earlier one, in
    # case repath_prim recreated prims along this path.
    asset_prim = manager.stage.DefinePrim(asset_path, 'Xform')
    asset_prim.SetActive(True)

    proxy_prim = manager.stage.DefinePrim(proxy_path, 'Xform')
    proxy_prim.SetActive(False)

    if prim_name in set_paths:
        print('reset transforms', prim_name)
        UsdGeom.Xformable(asset_prim).SetResetXformStack(True)


def move_prims(manager, root_prim, destination_path, set_paths):
    """Rebuild every asset directly under ``root_prim`` via ``move_asset``.

    An Xform named after ``root_prim`` is first defined under ``/World``.
    Then each immediate child of ``root_prim`` is handed to ``move_asset``.

    Args:
        manager: Project USD manager exposing ``stage``.
        root_prim: Group prim (e.g. ``/Characters``) whose children are assets.
        destination_path: Default parent path for the rebuilt assets.
        set_paths: Per-asset destination overrides forwarded to ``move_asset``.
    """
    group_path = '/World/%s' % root_prim.GetName()
    manager.stage.DefinePrim(group_path, 'Xform')

    for asset_prim in root_prim.GetChildren():
        print('Child', asset_prim.GetName())
        print(destination_path)
        move_asset(manager, asset_prim, destination_path, set_paths)



def add_node_attributes(node_name):
    """Tag ``node_name`` so its prefixed custom attributes are exported to USD.

    Attributes whose names start with one of ``attributes_prefix`` are written
    as keys (each mapping to an empty object) into a JSON string. That string
    is stored on the ``usd_export_attr`` string attribute, which is created
    on demand. Nodes with no matching attribute are left untouched.

    Args:
        node_name: Maya node name; a full DAG path avoids ambiguity.
    """
    # listAttr can return None rather than an empty list.
    attributes = cmds.listAttr(node_name) or []
    prefixes = tuple(attributes_prefix)

    # The JSON attribute itself starts with "USD_". Never list it as an
    # attribute to export, or re-running the tool would export the JSON blob.
    attributes_to_export = [
        attr for attr in attributes
        if attr.startswith(prefixes) and attr != usd_export_attr
    ]

    if not attributes_to_export:
        return

    if not cmds.attributeQuery(usd_export_attr, n=node_name, exists=True):
        cmds.addAttr(node_name, ln=usd_export_attr, dt='string')

    # json.dumps produces valid JSON; the dict keeps first-seen order.
    json_str = json.dumps({attr: {} for attr in attributes_to_export})
    cmds.setAttr('%s.%s' % (node_name, usd_export_attr), json_str, type='string')


def add_custom_attributes():
    """Run ``add_node_attributes`` on the selection and all of its descendants.

    Operates on the current selection (as long DAG paths) plus every node
    below it, then re-selects the original selection.
    """
    selected = cmds.ls(sl=True, l=True)
    for node in selected:
        add_node_attributes(node)
        # Request full DAG paths, matching the long names from ls above, so
        # every name passed on is unambiguous. listRelatives returns None
        # when there are no descendants.
        for child_node in cmds.listRelatives(node, ad=True, f=True) or []:
            add_node_attributes(child_node)

    cmds.select(selected, r=True)

def export_usd_cache(project, shot_name, start_frame, end_frame, output_path, set_paths=None):
    """Export the scene's rigs to a USD cache and restructure it under /World.

    Steps:
      1. Select ``*:Rig_Geo`` and ``*:Rig_Locators`` and tag their custom
         attributes for export (``add_custom_attributes``).
      2. Set the playback range and export the selection with the Maya USD
         exporter for ``start_frame``..``end_frame``.
      3. Re-open the file through the project ``UsdManager``. Move the
         ``/Characters`` and ``/Props`` assets under ``/World`` (see
         ``move_prims``), remove ``/Shot_elements``, and save.

    Args:
        project: Project code passed to ``UsdManager``.
        shot_name: Shot entity name registered on the manager.
        start_frame: First frame to export.
        end_frame: Last frame to export.
        output_path: Destination USD file; exported with ``f=True`` (force).
        set_paths: Optional mapping of prop asset name -> destination path
            (see ``move_asset``). ``None`` means no overrides.
    """
    import usd.lib.usd_manager as usd_manager
    # Reload so edits to the manager are picked up in a live Maya session.
    importlib.reload(usd_manager)

    if set_paths is None:
        set_paths = {}

    cmds.select('*:Rig_Geo', r=True)
    cmds.select('*:Rig_Locators', add=True)

    add_custom_attributes()

    options = ';exportUVs=0;exportSkels=none;exportSkin=none;exportBlendShapes=0;exportDisplayColor=0;;exportColorSets=0;exportComponentTags=0;defaultMeshScheme=catmullClark;animation=1;eulerFilter=0;staticSingleSample=0;startTime=%s;endTime=%s;frameStride=1;frameSample=0.0;defaultUSDFormat=usdc;parentScope=;shadingMode=useRegistry;convertMaterialsTo=[];exportRelativeTextures=automatic;exportInstances=1;exportVisibility=1;mergeTransformAndShape=1;stripNamespaces=1;worldspace=0'
    options = options % (start_frame, end_frame)
    print('Frame rate: %s' % cmds.currentUnit(t=True, q=True))
    print('start frame: %s' % start_frame)
    print('end frame: %s' % end_frame)
    cmds.playbackOptions(ast=start_frame, e=True)
    cmds.playbackOptions(aet=end_frame, e=True)
    cmds.playbackOptions(minTime=start_frame, e=True)
    cmds.playbackOptions(maxTime=end_frame, e=True)

    cmds.file(output_path, options=options, typ='USD Export', es=True, pr=True, f=True)

    manager = usd_manager.UsdManager(project)
    manager.set_entity(shot_name, 'Shot')
    manager.open(output_path)
    manager.remove_attributes('animation')
    manager.stage.DefinePrim('/World', 'Xform')
    manager.default_prim = '/World'
    # NOTE(review): hard-coded 24 regardless of the Maya time unit printed
    # above - confirm every scene exported here runs at 24 fps.
    manager.stage.SetMetadata('timeCodesPerSecond', 24)

    # Characters never use destination overrides; props honour set_paths.
    groups = (
        ('/Characters', '/World/Characters', {}),
        ('/Props', '/World/Props', set_paths),
    )
    for source_path, group_destination, overrides in groups:
        group_prim = manager.stage.GetPrimAtPath(source_path)
        if group_prim.IsValid():
            move_prims(manager, group_prim, group_destination, overrides)
            manager.stage.RemovePrim(group_prim.GetPath())

    elements_prim = manager.stage.GetPrimAtPath('/Shot_elements')
    if elements_prim.IsValid():
        manager.stage.RemovePrim(elements_prim.GetPath())

    manager.save_stage()
    print('Saved: %s ' % output_path)

if __name__ == '__main__':
    import argparse

    print('Running export usd')
    parser = argparse.ArgumentParser(description='Export a Maya scene to a USD animation cache')
    parser.add_argument('-s', '-source', '--source-file', dest='source', required=True, help='Source file path')
    parser.add_argument('-of', '-output', '--output-file', dest='output_path', required=True, help='Output file path')

    # Parse frames as numbers: Maya's playbackOptions expects numeric times.
    parser.add_argument('-sf', '-start', '--start-frame', dest='start_frame', type=float, required=True, help='First frame to cache')
    parser.add_argument('-ef', '-end', '--end-frame', dest='end_frame', type=float, required=True, help='Last frame to cache')

    parser.add_argument('-p', '-project', '--project', dest='project', required=True, help='Project code')
    # export_usd_cache needs a shot name; it was read from args but never declared.
    parser.add_argument('-sh', '-shot', '--shot-name', dest='shot_name', required=True, help='Shot name')

    args = parser.parse_args()

    import maya.standalone as standalone
    standalone.initialize(name='python')
    try:
        print('loading: %s' % args.source)

        cmds.file(args.source, o=True)
        print('loaded..')

        export_usd_cache(args.project, args.shot_name, args.start_frame,
                         args.end_frame, args.output_path)
    finally:
        # Shut Maya down cleanly before the interpreter exits, as the
        # maya.standalone docs recommend.
        standalone.uninitialize()

