import os
import yaml

from pxr import Usd, UsdGeom, Sdf

import shotgrid_lib.database as database
import library.core.config_manager as config_manager

from pprint import  pprint

class UsdAssetLayers():
    """Build and update the USD layer files that describe a ShotGrid asset.

    The layer hierarchy per asset type comes from the ``asset_usd_scheme``
    config (``layers`` key). A step layer file holds a ``<layer>_variant``
    variant set whose variants payload per-variant files; each of those holds
    a ``<layer>_version`` variant set whose variants payload the published USD
    files, plus the ``latest`` / ``approved`` / ``recommended`` aliases.
    """

    def __init__(self, project=None):
        """Connect to the ShotGrid database for *project* and load the configs.

        :param project: project code used for the database and config lookups.
        """
        self.project = project

        self.database = database.DataBase()
        self.database.fill(self.project, precatch=False)
        # Step and CustomEntity11 (variant) entities are indexed by value in
        # get_asset_data().
        self.database.query_sg_database('Step', as_precache=True)
        self.database.query_sg_database('CustomEntity11', as_precache=True)
        self.database.query_sg_database('Asset')
        self.get_config()

    def get_layer_path(self, layer_name=None, variant=None):
        """Return the path of an asset layer file under the ``assets`` USD root.

        :param layer_name: layer to resolve; when empty, the asset assembly
            file (``asset_assembly.usda``) is returned.
        :param variant: optional variant name; selects the per-variant file
            ``<asset>_<layer>_<variant>.usda`` instead of ``<asset>_<layer>.usda``.
        :return: the ``.usda`` path as a string.
        """
        asset_folder = '%s/%s/%s' % (self.asset_config['paths']['assets'],
                                     self.asset_type,
                                     self.asset_name)
        if not layer_name:
            return '%s/asset_assembly.usda' % asset_folder

        usd_layer_folder = '%s/%s' % (asset_folder, layer_name)
        if not variant:
            return '%s/%s_%s.usda' % (usd_layer_folder, self.asset_name, layer_name)
        return '%s/%s_%s_%s.usda' % (usd_layer_folder, self.asset_name, layer_name, variant)

    @staticmethod
    def _add_payload_variant(asset_prim, variant_set, variant_name, payload_path):
        """Create *variant_name* in *variant_set*, select it and add a payload to *payload_path*."""
        variant_set.AddVariant(variant_name)
        variant_set.SetVariantSelection(variant_name)
        with variant_set.GetVariantEditContext():
            asset_prim.GetPayloads().AddPayload(payload_path)

    def set_variant(self, asset_prim, variant_set_name, variant_name, variant_value):
        """Author (or replace) one payload variant on *asset_prim*.

        The variant ``variant_name`` of ``variant_set_name`` is created when
        missing and its payloads are replaced by a single payload to
        *variant_value*. The selection in effect before the call is restored
        afterwards (``'latest'`` when nothing was selected).
        """
        variant_set = asset_prim.GetVariantSets().AddVariantSet(variant_set_name)
        # Remember the current selection so authoring a variant does not
        # change the layer's default selection.
        previous_selection = variant_set.GetVariantSelection() or 'latest'

        variant_set.AddVariant(variant_name)
        variant_set.SetVariantSelection(variant_name)

        print(variant_value)
        with variant_set.GetVariantEditContext():
            payloads = asset_prim.GetPayloads()
            # Clear first so the variant ends up with exactly this payload.
            payloads.ClearPayloads()
            payloads.AddPayload(variant_value)

        variant_set.SetVariantSelection(previous_selection)

    def update_usd_file(self, usd_file, variant_set_name, variant_name, publish_path, is_version=False):
        """Add or update one payload variant in *usd_file*, creating the file if needed.

        :param usd_file: ``.usda`` layer to open (or create).
        :param variant_set_name: variant set on ``/asset`` to edit.
        :param variant_name: variant to author; for version files this is the
            zero-padded version number (e.g. ``'007'``).
        :param publish_path: path the variant's payload points to.
        :param is_version: when True and *variant_name* is the highest numeric
            version, also repoint ``latest`` (and ``recommended`` when there is
            no ``approved`` variant) at *publish_path*.
        :return: True when the file did not exist and was created, i.e. the
            parent layer needs updating.
        """
        update_parent_layer = False

        if os.path.exists(usd_file):
            stage = Usd.Stage.Open(usd_file)
        else:
            # CreateNew cannot write into a folder that does not exist yet.
            usd_folder = os.path.dirname(usd_file)
            if usd_folder and not os.path.exists(usd_folder):
                os.makedirs(usd_folder)
            stage = Usd.Stage.CreateNew(usd_file)
            update_parent_layer = True

        asset_prim = stage.GetPrimAtPath('/asset')
        if not asset_prim.IsValid():
            asset_prim = stage.OverridePrim('/asset')
            stage.SetDefaultPrim(asset_prim)

        self.set_variant(asset_prim, variant_set_name, variant_name, publish_path)

        if is_version:
            version_set = asset_prim.GetVariantSets().AddVariantSet(variant_set_name)
            versions = version_set.GetVariantNames()
            numeric_versions = [int(version_str) for version_str in versions if version_str.isdigit()]

            # The published version is the one encoded in variant_name; it was
            # just added above, so numeric_versions is non-empty when it is numeric.
            is_latest = (str(variant_name).isdigit()
                         and int(variant_name) == max(numeric_versions))
            if is_latest:
                print('latest versions')
                self.set_variant(asset_prim, variant_set_name, 'latest', publish_path)
                if 'approved' not in versions:
                    self.set_variant(asset_prim, variant_set_name, 'recommended', publish_path)
            else:
                print('Not latest version')

        stage.GetRootLayer().Save()

        return update_parent_layer

    def add_file(self, publish_path, step, variant_name, version):
        """Register a newly published USD file in the asset's layer files.

        Updates (creating when missing) the per-variant versions file, then the
        step layer file, then the asset assembly.

        :param publish_path: path of the published USD file.
        :param step: pipeline step; must appear in the asset's layer config.
        :param variant_name: variant the publish belongs to (e.g. ``'Master'``).
        :param version: integer version number of the publish.
        :return: the asset assembly path.
        :raises ValueError: if *step* is not part of the layer config.
        """
        layer_data = self.get_layer_data(step, self.asset_layers_config)
        if layer_data is None:
            raise ValueError('Step %r is not defined in the USD layers config' % (step,))
        layer_name = layer_data['name']
        versions_path = self.get_layer_path(layer_name=layer_name, variant=variant_name)
        print(versions_path, os.path.exists(versions_path))

        update_layer = self.update_usd_file(versions_path,
                                            '%s_version' % layer_name,
                                            '%03d' % version,
                                            publish_path,
                                            is_version=True)
        layer_path = self.get_layer_path(layer_name=layer_name)
        print('layer path', layer_path)
        if update_layer or not os.path.exists(layer_path):
            print('Updating the main sublayer')
            self.update_usd_file(layer_path,
                                 '%s_variant' % layer_name,
                                 variant_name,
                                 versions_path)

        assembly_path = self.get_layer_path()
        print(assembly_path, os.path.exists(assembly_path))
        if not os.path.exists(assembly_path):
            # NOTE(review): create_layer_files() writes under self.publish_path,
            # while assembly_path is under asset_config['paths']['assets'];
            # confirm both resolve to the same folder.
            self.create_layer_files()
            return assembly_path

        print('check root', assembly_path)
        main_stage = Usd.Stage.Open(assembly_path)
        root_layer = main_stage.GetRootLayer()
        layer_stack = main_stage.GetLayerStack(includeSessionLayers=False)
        layers_path = [layer.identifier for layer in layer_stack]

        # The root layer is always part of its own layer stack, so the useful
        # check is whether the step layer is already composed in.
        if layer_path in layers_path:
            return assembly_path

        root_layer.subLayerPaths = []
        for item in self.asset_layers_config:
            for sublayer_name in item:
                usd_sublayer_path = self.get_layer_path(layer_name=sublayer_name)
                if os.path.exists(usd_sublayer_path):
                    root_layer.subLayerPaths.append(usd_sublayer_path)

        root_layer.Save()
        return assembly_path

    def get_config(self):
        """Load the project config and the ``asset_usd_scheme`` USD config."""
        self.config_solver = config_manager.ConfigSolver(project=self.project)
        self.project_config = self.config_solver.get_config('project')
        self.asset_config = self.config_solver.get_config('asset_usd_scheme', module='usd')

    def set_asset(self, asset_name):
        """Select *asset_name* as the asset to work on and load its publish data.

        Sets the asset type (spaces replaced by underscores), the layer config
        for that type (falling back to ``'default'``), the USD publish folder
        and ``self.asset_publish_data``.
        """
        self.asset_name = asset_name

        self.asset_view = self.database['Asset'][self.asset_name]
        self.asset_type = self.asset_view.sg_asset_type.replace(' ', '_')
        self.asset_view.precache_dependencies(fields=['sg_published_elements'])

        self.asset_layers_config = self.asset_config['layers'].get(self.asset_type)
        if not self.asset_layers_config:
            self.asset_layers_config = self.asset_config['layers']['default']

        publish_root_folder = self.project_config['paths']['publish']['root']
        self.publish_path = '%s/usd/assets/%s/%s' % (publish_root_folder, self.asset_type, self.asset_name)
        print(self.publish_path)

        self.get_asset_data()

    def get_asset_data(self):
        """Group the asset's published elements by step and variant.

        Stores the result in ``self.asset_publish_data`` as
        ``{step: {variant: published_elements_view}}`` and returns it
        (an empty dict when nothing has been published).
        """
        published_elements = self.asset_view.sg_published_elements
        if published_elements.empty:
            self.asset_publish_data = {}
            return self.asset_publish_data

        asset_data = {}
        for step in published_elements.single_values('sg_step'):
            asset_data[step] = {}
            all_step_publish = published_elements.find_with_filters(sg_step=self.database['Step'][step])

            for variant in all_step_publish.single_values('sg_variant'):
                variant_publish = all_step_publish.find_with_filters(sg_variant=self.database['CustomEntity11'][variant])
                asset_data[step][variant] = variant_publish

        self.asset_publish_data = asset_data
        return asset_data

    def get_layer_data(self, step, layers_data):
        """Find the config entry for *step* in a (possibly nested) layer config.

        *layers_data* is a list of mappings; a value that is itself a list is a
        group of sub-layers and is searched recursively.

        :return: a copy of the step's config mapping with ``'name'`` set to
            *step*, or None when the step is not found.
        """
        for item in layers_data:
            for key, value in item.items():
                if isinstance(value, list):
                    data = self.get_layer_data(step, value)
                    if data:
                        return data
                elif key == step:
                    # Return a copy so looking a step up never mutates the
                    # config held in self.asset_config.
                    return dict(value or {}, name=step)
        return None

    def create_step_files(self, step):
        """(Re)build the step layer file and its per-variant versions files.

        :param step: step name as it appears in the layer config.
        :raises ValueError: if *step* is not part of the layer config.
        """
        layer_data = self.get_layer_data(step, self.asset_layers_config)
        if layer_data is None:
            raise ValueError('Step %r is not defined in the USD layers config' % (step,))
        layer_name = layer_data['name']
        publish_type = layer_data['publish_type']

        usd_file_path = '%s/%s/%s_%s.usda' % (self.publish_path,
                                              layer_name,
                                              self.asset_name,
                                              layer_name)

        print('usd_path:', usd_file_path)
        usd_layer_folder = os.path.dirname(usd_file_path)

        if not os.path.exists(usd_layer_folder):
            os.makedirs(usd_layer_folder)

        variant_paths = {}
        for variant_name, layer_view in self.asset_publish_data.get(publish_type, {}).items():
            print('variant:', variant_name)
            variant_file_path = '%s/%s_%s_%s.usda' % (usd_layer_folder,
                                                      self.asset_name,
                                                      layer_name,
                                                      variant_name)
            self.create_versions_file(variant_file_path, layer_view, layer_name)
            variant_paths[variant_name] = variant_file_path

        self.create_variant_file(usd_file_path, variant_paths, layer_name)

    def create_layer_files(self, layer_name='', layers_data=None, force_build=False):
        """(Re)build a layer file that sublayers the files listed in *layers_data*.

        Called without arguments it builds ``<publish_path>/asset_assembly.usda``
        from the asset's layer config, recursing into groups (config values
        that are lists), each of which gets its own
        ``<group>/<asset>_<group>.usda`` file. Step layers are only sublayered
        when their file exists on disk.

        :param layer_name: group layer to build; empty for the assembly.
        :param layers_data: layer config to build from; defaults to
            ``self.asset_layers_config``.
        :param force_build: when True, rebuild every step's files with
            create_step_files() before sublayering them.
        :return: path of the written layer file.
        """
        if layers_data is None:
            layers_data = self.asset_layers_config

        if not layer_name:
            usd_layer_path = '%s/asset_assembly.usda' % (self.publish_path)
        else:
            usd_layer_path = '%s/%s/%s_%s.usda' % (self.publish_path,
                                                   layer_name,
                                                   self.asset_name,
                                                   layer_name)

        usd_layer_folder = os.path.dirname(usd_layer_path)
        if not os.path.exists(usd_layer_folder):
            os.makedirs(usd_layer_folder)

        layers_paths = []
        for item in layers_data:
            for sublayer_name, sublayer_data in item.items():
                if isinstance(sublayer_data, list):
                    layer_file = self.create_layer_files(layer_name=sublayer_name, layers_data=sublayer_data, force_build=force_build)
                    layers_paths.append(layer_file)
                    continue

                usd_sublayer_path = '%s/%s/%s_%s.usda' % (self.publish_path,
                                                       sublayer_name,
                                                       self.asset_name,
                                                       sublayer_name)
                if force_build:
                    self.create_step_files(sublayer_name)

                if os.path.exists(usd_sublayer_path):
                    layers_paths.append(usd_sublayer_path)

        if os.path.exists(usd_layer_path):
            os.remove(usd_layer_path)

        stage = Usd.Stage.CreateNew(usd_layer_path)
        root_layer = stage.GetRootLayer()
        for path in layers_paths:
            root_layer.subLayerPaths.append(path)

        root_prim = stage.DefinePrim('/asset', 'Xform')
        stage.SetDefaultPrim(root_prim)

        if not layer_name:
            # Only the assembly authors the render / proxy purpose overrides.
            for child_name, purpose in (('render', UsdGeom.Tokens.render),
                                        ('proxy', UsdGeom.Tokens.proxy)):
                child_prim = stage.OverridePrim('/asset/%s' % child_name)
                UsdGeom.Imageable(child_prim).CreatePurposeAttr().Set(purpose)

        stage.GetRootLayer().Save()

        return usd_layer_path

    def create_variant_file(self, asset_layer_path, variant_paths, layer_name):
        """Write the step layer file holding the ``<layer>_variant`` variant set.

        :param asset_layer_path: ``.usda`` file to (re)create.
        :param variant_paths: mapping of variant name to the per-variant
            versions file its payload points at; when empty an empty layer
            is written.
        :param layer_name: layer name used to name the variant set.
        """
        print('layer_name', layer_name)
        print(variant_paths)
        # Start from a clean file, matching create_versions_file() and
        # create_layer_files().
        if os.path.exists(asset_layer_path):
            os.remove(asset_layer_path)

        stage = Usd.Stage.CreateNew(asset_layer_path)
        if not variant_paths:
            stage.GetRootLayer().Save()
            return

        asset_prim = stage.OverridePrim('/asset')
        stage.SetDefaultPrim(asset_prim)
        variant_set = asset_prim.GetVariantSets().AddVariantSet('%s_variant' % layer_name)

        for variant_name, path in variant_paths.items():
            self._add_payload_variant(asset_prim, variant_set, variant_name, path)

        # Otherwise the last authored variant stays selected.
        if 'Master' in variant_paths:
            variant_set.SetVariantSelection('Master')
        stage.GetRootLayer().Save()

    def create_versions_file(self, sublayer_path, versions_view, layer_name):
        """Write a per-variant file holding the ``<layer>_version`` variant set.

        One variant per usable publish (zero-padded version number), plus
        ``latest`` (highest version), ``approved`` (highest version with status
        ``'cmpt'``) and ``recommended`` (approved if any, else latest).
        Publishes that are deleted, incomplete, have no ``usd`` file, or whose
        file is missing on disk are skipped.

        :param sublayer_path: ``.usda`` file to (re)create.
        :param versions_view: iterable of published-element records.
        :param layer_name: layer name used to name the variant set.
        """
        print('create version file')
        print('sublayer path', sublayer_path)
        print('layer name', layer_name)
        versions_data = {}
        approved_version = 0
        latest_version = 0
        for published in versions_view:
            if published.sg_delete or not published.sg_complete:
                continue
            if not published.sg_files or not published.sg_files.get('usd', ''):
                continue
            published_path = '%s/%s' % (published.sg_published_folder, published.sg_files.get('usd', ''))
            if not os.path.exists(published_path):
                continue

            if published.sg_status_list == 'cmpt':
                approved_version = max(approved_version, published.sg_version_number)
            latest_version = max(latest_version, published.sg_version_number)

            versions_data[published.sg_version_number] = published_path

        if os.path.exists(sublayer_path):
            os.remove(sublayer_path)

        stage = Usd.Stage.CreateNew(sublayer_path)

        asset_prim = stage.OverridePrim('/asset')
        stage.SetDefaultPrim(asset_prim)
        version_set = asset_prim.GetVariantSets().AddVariantSet('%s_version' % layer_name)

        for version, path in versions_data.items():
            self._add_payload_variant(asset_prim, version_set, '%03d' % version, path)

        if latest_version:
            latest_path = versions_data[latest_version]
            self._add_payload_variant(asset_prim, version_set, 'latest', latest_path)
            if not approved_version:
                self._add_payload_variant(asset_prim, version_set, 'recommended', latest_path)

        if approved_version:
            approved_path = versions_data[approved_version]
            self._add_payload_variant(asset_prim, version_set, 'approved', approved_path)
            self._add_payload_variant(asset_prim, version_set, 'recommended', approved_path)

        stage.GetRootLayer().Save()

if __name__ == '__main__':
    import sys

    # Rebuild the USD layer files of every asset in the project given as the
    # first command-line argument (defaults to 'TPT').
    #
    # To register a single new publish instead, call e.g.:
    #   asset_manager.set_asset('LeoMesi')
    #   asset_manager.add_file(published_usd_path, 'Shading', 'Master', 1)
    project_code = sys.argv[1] if len(sys.argv) > 1 else 'TPT'
    asset_manager = UsdAssetLayers(project_code)

    for asset_view in asset_manager.database['Asset']:
        asset_manager.set_asset(asset_view.code)
        asset_manager.create_layer_files(force_build=True)
