
import importlib
from pprint import pprint

from pxr import Usd, Sdf, Plug, Tf, UsdRender, Gf, UsdShade

import library.core.config_manager as config_manager



class RenderLayerBuilder():
    """Author a USD render-settings layer on top of an existing scene.

    ``create_render_layers`` creates ``output_render_layer``, appends
    ``source_scene`` to its sublayers and, driven by the
    ``<render_type>_config`` config of the ``usd`` module, authors:

    * collections on ``/World`` (config key ``collections``),
    * one ``UsdRender.Product`` plus its ``UsdRender.Var`` AOVs and one
      ``UsdRender.Settings`` prim per entry of config key ``layers``,
    * settings values read from config key ``settings``.
    """

    def __init__(self,
                 project,
                 source_scene,
                 output_render_layer,
                 images_folder,
                 render_type,
                 start_frame=1001,
                 end_frame=1010
                 ):
        """Store the build parameters and load the project/render configs.

        Args:
            project: project code passed to ``config_manager.ConfigSolver``.
            source_scene: layer path appended to the output's sublayers.
            output_render_layer: path of the layer to create and save.
            images_folder: folder prefix of the per-frame product file names.
            render_type: selects the ``<render_type>_config`` usd config.
            start_frame: first frame (inclusive) that gets a product name sample.
            end_frame: last frame (inclusive) that gets a product name sample.
        """
        self.source_scene = source_scene
        self.output_render_layer = output_render_layer
        self.images_folder = images_folder
        self.project = project
        self.render_type = render_type
        self.start_frame = int(start_frame)
        self.end_frame = int(end_frame)

        self.config_solver = config_manager.ConfigSolver(project=self.project)
        self.project_config = self.config_solver.get_config('project')
        config_file_name = '%s_config' % self.render_type
        print(config_file_name)
        self.layers_config = self.config_solver.get_config(config_file_name, module='usd')
        pprint(self.layers_config)
        print('--')

    def create_collection(self, collection_name, render_prim, include_paths=None, exclude_path=None):
        """Apply a ``CollectionAPI`` named ``collection_name`` to ``render_prim``.

        Args:
            collection_name: instance name of the applied collection.
            render_prim: prim that receives the collection.
            include_paths: paths added as targets of the ``includes`` relationship.
            exclude_path: paths added as targets of the ``excludes`` relationship.

        Returns:
            The ``Usd.CollectionAPI``, with expansion rule ``expandPrims``.
        """
        collection_api = Usd.CollectionAPI.Apply(render_prim, collection_name)
        if include_paths:
            includes_rel = collection_api.GetIncludesRel()
            for inc_path in include_paths:
                includes_rel.AddTarget(inc_path)
        if exclude_path:
            excludes_rel = collection_api.GetExcludesRel()
            for exc_path in exclude_path:
                excludes_rel.AddTarget(exc_path)

        collection_api.GetExpansionRuleAttr().Set(Usd.Tokens.expandPrims)

        return collection_api

    def create_mate_shader(self, collection):
        """Define an Arnold ``matte`` material at ``/mtl/matte`` and bind it.

        Args:
            collection: name, or list of names, of collections stored in
                ``self.collections`` (filled by ``create_render_layers``).
                A single name is bound under binding name ``matte``; each
                name of a list under ``matte_<name>``.

        Raises:
            KeyError: if a name is not in ``self.collections``.
        """
        asset_prim = self.target_stage.GetPrimAtPath('/World')
        # Authored for its side effect: the /mtl Scope parenting the material.
        self.target_stage.DefinePrim('/mtl', 'Scope')

        material = UsdShade.Material.Define(self.target_stage, '/mtl/matte')
        shader = UsdShade.Shader.Define(self.target_stage, '/mtl/matte/matteShader')

        target_attr = material.GetPrim().CreateAttribute('outputs:arnold:surface', Sdf.ValueTypeNames.Token)
        material_input = UsdShade.Output(target_attr)
        material_input.ConnectToSource(shader.ConnectableAPI(), "surface")
        print(self.target_stage.GetRootLayer().ExportToString())
        mat_bind_api = UsdShade.MaterialBindingAPI.Apply(asset_prim)
        print(collection)

        if isinstance(collection, list):
            # A collection binding lives on the relationship
            # ``material:binding:collection:<bindingName>`` and Bind() replaces
            # its targets, so a shared name would leave only the last
            # collection bound: give every collection its own binding name.
            bindings = [(name, 'matte_%s' % name) for name in collection]
        else:
            bindings = [(collection, 'matte')]
        for name, binding_name in bindings:
            mat_bind_api.Bind(self.collections[name], material, binding_name,
                              bindingStrength=UsdShade.Tokens.strongerThanDescendants)

        shader_prim = shader.GetPrim()
        shader_prim.CreateAttribute('info:id', Sdf.ValueTypeNames.Token).Set('arnold:matte')
        shader_prim.CreateAttribute('inputs:color', Sdf.ValueTypeNames.Color4f).Set(value=(0, 0, 0, 1))
        shader_prim.CreateAttribute('inputs:opacity', Sdf.ValueTypeNames.Color3f).Set(value=(0, 0, 0))

    def create_product(self, product_name, aov, camera_path='/camera/camera1/cameraShape1'):
        """Define ``/Render/<product_name>`` and its AOV render vars.

        The product name is time-sampled to
        ``<images_folder>/<product_name>.<frame:04d>.exr`` for every frame of
        ``[start_frame, end_frame]``; its default value is the first frame.

        Args:
            product_name: name of the product prim and of its image files.
            aov: mapping of render var name -> source name expression; each
                becomes a ``UsdRender.Var`` under ``/Render/Vars``.
            camera_path: camera targeted by the product.

        Returns:
            The product ``Usd.Prim``.
        """
        product = UsdRender.Product.Define(self.target_stage, '/Render/%s' % product_name)
        product_prim = product.GetPrim()

        def image_path(frame):
            # One file per frame, zero-padded to four digits.
            return '%s/%s.%04d.exr' % (self.images_folder, product_name, frame)

        # The default (non time-sampled) value points at the start of the
        # frame range rather than a fixed frame number.
        product_name_attr = product.CreateProductNameAttr(image_path(self.start_frame))
        for frame in range(self.start_frame, self.end_frame + 1):
            product_name_attr.Set(time=frame, value=image_path(frame))

        product.CreateCameraRel().SetTargets([camera_path])

        aov_root_path = '/Render/Vars'
        self.target_stage.DefinePrim(aov_root_path, 'Scope')
        aov_path_list = []
        for aov_name, aov_expression in aov.items():
            aov_path = '%s/%s' % (aov_root_path, aov_name)
            aov_path_list.append(aov_path)
            var = UsdRender.Var.Define(self.target_stage, aov_path)
            var.CreateSourceNameAttr(aov_expression)
            var.CreateDataTypeAttr('color')

        product.CreateOrderedVarsRel().SetTargets(aov_path_list)
        return product_prim

    def create_render_settings(self, settings_name, product):
        """Define an ``ArnoldOptions`` prim at ``/Render/<settings_name>``.

        Args:
            settings_name: name of the settings prim.
            product: product prim (or list of them) targeted by ``products``.

        Returns:
            The settings ``Usd.Prim``.
        """
        render_settings_path = '/Render/%s' % settings_name
        if not isinstance(product, list):
            product = [product]

        render_settings_prim = self.target_stage.DefinePrim(render_settings_path, 'ArnoldOptions')
        products_rel = render_settings_prim.CreateRelationship('products')
        for prod in product:
            products_rel.AddTarget(prod.GetPath())

        render_settings_prim.CreateAttribute('arnold:AA_samples', Sdf.ValueTypeNames.Int).Set(value=5)

        return render_settings_prim

    def create_usd_render_settings(self, settings_name, product):
        """Define a ``UsdRender.Settings`` prim at ``/Render/<settings_name>``.

        Values come from the config's ``settings`` mapping (optional), with
        defaults ``renderingColorSpace='raw'``, ``resolution=(768, 576)`` and
        ``AASamples=4`` (written as ``arnold:AA_samples``).

        Args:
            settings_name: name of the settings prim.
            product: product prim (or list of them) targeted by ``products``.

        Returns:
            The ``UsdRender.Settings`` schema object.
        """
        render_settings_path = '/Render/%s' % settings_name
        products = product if isinstance(product, list) else [product]
        # Every key below has a default, so a missing 'settings' section is
        # treated as "use the defaults" instead of raising KeyError.
        settings_config = self.layers_config.get('settings') or {}

        render_settings = UsdRender.Settings.Define(self.target_stage, render_settings_path)
        render_settings.CreateRenderingColorSpaceAttr(settings_config.get('renderingColorSpace', 'raw'))
        render_settings.CreateResolutionAttr(Gf.Vec2i(settings_config.get('resolution', (768, 576))))

        aa_samples = settings_config.get('AASamples', 4)
        render_settings.GetPrim().CreateAttribute(
            'arnold:AA_samples', Sdf.ValueTypeNames.Int).Set(value=aa_samples)

        products_rel = render_settings.CreateProductsRel()
        for prod in products:
            products_rel.AddTarget(prod.GetPath())

        return render_settings

    def create_render_layers(self):
        """Create, populate and save ``output_render_layer``.

        Raises:
            ValueError: if the config defines no render ``layers``, or defines
                ``collections`` while the composed stage has no ``/World`` prim.
        """
        layers_config = self.layers_config or {}
        render_layers = layers_config.get('layers') or {}
        if not render_layers:
            # Checked before creating the stage: without at least one layer
            # there is no settings prim for renderSettingsPrimPath.
            raise ValueError('render config defines no "layers"; nothing to build')

        self.target_stage = Usd.Stage.CreateNew(self.output_render_layer)
        root_layer = self.target_stage.GetRootLayer()
        root_layer.subLayerPaths.append(self.source_scene)

        # Parent scope of every product, var and settings prim authored below.
        self.target_stage.DefinePrim('/Render', 'Scope')

        collections_config = layers_config.get('collections') or {}
        collections_prim = self.target_stage.GetPrimAtPath('/World')
        if collections_config and not collections_prim.IsValid():
            raise ValueError('cannot create collections: no /World prim in %s' % self.source_scene)

        self.collections = {}
        for collection_name, collection_settings in collections_config.items():
            self.collections[collection_name] = self.create_collection(
                collection_name,
                collections_prim,
                include_paths=collection_settings.get('include', []),
                exclude_path=collection_settings.get('exclude', []))

        for render_layer_name, render_layer_config in render_layers.items():
            print('create_render layer', render_layer_name)
            # TODO: matte shading (create_mate_shader on the layer's
            # 'collections' when 'excluded_nodes' is 'matte') is disabled.
            render_product = self.create_product(render_layer_name, render_layer_config['aov'])
            render_settings = self.create_usd_render_settings('Settings_%s' % render_layer_name,
                                                              render_product)

        pprint(self.collections)

        # The stage can name a single default settings prim: the last layer's.
        # TODO: render passes (see create_render_pass) are not authored yet.
        self.target_stage.SetMetadata('renderSettingsPrimPath', str(render_settings.GetPath()))

        print(root_layer.ExportToString())
        root_layer.Save()













