import importlib
import os
from pprint import pprint
import maya.cmds as cmds
import maya.mel as mel

from pxr import Usd, UsdGeom, UsdShade, Sdf

import usd.lib.usd_manager as usd_manager

def get_shaders(root='|asset|render'):
    """Group the mesh shapes under *root* by the shadingEngine they use.

    Side effect: for every shadingEngine found, a shader connected to its
    ``aiSurfaceShader`` plug is force-connected to ``surfaceShader`` (replacing
    whatever was there) and then disconnected from ``aiSurfaceShader``.

    Args:
        root: DAG path whose descendant mesh shapes are scanned.

    Returns:
        dict mapping each shadingEngine name to the list of full mesh shape
        paths assigned to it, in discovery order and without duplicates.
    """
    all_shaders = {}
    # Both commands return None (not an empty list) when nothing matches.
    for mesh in cmds.listRelatives(root, ad=True, f=True, type='mesh') or []:
        for shading_group in cmds.listConnections(mesh, type='shadingEngine') or []:
            if shading_group not in all_shaders:
                all_shaders[shading_group] = []
                _promote_arnold_surface_shader(shading_group)
            members = all_shaders[shading_group]
            # A mesh may reach the same shadingEngine through several plugs;
            # record it only once so it is not duplicated when baking.
            if mesh not in members:
                members.append(mesh)
    return all_shaders


def _promote_arnold_surface_shader(shading_group):
    """Move a shader on ``<shading_group>.aiSurfaceShader`` onto ``.surfaceShader``."""
    arnold_shader = cmds.listConnections('%s.aiSurfaceShader' % shading_group)
    if arnold_shader:
        cmds.connectAttr('%s.outColor' % arnold_shader[0], '%s.surfaceShader' % shading_group, f=True)
        cmds.disconnectAttr('%s.outColor' % arnold_shader[0], '%s.aiSurfaceShader' % shading_group)


def create_shader(mesh_list, texture_path, shader_name, connected_shader, base_color):
    """Build a UsdPreviewSurface material approximating an existing shader and
    assign it to the proxy counterparts of *mesh_list*.

    Creates ``<shader_name>_usd_SG`` (shadingEngine) and
    ``<shader_name>_usd_MAT`` (usdPreviewSurface). Opacity, metallic and
    roughness are derived from the source shader's ``transmission``,
    ``metalness`` and ``specularRoughness`` attributes when they exist.

    Args:
        mesh_list: full mesh paths; each is mapped from ``|render|`` to
            ``|proxy|`` and, if that node exists, unlocked and assigned the new
            shading group.
        texture_path: albedo texture path (may contain ``<UDIM>``); used only
            when *base_color* is falsy.
        shader_name: name of the original shadingEngine, used as name prefix.
        connected_shader: list of nodes connected to the original
            shadingEngine's surfaceShader; only the first entry is read.
        base_color: ``cmds.getAttr`` result for a compound colour
            (``[(r, g, b)]``), or a falsy value to use *texture_path* instead.
    """
    source_shader = connected_shader[0]

    transparency = 0.0
    if cmds.attributeQuery('transmission', n=source_shader, exists=True):
        transparency = cmds.getAttr('%s.transmission' % source_shader)

    shading_group = cmds.shadingNode('shadingEngine', asShader=True, name='%s_usd_SG' % shader_name)
    usd_shader_node = cmds.shadingNode('usdPreviewSurface', n='%s_usd_MAT' % shader_name, asShader=True)
    # Specular workflow by default; switched to metallic below when needed.
    cmds.setAttr('%s.useSpecularWorkflow' % usd_shader_node, 1)
    cmds.connectAttr('%s.outColor' % usd_shader_node, '%s.surfaceShader' % shading_group)

    # Diffuse: the (baked) UDIM albedo texture, or the constant base colour.
    if not base_color:
        texture_file = cmds.shadingNode('file', asTexture=True, isColorManaged=True)
        cmds.connectAttr('%s.outColor' % texture_file, '%s.diffuseColor' % usd_shader_node, f=True)
        cmds.setAttr('%s.fileTextureName' % texture_file, texture_path, type='string')
        cmds.setAttr('%s.uvTilingMode' % texture_file, 3)  # 3 = UDIM (Mari) tiling
    else:
        cmds.setAttr('%s.diffuseColor' % usd_shader_node, *base_color[0], type='double3')

    cmds.setAttr('%s.opacity' % usd_shader_node, 1.0 - transparency)

    # Metallic. UsdPreviewSurface only honours `metallic` in the metallic
    # workflow (useSpecularWorkflow = 0), so switch when metalness is textured
    # or its constant value is significant.
    metalness = 0.0
    metal_texture = None
    if cmds.attributeQuery('metalness', n=source_shader, exists=True):
        metalness = cmds.getAttr('%s.metalness' % source_shader)
        metal_texture = _find_texture_input('%s.metalness' % source_shader)
    if metal_texture:
        _connect_texture_alpha(metal_texture, '%s.metallic' % usd_shader_node)
        cmds.setAttr('%s.useSpecularWorkflow' % usd_shader_node, 0)
    else:
        if metalness > 0.1:
            cmds.setAttr('%s.useSpecularWorkflow' % usd_shader_node, 0)
        cmds.setAttr('%s.metallic' % usd_shader_node, metalness)

    # Roughness: texture if available, otherwise the value clamped to >= 0.001.
    roughness = 0.5
    rough_texture = None
    if cmds.attributeQuery('specularRoughness', n=source_shader, exists=True):
        roughness = max(cmds.getAttr('%s.specularRoughness' % source_shader), 0.001)
        rough_texture = _find_texture_input('%s.specularRoughness' % source_shader)
    if rough_texture:
        _connect_texture_alpha(rough_texture, '%s.roughness' % usd_shader_node)
    else:
        cmds.setAttr('%s.roughness' % usd_shader_node, roughness)

    # Register the new shadingEngine with the render partition.
    cmds.connectAttr('%s.pa' % shading_group, ':renderPartition.st', f=True, na=True)

    for mesh in mesh_list:
        proxy_mesh = mesh.replace('|render|', '|proxy|')
        if not cmds.objExists(proxy_mesh):
            continue
        cmds.lockNode(proxy_mesh, lock=False)
        cmds.sets(proxy_mesh, e=True, forceElement=shading_group)


