import os
import importlib
from pprint import pprint
from pxr import Usd, UsdGeom, Sdf, Kind

import maya.cmds as cmds
import shotgrid_lib.database as database
import library.core.config_manager as config_manager

import usd.lib.usd_manager as usd_manager


importlib.reload(usd_manager)

#importlib.reload(database)
#importlib.reload(shotgrid_helpers)


def add_attribute(prim, attribute, value=None, attr_type=Sdf.ValueTypeNames.String):
    """Create an attribute on a USD prim and optionally author its value.

    Args:
        prim: The ``Usd.Prim`` that receives the attribute.
        attribute: Name of the attribute to create.
        value: Value to author. ``None`` leaves the attribute created but unset.
        attr_type: ``Sdf.ValueTypeNames`` entry for the attribute (string by default).

    Returns:
        The attribute object returned by ``prim.CreateAttribute``.
    """
    created = prim.CreateAttribute(attribute, attr_type)
    if value is None:
        return created
    created.Set(value)
    return created



def get_all_variants(node):
    """Group a Maya node's ``*_variant`` / ``*_version`` attributes by step.

    The name of each attribute of *node* that ends with ``_variant`` or
    ``_version`` is split on its last underscore. The prefix becomes the step
    name, and the suffix (``'variant'`` or ``'version'``) becomes the key. Values
    are read with ``cmds.getAttr(..., asString=True)``.

    Args:
        node: Name of the Maya node to inspect.

    Returns:
        dict mapping ``step_name`` -> ``{'variant': str, 'version': str}``.
        A step only has the keys whose attributes exist on the node.
    """
    variants_by_step = {}
    for attr_name in cmds.listAttr(node):
        if not attr_name.endswith(('_variant', '_version')):
            continue
        step, key = attr_name.rsplit('_', 1)
        bucket = variants_by_step.setdefault(step, {})
        bucket[key] = cmds.getAttr('%s.%s' % (node, attr_name), asString=True)
    return variants_by_step

# Maya's rotateOrder enum values 0-5 (xyz, yzx, zxy, xzy, yxz, zyx), mapped to
# the UsdGeom.Xformable method that authors the matching rotate op.
_ROTATE_OP_ADDERS = (
    'AddRotateXYZOp',
    'AddRotateYZXOp',
    'AddRotateZXYOp',
    'AddRotateXZYOp',
    'AddRotateYXZOp',
    'AddRotateZYXOp',
)


def _set_local_transform(prim, node):
    """Author translate, rotate and scale xform ops on *prim* from Maya *node*.

    The rotate op follows the node's ``rotateOrder`` when that attribute
    exists. Otherwise it falls back to XYZ, which is the op the exporter
    previously always wrote.
    """
    transformable = UsdGeom.Xformable(prim)
    translate = cmds.getAttr('%s.translate' % node)[0]
    rotate = cmds.getAttr('%s.rotate' % node)[0]
    scale = cmds.getAttr('%s.scale' % node)[0]

    rotate_order = 0
    if cmds.attributeQuery('rotateOrder', node=node, exists=True):
        rotate_order = cmds.getAttr('%s.rotateOrder' % node)
    add_rotate_op = getattr(transformable, _ROTATE_OP_ADDERS[rotate_order])

    transformable.AddTranslateOp().Set(value=tuple(translate))
    add_rotate_op().Set(value=tuple(rotate))
    transformable.AddScaleOp().Set(value=tuple(scale))


def _export_set_assembly(prim, node):
    """Fill *prim* from a Maya ``SetAssembly`` node.

    Authors the metadata attrs, setDressing variant selections, local
    transform, a payload to the node's ``setPath``, and kind ``assembly``.
    """
    usd_path = cmds.getAttr('%s.setPath' % node)
    asset_name = cmds.getAttr('%s.assetName' % node)

    add_attribute(prim, 'asset_name', value=asset_name)
    add_attribute(prim, 'asset_type', value='Sets')
    add_attribute(prim,
                  'instance_number',
                  value=cmds.getAttr('%s.instanceNumber' % node),
                  attr_type=Sdf.ValueTypeNames.Int)

    # Node attributes override these defaults when present and non-empty.
    variant_sets = prim.GetVariantSets()
    for set_name, default in (('setDressing_variant', 'Master'),
                              ('setDressing_version', 'recommended')):
        selection = None
        if cmds.attributeQuery(set_name, node=node, exists=True):
            selection = cmds.getAttr('%s.%s' % (node, set_name), asString=True)
        variant_sets.AddVariantSet(set_name).SetVariantSelection(selection or default)

    _set_local_transform(prim, node)

    payloads = prim.GetPayloads()
    payloads.ClearPayloads()
    if usd_path:
        payloads.AddPayload(usd_path)
    else:
        # USD treats an empty asset path as an internal (same layer stack)
        # payload, which is not what a set file reference means. A None path
        # would simply raise. Skip the payload and warn instead.
        print('WARNING: %s has no setPath; no payload authored' % node)
    Usd.ModelAPI(prim).SetKind(Kind.Tokens.assembly)


