import os
import re
from pprint import pprint

import mayaUsd
from maya import cmds
from pxr import Usd, UsdGeom, UsdShade, Sdf



# Arnold surface-shader input name -> UsdPreviewSurface input name.
# An empty string marks an Arnold input that is deliberately left unmapped
# (create_preview_material skips falsy targets).
CONNECTION_DICTIONARY = {
    'inputs:base_color': 'diffuseColor',
    'inputs:metalness': 'metallic',
    'inputs:normal': 'normal',
    'inputs:specular_anisotropy': '',
    'inputs:specular_color': 'specularColor',
    # UsdPreviewSurface 'roughness' is the roughness of the primary specular
    # lobe ('clearcoatRoughness' belongs to the second, clearcoat lobe), so
    # Arnold's specular_roughness maps here.
    'inputs:specular_roughness': 'roughness',
    # Diffuse roughness has no UsdPreviewSurface counterpart.
    'inputs:diffuse_roughness': '',
}

def get_arnold_materials(stage):
    """Return every prim on *stage* whose type name is ``Material``.

    No Arnold-specific filtering happens here: any Material prim visited by
    ``stage.Traverse()`` is included, in traversal order.
    """
    return [
        candidate
        for candidate in stage.Traverse()
        if candidate.GetTypeName() == 'Material'
    ]

def get_values(prim):
    """Snapshot every attribute of *prim* as either a value or its connections.

    Args:
        prim: the shader prim to read.

    Returns:
        dict mapping attribute name to a dict holding ``'type'`` (the
        attribute's type name) plus either ``'connection'`` (the list of
        connection source paths, when the attribute is connected) or
        ``'value'`` (the result of ``attribute.Get()`` otherwise).
    """
    all_connections = {}
    # GetAttributes() rather than GetProperties(): properties also include
    # relationships, which have no GetConnections()/GetTypeName()/Get().
    for attribute in prim.GetAttributes():
        entry = {'type': attribute.GetTypeName()}
        connections = attribute.GetConnections()
        if connections:
            entry['connection'] = connections
        else:
            entry['value'] = attribute.Get()
        all_connections[attribute.GetName()] = entry

    return all_connections


def get_file_path(texture_path):
    """Swap an Arnold ``.tx`` texture path for a sibling viewport-readable image.

    ``<udim>`` is normalised to ``<UDIM>``. Paths that do not end in ``.tx``
    (case-insensitive) are returned normalised but otherwise unchanged. For
    ``.tx`` paths, the texture's directory is searched for a file with the
    same name but an ``exr``, ``png`` or ``jpg`` extension, tried in that
    order. A ``<UDIM>`` token matches any four-digit tile number. If nothing
    matches, or the directory cannot be listed, the normalised ``.tx`` path
    is returned.

    Args:
        texture_path: path of the texture as authored on the Arnold node.

    Returns:
        str: the replacement image path, or the normalised input path.
    """
    texture_path = str(texture_path).replace('<udim>', '<UDIM>')
    valid_textures = ['exr', 'png', 'jpg']

    basename, dot, extension = texture_path.rpartition('.')
    if not dot or extension.lower() != 'tx':
        return texture_path

    directory, stem = os.path.split(basename)
    # A UDIM path names a family of tiles (wood.<UDIM> -> wood.1001, ...).
    stem_pattern = re.compile(
        r'\d{4}'.join(re.escape(part) for part in stem.split('<UDIM>')) + r'\Z')

    try:
        file_names = os.listdir(directory or os.curdir)
    except OSError:
        # Missing/unreadable directory: keep the original path rather than crash.
        return texture_path

    # Lower-cased extension -> extension as spelled on disk, restricted to
    # files whose name (minus extension) matches this texture's name.
    found = {}
    for file_name in file_names:
        file_stem, _, file_extension = file_name.rpartition('.')
        if file_stem and stem_pattern.match(file_stem):
            found.setdefault(file_extension.lower(), file_extension)

    for candidate in valid_textures:
        if candidate in found:
            return '%s.%s' % (basename, found[candidate])

    return texture_path