def _find_texture_input(plug):
    """Return the first ``file``/``aiImage`` node connected to *plug*, or None."""
    # listConnections returns None when there are no connections.
    for node in cmds.listConnections(plug) or []:
        if cmds.nodeType(node) in ('file', 'aiImage'):
            return node
    return None


def _connect_texture_alpha(texture_node, destination_plug):
    """Drive *destination_plug* from ``<texture_node>.outAlpha``."""
    cmds.connectAttr('%s.outAlpha' % texture_node, destination_plug, f=True)
    # Derive outAlpha from the image luminance; guarded because not every
    # texture node type is guaranteed to expose this attribute.
    if cmds.attributeQuery('alphaIsLuminance', n=texture_node, exists=True):
        cmds.setAttr('%s.alphaIsLuminance' % texture_node, 1)


def bake_textures(outFolder, resolution, bake=False):
    """Create a UsdPreviewSurface proxy material for every shadingEngine found
    by ``get_shaders()``, optionally baking it with Arnold first.

    For each shadingEngine:
      1. its meshes are duplicated into one temporary mesh (combined with
         polyUnite when there are several) that is assigned the shadingEngine;
      2. the original meshes are moved to ``initialShadingGroup``
         (presumably so only the temporary mesh carries the shader while baking);
      3. if *bake* is true, ``arnoldRenderToTexture`` is run on the temporary mesh;
      4. ``create_shader`` builds the preview material for the proxy meshes;
      5. the temporary mesh is deleted and the original assignment restored.

    shadingEngines with nothing connected to ``surfaceShader`` are skipped.

    Args:
        outFolder: folder the baked textures are written to and referenced from.
        resolution: bake resolution passed to ``arnoldRenderToTexture``.
        bake: run the Arnold bake; when false the preview material still points
            at the expected texture path.
    """
    all_shaders = get_shaders()
    for shader_name, shaders_geo in all_shaders.items():
        connected_shader = cmds.listConnections('%s.surfaceShader' % shader_name)
        if not connected_shader:
            print('skipping %s: no surfaceShader connected' % shader_name)
            continue

        if len(shaders_geo) > 1:
            duplicated_meshes = cmds.duplicate(shaders_geo, name='|duplicated')
            print('duplicated', duplicated_meshes)
            new_mesh = cmds.polyUnite(duplicated_meshes, ch=False, mergeUVSets=2)
            print('new mesh', new_mesh)
        else:
            node = cmds.duplicate(shaders_geo[0])
            new_mesh = cmds.parent(node, w=True)
        parent_mesh = new_mesh[0]
        new_mesh = cmds.listRelatives(parent_mesh, c=True, f=True)

        cmds.sets(new_mesh, e=True, forceElement=shader_name)
        texture_path = '%s/%s_%s_<UDIM>.albedo.exr' % (outFolder, connected_shader[0], new_mesh[0].split('|')[-1])

        for mesh in shaders_geo:
            cmds.sets(mesh, e=True, forceElement='initialShadingGroup')

        # A constant base colour is used only when baseColor is not driven.
        connected_texture = cmds.listConnections('%s.baseColor' % connected_shader[0])
        base_color = None
        if not connected_texture:
            base_color = cmds.getAttr('%s.baseColor' % connected_shader[0])

        if bake:
            # arnoldRenderToTexture works on the selection: select the
            # temporary mesh explicitly rather than relying on duplicate/parent
            # having left it selected.
            cmds.select(parent_mesh, r=True)
            cmds.arnoldRenderToTexture(folder=outFolder,
                                       shader=shader_name,
                                       resolution=resolution,
                                       all_udims=True,
                                       aa_samples=1,
                                       enable_aovs=True,
                                       extend_edges=True)

        create_shader(shaders_geo, texture_path, shader_name, connected_shader, base_color)

        cmds.delete(parent_mesh)
        for mesh in shaders_geo:
            cmds.sets(mesh, e=True, forceElement=shader_name)

    try:
        mel.eval('generateAllUvTilePreviews')
    except RuntimeError as err:
        # Best effort: failure here is reported but does not abort the export.
        print('generateAllUvTilePreviews failed: %s' % err)



