import os
import shutil
from collections import OrderedDict
from pprint import pprint

import maya.cmds as cmds

import shotgrid_lib.database as database
from pxr import Usd, UsdGeom, Sdf


import library.core.config_manager as config_manager

# Ordered mapping at module scope; nothing in this chunk reads or fills it —
# presumably intended to map pipeline steps to layer names (TODO confirm).
step_to_layers = OrderedDict([])
def get_stage(filename, root, create=False):
    """Return the USD stage backed by *filename*, creating it if missing.

    Args:
        filename: Path of the layer file backing the stage.
        root: Not read by this function; kept for caller compatibility.
        create: When True and the stage had to be created, save its root
            layer right away.

    Returns:
        The opened or newly created ``Usd.Stage``.
    """
    if not os.path.exists(filename):
        # Make sure the parent folder exists before USD writes the new layer.
        parent_dir = os.path.dirname(os.path.abspath(filename))
        if not os.path.exists(parent_dir):
            os.makedirs(parent_dir)
        new_stage = Usd.Stage.CreateNew(filename)
        if create:
            new_stage.GetRootLayer().Save()
        return new_stage

    return Usd.Stage.Open(filename)

class UsdAsset:
    """Build and update the published USD files of a single ShotGrid asset.

    Construction queries the ShotGrid database right away. If no asset with
    the given code exists, ``self.error`` is set to True. In that case
    ``asset_view``, ``publish_path`` and the configuration attributes are
    never assigned, so the public entry points return early.
    """

    def __init__(self, asset_name, project=None, root_node='Geometry'):
        """Store the asset identity and load its database/config state.

        Args:
            asset_name: ShotGrid ``code`` of the asset.
            project: Project handed to the database and the config solver.
            root_node: Stored on the instance; not read elsewhere in this class.
        """
        self.asset_name = asset_name
        self.project = project
        self.root_node = root_node
        self.error = False
        self.init()

    def get_config(self):
        """Load the ``project`` and ``usd_asset`` (usd module) configurations."""
        self.config_solver = config_manager.ConfigSolver(project=self.project)
        self.project_data = self.config_solver.get_config('project')
        self.usd_asset_data = self.config_solver.get_config('usd_asset', module='usd')

    def get_publish_path(self):
        """Compute, store and return the asset's USD publish folder.

        The path is ``<publish_server>/publish/usd/assets/<asset_type>/<asset>``.
        Spaces in the asset type are replaced by underscores.
        """
        self.root_path = self.project_data['paths']['publish_server']
        asset_type = self.asset_view.sg_asset_type.replace(' ', '_')

        self.publish_path = '%s/publish/usd/assets/%s/%s' % (self.root_path, asset_type, self.asset_name)

        return self.publish_path

    def init(self):
        """Query ShotGrid for the asset; if found, load config and paths.

        Sets ``self.error`` to True and returns early when the ``Asset`` query
        for ``code == asset_name`` comes back empty.
        """
        self.database = database.DataBase()
        self.database.fill(self.project, precatch=False)
        asset_filters = [['code', 'is', self.asset_name]]

        self.database.query_sg_database('Asset', filters=asset_filters)
        self.database.query_sg_database('Step', as_precache=True)

        if self.database['Asset'].empty:
            self.error = True
            return
        self.asset_view = self.database['Asset'][self.asset_name]

        self.asset_view.precache_dependencies(fields=['sg_published_elements'])

        self.get_config()
        self.get_publish_path()

    def add_sub_layer(self, sub_layer_path, root_layer):
        """Append the layer at *sub_layer_path* to *root_layer*'s sub-layers.

        An existing file is opened, not re-created. ``Sdf.Layer.CreateNew``
        saves a fresh empty layer over whatever is on disk. Rebuilding an
        assembly (e.g. a sandbox one from ``build_usd_scene``) on top of
        existing publishes would otherwise wipe the published variant data.

        Returns:
            The ``Sdf.Layer`` that was sub-layered.
        """
        if os.path.exists(sub_layer_path):
            sub_layer = Sdf.Layer.FindOrOpen(sub_layer_path)
        else:
            sub_layer = Sdf.Layer.CreateNew(sub_layer_path)
        # subLayerPaths behaves like a list: insert() could pick another position.
        root_layer.subLayerPaths.append(sub_layer.identifier)
        return sub_layer

    def build_asset_assembly(self, output_path=None, force_update=True, valid_layers=None, prim_list=None, layer_config=None):
        """Write an assembly layer that sub-layers one file per configured layer.

        Args:
            output_path: Destination of the assembly layer. Defaults to
                ``<publish_path>/asset_assembly.usda``.
            force_update: Delete an existing file at the destination first, so
                the assembly is rebuilt from scratch.
            valid_layers: Optional collection of layer names. Top-level layers
                not in it are skipped. It is not forwarded to nested groups.
            prim_list: Accepted for interface compatibility; currently unused.
            layer_config: Sequence of dicts, each read via its first key (the
                layer name) and first value (a pipeline step, or a nested
                list of such dicts). Defaults to
                ``self.usd_asset_data['layers']``.

        Returns:
            The path of the written assembly layer.
        """
        assembly_path = output_path or '%s/asset_assembly.usda' % self.publish_path

        if force_update and os.path.exists(assembly_path):
            os.remove(assembly_path)

        stage = get_stage(assembly_path, self.asset_name)
        root_layer = stage.GetRootLayer()

        if layer_config is None:
            layer_config = self.usd_asset_data['layers']

        for layer_data in layer_config:
            layer = list(layer_data.keys())[0]
            if valid_layers is not None and layer not in valid_layers:
                continue

            pipeline_step = list(layer_data.values())[0]
            if isinstance(pipeline_step, list):
                # A list groups several sub-layers. Build them into their own
                # nested assembly file and sub-layer that file here.
                sublayer_path = '%s/%s/%s.usda' % (self.publish_path, layer, layer)
                nested_path = self.build_asset_assembly(output_path=sublayer_path, layer_config=pipeline_step)
                root_layer.subLayerPaths.append(nested_path)
            else:
                variant_path = '%s/%s/%s_variant.usda' % (self.publish_path, layer, layer)
                # Ensure the variant file (and its folder) exists on disk.
                get_stage(variant_path, self.asset_name, create=True)
                self.add_sub_layer(variant_path, root_layer)

        root_layer.Save()

        return assembly_path

    def create_prims(self, stage, prim_list):
        """Define or retype prims on *stage* from a ``{prim_path: data}`` mapping.

        Each ``data`` dict may hold ``'type'`` (default ``'Xform'``) and
        ``'purpose'``; the latter is written to the prim's ``purpose``
        attribute.
        """
        for prim_path, prim_data in prim_list.items():
            prim_type = prim_data.get('type', 'Xform')
            current_prim = stage.GetPrimAtPath(prim_path)
            if not current_prim.IsValid():
                current_prim = stage.DefinePrim(prim_path, prim_type)
            else:
                current_prim.SetTypeName(prim_type)

            if 'purpose' in prim_data:
                render_purpose = current_prim.GetAttribute('purpose')
                render_purpose.Set(prim_data['purpose'])

    def get_sandbox_path(self, asset_name, asset_type, pipeline_step):
        """Return the next versioned sandbox ``.usda`` path for a pipeline step.

        The pattern is ``<sandbox>/maya/data/<asset>_<step>_v<version_number>.usda``;
        ``<version_number>`` is resolved by ``sandbox.get_version_path``.
        """
        import library.sandbox as sandbox_builder

        sandbox_solver = sandbox_builder.Sandbox(project=self.project)

        sandbox_solver.set_context(asset_name=asset_name, asset_type=asset_type)
        path = sandbox_solver.generate_path()
        full_path_pattern = '%s/maya/data/%s_%s_v<version_number>.usda' % (path, asset_name, pipeline_step)
        full_path = sandbox_builder.get_version_path(full_path_pattern)
        return full_path

    def build_usd_scene(self, pipeline_step):
        """Open a new Maya scene showing a sandbox assembly of this asset.

        The assembly is written to the step's sandbox path, restricted to the
        step's ``valid_layers`` from the ``pipeline_steps`` config. It is then
        loaded through a ``mayaUsdProxyShape`` whose parent transform is
        named after the asset. Does nothing if the asset lookup failed.
        """
        if self.error:
            # init() bailed out: asset_view / usd_asset_data were never set.
            return

        cmds.file(new=True, f=True)
        asset_type = self.asset_view.sg_asset_type
        sandbox_path = self.get_sandbox_path(self.asset_name, asset_type, pipeline_step)
        step_data = self.usd_asset_data['pipeline_steps'].get(pipeline_step, {})

        valid_layers = step_data.get('valid_layers')
        # create_prims expects a mapping, so default to an empty dict.
        prim_list = step_data.get('added_prims', {})
        self.build_asset_assembly(output_path=sandbox_path, valid_layers=valid_layers, prim_list=prim_list)

        shape_node = cmds.createNode('mayaUsdProxyShape', name='%sShape' % self.asset_name)
        # listRelatives returns a list; rename the shape's parent transform.
        parent_nodes = cmds.listRelatives(shape_node, p=True)
        cmds.rename(parent_nodes[0], self.asset_name)
        cmds.setAttr('%s.filePath' % shape_node, sandbox_path, type='string')

    def update_step(self, layer_name, pipeline_step, all_publish):
        """Author version and variant files for every usable publish of a step.

        Args:
            layer_name: Layer whose folder/files receive the variants.
            pipeline_step: Key into the precached ``Step`` table.
            all_publish: The asset's published elements view; it is filtered
                with ``find_with_filters``.

        Returns:
            The path of the last variant file written. Returns ``''`` when
            publishes exist but none has a USD file on disk, and ``None`` when
            the step or its complete, non-deleted publishes are missing.
        """
        variant_file = ''
        step_view = self.database['Step'][pipeline_step]
        if step_view.empty:
            return
        step_versions_view = all_publish.find_with_filters(sg_step=step_view, sg_delete=False, sg_complete=True)
        if step_versions_view.empty:
            return
        for version in step_versions_view:
            full_path = '%s/%s' % (version.sg_published_folder, version.sg_files.get('usd'))

            if not os.path.exists(full_path):
                continue

            variant_name = version.sg_variant_name
            version_number = str(version.sg_version_number)

            version_file = self.add_file_variant(layer_name, 'version', full_path, version_number, name=variant_name)

            # The variant file lives in the same folder as the version file.
            relative_path = './%s' % version_file.split('/')[-1]

            variant_file = self.add_file_variant(layer_name,
                                                 'variant',
                                                 relative_path,
                                                 variant_name,
                                                 select_variant='Master')
        return variant_file

    def update_asset(self, publish_step=None):
        """Rebuild the asset's publish folder from its ShotGrid publishes.

        Everything in the publish folder is deleted. The assembly is then
        rebuilt and version/variant files are authored for each configured
        layer.

        Args:
            publish_step: If given, only plain (non-grouped) layers mapped to
                this step are updated.

        Returns:
            The last value returned by ``update_step``, or ``None``.
        """
        if self.error:
            return

        # On an asset's first publish the folder does not exist yet;
        # build_asset_assembly creates it through get_stage.
        if os.path.isdir(self.publish_path):
            for element in os.listdir(self.publish_path):
                full_path = '%s/%s' % (self.publish_path, element)
                if os.path.isdir(full_path):
                    shutil.rmtree(full_path)
                else:
                    os.remove(full_path)

        self.build_asset_assembly()
        all_publish = self.asset_view.sg_published_elements
        variant_file = None

        for layer_data in self.usd_asset_data['layers']:
            layer_name = list(layer_data.keys())[0]
            pipeline_step = list(layer_data.values())[0]

            if isinstance(pipeline_step, list):
                # NOTE(review): publish_step is not applied to grouped
                # sub-steps, and the wipe above also clears layers that the
                # filter then skips. Confirm this is intended.
                for step in pipeline_step:
                    sub_layer = list(step.keys())[0]
                    sub_step = list(step.values())[0]
                    variant_file = self.update_step(sub_layer, sub_step, all_publish)
            else:
                if publish_step and publish_step != pipeline_step:
                    continue

                variant_file = self.update_step(layer_name, pipeline_step, all_publish)

        return variant_file

    def add_file_variant(self, layer_type, variant_name, path, variant_value, select_variant='', name=''):
        """Add a payload-backed variant to a per-layer variant file.

        The file is ``<publish_path>/<layer_type>/<layer_type>[_<name>]_<variant_name>.usda``.
        It holds an over prim ``/asset`` (the default prim) carrying the
        variant set ``<Layer_type>_<variant_name>``.

        Args:
            layer_type: Layer name; used for the folder, file name and set name.
            variant_name: Kind of variant set (``'version'`` or ``'variant'``
                in this file).
            path: Payload asset path authored inside the new variant.
            variant_value: Name of the variant to add (converted with ``str``).
            select_variant: Variant left selected afterwards. When empty, the
                newly added variant stays selected.
            name: Optional extra token inserted into the file name.

        Returns:
            The path of the variant file. If the variant already exists, the
            file is left untouched and its path is still returned.
        """
        print('updating file: %s %s %s ' % (self.asset_name, layer_type, variant_name))
        if name:
            filename = '%s/%s/%s_%s_%s.usda' % (self.publish_path, layer_type, layer_type, name, variant_name)
        else:
            filename = '%s/%s/%s_%s.usda' % (self.publish_path, layer_type, layer_type, variant_name)

        variant_set_name = '%s_%s' % (layer_type.title(), variant_name)

        stage = get_stage(filename, self.asset_name)

        asset_prim = stage.GetPrimAtPath('/asset')
        if not asset_prim.IsValid():
            asset_prim = stage.OverridePrim('/asset')
            stage.SetDefaultPrim(asset_prim)

        variant_sets = asset_prim.GetVariantSets()
        if variant_set_name in variant_sets.GetNames():
            variant_set = variant_sets.GetVariantSet(variant_set_name)
        else:
            variant_set = variant_sets.AddVariantSet(variant_set_name)

        variant_value = str(variant_value)
        # Skip variants that are already authored, e.g. several publishes
        # sharing one variant name. Returning the filename keeps callers that
        # use the result as a path working.
        if variant_value in variant_set.GetVariantNames():
            return filename

        variant_set.AddVariant(variant_value)
        # The payload must be authored while this variant is selected.
        variant_set.SetVariantSelection(variant_value)

        with variant_set.GetVariantEditContext():
            payloads = asset_prim.GetPayloads()
            payloads.AddPayload(path)

        if select_variant:
            variant_set.SetVariantSelection(str(select_variant))

        stage.GetRootLayer().Save()

        return filename