import os
import shutil
from pprint import pprint


import shotgrid_lib.database as database
from pxr import Usd, UsdGeom, Sdf


import library.core.config_manager as config_manager



class UsdShot:
    """Build the USD layers that describe a single shot.

    On construction the shot is queried from ShotGrid (through
    ``shotgrid_lib.database``), its breakdown and published elements are
    pre-cached, the ``project``/``usd_shot`` configuration is loaded and the
    shot publish paths are computed.

    :param shot_name: shot code; its first one, two and three ``_``-separated
        tokens are used as the season, episode and sequence folder names
        (e.g. ``s00_ep00_sq010_sh010``).
    :param project: project handed to the database and the config solver.
    :param root_node: name of the root prim authored in every layer.
    """

    def __init__(self, shot_name, project=None, root_node='World'):
        self.shot_name = shot_name
        self.project = project
        self.root_node = root_node
        self.init()

    def get_config(self):
        """Load the ``project`` and ``usd_shot`` configurations for the project."""
        self.config_solver = config_manager.ConfigSolver(project=self.project)
        self.project_data = self.config_solver.get_config('project')
        self.usd_shot_data = self.config_solver.get_config('usd_shot', module='usd')

    def get_publish_path(self):
        """Compute and store the shot and asset publish folders.

        Sets ``root_path``, ``season``, ``episode``, ``sequence``,
        ``publish_path`` and ``publish_asset_path`` on the instance.

        :return: the shot publish folder.
        """
        self.root_path = self.project_data['paths']['publish_server']

        bits = self.shot_name.split('_')
        self.season = bits[0]
        self.episode = '_'.join(bits[:2])
        self.sequence = '_'.join(bits[:3])

        self.publish_path = '%s/publish/usd/shots/%s/%s/%s/%s' % (
            self.root_path, self.season, self.episode, self.sequence, self.shot_name)
        self.publish_asset_path = '%s/publish/usd/assets' % self.root_path

        return self.publish_path

    def init(self):
        """Query the shot from ShotGrid, then load config and publish paths.

        The shot's ``sg_breakdowns`` and ``sg_published_elements``, and the
        ``sg_published_elements`` of every breakdown asset, are pre-cached via
        ``precache_dependencies``.
        """
        self.database = database.DataBase()
        self.database.fill(self.project, precatch=False)
        shot_filters = [['code', 'is', self.shot_name]]
        self.database.query_sg_database('Shot', filters=shot_filters)
        self.shot_view = self.database['Shot'][self.shot_name]

        self.shot_view.precache_dependencies(fields=['sg_breakdowns', 'sg_published_elements'])
        asset_list = self.shot_view.sg_breakdowns.sg_asset
        asset_list.precache_dependencies(fields=['sg_published_elements'])
        self.get_config()
        self.get_publish_path()

    def create_shot_scene(self, shot_layers=None, force_update=False, filename=None):
        """Create the shot assembly layer that references one layer per shot layer.

        :param shot_layers: layer names; defaults to ``usd_shot_data['layers']``.
            For each name a ``<name>_layer.usda`` file is created in the shot
            publish folder if missing and referenced from the root prim.
        :param force_update: when true, the WHOLE shot publish folder is deleted
            first and the assembly is rebuilt even if it already exists.
        :param filename: assembly file name without extension; defaults to
            ``shot_assembly``.
        :return: None. Returns early (printing a message) when the assembly
            already exists and ``force_update`` is false.
        """
        if os.path.exists(self.publish_path) and force_update:
            shutil.rmtree(self.publish_path)

        if not shot_layers:
            shot_layers = self.usd_shot_data['layers']

        if filename:
            assembly_path = '%s/%s.usda' % (self.publish_path, filename)
        else:
            assembly_path = '%s/shot_assembly.usda' % self.publish_path

        if not force_update and os.path.exists(os.path.abspath(assembly_path)):
            # Use the instance's shot name: the module-level ``shot_name`` only
            # exists when this file runs as a script.
            print('Shot assembly for %s already exists' % self.shot_name)
            return

        stage = get_stage(assembly_path, self.root_node)

        asset_prim = stage.GetPrimAtPath('/%s' % self.root_node)
        if not asset_prim.IsValid():
            asset_prim = stage.DefinePrim('/%s' % self.root_node, 'Xform')

        # Rebuild the reference list from scratch so it matches shot_layers.
        references = asset_prim.GetReferences()
        references.ClearReferences()

        for layer in shot_layers:
            layer_path = '%s/%s_layer.usda' % (self.publish_path, layer)
            # Make sure the referenced layer file exists on disk.
            get_stage(layer_path, self.root_node, create=True)
            references.AddReference(layer_path)

        stage.GetRootLayer().Save()

    def build_breakdown(self):
        """Author every non-camera breakdown item into the breakdown layer.

        Each item is added with its geometry/shading variant selections and,
        when published versions are known for the asset, the approved (else
        highest) version of the matching pipeline steps.
        """
        correct_version = self.get_breakdown_versions()

        for item in self.shot_view.sg_breakdowns:
            asset_name = item.sg_asset.code
            # Only the last space-separated word of the asset type is used.
            asset_type = item.sg_asset.sg_asset_type.split(' ')[-1]
            print(asset_name, asset_type)
            if asset_type == 'Cameras':
                continue
            alias = item.sg_alias if item.sg_alias else '%s_%03d' % (asset_name, item.sg_instance)

            # Insertion order matters: variant selections are applied in this
            # order by add_asset_breakdown.
            variants = {'Geometry_variant': item.sg_geometry_variant.code,
                        'Shading_variant': item.sg_shading_variant.code}

            if asset_name in correct_version:
                asset_versions = correct_version[asset_name]
                if 'Model' in asset_versions:
                    variants['Geometry_version'] = self._preferred_version(asset_versions['Model'])
                # Checked after 'Model', so SetDressing wins when both exist.
                if 'SetDressing' in asset_versions:
                    variants['Geometry_version'] = self._preferred_version(asset_versions['SetDressing'])
                if 'Shading' in asset_versions:
                    variants['Shading_version'] = self._preferred_version(asset_versions['Shading'])
                # NOTE(review): step code spelled 'Lighsets'; presumably matches
                # the ShotGrid pipeline step code — confirm.
                if 'Lighsets' in asset_versions:
                    print('add lightsets')
                    # NOTE(review): the light variant reuses the geometry
                    # variant code — confirm this is intended.
                    variants['Light_variant'] = item.sg_geometry_variant.code
                    variants['Light_version'] = self._preferred_version(asset_versions['Lighsets'])

            self.add_asset_breakdown(asset_name, alias, variants=variants, asset_type=asset_type)

    @staticmethod
    def _preferred_version(step_versions):
        """Return a step's ``approved_version`` if present, else its ``highest_version``.

        :param step_versions: one per-step dict built by ``get_breakdown_versions``.
        """
        if 'approved_version' in step_versions:
            return step_versions['approved_version']
        return step_versions['highest_version']

    def get_breakdown_versions(self):
        """Collect the highest and approved version numbers per asset and step.

        :return: ``{asset_code: {step_code: {'highest_version': n,
            'approved_version': n}}}``. ``approved_version`` is only set from
            versions whose status is ``'cmpt'``.
        """
        correct_versions = {}
        for asset_view in self.shot_view.sg_breakdowns.sg_asset:
            for version in asset_view.sg_published_elements:
                pipeline_step = version.sg_step.code
                version_number = version.sg_version_number
                status = version.sg_status_list
                asset_name = version.sg_asset.code

                step_versions = correct_versions.setdefault(asset_name, {}).setdefault(pipeline_step, {})

                if version_number > step_versions.get('highest_version', 0):
                    step_versions['highest_version'] = version_number

                if status == 'cmpt' and version_number > step_versions.get('approved_version', 0):
                    step_versions['approved_version'] = version_number

        return correct_versions

    def add_asset_breakdown(self, asset_name, alias, variants=None, asset_type='Char'):
        """Add (or refresh) one asset instance in the shot breakdown layer.

        The instance is authored as an ``Xform`` at
        ``/<root_node>/<asset_type>/<alias>`` below a ``Scope`` per asset type.
        Its payload is the asset's ``asset_assembly.usda``. Variant selections
        are only set for variant sets that exist on the prim.

        :param asset_name: asset code, used to locate the asset publish folder.
        :param alias: prim name of this instance in the shot.
        :param variants: ``{variant_set_name: selection}``; selections are
            converted with ``str``.
        :param asset_type: name of the scope grouping assets of one type.
        """
        filename = '%s/breakdown_layer.usda' % self.publish_path

        stage = get_stage(filename, self.root_node)

        # NOTE: model kinds (assembly/group/component) are not authored yet.
        stage.DefinePrim('/%s' % self.root_node, 'Xform')
        print('/%s/%s' % (self.root_node, asset_type))
        stage.DefinePrim('/%s/%s' % (self.root_node, asset_type), 'Scope')
        asset_prim_path = '/%s/%s/%s' % (self.root_node, asset_type, alias)
        asset_prim = stage.DefinePrim(asset_prim_path, 'Xform')

        # Replace any previous payload so re-running does not stack payloads.
        payloads = asset_prim.GetPayloads()
        payloads.ClearPayloads()
        asset_path = '%s/%s/asset_assembly.usda' % (self.publish_asset_path, asset_name)
        payloads.AddPayload(asset_path)

        for variant_set_name, variant_value in (variants or {}).items():
            variant_value = str(variant_value)
            # Re-query on every iteration rather than once: selecting a variant
            # can expose variant sets nested inside that variant.
            variant_sets = asset_prim.GetVariantSets()
            if variant_set_name not in variant_sets.GetNames():
                continue
            print(variant_set_name, variant_value)
            variant_sets.GetVariantSet(variant_set_name).SetVariantSelection(variant_value)

        print(stage.GetRootLayer().ExportToString())

        stage.GetRootLayer().Save()