def create_render_pass(stage, layer_name, collection, render_prim, render_settings):
    """Define a ``UsdRender.Pass`` at ``<render_prim>/Passes/<layer_name>``.

    The pass has pass type ``arnold``, command ``kick -i``, takes
    ``render_settings`` as its ``renderSource`` and limits its
    ``renderVisibility`` collection to ``/World/sets`` (root not included).

    Args:
        stage: stage to author on.
        layer_name: name of the pass prim.
        collection: currently unused. NOTE(review): presumably meant to drive
            the visibility collection instead of the hard-coded ``/World/sets``.
        render_prim: prim under which the ``Passes`` scope is defined.
        render_settings: object whose ``GetPath()`` becomes ``renderSource``.

    Returns:
        The render pass ``Usd.Prim``.
    """
    passes_path = render_prim.GetPath().AppendPath('Passes')
    stage.DefinePrim(passes_path, 'Scope')
    render_pass_path = passes_path.AppendPath(layer_name)

    render_pass = UsdRender.Pass.Define(stage, render_pass_path)
    pass_prim = render_pass.GetPrim()
    print(render_pass)
    render_pass.CreatePassTypeAttr('arnold')
    render_pass.CreateCommandAttr(['kick', '-i'])

    # Collection membership is read from ``collection:renderVisibility:includes``.
    # A bare ``collection:renderVisibility`` relationship is not part of the
    # collection, so its targets were ignored and, with includeRoot off, the
    # pass saw nothing.
    visibility = Usd.CollectionAPI(pass_prim, 'renderVisibility')
    visibility.CreateIncludesRel().SetTargets(['/World/sets'])
    render_pass.CreateRenderSourceRel().SetTargets([render_settings.GetPath()])
    visibility.CreateIncludeRootAttr(False)

    return pass_prim