def create_file_shader(stage, arnod_texture, shader_path):
    """Build a UsdUVTexture plus st primvar reader mirroring an Arnold texture.

    If *arnod_texture* is not an ``arnold:image`` node, its ``inputs:input``
    connection is followed (only when there is exactly one) and the function
    recurses. This lets pass-through nodes in front of an image be skipped.

    Args:
        stage: the Usd.Stage being edited.
        arnod_texture: UsdShade.Shader wrapping the Arnold node to convert.
        shader_path: Sdf.Path under which the new texture and st-reader
            shaders are defined.

    Returns:
        ``(texture_shader, st_reader)`` on success, or ``(None, None)`` when
        no Arnold image with a filename can be reached.
    """
    # Invalid or non-Shader schema objects cannot be inspected further.
    if not arnod_texture:
        return None, None

    prim = arnod_texture.GetPrim()
    if arnod_texture.GetIdAttr().Get() != 'arnold:image':
        input_attr = prim.GetAttribute('inputs:input')
        connections = input_attr.GetConnections() if input_attr else []
        if len(connections) != 1:
            return None, None
        # Get (not Define): read the upstream prim without re-authoring it.
        upstream = UsdShade.Shader.Get(stage, connections[0].GetPrimPath())
        return create_file_shader(stage, upstream, shader_path)

    filename_input = arnod_texture.GetInput('filename')
    raw_value = filename_input.Get() if filename_input else None
    # Asset-typed inputs yield Sdf.AssetPath (use .path); tolerate plain strings.
    texture_file = getattr(raw_value, 'path', raw_value)
    if not texture_file:
        return None, None
    file_path = get_file_path(texture_file)

    name = prim.GetName()
    shader = UsdShade.Shader.Define(stage, shader_path.AppendPath(name))
    shader.CreateIdAttr("UsdUVTexture")
    shader.CreateInput("file", Sdf.ValueTypeNames.Asset).Set(file_path)
    shader.CreateOutput("rgb", Sdf.ValueTypeNames.Float3)
    st_input = shader.CreateInput("st", Sdf.ValueTypeNames.Float2)

    st_reader = UsdShade.Shader.Define(stage, shader_path.AppendPath('%s_st' % name))
    st_reader.CreateIdAttr("UsdPrimvarReader_float2")
    st_input.ConnectToSource(st_reader.ConnectableAPI(), 'result')

    return shader, st_reader

def create_preview_material(stage, arnold_material, arnold_shader, values):
    """Author a UsdPreviewSurface material mirroring an Arnold surface shader.

    The material is defined at ``/asset/mtl/<material>_usd_mat`` with a
    ``<shader>_usd_shd`` UsdPreviewSurface child. Each input listed in
    CONNECTION_DICTIONARY is carried over. Plain values are set directly.
    Connected inputs are rebuilt as UsdUVTexture networks via
    create_file_shader.

    Args:
        stage: the Usd.Stage being edited.
        arnold_material: Usd.Prim of the source Arnold material.
        arnold_shader: Usd.Prim of the Arnold surface shader feeding it.
        values: mapping as returned by get_values() for *arnold_shader*.

    Returns:
        ``(material, shader)``: the new UsdShade.Material and its
        UsdPreviewSurface UsdShade.Shader.
    """
    mtl_path = Sdf.Path('/asset/mtl/%s_usd_mat' % arnold_material.GetName())
    material = UsdShade.Material.Define(stage, mtl_path)
    # Texture st readers take their primvar name from this material input.
    material_st_input = material.CreateInput('frame:stPrimvarName', Sdf.ValueTypeNames.Token)
    material_st_input.Set('st')

    shader_path = mtl_path.AppendChild('%s_usd_shd' % arnold_shader.GetName())
    shader = UsdShade.Shader.Define(stage, shader_path)
    shader.CreateIdAttr("UsdPreviewSurface")
    shader.CreateOutput("surface", Sdf.ValueTypeNames.Token)
    material.CreateSurfaceOutput().ConnectToSource(shader.ConnectableAPI(), "surface")

    for attribute, value_data in values.items():
        preview_attribute = CONNECTION_DICTIONARY.get(attribute)
        if not preview_attribute:
            continue  # not mapped, or mapped to '' on purpose

        if attribute == 'inputs:normal':
            type_name = Sdf.ValueTypeNames.Normal3f
        else:
            type_name = value_data['type']
        new_input = shader.CreateInput(preview_attribute, type_name)

        if 'value' in value_data:
            # An attribute with no authored value and no fallback reads as
            # None; leave the UsdPreviewSurface default instead of Set(None).
            if value_data['value'] is not None:
                new_input.Set(value_data['value'])
            continue

        for connection in value_data['connection']:
            # Get (not Define): inspect the connected prim without re-authoring it.
            connected_shader = UsdShade.Shader.Get(stage, connection.GetPrimPath())
            if not connected_shader:
                continue

            usd_texture, st_reader = create_file_shader(stage, connected_shader, mtl_path)
            if not usd_texture:
                continue

            # Scalar inputs read the red channel; everything else reads rgb.
            if new_input.GetTypeName() == Sdf.ValueTypeNames.Float:
                output_name = 'r'
            else:
                output_name = 'rgb'
            new_input.ConnectToSource(usd_texture.ConnectableAPI(), output_name)

            st_reader.CreateInput('varname', Sdf.ValueTypeNames.Token).ConnectToSource(material_st_input)

    return material, shader