def get_breakdown_data(shot_name, project_code='SGD'):
    """Fetch the breakdown records of *shot_name* directly from ShotGrid.

    :param shot_name: shot code matched against ``sg_link.Shot.code``.
    :param project_code: ShotGrid project code; defaults to ``'SGD'``.
    :return: the records returned by ``Shotgun.find`` on entity
        ``CustomEntity12`` with the breakdown fields listed below.
    """
    # shotgun_api3 was never imported by this module, so every call raised
    # NameError. Importing it here keeps the dependency call-time only.
    import shotgun_api3

    # NOTE(security): ShotGrid credentials are hard-coded in source. They
    # should be rotated and read from configuration or the environment.
    server = 'https://tds.shotgunstudio.com'
    script_name = 'Interface'
    api_key = 'dxfvstbke8zfzLmoacdlno@va'

    connection = shotgun_api3.Shotgun(server, script_name, api_key)

    filters = [['project.Project.code', 'is', project_code],
               ['sg_link.Shot.code', 'is', shot_name],
               ]
    fields = ['code', 'sg_asset', 'sg_alias', 'sg_instance', 'sg_geometry_variant', 'sg_shading_variant',
              'sg_asset.Asset.sg_asset_type']
    return connection.find('CustomEntity12', filters=filters, fields=fields)