def create_render_layers(source, output):
    """Legacy, stage-based version of RenderLayerBuilder.create_render_layers.

    Creates ``output``, sublayers ``source`` into it, then authors a matte
    material, Foreground/Background products, an ``HDSettings`` settings prim
    and a ``Characters`` render pass before saving.

    NOTE(review): ``create_collection``, ``create_mate_shader``,
    ``create_product`` and ``create_render_settings`` are not defined at module
    level in the visible source -- they exist only as RenderLayerBuilder
    methods with different signatures -- so calling this as written should
    raise NameError after the output stage has been created. TODO confirm and
    port to RenderLayerBuilder or delete.
    """
    print(source)
    stage = Usd.Stage.CreateNew(output)
    root_layer = stage.GetRootLayer()
    root_layer.subLayerPaths.append(source)

    render_prim = stage.DefinePrim('/Render', 'Scope')

    collections_prim = stage.GetPrimAtPath('/World')
    #characters_collection = create_collection(stage, '/World/Characters', 'foreground', collections_prim)
    # NOTE(review): collects '/World/set', while create_render_pass hard-codes
    # '/World/sets' for visibility -- confirm which path is intended.
    set_collection = create_collection(stage, '/World/set', 'background', collections_prim)

    create_mate_shader(stage, set_collection)
    # AOV name -> Arnold source expression for each product's render vars.
    aov = {'RGBA': 'RGBA',
           'a': 'a',
           'directDiffuse': "C<RD>[<L.>O]",
           'N': 'N'}

    foreground_product = create_product(stage, 'Foreground', aov)

    # NOTE(review): background_product is never targeted by the settings below.
    background_product = create_product(stage, 'Background', aov)

    render_settings = create_render_settings(stage,
                                             'HDSettings',
                                             foreground_product)


    render_pass = create_render_pass(stage, 'Characters', set_collection, render_prim, render_settings)

    stage.SetMetadata('renderSettingsPrimPath', str(render_settings.GetPath()))
    #stage.SetMetadata('renderPassPrimPath', str(render_pass.GetPath()))


    print(root_layer.ExportToString())
    root_layer.Save()