def create_aov(aov_name):
    """Create an ``aiAOV`` node and register it with the Arnold render options.

    Output 0 of the AOV is wired to ``defaultArnoldFilter`` and
    ``defaultArnoldDriver``, the node is appended to
    ``defaultArnoldRenderOptions.aovList`` and its ``name`` is set to *aov_name*.

    Returns:
        The node name returned by ``cmds.createNode``.
    """
    aov_node = cmds.createNode('aiAOV', name=aov_name)

    output_links = (
        ('defaultArnoldFilter.message', '%s.outputs[0].filter' % aov_node),
        ('defaultArnoldDriver.message', '%s.outputs[0].driver' % aov_node),
    )
    for source_plug, target_plug in output_links:
        cmds.connectAttr(source_plug, target_plug, f=True)

    # na=True appends to the next free index of the aovList array.
    cmds.connectAttr('%s.message' % aov_node, 'defaultArnoldRenderOptions.aovList', na=True, f=True)
    cmds.setAttr('%s.name' % aov_node, aov_name, type='string')

    return aov_node


def generate_proxy_shaders(maya_scene,
                           images_path,
                           usd_path,
                           asset_name,
                           asset_type,
                           project='TPT',
                           resolution=512,
                           bake=False):
    """Open *maya_scene*, build UsdPreviewSurface proxy materials and export
    ``|asset`` to *usd_path*.

    Steps:
      1. open the scene;
      2. replace its Arnold AOVs with ``albedo`` and ``specular_albedo``;
      3. delete mayaUsdProxyShape nodes;
      4. set the Arnold GI sample counts to 1;
      5. run ``bake_textures``;
      6. export ``|asset`` to USD;
      7. post-process and save the exported stage (default Arnold subdiv
         primvars, preview-purpose material bindings).

    Args:
        maya_scene: Maya scene file to open (opened with force, discarding the
            current session).
        images_path: folder for baked textures; created if missing.
        usd_path: output USD file; its folder is created if missing.
        asset_name: asset name passed to ``UsdManager.set_entity``.
        asset_type: asset type passed to ``UsdManager.set_entity``.
        project: project code for ``UsdManager``; a falsy value (e.g. ``None``
            from an omitted CLI flag) falls back to ``'TPT'``.
        resolution: bake resolution.
        bake: whether to actually run the Arnold bake.

    Raises:
        ValueError: if *asset_name* or *asset_type* is empty.
    """
    # Validate before the expensive scene open / bake.
    if not asset_name or not asset_type:
        raise ValueError('asset_name and asset_type are required (got %r, %r)'
                         % (asset_name, asset_type))
    project = project or 'TPT'

    _ensure_directory(images_path)
    _ensure_directory(os.path.dirname(usd_path))

    cmds.file(maya_scene, o=True, f=True)

    # Replace the scene's AOVs with just 'albedo' and 'specular_albedo'.
    # Empty lists are never passed to cmds.delete.
    existing_aovs = cmds.ls(type='aiAOV')
    if existing_aovs:
        cmds.delete(existing_aovs)
    create_aov('albedo')
    create_aov('specular_albedo')

    usd_proxies = cmds.ls(type='mayaUsdProxyShape')
    if usd_proxies:
        cmds.delete(usd_proxies)

    # Hidden while baking; made visible again before export.
    cmds.setAttr('|asset.visibility', 0)

    for sample_attr in ('GIDiffuseSamples', 'GISpecularSamples', 'GITransmissionSamples', 'GISssSamples'):
        cmds.setAttr('defaultArnoldRenderOptions.%s' % sample_attr, 1)

    bake_textures(images_path, resolution, bake=bake)
    cmds.setAttr('|asset|proxy.visibility', 1)
    cmds.setAttr('|asset.visibility', 1)
    cmds.select('|asset', r=True)

    options = 'exportUVs=1;exportSkels=none;exportSkin=none;exportBlendShapes=0;exportDisplayColor=0;;exportColorSets=1;exportComponentTags=1;defaultMeshScheme=catmullClark;animation=0;eulerFilter=0;staticSingleSample=0;startTime=1;endTime=500;frameStride=1;frameSample=0.0;defaultUSDFormat=usda;parentScope=;shadingMode=useRegistry;convertMaterialsTo=[UsdPreviewSurface];exportRelativeTextures=automatic;exportInstances=1;exportVisibility=1;mergeTransformAndShape=1;stripNamespaces=1;worldspace=0;jobContext=[Arnold]'
    cmds.file(usd_path, options=options, type='USD Export', es=True, force=True, pr=True)

    manager = usd_manager.UsdManager(project)
    manager.set_entity(asset_name, 'Asset', asset_type=asset_type)
    manager.open(usd_path)

    _set_default_subdiv(manager)
    manager.remove_attributes('shading', flatten_meshes=False)
    _bind_preview_materials(manager)

    # TODO: tag '/asset' with atlantis:* metadata (asset name/type, shading
    # variant/version, ShotGrid id, source/output paths) once those values are
    # available to this function.
    manager.save_stage()


