import os
import shutil
from pprint import pprint


import shotgrid_lib.database as database
from pxr import Usd, UsdGeom, Sdf


import library.core.config_manager as config_manager


class BreakdownBuilder():
    """Build the USD breakdown layer of a shot from its ShotGrid breakdown.

    For every top-level breakdown entry of the shot, a prim is authored under
    ``/<root_node>/<short asset type>/<alias>``. That prim payloads the asset's
    ``asset_assembly.usda`` and selects its geometry/shading variants. The
    render-layer collections listed in the ``render_layers_config`` are then
    applied on the stage's default prim.

    Constructing an instance queries the ShotGrid database and loads the
    project configuration (see ``init``).
    """

    def __init__(self, shot_name, project=None, root_node='World'):
        """
        Args:
            shot_name: Shot code with ``_``-separated tokens. The first three
                tokens are used as season / episode / sequence to build the
                publish path (e.g. ``s00_ep01_sq030_sh060``).
            project: Project code passed to the database and the config solver.
            root_node: Name of the top-level prim of the breakdown stage.
        """
        self.shot_name = shot_name
        self.project = project
        self.root_node = root_node
        self.init()

    def get_config(self):
        """Load the project configuration and the USD render-layer configuration."""
        self.config_solver = config_manager.ConfigSolver(project=self.project)
        self.project_data = self.config_solver.get_config('project')
        self.render_layers_config = self.config_solver.get_config('render_layers_config', module='usd')

    def get_publish_path(self):
        """Compute and store the shot and asset USD publish directories.

        Sets ``season``, ``episode``, ``sequence``, ``publish_path`` and
        ``publish_asset_path`` from ``project_data['paths']['publish_server']``
        and the shot name.

        Returns:
            The shot publish directory,
            ``<server>/publish/usd/shots/<season>/<episode>/<sequence>/<shot>``.
        """
        self.root_path = self.project_data['paths']['publish_server']

        bits = self.shot_name.split('_')
        self.season = bits[0]
        self.episode = '_'.join(bits[:2])
        self.sequence = '_'.join(bits[:3])

        self.publish_path = '%s/publish/usd/shots/%s/%s/%s/%s' % (
            self.root_path, self.season, self.episode, self.sequence, self.shot_name)
        self.publish_asset_path = '%s/publish/usd/assets' % self.root_path

        return self.publish_path

    def init(self):
        """Query the shot from ShotGrid, pre-cache its breakdown, then load config and paths."""
        self.database = database.DataBase()
        self.database.fill(self.project, precatch=False)
        shot_filters = [['code', 'is', self.shot_name]]
        self.database.query_sg_database('Shot', filters=shot_filters)
        self.shot_view = self.database['Shot'][self.shot_name]

        self.shot_view.precache_dependencies(fields=['sg_breakdowns', 'sg_published_elements'])
        asset_list = self.shot_view.sg_breakdowns.sg_asset
        if not asset_list.empty:
            asset_list.precache_dependencies(fields=['sg_published_elements'])
        self.get_config()
        self.get_publish_path()

    @staticmethod
    def _get_alias(item):
        """Return the prim name of a breakdown item: its alias, or ``<asset>_<NNN>``."""
        if item.sg_alias:
            return item.sg_alias
        return '%s_%03d' % (item.sg_asset.code, item.sg_instance)

    @staticmethod
    def _pick_version(step_versions):
        """Return the approved version of a step, falling back to its highest version."""
        return step_versions.get('approved_version', step_versions.get('highest_version'))

    def _breakdown_filename(self):
        """Return the path of this shot's breakdown ``.usda`` file."""
        return '%s/breakdown/%s_breakdown.usda' % (self.publish_path, self.shot_name)

    def _get_version_variants(self, item, asset_versions):
        """Return the ``*_version`` (and light) variant entries for one breakdown item.

        Args:
            item: Breakdown entry (used for the light variant).
            asset_versions: ``{step: {'highest_version'/'approved_version': ...}}``
                for the item's asset, or ``None`` when nothing was published.
        """
        variants = {}
        if not asset_versions:
            return variants
        # SetDressing is looked at after Model, so it wins when both steps exist.
        for geometry_step in ('Model', 'SetDressing'):
            if geometry_step in asset_versions:
                variants['Geometry_version'] = self._pick_version(asset_versions[geometry_step])
        if 'Shading' in asset_versions:
            variants['Shading_version'] = self._pick_version(asset_versions['Shading'])
        # 'Lighsets' is kept as spelled; presumably it matches the ShotGrid step code - confirm.
        if 'Lighsets' in asset_versions:
            print('add lightsets')
            # NOTE(review): the light variant reuses the geometry variant code;
            # confirm this is intended rather than a dedicated light variant field.
            variants['Light_variant'] = item.sg_geometry_variant.code
            variants['Light_version'] = self._pick_version(asset_versions['Lighsets'])
        return variants

    def build_breakdown(self):
        """Author the shot breakdown stage from scratch, add collections and save it.

        Returns:
            Path of the saved ``<publish_path>/breakdown/<shot>_breakdown.usda``.
        """
        correct_version = self.get_breakdown_versions()
        print(correct_version)

        filename = self._breakdown_filename()
        stage = get_stage(filename, self.root_node, clean=True)

        for item in self.shot_view.sg_breakdowns:
            # Only top-level breakdown entries get their own prim.
            if item.sg_parent_asset:
                continue
            asset_name = item.sg_asset.code
            asset_type = item.sg_asset.sg_asset_type.replace(' ', '_')

            print(asset_name, asset_type)
            if asset_type == 'Cameras':
                continue

            variants = {'Geometry_variant': item.sg_geometry_variant.code,
                        'Shading_variant': item.sg_shading_variant.code}
            variants.update(self._get_version_variants(item, correct_version.get(asset_name)))

            self.add_asset_breakdown(self.shot_view.code,
                                     asset_name,
                                     self._get_alias(item),
                                     variants=variants,
                                     asset_type=asset_type,
                                     stage=stage)

        self.build_collections(stage)
        stage.GetRootLayer().Save()
        return filename

    @staticmethod
    def _existing_paths(scene_manager, patterns, alias):
        """Substitute ``<alias>`` in each pattern and keep the paths whose prim is valid."""
        candidates = (pattern.replace('<alias>', alias) for pattern in patterns)
        return [path for path in candidates if scene_manager.GetPrimAtPath(path).IsValid()]

    def get_full_paths(self, scene_manager, include_paths, exclude_paths):
        """Expand ``<alias>`` path patterns against every top-level breakdown item.

        Args:
            scene_manager: Stage-like object exposing ``GetPrimAtPath``.
            include_paths: Iterable of prim path patterns containing ``<alias>``.
                ``None`` is treated as empty.
            exclude_paths: Same as ``include_paths``, for excluded prims.

        Returns:
            ``(full_include_paths, full_exclude_paths)``: lists of concrete
            paths whose prim is valid on ``scene_manager``.
        """
        include_paths = include_paths or []
        exclude_paths = exclude_paths or []
        full_include_paths = []
        full_exclude_paths = []

        for item in self.shot_view.sg_breakdowns:
            if item.sg_parent_asset:
                continue
            alias = self._get_alias(item)
            full_include_paths.extend(self._existing_paths(scene_manager, include_paths, alias))
            full_exclude_paths.extend(self._existing_paths(scene_manager, exclude_paths, alias))

        return full_include_paths, full_exclude_paths

    def create_collection(self, collection_name, render_prim, scene_manager, include_paths=None, exclude_path=None):
        """Apply a ``CollectionAPI`` named ``collection_name`` on ``render_prim``.

        ``<alias>`` patterns in ``include_paths`` / ``exclude_path`` are expanded
        per breakdown item. Only paths that exist on ``scene_manager`` become
        include/exclude targets. The expansion rule is set to ``expandPrims``.

        Returns:
            The applied ``Usd.CollectionAPI``.
        """
        collection_api = Usd.CollectionAPI.Apply(render_prim, collection_name)
        full_include_paths, full_exclude_paths = self.get_full_paths(scene_manager, include_paths, exclude_path)

        includes_rel = collection_api.GetIncludesRel()
        for inc_path in full_include_paths:
            includes_rel.AddTarget(inc_path)
        excludes_rel = collection_api.GetExcludesRel()
        for exc_path in full_exclude_paths:
            excludes_rel.AddTarget(exc_path)

        collection_api.GetExpansionRuleAttr().Set(Usd.Tokens.expandPrims)

        return collection_api

    def create_layer(self, layer_name, layer_config, scene_manager):
        """Create the collection of one render layer on the stage's default prim.

        ``layer_config['added']`` / ``layer_config['removed']`` (both optional)
        hold the include / exclude path patterns.
        """
        default_prim = scene_manager.GetDefaultPrim()
        add_paths = layer_config.get('added', [])
        removed_paths = layer_config.get('removed', [])

        return self.create_collection(layer_name,
                                      default_prim,
                                      scene_manager,
                                      include_paths=add_paths,
                                      exclude_path=removed_paths)

    def build_collections(self, stage):
        """Create one collection per entry of ``render_layers_config['layers']``."""
        layers = (self.render_layers_config or {}).get('layers', {})
        for layer_name, layer_data in layers.items():
            print('create layer', layer_name)
            self.create_layer(layer_name, layer_data, stage)

    def get_breakdown_versions(self):
        """Collect, per asset and pipeline step, the highest and approved version numbers.

        Iterates the published elements of every asset of the shot breakdown.
        A version counts as approved when its status is ``'cmpt'``.

        Returns:
            ``{asset_code: {step_code: {'highest_version': ..., 'approved_version': ...}}}``.
            ``approved_version`` is only present when an approved version was found.
        """
        correct_versions = {}
        for asset_view in self.shot_view.sg_breakdowns.sg_asset:
            for version in asset_view.sg_published_elements:
                pipeline_step = version.sg_step.code
                version_number = version.sg_version_number
                asset_name = version.sg_asset.code

                step_versions = correct_versions.setdefault(asset_name, {}).setdefault(pipeline_step, {})

                if version_number > step_versions.get('highest_version', 0):
                    step_versions['highest_version'] = version_number

                if version.sg_status_list == 'cmpt' and version_number > step_versions.get('approved_version', 0):
                    step_versions['approved_version'] = version_number

        return correct_versions

    def add_asset_breakdown(self, shot_name, asset_name, alias, variants=None, asset_type='Char', stage=None):
        """Author one asset prim, its payload and its variant selections.

        The prim is defined at ``/<root_node>/<short type>/<alias>``, where the
        short type is the last ``_``-separated token of ``asset_type``. Its
        payloads are replaced by
        ``<publish_asset_path>/<asset_type>/<asset_name>/asset_assembly.usda``.

        Args:
            shot_name: Unused; the breakdown file is derived from ``self.shot_name``.
            asset_name: Asset code, used to build the payload path.
            alias: Name of the asset prim.
            variants: Mapping of variant set name to selection. Only ``str``
                values are applied, and ``*_version`` entries are skipped.
            asset_type: Asset type with spaces replaced by ``_``.
            stage: Stage to author on. When ``None``, the shot breakdown file
                is opened (or created) through ``get_stage``.
        """
        short_asset_type = asset_type.split('_')[-1]
        filename = self._breakdown_filename()

        if stage is None:
            stage = get_stage(filename, self.root_node)

        type_path = '/%s/%s' % (self.root_node, short_asset_type)
        stage.DefinePrim('/%s' % self.root_node, 'Xform')
        print(type_path)
        stage.DefinePrim(type_path, 'Scope')
        asset_prim = stage.DefinePrim('%s/%s' % (type_path, alias), 'Xform')

        payloads = asset_prim.GetPayloads()
        payloads.ClearPayloads()
        payloads.AddPayload('%s/%s/%s/asset_assembly.usda' % (self.publish_asset_path, asset_type, asset_name))

        if variants:
            variant_sets = asset_prim.GetVariantSets()
            for variant_set_name, variant_value in variants.items():
                # Non-string values (e.g. None, or version numbers, which are
                # presumably numeric) are skipped.
                # NOTE(review): this means '*_version' selections are never authored.
                if not isinstance(variant_value, str):
                    continue
                print(variant_set_name, variant_value)
                if variant_set_name.endswith('_version'):
                    continue
                variant_sets.GetVariantSet(variant_set_name).SetVariantSelection(variant_value)

        print(filename)