if __name__ == '__main__':
    import argparse

    # Command-line entry point. The shot paths and project below are the
    # previous hard-coded values, kept as defaults.
    shot_root = 'V:/TPT/publish/usd/shots/s00/s00_ep01/s00_ep01_sq020/s00_ep01_sq020_sh010'
    parser = argparse.ArgumentParser(description='Build a USD render-settings layer for a shot.')
    parser.add_argument('--project', default='TPT')
    parser.add_argument('--source',
                        default=shot_root + '/breakdown/s00_ep01_sq020_sh010_breakdown.usda')
    parser.add_argument('--output',
                        default=shot_root + '/render_layers/s00_ep01_sq020_sh010_settings.usda')
    # RenderLayerBuilder has no defaults for these two, so they must be given.
    parser.add_argument('--images-folder', required=True,
                        help='folder the per-frame product images are written to')
    parser.add_argument('--render-type', required=True,
                        help='loads the usd "<render_type>_config" config')
    parser.add_argument('--start-frame', type=int, default=1001)
    parser.add_argument('--end-frame', type=int, default=1010)
    args = parser.parse_args()

    builder = RenderLayerBuilder(args.project,
                                 args.source,
                                 args.output,
                                 args.images_folder,
                                 args.render_type,
                                 start_frame=args.start_frame,
                                 end_frame=args.end_frame)
    builder.create_render_layers()