def _export_usd_assembly(prim, node, project_config):
    """Fill *prim* from a Maya ``UsdAssembly`` node.

    Authors the metadata attrs, per-step variant/version selections, local
    transform, and kind ``assembly``. It also adds a payload to
    ``<usd_files>/assets/<type>/<name>/asset_assembly.usda``, where
    ``usd_files`` comes from ``project_config['paths']``.
    """
    asset_name = cmds.getAttr('%s.assetName' % node)
    # Presumably spaces are replaced so the type matches the on-disk folder
    # name used in the payload path below. NOTE(review): confirm.
    asset_type = cmds.getAttr('%s.assetType' % node).replace(' ', '_')

    add_attribute(prim, 'asset_name', value=asset_name)
    add_attribute(prim, 'asset_type', value=asset_type)
    add_attribute(prim,
                  'instance_number',
                  value=cmds.getAttr('%s.instanceNumber' % node),
                  attr_type=Sdf.ValueTypeNames.Int)

    variant_sets = prim.GetVariantSets()
    for step_name, step_data in get_all_variants(node).items():
        # A step may define only one of <step>_variant / <step>_version.
        # Fall back to the same defaults SetAssembly uses instead of KeyError.
        variant_sets.AddVariantSet('%s_variant' % step_name).SetVariantSelection(
            step_data.get('variant') or 'Master')
        variant_sets.AddVariantSet('%s_version' % step_name).SetVariantSelection(
            step_data.get('version') or 'recommended')

    _set_local_transform(prim, node)

    usd_files_root = project_config['paths']['usd_files']
    payload_path = '%s/assets/%s/%s/asset_assembly.usda' % (usd_files_root, asset_type, asset_name)
    print('payload path', payload_path)
    payloads = prim.GetPayloads()
    payloads.ClearPayloads()
    payloads.AddPayload(payload_path)
    Usd.ModelAPI(prim).SetKind(Kind.Tokens.assembly)


def export_children(stage, maya_root, usd_root, sg_database, project, project_config=None):
    """Recursively mirror the Maya DAG children of *maya_root* under *usd_root*.

    Each child becomes an ``Xform`` prim named after its Maya short name:

    * ``transform``: local TRS, kind ``group``; its children are exported
      recursively.
    * ``SetAssembly``: kind ``assembly``, setDressing variants, and a payload
      to ``setPath``.
    * ``UsdAssembly``: kind ``assembly``, per-step variants, and a payload to
      the asset's ``asset_assembly.usda``.

    Other node types (e.g. shapes) are skipped.

    Args:
        stage: The ``Usd.Stage`` being authored.
        maya_root: Maya DAG node whose children are exported.
        usd_root: ``Usd.Prim`` under which child prims are defined.
        sg_database: ShotGrid database handle. Not used here; kept for callers.
        project: Project name passed to ``config_manager.ConfigSolver``.
        project_config: Already-resolved project config. When ``None`` it is
            loaded from *project* once and reused by the recursive calls.
    """
    if project_config is None:
        project_config = config_manager.ConfigSolver(project=project).project_config

    usd_root_path = usd_root.GetPath()
    # listRelatives returns None (not []) for a node without children, e.g.
    # an empty group. Full paths keep duplicated short names unambiguous.
    children = cmds.listRelatives(maya_root, children=True, fullPath=True) or []

    print('maya root', maya_root)
    print('usd_root', usd_root_path)
    for child in children:
        node_type = cmds.nodeType(child)
        if node_type not in ('transform', 'SetAssembly', 'UsdAssembly'):
            continue

        prim_path = usd_root_path.AppendChild(child.rsplit('|', 1)[-1])
        print('path', prim_path)
        child_prim = stage.DefinePrim(prim_path, 'Xform')

        if node_type == 'transform':
            _set_local_transform(child_prim, child)
            export_children(stage, child, child_prim, sg_database, project,
                            project_config=project_config)
            Usd.ModelAPI(child_prim).SetKind(Kind.Tokens.group)
        elif node_type == 'SetAssembly':
            _export_set_assembly(child_prim, child)
        else:
            _export_usd_assembly(child_prim, child, project_config)