def _connect_shotgun():
    """Return a ``shotgun_api3.Shotgun`` connection for the 'Interface' script.

    The server, script name and API key can be overridden with the
    ``SHOTGRID_SERVER``, ``SHOTGRID_SCRIPT_NAME`` and ``SHOTGRID_API_KEY``
    environment variables.
    """
    # Imported lazily so that importing this module does not require shotgun_api3.
    import shotgun_api3

    server = os.environ.get('SHOTGRID_SERVER', 'https://tds.shotgunstudio.com')
    script_name = os.environ.get('SHOTGRID_SCRIPT_NAME', 'Interface')
    # SECURITY: this script key is committed to source control. It should be
    # rotated and then supplied only through the SHOTGRID_API_KEY variable.
    api_key = os.environ.get('SHOTGRID_API_KEY', 'dxfvstbke8zfzLmoacdlno@va')
    return shotgun_api3.Shotgun(server, script_name, api_key)


def get_breakdown_data(shot_name, project='SGD'):
    """Fetch the breakdown entries (``CustomEntity12``) linked to a shot.

    Args:
        shot_name: Code of the linked shot.
        project: Project code to filter on.

    Returns:
        The list of entity dicts returned by ``Shotgun.find``.
    """
    connection = _connect_shotgun()

    filters = [['project.Project.code', 'is', project],
               ['sg_link.Shot.code', 'is', shot_name],
               ]

    fields = ['code', 'sg_asset', 'sg_alias', 'sg_instance', 'sg_geometry_variant', 'sg_shading_variant',
              'sg_asset.Asset.sg_asset_type']
    return connection.find('CustomEntity12', filters=filters, fields=fields)