def get_versions_data(shot_name, assets, project_code='SGD'):
    """Fetch the non-deleted, complete published versions of *assets* from ShotGrid.

    :param shot_name: not used by the query; kept for backward compatibility.
    :param assets: asset entities matched with an ``'in'`` filter on ``sg_asset``.
    :param project_code: ShotGrid project code; defaults to ``'SGD'``.
    :return: the records returned by ``Shotgun.find`` on entity
        ``CustomEntity09`` with the fields listed below.
    """
    # shotgun_api3 was never imported by this module, so every call raised
    # NameError. Importing it here keeps the dependency call-time only.
    import shotgun_api3

    # NOTE(security): ShotGrid credentials are hard-coded in source. They
    # should be rotated and read from configuration or the environment.
    server = 'https://tds.shotgunstudio.com'
    script_name = 'Interface'
    api_key = 'dxfvstbke8zfzLmoacdlno@va'

    connection = shotgun_api3.Shotgun(server, script_name, api_key)

    # The 'sg_complete' filter was listed twice; once is enough.
    filters = [['project.Project.code', 'is', project_code],
               ['sg_asset', 'in', assets],
               ['sg_delete', 'is', False],
               ['sg_complete', 'is', True],
               ]

    fields = ['code',
              'sg_asset',
              'sg_alias',
              'sg_instance',
              'sg_geometry_variant',
              'sg_shading_variant',
              'sg_asset.Asset.sg_asset_type',
              'sg_step',
              'sg_version_number',
              'sg_status_list']
    return connection.find('CustomEntity09', filters=filters, fields=fields)


def get_stage(filename, root, create=False):
    """Open the USD stage stored at *filename*, creating it when missing.

    When the file does not exist, its parent folder is created if needed.
    A new stage is created with an ``over`` prim ``/<root>`` set as the
    default prim.

    :param filename: path of the ``.usd``/``.usda`` layer.
    :param root: name of the default prim authored on newly created stages.
    :param create: when true, a newly created stage is saved to disk
        immediately. Existing files are opened and never saved here.
    :return: the opened or newly created ``Usd.Stage``.
    """
    print(filename)
    if os.path.exists(filename):
        print('The file do exists, opening')
        return Usd.Stage.Open(filename)

    parent_dir = os.path.dirname(os.path.abspath(filename))
    if not os.path.exists(parent_dir):
        os.makedirs(parent_dir)
    print('The file do not exists, creating')
    new_stage = Usd.Stage.CreateNew(filename)
    new_stage.SetDefaultPrim(new_stage.OverridePrim('/%s' % root))
    if create:
        new_stage.GetRootLayer().Save()
    return new_stage







if __name__ == '__main__':
    # Example: build the breakdown layer of one shot. ``build_breakdown`` is a
    # ``UsdShot`` method, not a module-level function.
    shot_name = 's00_ep00_sq010_sh010'
    usd_shot = UsdShot(shot_name)
    usd_shot.build_breakdown()

    # Adding single assets to the breakdown layer by hand:
    # usd_shot.add_asset_breakdown('LeoMessi', 'Lionel_Messi',
    #                              variants={'Shading_variant': 'Laminate',
    #                                        'Shading_version': '002'},
    #                              asset_type='Char')
    # usd_shot.add_asset_breakdown('room', 'Room_Messi',
    #                              variants={'Shading_variant': 'Metal',
    #                                        'Shading_version': '002'},
    #                              asset_type='Set')

    # Rebuilding the shot assembly from scratch (deletes the shot publish folder):
    # usd_shot.create_shot_scene(force_update=True)