def export_usd(output_path='', asset_name='', root='', project='', variant='', version=1, shotgrid_id=0):
    """Export the Maya hierarchy under *root* to the USD file *output_path*.

    Steps:
    1. Look up *asset_name* in ShotGrid to get its asset type.
    2. Open *output_path* through ``usd_manager.UsdManager``.
    3. Define ``/asset`` as the default prim and export the Maya children of
       *root* beneath it.
    4. Tag ``/asset`` with ``atlantis:*`` attributes and save the stage.

    Args:
        output_path: Destination USD file path.
        asset_name: ShotGrid Asset ``code`` to export as.
        root: Maya DAG node whose children are exported.
        project: Project name used for ShotGrid and config lookups.
        variant: Stored as ``atlantis:set_variant``.
        version: Stored as ``atlantis:set_version`` (int).
        shotgrid_id: Stored as ``atlantis:set_shotgrid_id`` (int).

    Returns:
        dict: ``{'usd': output_path}``.

    Raises:
        ValueError: if the ShotGrid lookup returns ``None`` for *asset_name*.
    """
    sg_database = database.DataBase()
    sg_database.fill(project, precatch=False)
    sg_database.query_sg_database('Asset', filters=[['code', 'is', asset_name]])

    asset_view = sg_database['Asset'].find_with_filters(code=asset_name, single_item=True)
    # NOTE(review): assumes find_with_filters returns None when nothing
    # matches. Fail with a clear message instead of an AttributeError below.
    if asset_view is None:
        raise ValueError('Asset %r not found in ShotGrid project %r' % (asset_name, project))
    asset_type = asset_view.sg_asset_type

    manager = usd_manager.UsdManager(project)
    manager.set_entity(asset_name, 'Asset', asset_type=asset_type)
    manager.open(output_path)

    stage = manager.stage
    asset_prim = stage.DefinePrim('/asset', 'Xform')
    print('asset_prim', asset_prim.GetPath())
    stage.SetDefaultPrim(asset_prim)

    export_children(stage, root, asset_prim, sg_database, project)

    manager.create_attribute(asset_prim, 'atlantis:asset_name', value=asset_name)
    manager.create_attribute(asset_prim, 'atlantis:asset_type', value=asset_type)
    manager.create_attribute(asset_prim, 'atlantis:set_variant', value=variant)
    manager.create_attribute(asset_prim, 'atlantis:set_version', value=version, attribute_type=int)
    manager.create_attribute(asset_prim, 'atlantis:set_shotgrid_id', value=shotgrid_id, attribute_type=int)
    manager.create_attribute(asset_prim, 'atlantis:path', value=output_path, attribute_type='path')

    manager.save_stage()

    return {'usd': output_path}



if __name__ == '__main__':
    import argparse

    print('Running export usd')
    parser = argparse.ArgumentParser(description='Export a Maya set hierarchy to USD')
    # Every export step needs these; missing values previously crashed later
    # with an unrelated error.
    parser.add_argument('-s', '-source', '--source-file', dest='source', required=True,
                        help='Source file path')
    parser.add_argument('-o', '-output', '--output-file', dest='output', required=True,
                        help='Output file path')
    parser.add_argument('-r', '-root', '--root-geo', dest='root', required=True,
                        help='Root geometry')
    parser.add_argument('-a', '-asset', '--asset_name', dest='asset_name', required=True,
                        help='Asset name')
    # Previously there was no way to pass the project, so export_usd always
    # received ''. The default keeps that behaviour for existing callers.
    parser.add_argument('-p', '-project', '--project', dest='project', default='',
                        help='Project name')

    args = parser.parse_args()

    import maya.standalone as standalone

    standalone.initialize(name='python')
    try:
        import maya.cmds as cmds

        cmds.loadPlugin('AbcImport')
        cmds.file(args.source, o=True)

        export_usd(args.output, args.asset_name, args.root, project=args.project)
    finally:
        # Autodesk recommends uninitializing maya.standalone before the
        # interpreter exits so Maya can shut down cleanly.
        standalone.uninitialize()