def get_versions_data(shot_name, assets, project='SGD'):
    """Fetch the non-deleted, complete published versions (``CustomEntity09``) of assets.

    Args:
        shot_name: Unused; kept for backward compatibility.
        assets: Asset entities to match with the ``in`` operator on ``sg_asset``.
        project: Project code to filter on.

    Returns:
        The list of entity dicts returned by ``Shotgun.find``.
    """
    connection = _connect_shotgun()

    filters = [['project.Project.code', 'is', project],
               ['sg_asset', 'in', assets],
               ['sg_delete', 'is', False],
               ['sg_complete', 'is', True],
               ]

    fields = ['code',
              'sg_asset',
              'sg_alias',
              'sg_instance',
              'sg_geometry_variant',
              'sg_shading_variant',
              'sg_asset.Asset.sg_asset_type',
              'sg_step',
              'sg_version_number',
              'sg_status_list']
    return connection.find('CustomEntity09', filters=filters, fields=fields)


def get_stage(filename, root, create=False, clean=False):
    """Open the USD stage at ``filename``, or create it with a default prim.

    Args:
        filename: Path of the root layer.
        root: Name of the prim (authored as an ``over``) set as default prim
            when the stage is created.
        create: Save the root layer immediately after creating it.
        clean: Delete an existing file and start from an empty layer.

    Returns:
        The opened or newly created ``Usd.Stage``.
    """
    print(filename)
    if os.path.exists(filename):
        print('The file do exists, opening')
        if not clean:
            return Usd.Stage.Open(filename)
        os.remove(filename)

    dirname = os.path.dirname(os.path.abspath(filename))
    # exist_ok avoids a race between the existence check and the creation.
    os.makedirs(dirname, exist_ok=True)
    print('The file do not exists, creating')

    # Usd.Stage.CreateNew reports an error when a layer with the same identifier
    # is still open in this process (e.g. the breakdown is rebuilt in the same
    # session). Reuse that layer instead, emptied when a clean build is requested.
    open_layer = Sdf.Layer.Find(filename)
    if open_layer:
        if clean:
            open_layer.Clear()
        stage = Usd.Stage.Open(open_layer)
    else:
        stage = Usd.Stage.CreateNew(filename)

    asset_prim = stage.OverridePrim('/%s' % root)
    stage.SetDefaultPrim(asset_prim)
    if create:
        stage.GetRootLayer().Save()

    return stage







if __name__ == '__main__':
    # Build and save the breakdown stage of a single shot.
    # Edit these values to target another shot or project.
    shot_name = 's00_ep01_sq030_sh060'
    project = 'TPT'

    builder = BreakdownBuilder(shot_name, project=project)
    builder.build_breakdown()