def _ensure_directory(folder):
    """Create *folder* (with parents) unless it is empty or already exists."""
    # os.path.dirname of a bare file name is '', which os.makedirs rejects.
    if folder and not os.path.exists(folder):
        os.makedirs(folder)


def _set_default_subdiv(manager):
    """Give every Mesh prim lacking ``primvars:arnold:subdiv_type`` the values
    subdiv_type='none' and subdiv_iterations=0."""
    for prim in manager.stage.Traverse():
        if prim.GetTypeName() != 'Mesh':
            continue
        if manager.get_attribute(prim, 'primvars:arnold:subdiv_type') is None:
            manager.create_attribute(prim, 'primvars:arnold:subdiv_type', attribute_type='token', value='none')
            manager.set_attribute(prim, 'primvars:arnold:subdiv_iterations', 0)


def _bind_preview_materials(manager):
    """For each prim at or below ``/asset/render`` with a direct binding to
    material ``<mat>``, bind ``<mat>_usd_SG`` for the ``preview`` purpose and
    ``<mat>`` for ``allPurpose``."""
    stage = manager.stage
    render_root = Sdf.Path('/asset/render')
    # Collect first so no prims are authored while the traversal is running.
    render_prims = [prim for prim in stage.Traverse() if prim.GetPath().HasPrefix(render_root)]
    for prim in render_prims:
        direct_binding = UsdShade.MaterialBindingAPI.Get(stage, prim.GetPath()).GetDirectBinding()
        applied_material = direct_binding.GetMaterialPath()
        if not applied_material:
            continue
        original_material = direct_binding.GetMaterial()
        preview_material = UsdShade.Material.Define(stage, '%s_usd_SG' % applied_material)
        binding = UsdShade.MaterialBindingAPI(prim)
        binding.Bind(preview_material, materialPurpose=UsdShade.Tokens.preview)
        binding.Bind(original_material, materialPurpose=UsdShade.Tokens.allPurpose)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Export shaders to usd')
    parser.add_argument('-s', '-source', '--source-file', dest='source', help='Source file path')
    parser.add_argument('-r', '-root', '--root-geo', dest='root', help='Root geometry (currently unused)')
    parser.add_argument('-u', '-usd', '--usd-file', dest='usd_file', help='Output file path')
    parser.add_argument('-i', '-images', '--images-folder', dest='images_folder', help='Output images folder')
    parser.add_argument('-an', '-asset', '--asset-name', dest='asset_name', help='Asset name')
    parser.add_argument('-at', '-type', '--asset-type', dest='asset_type', help='Asset type')
    parser.add_argument('-p', '-project', '--project', dest='project', default='TPT', help='Project code')
    parser.add_argument('--resolution', type=int, default=512, help='Bake resolution (default: 512)')
    parser.add_argument('--no-bake', dest='bake', action='store_false',
                        help='Skip the Arnold bake; only build and export the preview materials')
    args = parser.parse_args()

    print('ENVIRONMENT VARIABLES')
    for env, value in os.environ.items():
        print('env var %s = %s' % (env, value))

    import maya.standalone as standalone
    standalone.initialize(name='python')
    try:
        generate_proxy_shaders(args.source,
                               args.images_folder,
                               args.usd_file,
                               args.asset_name,
                               args.asset_type,
                               project=args.project,
                               resolution=args.resolution,
                               bake=args.bake)
    finally:
        # Autodesk recommends uninitializing a standalone session before exit.
        standalone.uninitialize()