def create_preview_from(stage):
    """Add UsdPreviewSurface materials for the Arnold materials on *stage*.

    First builds one preview material per Material prim whose
    ``outputs:arnold:surface`` is connected. Then binds those preview
    materials to prims under ``/asset/proxy`` or ``/asset/render`` (see
    _bind_preview_materials).

    Args:
        stage: the Usd.Stage to edit in place.
    """
    material_relations = _build_preview_materials(stage)
    _bind_preview_materials(stage, material_relations)


def _build_preview_materials(stage):
    """Create preview materials; return {Arnold material path: preview Material}."""
    material_relations = {}
    for material in get_arnold_materials(stage):
        surface_attribute = material.GetAttribute('outputs:arnold:surface')
        # Non-Arnold materials (e.g. preview materials from an earlier run)
        # have no such attribute; skip them instead of querying it.
        if not surface_attribute:
            continue
        connections = surface_attribute.GetConnections()
        if not connections:
            continue

        shader_prim = stage.GetPrimAtPath(connections[0].GetPrimPath())
        if not shader_prim:
            continue

        values = get_values(shader_prim)
        preview_material, _ = create_preview_material(stage, material, shader_prim, values)
        material_relations[material.GetPath()] = preview_material

    return material_relations


def _bind_preview_materials(stage, material_relations):
    """Bind preview materials for every directly-bound prim under /asset/render.

    If the matching prim under /asset/proxy exists and has children, the
    preview material is bound to it (all purposes). Otherwise the render
    prim gets the preview material with the 'preview' purpose, and its
    original material is re-bound for all purposes.
    """
    render_root = Sdf.Path('/asset/render')
    proxy_root = Sdf.Path('/asset/proxy')

    render_prim = stage.GetPrimAtPath(render_root)
    if not stage.GetPrimAtPath(proxy_root).IsValid():
        stage.DefinePrim(proxy_root, 'Xform')

    for prim in stage.Traverse():
        prim_path = prim.GetPath()
        # Path-aware prefix test: excludes siblings such as /asset/render_x.
        if not prim_path.HasPrefix(render_root):
            continue

        direct_binding = UsdShade.MaterialBindingAPI.Get(stage, prim_path).GetDirectBinding()
        applied_material = direct_binding.GetMaterialPath()
        original_material = direct_binding.GetMaterial()
        if not applied_material:
            continue
        preview_material = material_relations.get(applied_material)
        if not preview_material:
            continue

        # Look the proxy twin up with GetPrimAtPath; OverridePrim would author
        # an empty 'over' for every render prim lacking a proxy counterpart.
        proxy_prim = stage.GetPrimAtPath(prim_path.ReplacePrefix(render_root, proxy_root))
        if proxy_prim.IsValid() and proxy_prim.GetChildren():
            proxy_prim.ApplyAPI(UsdShade.MaterialBindingAPI)
            UsdShade.MaterialBindingAPI(proxy_prim).Bind(preview_material)
        else:
            prim.ApplyAPI(UsdShade.MaterialBindingAPI)
            # NOTE(review): presumably so the render geometry is drawn in the
            # viewport now that it carries a preview-purpose binding.
            render_prim.GetAttribute('purpose').Set('default')
            binding = UsdShade.MaterialBindingAPI(prim)
            binding.Bind(preview_material, materialPurpose=UsdShade.Tokens.preview)
            binding.Bind(original_material, materialPurpose=UsdShade.Tokens.allPurpose)





