import os
import importlib

from pprint import pprint

from pxr import Usd, Sdf
import library.core.config_manager as config_manager
import shotgrid_lib.database as database
import usd.lib.usd_manager as usd_manager

importlib.reload(usd_manager)


class RenderConfig:
    """Author an Arnold render-configuration USD stage for one ShotGrid entity.

    All of the work happens in ``__init__``:
    - the entity is queried in ShotGrid and its resolution is resolved;
    - ``output_file`` is opened as a stage that sublayers the scene layers;
    - Arnold options, color manager, filters and drivers are authored;
    - render-layer collections, render-pass overrides and per-mesh
      subdivision overrides are authored;
    - the stage is saved.
    """

    # Extra subdivision iterations per breakdown shot type. 'Very long'
    # disables subdivision instead (handled separately in set_subdivisions).
    _SUBDIV_ITERATION_BOOST = {'Medium': 1, 'Close': 2}

    # OCIO config used by create_color_manager unless the caller passes one.
    DEFAULT_OCIO_CONFIG = 'V:/company/tools/ocio/0.1.0/config/config.ocio'

    def __init__(self,
                 project,
                 entity,
                 entity_type,
                 render_type,
                 output_file,
                 output_images_path,
                 start_frame=1001,
                 end_frame=1010,
                 renderpass=None,
                 layer_paths=None
                 ):
        """Build and save the render configuration stage.

        Args:
            project: project code used by the ShotGrid database, the USD
                manager and the config solver.
            entity: entity code (e.g. a shot name).
            entity_type: ShotGrid entity type. ``'Shot'`` gets a
                shot/sequence/episode resolution lookup; any other type uses
                the ``'Default'`` resolution.
            render_type: name of the ``render_config`` config to read.
            output_file: path of the USD file to author.
            output_images_path: image path pattern. ``<aov>`` and ``<frame>``
                are substituted per driver and per frame.
            start_frame: first frame (inclusive).
            end_frame: last frame (inclusive).
            renderpass: key into the render type's ``renderpass`` config.
            layer_paths: layers to sublayer. When empty or ``None``,
                ``self.manager.filename`` (the entity's layer, presumably) is
                used.
        """
        self.project = project
        self.entity = entity
        self.entity_type = entity_type
        self.render_type = render_type
        self.output_file = output_file

        self.output_images_path = output_images_path
        self.start_frame = start_frame
        self.end_frame = end_frame
        self.renderpass = renderpass
        self.layer_paths = layer_paths

        self.database = database.DataBase()
        self.database.fill(self.project, precatch=False)

        self.resolution = self.get_resolution()
        self.manager = usd_manager.UsdManager(self.project)
        self.manager.set_entity(self.entity, self.entity_type)

        if not self.layer_paths:
            self.layer_paths = [self.manager.filename]

        self.manager.open(self.output_file)
        self.manager.stage.SetMetadata('timeCodesPerSecond', 24)
        # Iterate the resolved list, not the raw argument. The argument is
        # None when the caller relies on the default layer.
        for layer_path in self.layer_paths:
            self.manager.add_sublayer(layer_path)

        self.read_render_config()

        self.set_metadata()

        self.create_option()
        self.manager.save_stage()

        self.build_collections()

        self.create_renderpass()

        for check_prim in self.manager.stage.Traverse():
            if check_prim.GetTypeName() == 'Mesh':
                self.set_subdivisions(check_prim)

        self.manager.save_stage()

    def get_full_paths(self, include_paths, exclude_paths):
        """Resolve include/exclude entries to full prim paths on the stage.

        An entry matches a traversed prim when it equals either the prim's
        name or its full path string.

        Args:
            include_paths: prim names/paths to include (``None`` = none).
            exclude_paths: prim names/paths to exclude (``None`` = none).

        Returns:
            tuple: ``(full_include_paths, full_exclude_paths)``, two lists of
            path strings in stage traversal order.
        """
        # Build the lookup sets once instead of scanning lists per prim.
        include_set = set(include_paths or ())
        exclude_set = set(exclude_paths or ())

        full_include_paths = []
        full_exclude_paths = []
        for prim in self.manager.stage.Traverse():
            prim_name = prim.GetName()
            prim_path = prim.GetPath().pathString
            if prim_name in include_set or prim_path in include_set:
                full_include_paths.append(prim_path)
            if prim_name in exclude_set or prim_path in exclude_set:
                full_exclude_paths.append(prim_path)

        return full_include_paths, full_exclude_paths

    def create_collection(self, collection_name, render_prim, include_paths=None, exclude_path=None):
        """Apply a named ``UsdCollectionAPI`` to ``render_prim`` and fill it.

        Include/exclude entries are resolved with :meth:`get_full_paths`. The
        expansion rule is set to ``expandPrims``.

        Returns:
            The applied ``Usd.CollectionAPI``.
        """
        collection_api = Usd.CollectionAPI.Apply(render_prim, collection_name)
        full_include_paths, full_exclude_paths = self.get_full_paths(include_paths, exclude_path)

        includes_rel = collection_api.GetIncludesRel()
        for inc_path in full_include_paths:
            includes_rel.AddTarget(inc_path)
        excludes_rel = collection_api.GetExcludesRel()
        for exc_path in full_exclude_paths:
            excludes_rel.AddTarget(exc_path)

        collection_api.GetExpansionRuleAttr().Set(Usd.Tokens.expandPrims)

        return collection_api

    def create_layer(self, layer_name, layer_config):
        """Create a render-layer collection on the manager's default prim.

        Args:
            layer_name: collection name.
            layer_config: mapping with optional ``'added'`` / ``'removed'``
                lists of prim names or paths. ``None`` is treated as empty.

        Returns:
            The applied ``Usd.CollectionAPI``.
        """
        layer_config = layer_config or {}
        return self.create_collection(layer_name,
                                      self.manager.default_prim,
                                      include_paths=layer_config.get('added', []),
                                      exclude_path=layer_config.get('removed', []))

    def build_collections(self):
        """Create one collection per entry in the render-layers config's ``'layers'``."""
        for layer_name, layer_data in self.render_layers_config.get('layers', {}).items():
            print('create layer', layer_name)
            self.create_layer(layer_name, layer_data)

    @staticmethod
    def _alias_from_path(path):
        """Return the third prim level of ``path`` (``/<a>/<b>/<alias>/...``).

        Returns ``None`` when the path is shallower than that.
        """
        bits = path.split('/')
        return bits[3] if len(bits) > 3 else None

    @staticmethod
    def _split_alias_instance(alias):
        """Split ``'<asset_name>_<instance>'`` into ``(asset_name, int(instance))``.

        Returns ``None`` when the alias has no ``'_'`` or the text after the
        last ``'_'`` is not an integer.
        """
        asset_name, separator, instance = alias.rpartition('_')
        if not separator:
            return None
        try:
            return asset_name, int(instance)
        except ValueError:
            return None

    def _find_breakdown(self, alias):
        """Return the entity breakdown matching ``alias``, or ``None``.

        Tries an exact ``sg_alias`` match first. It then falls back to
        matching ``sg_asset_name``/``sg_instance`` parsed from
        ``'<asset_name>_<instance>'``.
        """
        breakdowns = self.entity_view.sg_breakdowns
        breakdown = breakdowns.find_with_filters(sg_alias=alias, single_item=True)
        if not breakdown.empty:
            return breakdown

        parsed = self._split_alias_instance(alias)
        if parsed is None:
            return None
        asset_name, instance = parsed
        for candidate in breakdowns:
            if candidate.sg_asset_name == asset_name and candidate.sg_instance == instance:
                return None if candidate.empty else candidate
        return None

    def set_subdivisions(self, prim):
        """Override Arnold subdivision on a mesh from its breakdown shot type.

        Behavior by breakdown shot type:
        - ``'Very long'`` sets ``subdiv_type`` to ``'none'``.
        - ``'Medium'`` adds 1 to the current ``subdiv_iterations``
          (1 when unset).
        - ``'Close'`` adds 2 to the current ``subdiv_iterations``
          (1 when unset).

        The mesh is left untouched when any of these hold:
        - its ``subdiv_type`` is already ``'none'``;
        - its path is too shallow to carry an alias;
        - no breakdown matches the alias.
        """
        subdiv_type = self.manager.get_attribute(prim, 'primvars:arnold:subdiv_type')
        if subdiv_type == 'none':
            return

        alias = self._alias_from_path(prim.GetPath().pathString)
        if alias is None:
            return
        breakdown = self._find_breakdown(alias)
        if breakdown is None:
            return

        breakdown_shot_type = breakdown.sg_shot_type
        if breakdown_shot_type == 'Very long':
            self.manager.create_attribute(prim,
                                          'primvars:arnold:subdiv_type',
                                          value='none',
                                          attribute_type=str)
            return

        boost = self._SUBDIV_ITERATION_BOOST.get(breakdown_shot_type)
        if boost is None:
            return
        subdiv_iterations = self.manager.get_attribute(prim, 'primvars:arnold:subdiv_iterations')
        if subdiv_iterations is None:
            subdiv_iterations = 1
        self.manager.create_attribute(prim,
                                      'primvars:arnold:subdiv_iterations',
                                      value=subdiv_iterations + boost,
                                      attribute_type=int)

    def create_renderpass(self):
        """Apply the selected render pass's per-layer, per-prim-type overrides.

        For each layer in ``renderpass/<self.renderpass>/layers``, every prim
        included by the same-named collection on the default prim receives
        the overrides listed under its type name. The ``'active'`` key calls
        ``SetActive``; every other key is authored as an int attribute.
        """
        print('=== create render passes == ')
        renderpass_config = self.render_type_data.get('renderpass', {}).get(self.renderpass, {})
        if not renderpass_config:
            return
        default_prim = self.manager.default_prim
        stage = self.manager.stage

        for layer, layer_data in renderpass_config.get('layers', {}).items():
            print(layer)
            if not layer_data:
                continue
            collection_api = Usd.CollectionAPI.Apply(default_prim, layer)

            collection_query = collection_api.ComputeMembershipQuery()
            for path in collection_api.ComputeIncludedPaths(collection_query, stage):
                check_prim = stage.GetPrimAtPath(path)
                # A prim may have become invalid if an ancestor was
                # deactivated earlier in this loop.
                if not check_prim.IsValid():
                    continue
                overrides = layer_data.get(check_prim.GetTypeName())
                if not overrides:
                    continue
                for attr, value in overrides.items():
                    if attr == 'active':
                        check_prim.SetActive(value)
                    else:
                        # NOTE(review): every non-'active' override is authored as int.
                        self.manager.create_attribute(check_prim, attr, value=value, attribute_type=int)

    def set_metadata(self):
        """Write fps, time-codes-per-second and the frame range to the stage metadata."""
        self.manager.stage.SetMetadata('framesPerSecond', self.resolution_data['fps'])
        self.manager.stage.SetMetadata('timeCodesPerSecond', 24)
        self.manager.stage.SetMetadata('startFrame', self.start_frame)
        self.manager.stage.SetMetadata('endFrame', self.end_frame)

    def get_resolution(self):
        """Query the entity in ShotGrid and return its resolution name.

        Also stores the entity view on ``self.entity_view`` and precaches
        its breakdowns and their assets. For shots the resolution falls back
        shot -> sequence -> episode -> ``'Default'``. Other entity types
        always get ``'Default'``.
        """
        entity_filters = [['code', 'is', self.entity]]
        self.database.query_sg_database(self.entity_type, filters=entity_filters)
        self.entity_view = self.database[self.entity_type][self.entity]
        self.entity_view.precache_dependencies(fields=['sg_breakdowns'])
        self.entity_view.sg_breakdowns.precache_dependencies(fields=['sg_assets'])
        if self.entity_type != 'Shot':
            return 'Default'
        view = self.entity_view
        # `or` short-circuits, so sequence/episode are only read when needed.
        return (view.sg_resolution
                or view.sg_sequence.sg_resolution
                or view.sg_sequence.episode.sg_resolution
                or 'Default')

    def create_color_manager(self, node_name='defaultColorMgtGlobals', ocio_config=None):
        """Define an ``ArnoldColorManagerOcio`` prim at ``/<node_name>``.

        Args:
            node_name: prim name, also written to ``arnold:name``.
            ocio_config: OCIO config path. Defaults to
                :attr:`DEFAULT_OCIO_CONFIG`.

        Returns:
            ``node_name``, for use as the options' ``arnold:color_manager``.
        """
        if ocio_config is None:
            ocio_config = self.DEFAULT_OCIO_CONFIG
        color_manager_prim = self.manager.stage.DefinePrim('/%s' % node_name, 'ArnoldColorManagerOcio')
        self.manager.create_attribute(color_manager_prim, 'arnold:name', value=node_name)
        self.manager.create_attribute(color_manager_prim, 'arnold:color_space_linear', value='ACEScg')
        self.manager.create_attribute(color_manager_prim,
                                      'arnold:config',
                                      attribute_type='path',
                                      value=ocio_config)
        return node_name

    def add_attributes(self, prim, attribute_data):
        """Author each config entry as an ``arnold:<key>`` attribute on ``prim``.

        The bookkeeping keys ``type``, ``buffer`` and ``short_name`` are
        skipped. ``name`` is prefixed with the parent prim path, minus its
        leading ``'/'``.

        The attribute type is chosen from the Python value:
        - ``float`` values are authored as ``float``;
        - ``int`` values are authored as ``int`` (so are ``bool`` values,
          since ``bool`` subclasses ``int``);
        - anything else is authored as ``str``.
        """
        ignore_attributes = {'type', 'buffer', 'short_name'}
        for attribute, value in attribute_data.items():
            if attribute in ignore_attributes:
                continue
            if attribute == 'name':
                value = ('%s/%s' % (prim.GetParent().GetPath(), value))[1:]
            if isinstance(value, float):
                attr_type = float
            elif isinstance(value, int):
                attr_type = int
            else:
                attr_type = str
            self.manager.create_attribute(prim, 'arnold:%s' % attribute, attribute_type=attr_type, value=value)

    def _author_driver_filenames(self, driver_prim, aov_short_name):
        """Author a time-sampled ``arnold:filename`` for every frame in range.

        ``<aov>`` in the output pattern becomes ``aov_short_name`` and
        ``<frame>`` becomes the zero-padded frame. The output folder is
        created if needed.
        """
        attr = driver_prim.CreateAttribute('arnold:filename', Sdf.ValueTypeNames.Asset)
        output_pattern = self.output_images_path.replace('<aov>', aov_short_name)
        output_folder = os.path.dirname(output_pattern)
        # exist_ok avoids the check-then-create race; skip when there is no
        # directory part, since os.makedirs('') raises.
        if output_folder:
            os.makedirs(output_folder, exist_ok=True)
        for frame in range(self.start_frame, self.end_frame + 1):
            full_path = output_pattern.replace('<frame>', '%04d' % frame)
            attr.Set(time=frame, value=Sdf.AssetPath(full_path))

    def generate_outputs(self):
        """Define Arnold filter and driver prims and build the outputs list.

        Filters come from the render type's ``defaultArnoldFilter`` config.
        Drivers come from ``renderpass/<self.renderpass>/aov``; a missing
        section yields no drivers.

        Returns:
            list[str]: one ``"<short_name> <buffer> <filter> <driver>"``
            string per AOV. Every output references the last configured
            filter.
        """
        stage = self.manager.stage

        stage.DefinePrim('/defaultArnoldFilter')
        filter_name = ''
        for filter_key, filter_data in self.render_type_data.get('defaultArnoldFilter', {}).items():
            filter_prim = stage.DefinePrim('/defaultArnoldFilter/%s' % filter_key, 'ArnoldGaussianFilter')
            self.add_attributes(filter_prim, filter_data)
            filter_name = filter_data['name']
        filter_ref = 'defaultArnoldFilter/%s' % filter_name

        stage.DefinePrim('/defaultArnoldDriver')

        renderpass_config = self.render_type_data.get('renderpass', {}).get(self.renderpass, {})
        aov_data = renderpass_config.get('aov') or {}

        outputs = []
        for aov_name, driver_data in aov_data.items():
            driver_name = driver_data.get('name', aov_name).replace('.', '_')
            driver_prim = stage.DefinePrim('/defaultArnoldDriver/%s' % driver_name,
                                           driver_data.get('type', 'ArnoldDriverExr'))

            aov_short_name = driver_data['short_name']
            if aov_short_name == 'RGBA':
                aov_short_name = 'beauty'

            self.add_attributes(driver_prim, driver_data)
            # The driver reference must match the arnold:name written by
            # add_attributes, which uses the raw config name.
            outputs.append(' '.join([driver_data['short_name'],
                                     driver_data['buffer'],
                                     filter_ref,
                                     'defaultArnoldDriver/%s' % driver_data['name']]))

            self._author_driver_filenames(driver_prim, aov_short_name)
        return outputs

    def create_option(self):
        """Define the ``/Options`` ``ArnoldOptions`` prim and its attributes.

        Sets name, color manager, fps, resolution, config-driven options and
        outputs. It also points ``arnold:atmosphere`` at any Shader prim
        whose ``info:id`` is ``arnold:atmosphere_volume``.
        """
        stage = self.manager.stage
        options_prim = stage.DefinePrim('/Options', 'ArnoldOptions')
        self.manager.create_attribute(options_prim, 'arnold:name', value='options')

        self.manager.create_attribute(options_prim, 'arnold:color_manager', value=self.create_color_manager())

        fps = self.resolution_data['fps']
        self.manager.create_attribute(options_prim, 'arnold:fps', attribute_type=int, value=fps)
        width = self.resolution_data['width']
        height = self.resolution_data['height']
        self.manager.create_attribute(options_prim, 'arnold:xres', attribute_type=int, value=width)
        self.manager.create_attribute(options_prim, 'arnold:yres', attribute_type=int, value=height)

        self.add_attributes(options_prim, self.render_type_data.get('ArnoldOptions', {}))
        outputs = self.generate_outputs()
        self.manager.create_attribute(options_prim, 'arnold:outputs', attribute_type='string_array', value=outputs)

        volume_shaders = ('arnold:atmosphere_volume',)
        for prim in stage.Traverse():
            if prim.GetTypeName() != 'Shader':
                continue
            info_id = self.manager.get_attribute(prim, 'info:id')
            if info_id in volume_shaders:
                self.manager.create_attribute(options_prim, 'arnold:atmosphere', value=prim.GetPath().pathString)

    def read_render_config(self):
        """Load project, render-type, render-layer and resolution configs.

        ``self.resolution_data`` is the entry for ``self.resolution``, or
        the ``'Default'`` entry when that one is missing or empty.
        """
        self.config_solver = config_manager.ConfigSolver(project=self.project)
        self.project_data = self.config_solver.get_config('project')
        self.render_type_data = self.config_solver.get_config(self.render_type, module='render_config')
        self.render_layers_config = self.config_solver.get_config('render_layers_config', module='usd')

        self.all_resolutions_data = self.config_solver.get_config('show_resolutions')
        self.resolution_data = (self.all_resolutions_data.get(self.resolution)
                                or self.all_resolutions_data.get('Default'))






def test_create_config():
    """Developer smoke test: build a lighting render config for one TpT shot.

    Any existing file at ``path`` is removed first. To exercise the
    final-render setup instead, use render type ``'rendering'`` with
    ``renderpass='Characters'``.

    The paths are hard-coded to one user's local Maya project; adjust them
    before running.
    """
    path = 'C:/Users/fernando.vizoso/Documents/maya/projects/default/datad/render_config.usda'
    output_images = 'C:/Users/fernando.vizoso/Documents/maya/projects/default/images/<aov>/image.<frame>.exr'
    # EAFP: avoids the exists()/remove() race.
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    entity = 's00_ep01_sq020_sh010'
    entity_type = 'Shot'
    RenderConfig('TpT', entity, entity_type, 'lighting', path, output_images, renderpass='Preview')


if __name__ == '__main__':
    test_create_config()