import os
import importlib
import time

from collections import OrderedDict
from pprint import pprint

from pxr import Usd, Sdf, UsdGeom, Kind, UsdShade, Vt

import library.core.config_manager as config_manager

# Maya object sets that UsdManager.maya_set_to_collection skips.
IGNORE_SETS = ['defaultLightSet', 'defaultObjectSet']

class UsdManager:

    def __init__(self, project):
        self.project = project
        self.entity_type = 'Asset'

    def set_entity(self, entity_name, entity_type, asset_type=None):
        """Bind the manager to an entity and resolve its configuration and paths.

        Loads the ``project`` config and the ``entity_usd_scheme`` config
        (module ``usd``) through ``config_manager.ConfigSolver``. It then picks
        the layer layout for ``asset_type`` under ``layers/<entity_type>``,
        falling back to that entity type's ``default`` layout when the lookup
        is missing or empty. Finally it derives ``assembly_folder`` and
        ``filename``.
        """
        self.entity_name, self.entity_type, self.asset_type = entity_name, entity_type, asset_type

        solver = config_manager.ConfigSolver(project=self.project)
        self.config_solver = solver
        self.project_config = solver.get_config('project')
        self.asset_config = solver.get_config('entity_usd_scheme', module='usd')

        layouts = self.asset_config['layers'][self.entity_type]
        self.asset_layers_config = layouts.get(self.asset_type) or layouts['default']

        self.assembly_folder = self.get_assembly_root()
        self.filename = self.get_assembly_filename()

    def get_assembly_root(self):
        root = self.project_config['paths']['usd_files']
        if self.entity_type == 'Asset':
            asset_type = self.asset_type.replace(' ', '_')
            print('root', root)
            print('asset_type', asset_type)
            print('root', self.entity_name)
            assembly_folder = '%s/assets/%s/%s' % (root, asset_type, self.entity_name)
        else:
            bits = self.entity_name.split('_')
            season = bits[0]
            episode = '_'.join(bits[:2])
            sequence = '_'.join(bits[:3])
            assembly_folder = '%s/shots/%s/%s/%s/%s' % (root, season, episode, sequence,  self.entity_name)

        return assembly_folder

    def get_assembly_filename(self):

        if self.entity_type == 'Asset':
            filename = '%s/asset_assembly.usda' % self.assembly_folder
        else:
            filename = '%s/shot_assembly.usda' % self.assembly_folder

        return filename

    def get_layer_filename(self, step, variant=None):
        print(step, self.entity_name, variant)
        print('assembly folder', self.assembly_folder)
        if variant is None:
            layer_name = '%s/%s/%s_%s.usda' % (self.assembly_folder, step, self.entity_name, step)
        else:
            layer_name = '%s/%s/%s_%s_%s.usda' % (self.assembly_folder, step, self.entity_name, step, variant)

        return layer_name

    def exists_root_layer(self):
        path = self.get_assembly_filename()
        return os.path.exists(path)

    def open(self, filename):
        """Open *filename* as ``self.stage``, creating it (and its folder) if missing.

        ``self.filename`` and ``self.assembly_folder`` are updated to point at
        *filename* and its directory.
        """
        self.filename = filename
        self.assembly_folder = os.path.dirname(self.filename)

        if not os.path.exists(self.filename):
            if not os.path.exists(self.assembly_folder):
                os.makedirs(self.assembly_folder)
            print('The file do not exists, creating: %s' % filename)
            self.stage = Usd.Stage.CreateNew(self.filename)
            return

        print('The file do exists, opening: %s' % filename)
        self.stage = Usd.Stage.Open(self.filename)

    def get_stage_from_maya(self, node):
        """Bind ``self.stage`` to the USD stage behind the Maya proxy *node*.

        The node is expanded to its long DAG path before it is passed to
        ``mayaUsd.ufe.getStage``. The default prim is then resolved through
        :meth:`get_default_prim`, which authors one if the stage has none.
        """
        import maya.cmds as cmds
        import mayaUsd

        long_names = cmds.ls(node, long=True)
        # mayaUsd expects the full DAG path, not a short name.
        self.stage = mayaUsd.ufe.getStage(long_names[0])
        self.get_default_prim()

    def get_default_prim(self):
        """Return the stage's default prim, authoring one if none is set.

        If the stage has no default prim, an ``over`` is created and made the
        default: at ``/asset`` for assets, at ``/World`` for shots. For any
        other entity type the (invalid) result of ``GetDefaultPrim`` is
        returned unchanged.
        """
        prim = self.stage.GetDefaultPrim()
        if prim:
            return prim

        fallback_path = {'Asset': '/asset', 'Shot': '/World'}.get(self.entity_type)
        if fallback_path is not None:
            prim = self.stage.OverridePrim(fallback_path)
            self.stage.SetDefaultPrim(prim)
        return prim

    def set_default_prim(self, prim_name):
        """Make the prim at path *prim_name* the stage's default prim and return it."""
        print('set')
        prim = self.stage.GetPrimAtPath(prim_name)
        self.stage.SetDefaultPrim(prim)
        return prim

    # Reading gives the (possibly auto-authored) default prim; assigning takes a prim path.
    default_prim = property(get_default_prim, set_default_prim)

    @property
    def root_layer(self):
        """The root ``Sdf.Layer`` of ``self.stage``."""
        return self.stage.GetRootLayer()

    def print(self):
        """Print the stage's root layer as USDA text."""
        text = self.root_layer.ExportToString()
        print(text)

    def save_stage(self):
        """Save the stage's root layer to its file; the result of ``Save`` is discarded."""
        root = self.stage.GetRootLayer()
        root.Save()

    def save_as(self, new_path):
        """Export the stage's root layer to *new_path*; the stage itself is not re-targeted."""
        layer = self.root_layer
        layer.Export(new_path)

    def get_asset_info(self):
        """Return the ``assetInfo`` dictionary of the default prim."""
        return self.default_prim.GetAssetInfo()


    def set_prim_purpose(self, prim_path, purpose):
        """Set the ``purpose`` attribute of the prim at *prim_path* to *purpose*.

        NOTE(review): relies on the prim already having a ``purpose``
        attribute (true for typed imageable prims such as Xform); confirm for
        other prim types.
        """
        purpose_attr = self.stage.GetPrimAtPath(prim_path).GetAttribute('purpose')
        purpose_attr.Set(purpose)

    def override_prim(self, prim_path):
        """Change the specifier of the prim at *prim_path* to ``over``."""
        print('override', prim_path)
        target = self.stage.GetPrimAtPath(prim_path)
        target.SetSpecifier(Sdf.SpecifierOver)

    def get_attribute(self, prim, attribute_name):
        """Return the current value of attribute *attribute_name* on *prim*."""
        return prim.GetAttribute(attribute_name).Get()

    def set_attribute(self, prim, attribute_name, value):
        """Set attribute *attribute_name* on *prim* to *value*; returns the result of ``Set``."""
        return prim.GetAttribute(attribute_name).Set(value)

    def unique_name(self, prim_path: Sdf.Path) -> Sdf.Path:
        """Return an Sdf.Path with no prim on the currently composed stage.

        If *prim_path* is free it is returned as is. Otherwise the trailing
        digits 1-9 are stripped from its string form and ``1``, ``2``, … are
        appended until a free path is found.

        This only checks the composed stage: the path may still exist in a
        layer, e.g. inside an unselected variant or a muted layer.
        """
        # '0' is not in the strip set, so "/cube_001" yields the stem "/cube_00".
        stem = prim_path.pathString.rstrip("123456789")
        candidate = prim_path
        counter = 1
        while self.stage.GetPrimAtPath(candidate):
            candidate = Sdf.Path(f"{stem}{counter}")
            counter += 1
        return candidate
    def duplicate_prim(self, prim: Usd.Prim, new_path: str) -> Sdf.Path:
        """Duplicate *prim* to *new_path* in every layer that has a spec for it.

        For each spec in ``prim.GetPrimStack()``, ``Sdf.CopySpec`` copies the
        spec at the prim's composed path to *new_path* within that same layer.

        Returns *new_path* unchanged.

        NOTE(review): the source path is the composed ``prim.GetPath()``, not
        ``spec.path``. Specs contributed through references, payloads or
        variants live at a different path inside their layer; confirm that
        CopySpec behaves as intended for those.
        NOTE(review): annotated as returning ``Sdf.Path``, but the function
        returns the ``new_path`` argument as passed (annotated ``str``).
        """
        path = prim.GetPath()
        for spec in prim.GetPrimStack():
            layer = spec.layer
            # Source and destination layer are the same: the copy stays in `layer`.
            Sdf.CopySpec(layer, path, layer, new_path)
        return new_path

    def create_attribute(self, prim, attribute_name, attribute_type=str, value=None):
        """Create attribute *attribute_name* on *prim* and optionally set its value.

        *attribute_type* can be:

        - one of the Python types ``str``, ``int``, ``float`` (authored as
          Double) or ``bool``;
        - one of the strings ``'string_array'``, ``'uchar'``, ``'token'`` or
          ``'path'`` (Asset type; a ``str`` value is wrapped in
          ``Sdf.AssetPath``);
        - anything else (e.g. an ``Sdf.ValueTypeName``), which is passed to
          ``CreateAttribute`` unchanged.

        The value is only set when *value* is not None. Returns the created
        ``Usd.Attribute``.
        """
        names = Sdf.ValueTypeNames
        python_types = (
            (str, names.String),
            (int, names.Int),
            (float, names.Double),
            (bool, names.Bool),
        )
        for py_type, sdf_type in python_types:
            # Identity check, so bool and int stay distinct.
            if attribute_type is py_type:
                type_token = sdf_type
                break
        else:
            if attribute_type == 'string_array':
                type_token = names.StringArray
            elif attribute_type == 'path':
                type_token = names.Asset
                if value and isinstance(value, str):
                    value = Sdf.AssetPath(value)
            elif attribute_type == 'uchar':
                type_token = names.UChar
            elif attribute_type == 'token':
                type_token = names.Token
            else:
                type_token = attribute_type

        print(attribute_name, type_token)
        new_attribute = prim.CreateAttribute(attribute_name, type_token)
        if value is not None:
            new_attribute.Set(value)
        return new_attribute

    def set_prim_typename(self, prim_path, typename):
        """Set the schema type name of the prim at *prim_path* to *typename*."""
        target = self.stage.GetPrimAtPath(prim_path)
        target.SetTypeName(typename)

    def set_prim_kind(self, prim_path, kind_type):
        """Set the model kind of the prim at *prim_path*.

        *kind_type* must be ``'assembly'``, ``'component'`` or ``'group'``;
        any other value leaves the kind untouched.
        """
        prim = self.stage.GetPrimAtPath(prim_path)
        model_api = Usd.ModelAPI(prim)
        known_kinds = (
            ('assembly', Kind.Tokens.assembly),
            ('component', Kind.Tokens.component),
            ('group', Kind.Tokens.group),
        )
        for kind_name, kind_token in known_kinds:
            if kind_type == kind_name:
                model_api.SetKind(kind_token)
                break

    def delete_prim_at(self, prim_path):
        """Remove the prim at *prim_path* from the stage (result is not returned)."""
        stage = self.stage
        stage.RemovePrim(prim_path)

    def delete_prim(self, prim):
        """Remove *prim* from the stage by its path (result is not returned)."""
        target_path = prim.GetPath()
        self.stage.RemovePrim(target_path)


    def set_asset_info(self, value):
        """Author *value* as the default prim's ``assetInfo`` dictionary."""
        prim = self.default_prim
        prim.SetAssetInfo(value)

    # Read/write access to the default prim's assetInfo.
    asset_info = property(get_asset_info, set_asset_info)

    def add_layer_metadata(self):
        pass

    def set_assembly(self):
        pass

    def get_payloads(self, prim):
        """Return the ``Usd.Payloads`` editing interface of *prim*."""
        return prim.GetPayloads()

    def get_layers(self):
        layer_stack = self.stage.GetLayerStack(includeSessionLayers=False)
        layers_data = OrderedDict()
        for layer in layer_stack:
            basename = os.path.basename(layer.realPath)
            name = basename.split('.')[0]
            layer_name = name.split('_')[-1].lower()
            layers_data[layer_name] = layer

        return layers_data


    def approve_layer_version(self, step, variant, version):
        """Mark *version* of *variant* for pipeline *step* as approved and recommended.

        Steps:

        1. Resolve the layer that publishes *step* from the layer config.
        2. Open that layer's per-variant file,
           ``get_layer_filename(<layer>, variant)``. This is the file where
           ``add_asset_version`` and ``build_asset`` author versions.
        3. Read the payload of the ``<layer>_version`` variant *version*.
        4. Author that payload under the ``approved`` and ``recommended``
           variants, select ``recommended`` and save.

        *version* may be an int (formatted as ``%03d``) or an already
        formatted string. Returns None. Prints a message and returns early
        when no layer publishes *step* or the version has no payload.
        """
        print('=== approve_layer_version ===')
        print('Entity: %s' % self.entity_name)
        print('Entity type: %s' % self.entity_type)
        print('Variant: %s' % variant)
        print('Pipeline step: %s' % step)

        layers, _ = self.check_layer_in_assembly(step)
        if not layers:
            print('cant find layers')
            return
        layer_name = layers[-1]
        variant_set_name = '%s_version' % layer_name
        version_str = '%03d' % version if isinstance(version, int) else version

        # Version files are named after the layer, not the pipeline step.
        # Using `step` here opened, or created, an unrelated file.
        version_filename = self.get_layer_filename(layer_name, variant=variant)

        version_manager = UsdManager(self.project)
        # Without set_entity the manager always behaves as an 'Asset', so
        # shots would get an /asset default prim instead of /World.
        version_manager.set_entity(self.entity_name, self.entity_type, asset_type=self.asset_type)
        version_manager.open(version_filename)

        default_prim = version_manager.default_prim
        version_set = default_prim.GetVariantSets().AddVariantSet(variant_set_name)
        version_set.SetVariantSelection(version_str)

        with version_set.GetVariantEditContext():
            # Composed payload list while the requested version is selected.
            payload_op = default_prim.GetMetadata("payload")
            payloads = payload_op.ApplyOperations([]) if payload_op else []
        if not payloads:
            print('Can\'t find payload for version %s' % version_str)
            return
        version_path = payloads[0].assetPath

        version_manager.add_variant(variant_set_name, 'approved', version_path)
        version_manager.add_variant(variant_set_name, 'recommended', version_path)

        version_set.SetVariantSelection('recommended')
        version_manager.save_stage()


    def remove_layer_version(self, step, variant, version):
        """Delete *version* from the ``<layer>_version`` variant set of *variant*'s file.

        The file is the per-variant layer ``get_layer_filename(<layer>,
        variant)``, where ``add_asset_version`` authors versions. The variant
        spec is removed from that file's root layer and the file is saved.

        Left untouched:

        - the current variant selection;
        - alias variants such as ``latest``, ``recommended`` or ``approved``,
          which may point at the same payload.

        *version* may be an int (formatted ``%03d``) or a pre-formatted
        string. Prints a message and returns early when the layer, the
        version or its spec cannot be found.
        """
        layers, _ = self.check_layer_in_assembly(step)
        if not layers:
            print('Can\'t find layers definition for the pipeline step %s' % step)
            return
        layer_name = layers[-1]
        variant_set_name = '%s_version' % layer_name
        version_filename = self.get_layer_filename(layer_name, variant=variant)

        version_manager = UsdManager(self.project)
        version_manager.set_entity(self.entity_name, self.entity_type, asset_type=self.asset_type)
        version_manager.open(version_filename)
        default_prim = version_manager.default_prim

        version_str = '%03d' % version if isinstance(version, int) else version

        variants = version_manager.get_layer_variants()
        if version_str not in variants.get(variant_set_name, {}):
            print('Version %s not found in %s' % (version_str, variant_set_name))
            return

        # Usd.VariantSet cannot delete variants, so edit the root layer's prim
        # spec directly. The previous code looked the set up but never removed
        # anything.
        prim_spec = version_manager.root_layer.GetPrimAtPath(default_prim.GetPath())
        if not prim_spec or variant_set_name not in prim_spec.variantSets:
            print('Variant set %s is not authored in %s' % (variant_set_name, version_filename))
            return
        variant_set_spec = prim_spec.variantSets[variant_set_name]
        variant_spec = variant_set_spec.variants.get(version_str)
        if variant_spec is None:
            print('Version %s is not authored in %s' % (version_str, version_filename))
            return

        variant_set_spec.RemoveVariant(variant_spec)
        version_manager.save_stage()

    def check_assembly_exists(self):
        """Create the entity's assembly file if it does not exist yet.

        A fresh manager is bound to the same entity. If its assembly file is
        missing, the file is created and populated by
        :meth:`create_entity_assembly`.
        """
        print('= Check entity assembly =')
        assembly_manager = UsdManager(self.project)
        assembly_manager.set_entity(self.entity_name, self.entity_type, asset_type=self.asset_type)
        assembly_path = assembly_manager.filename
        if os.path.exists(assembly_path):
            print('Exists...')
            return
        print('Don\'t exist, generating...')
        assembly_manager.open(assembly_path)
        assembly_manager.create_entity_assembly()


    def build_asset(self, versions_data):
        """Rebuild the entity's per-layer variant/version files from *versions_data*.

        *versions_data* maps pipeline step -> {variant: {version: payload_path}}.
        After ensuring the assembly exists, each step that has a configured
        layer is processed as follows:

        - The layer file gets one ``<layer>_variant`` variant per variant,
          each pointing at that variant's file.
        - Each variant file gets one ``<layer>_version`` variant per version
          (int versions are formatted ``%03d``), and ``recommended`` is
          selected.
        - The layer file selects ``Master`` if present, otherwise the first
          variant.

        Steps without a layer definition or without variants are skipped
        with a message.
        """
        self.check_assembly_exists()

        for step, steps_data in versions_data.items():
            layers, _ = self.check_layer_in_assembly(step)
            if not layers:
                print('Can\'t find layers definition for the pipeline step %s' % step)
                continue
            if not steps_data:
                # Nothing to build, and no variant to select (previously an IndexError).
                print('No variants to build for the pipeline step %s' % step)
                continue

            layer_name = layers[-1]
            variant_set_name = '%s_variant' % layer_name
            version_set_name = '%s_version' % layer_name

            step_filename = self.get_layer_filename(layer_name)
            step_manager = UsdManager(self.project)
            step_manager.set_entity(self.entity_name, self.entity_type, asset_type=self.asset_type)
            step_manager.open(step_filename)

            for variant, variant_data in steps_data.items():
                variant_filename = self.get_layer_filename(layer_name, variant=variant)
                step_manager.add_variant(variant_set_name, variant, variant_filename)

                variant_manager = UsdManager(self.project)
                variant_manager.set_entity(self.entity_name, self.entity_type, asset_type=self.asset_type)
                variant_manager.open(variant_filename)

                for version, version_path in variant_data.items():
                    version_str = '%03d' % version if isinstance(version, int) else version
                    variant_manager.add_variant(version_set_name, version_str, version_path)

                variant_manager.set_selected_variant(variant_manager.default_prim, version_set_name, 'recommended')
                variant_manager.save_stage()

            selected_variant = 'Master' if 'Master' in steps_data else next(iter(steps_data))
            step_manager.set_selected_variant(step_manager.default_prim, variant_set_name, selected_variant)
            step_manager.save_stage()



    def add_asset_version(self, step, variant, version, version_path, add_empty_layer=False):
        """Publish *version_path* as *version* of *variant* for pipeline *step*.

        After ensuring the entity assembly exists and resolving the layer that
        publishes *step*:

        1. In the per-variant file ``get_layer_filename(<layer>, variant)``,
           author ``<layer>_version`` variant *version* with a payload to
           *version_path*. If it is the highest numeric version, also point
           ``latest`` at it. Then either point ``recommended`` at it and
           select it (when no ``approved`` variant exists), or select
           ``approved``.
        2. In the layer file ``get_layer_filename(<layer>)``, author *variant*
           in ``<layer>_variant`` with a payload to the per-variant file, plus
           an empty ``None`` variant when *add_empty_layer* is set. Select
           ``Master``.

        *version* may be an int (formatted ``%03d``) or a pre-formatted
        string. Each file is saved with one retry; save failures are printed,
        not raised.
        """
        print('=== add_asset_version ===')
        print('Entity: %s' % self.entity_name)
        print('Entity type: %s' % self.entity_type)
        print('Pipeline step: %s' % step)
        print('Variant: %s' % variant)
        print('Version: %s' % version)

        self.check_assembly_exists()

        layers, _ = self.check_layer_in_assembly(step)
        if not layers:
            print('Can\'t find layers definition for the pipeline step %s' % step)
            return

        layer_name = layers[-1]
        version_set_name = '%s_version' % layer_name
        variant_set_name = '%s_variant' % layer_name

        variant_filename = self.get_layer_filename(layer_name)
        version_filename = self.get_layer_filename(layer_name, variant=variant)

        version_manager = UsdManager(self.project)
        version_manager.set_entity(self.entity_name, self.entity_type, asset_type=self.asset_type)
        version_manager.open(version_filename)

        version_str = '%03d' % version if isinstance(version, int) else version

        versions = version_manager.add_variant(version_set_name, version_str, version_path)
        numeric_versions = [int(ver) for ver in versions if ver.isdigit()]

        # Compare using the same digit rule as above. A string version such as
        # '012' used to raise TypeError when compared against ints.
        version_key = str(version_str)
        version_number = int(version_key) if version_key.isdigit() else None

        if numeric_versions and version_number is not None and version_number >= max(numeric_versions):
            version_manager.add_variant(version_set_name, 'latest', version_path)

            if 'approved' not in versions:
                version_manager.add_variant(version_set_name, 'recommended', version_path)
                version_manager.set_selected_variant(version_manager.default_prim, version_set_name, 'recommended')
            else:
                version_manager.set_selected_variant(version_manager.default_prim, version_set_name, 'approved')

        print('Saving version layer: %s' % version_manager.filename)
        self._save_with_retry(version_manager)

        variant_manager = UsdManager(self.project)
        variant_manager.set_entity(self.entity_name, self.entity_type, asset_type=self.asset_type)
        variant_manager.open(variant_filename)

        if add_empty_layer:
            variant_manager.add_variant(variant_set_name, 'None', None)

        variant_manager.add_variant(variant_set_name, variant, version_filename)
        variant_manager.set_selected_variant(variant_manager.default_prim, variant_set_name, 'Master')

        print('Saving variant layer: %s' % variant_manager.filename)
        self._save_with_retry(variant_manager)

    @staticmethod
    def _save_with_retry(manager, delay=2):
        """Save *manager*'s stage, retrying once after *delay* seconds.

        This stays best-effort, so a failed save does not abort the publish,
        but failures are reported instead of silently swallowed.
        """
        try:
            manager.save_stage()
        except Exception as exc:
            print('Save failed, retrying in %ss: %s (%s)' % (delay, manager.filename, exc))
            time.sleep(delay)
            try:
                manager.save_stage()
            except Exception as exc:
                print('Could not save layer: %s (%s)' % (manager.filename, exc))

    def get_layer_info(self, step, layers_data):
        for item in layers_data:
            for key, value in item.items():
                return_steps = [key]

                if isinstance(value, list):
                    next_step, layer_info = self.get_layer_info(step, value)
                    if layer_info:
                        next_step.insert(0, key)
                        return next_step, layer_info
                elif value.get('publish_type') == step:

                    return return_steps, value

        return None, None

    def check_layer_in_assembly(self, step):
        all_steps, layer_data = self.get_layer_info(step, self.asset_layers_config)
        return all_steps, layer_data

    def get_layer_step(self, layer_name, data=None):
        if data is None:
            data = self.asset_layers_config

        for layer_config in data:
            for key, value in layer_config.items():
                if isinstance(value, list):
                    step = self.get_layer_step(layer_name, value)
                    if step:
                        return step
                else:
                    if key == layer_name:
                        return value['publish_type']
        return None


    def add_variant(self, variant_set_name, variant_name, variant_value, as_selected=False):
        """Author variant *variant_name* in set *variant_set_name* on the default prim.

        If *variant_value* is truthy, the variant's payloads are replaced by a
        single payload to *variant_value*. If it is falsy (e.g. None), an
        empty variant with no payload is authored; the ``'None'`` option of
        ``add_asset_version(add_empty_layer=True)`` relies on this.

        Unless *as_selected* is true, any variant selection that existed
        before the call is restored.

        Returns the names of all variants in the set, or ``[]`` if selecting
        the variant failed.
        """
        print('= add_variant =')
        print('Variant set name: %s' % variant_set_name)
        print('Variant name: %s' % variant_name)
        print('Variant value: %s' % variant_value)

        variant_set = self.default_prim.GetVariantSets().AddVariantSet(variant_set_name)
        selected = variant_set.GetVariantSelection()

        if variant_value:
            try:
                variant_set.SetVariantSelection(variant_name)
            except Exception as exc:
                print('Can\'t select variant %s: %s' % (variant_name, exc))
                return []

            with variant_set.GetVariantEditContext():
                payloads = self.default_prim.GetPayloads()
                payloads.ClearPayloads()
                try:
                    payloads.AddPayload(variant_value)
                except Exception:
                    print('Can\'t add layer: %s' % variant_value)
        else:
            # Previously this branch only re-applied the current selection,
            # so empty variants were never created.
            variant_set.AddVariant(variant_name)
            variant_set.SetVariantSelection(selected)

        if not as_selected and selected:
            variant_set.SetVariantSelection(selected)

        return variant_set.GetVariantNames()

    def create_entity_assembly(self, layers_config=None):
        """(Re)build the entity's layer structure on ``self.stage`` and save it.

        Choosing the layout:

        - When *layers_config* is None, the layout comes from
          ``self.asset_layers_config``. If an entry has a key equal to
          ``self.asset_type``, that entry's value is used (the last such entry
          wins); otherwise the whole config is used.
        - An explicit list is passed when recursing into nested layers.

        Building the layers:

        - All existing sub-layers are cleared first.
        - ``breakdown`` is generated with ``usd.lib.build_breakdown`` and added
          as a sub-layer.
        - Every other layer gets a fresh file from :meth:`get_layer_filename`
          (an existing file is deleted first). Nested lists are built
          recursively into that file, which is then appended as a sub-layer.

        Finally the root prim is defined and made the default prim:

        - ``/asset`` for assets, with ``render``/``proxy`` purpose children
          except for the ``Sets`` and ``Collections`` asset types;
        - ``/World`` otherwise.
        """
        print('=== create_entity_assembly ===')
        print('Entity: %s' % self.entity_name)
        print('Entity type: %s' % self.entity_type)

        if layers_config is None:
            entity_layers_config = self.asset_layers_config
            for config in entity_layers_config:
                print(config)
                if self.asset_type in config:
                    layers_config = config[self.asset_type]
                    print('>>', layers_config)

            if layers_config is None:
                layers_config = entity_layers_config
        else:
            print('Adding children layers')

        self.clear_sublayers()

        for layer_config in layers_config:
            for layer_name, layer_data in layer_config.items():
                print('Layer: %s' % layer_name)
                if layer_name == 'breakdown':
                    print('Generating breakdown')
                    import usd.lib.build_breakdown as build_breakdown
                    importlib.reload(build_breakdown)
                    builder = build_breakdown.BreakdownBuilder(self.entity_name, project=self.project)
                    breakdown_path = builder.build_breakdown()
                    self.add_sublayer(breakdown_path)
                    continue

                new_layer_file = UsdManager(self.project)
                new_layer_file.set_entity(self.entity_name, self.entity_type, asset_type=self.asset_type)
                layer_path = new_layer_file.get_layer_filename(layer_name)

                # Every layer file is rebuilt from scratch.
                if os.path.exists(layer_path):
                    os.remove(layer_path)

                print('Layer path', layer_path)
                new_layer_file.open(layer_path)

                if isinstance(layer_data, list):
                    # Nested definitions become sub-layers of this layer file.
                    print('>>> create_entity_assembly')
                    new_layer_file.create_entity_assembly(layers_config=layer_data)
                new_layer_file.save_stage()
                self.add_sublayer(layer_path)

        print('adding prims', self.entity_type)
        if self.entity_type == 'Asset':
            self.stage.DefinePrim('/asset', 'Xform')
            self.default_prim = '/asset'
            # The old test `!= 'Sets' or != 'Collections'` was always true.
            if self.asset_type not in ('Sets', 'Collections'):
                self.stage.DefinePrim('/asset/render', 'Xform')
                self.stage.DefinePrim('/asset/proxy', 'Xform')

                self.set_prim_purpose('/asset/render', 'render')
                self.set_prim_purpose('/asset/proxy', 'proxy')
        else:
            self.stage.DefinePrim('/World', 'Xform')
            self.default_prim = '/World'

        self.save_stage()

    def set_mute_layer_status(self, layer_name, status):
        """Mute (*status* truthy) or unmute a layer of the stage's layer stack.

        *layer_name* is lower-cased and matched against the short names
        produced by :meth:`get_layers`. Unknown names are ignored.
        """
        layers = self.get_layers()
        key = layer_name.lower()
        if key not in layers:
            return
        identifier = layers[key].identifier
        if status:
            print('mute', identifier)
            self.stage.MuteLayer(identifier)
        else:
            print('unmute', identifier)
            self.stage.UnmuteLayer(identifier)

    def clear_sublayers(self):
        """Remove every sub-layer path from the stage's root layer."""
        layer = self.root_layer
        layer.subLayerPaths = []

    def get_sublayers(self, layer=None):
        """Return the sub-layer paths of *layer*, or of the root layer when *layer* is falsy."""
        source = layer if layer else self.root_layer
        return source.subLayerPaths

    def add_sublayer(self, sublayer_path, index=None):
        if not sublayer_path or not os.path.exists(sublayer_path):
            return
        print('add sublayer', sublayer_path)
        if isinstance(sublayer_path, str):
            sublayer_path = sublayer_path.replace('\\', '/')
        if sublayer_path in self.root_layer.subLayerPaths:
            return

        if index is None:
            print('add next', sublayer_path)
            self.root_layer.subLayerPaths.append(sublayer_path)

        elif isinstance(index, int):
            self.root_layer.subLayerPaths.insert(index, sublayer_path)

    def repath_prim(self, source_path, new_path):
        """Re-path the prim at *source_path* to *new_path* in the root layer.

        The work is done by the project helper
        ``usd.lib.reparent_prim.reparent_prim``. That module is reloaded on
        every call, presumably so edits to it are picked up without
        restarting the host application.
        """
        import usd.lib.reparent_prim as reparent_module
        reparent_module = importlib.reload(reparent_module)
        reparent_module.reparent_prim(self.root_layer, source_path, new_path)

    def remove_attributes(self, parser_config, flatten_meshes=False):
        """Strip authored properties that the cleaner config does not whitelist.

        Reads entry *parser_config* of the ``clean_attributes_config`` config
        (module ``usd``):

        - Prims whose type is in ``node_types`` are skipped entirely.
        - ``GeomSubset`` prims keep all their properties.
        - Properties named in ``attributes``, or starting with any ``prefix``
          entry, are kept.
        - Every other authored property is removed.

        With *flatten_meshes*, every ``Mesh`` prim is collapsed into its
        parent:

        - the parent is retyped to ``Mesh``;
        - kept non-relationship properties are re-created on the parent with
          their values;
        - a kept ``material:binding`` with at least one target is re-bound on
          the parent;
        - the mesh prim is removed at the end.
        """
        print('remove_attributes')
        cleaner_config = self.config_solver.get_config('clean_attributes_config', module='usd')
        parser_settings = cleaner_config[parser_config]

        node_types = parser_settings.get('node_types', [])
        attributes = parser_settings.get('attributes', [])
        prefixes = parser_settings.get('prefix', [])

        # Keyed by path so a prim scheduled more than once is removed only once.
        to_delete = OrderedDict()
        for prim in self.stage.TraverseAll():
            prim_type = prim.GetTypeName()
            if prim_type in node_types:
                continue

            parent_prim = prim.GetParent()
            flatten = flatten_meshes and prim_type == 'Mesh'
            if flatten:
                parent_prim.SetTypeName("Mesh")
                to_delete[prim.GetPath()] = None

            if prim_type == 'GeomSubset':
                continue

            for prim_property in prim.GetAuthoredProperties():
                property_name = prim_property.GetName()
                keep = property_name in attributes
                if not keep:
                    for prefix in prefixes:
                        if property_name.startswith(prefix):
                            print(property_name, prefix)
                            keep = True

                if not keep:
                    prim.RemoveProperty(property_name)
                elif flatten:
                    if property_name == 'material:binding':
                        targets = prim_property.GetTargets()
                        # An empty binding used to raise IndexError here.
                        if targets:
                            material = UsdShade.Material(self.stage.GetPrimAtPath(targets[0]))
                            UsdShade.MaterialBindingAPI.Apply(parent_prim).Bind(material)
                    elif not hasattr(prim_property, 'AddTarget'):
                        # Attributes only (relationships have AddTarget).
                        self.create_attribute(parent_prim, property_name,
                                              attribute_type=prim_property.GetTypeName(),
                                              value=prim_property.Get())

        for prim_path in to_delete:
            self.stage.RemovePrim(prim_path)


    def get_selected_variant(self, prim, variant_set_name):
        """Return the current selection of *variant_set_name* on *prim*.

        Uses ``AddVariantSet``, so the variant set is authored if it does not
        exist yet.
        """
        version_set = prim.GetVariantSets().AddVariantSet(variant_set_name)
        return version_set.GetVariantSelection()

    def set_selected_variant(self, prim, variant_set_name, variant_value):
        """Select *variant_value* in *variant_set_name* on *prim*.

        Uses ``AddVariantSet``, so the set is authored if missing. Returns the
        result of ``SetVariantSelection``.
        """
        version_set = prim.GetVariantSets().AddVariantSet(variant_set_name)
        return version_set.SetVariantSelection(variant_value)

    def get_variant_versions(self, layer_name, variant):
        """Return the version names available for *variant* of layer *layer_name*.

        Selects *variant* in the default prim's ``<layer>_variant`` set and
        leaves that selection in place. Then lists the variants of the
        ``<layer>_version`` set. Returns ``{}`` if the default prim is
        invalid.
        """
        if not self.default_prim.IsValid():
            print('not valid prim')
            return {}

        variant_sets = self.default_prim.GetVariantSets()
        variant_set = variant_sets.GetVariantSet('%s_variant' % layer_name)
        variant_set.SetVariantSelection(variant)

        with variant_set.GetVariantEditContext():
            version_set = variant_sets.GetVariantSet('%s_version' % layer_name)
            names = version_set.GetVariantNames()
        return names


    def get_layer_variants(self):
        """Describe the default prim's variant sets as ``{set_name: {variant_name: {}}}``.

        Returns ``{}`` when the default prim is invalid. Each set's current
        selection, if any, is re-applied after reading its names.
        """
        if not self.default_prim.IsValid():
            return {}

        variant_sets = self.default_prim.GetVariantSets()
        result = {}
        for set_name in variant_sets.GetNames():
            vset = variant_sets.GetVariantSet(set_name)
            previous = vset.GetVariantSelection()
            result[set_name] = {name: {} for name in vset.GetVariantNames()}
            if previous:
                vset.SetVariantSelection(previous)
        return result


    def build_breakdown(self):
        """Add every non-camera asset of ``self.shot_view``'s breakdown to the breakdown layer.

        For each breakdown item:

        - The asset type is the last word of ``sg_asset.sg_asset_type``;
          ``Cameras`` are skipped.
        - The alias is ``sg_alias``, or ``<asset>_<instance:03d>`` when that
          is empty.
        - Variants start from the item's geometry and shading variants. When
          :meth:`get_breakdown_versions` knows the asset, version selections
          for ``Model``/``SetDressing`` (geometry), ``Shading`` and
          ``Lighsets`` are added. Each uses the approved version, falling
          back to the highest one.

        Each item is then passed to :meth:`add_asset_breakdown`.
        """
        def preferred(step_versions):
            # The fallback is evaluated eagerly, so a missing 'highest_version' raises KeyError.
            return step_versions.get('approved_version', step_versions['highest_version'])

        versions_by_asset = self.get_breakdown_versions()

        for item in self.shot_view.sg_breakdowns:
            asset_name = item.sg_asset.code
            asset_type = item.sg_asset.sg_asset_type.split(' ')[-1]
            print(asset_name, asset_type)
            if asset_type == 'Cameras':
                continue

            if item.sg_alias:
                alias = item.sg_alias
            else:
                alias = '%s_%03d' % (asset_name, item.sg_instance)

            variants = {
                'Geometry_variant': item.sg_geometry_variant.code,
                'Shading_variant': item.sg_shading_variant.code,
            }
            if asset_name in versions_by_asset:
                asset_versions = versions_by_asset[asset_name]
                # SetDressing, when present, overrides Model for the geometry version.
                for geometry_step in ('Model', 'SetDressing'):
                    if geometry_step in asset_versions:
                        variants['Geometry_version'] = preferred(asset_versions[geometry_step])
                if 'Shading' in asset_versions:
                    variants['Shading_version'] = preferred(asset_versions['Shading'])
                # NOTE(review): 'Lighsets' is matched verbatim; confirm it is the real step code.
                if 'Lighsets' in asset_versions:
                    print('add lightsets')
                    variants['Light_variant'] = item.sg_geometry_variant.code
                    variants['Light_version'] = preferred(asset_versions['Lighsets'])

            self.add_asset_breakdown(asset_name, alias, variants=variants, asset_type=asset_type)

    def get_breakdown_versions(self):
        correct_versions = {}
        for asset_view in self.shot_view.sg_breakdowns.sg_asset:
            for version in asset_view.sg_published_elements:
                pipeline_step = version.sg_step.code
                version_number = version.sg_version_number
                status = version.sg_status_list
#                if pipeline_step not in ['ModelingHD', 'Shading', 'SetDressing', 'Lighsets']:
#                    continue

                asset_name = version.sg_asset.code

                correct_versions[asset_name] = correct_versions.get(asset_name, {})
                correct_versions[asset_name][pipeline_step] = correct_versions[asset_name].get(pipeline_step, {})

                last_higest_version = correct_versions[asset_name][pipeline_step].get('highest_version', 0)
                approved_version = correct_versions[asset_name][pipeline_step].get('approved_version', 0)

                if version_number > last_higest_version:
                    correct_versions[asset_name][pipeline_step]['highest_version'] = version_number

                if status == 'cmpt' and version_number > approved_version:
                    correct_versions[asset_name][pipeline_step]['approved_version'] = version_number

        return correct_versions

    def get_all_prims_by_type(self, prim_type):
        all_prims = []
        for prim in self.stage.Traverse():
            if prim.GetTypeName() == prim_type:
                all_prims.append(prim)
        return all_prims


    def add_asset_breakdown(self,asset_name, alias, variants=None, asset_type='Char'):
        """Add (or update) one asset instance in the shot's breakdown layer and save it.

        Steps:

        1. Open ``<publish_path>/breakdown_layer.usda`` through ``get_stage``.
        2. Define ``/<root_node>`` (Xform), ``/<root_node>/<asset_type>``
           (Scope) and ``/<root_node>/<asset_type>/<alias>`` (Xform).
        3. Replace the instance prim's payloads with a single payload to
           ``<publish_asset_path>/<asset_name>/asset_assembly.usda``.
        4. For each ``variant_set_name -> variant_value`` in *variants*,
           convert the value with ``str`` and author the selection, but only
           if the prim currently exposes that variant set.
        5. Print the layer and save it.

        NOTE(review): ``get_stage``, ``self.publish_path``, ``self.root_node``
        and ``self.publish_asset_path`` are not defined in the visible part of
        this class/module. They are presumably provided elsewhere; confirm.
        """
        filename = '%s/breakdown_layer.usda' % self.publish_path

        stage = get_stage(filename, self.root_node)

        root_prim = stage.DefinePrim('/%s' % self.root_node, 'Xform')
        #root_prim.SetKind('assembly')
        print('/%s/%s' % (self.root_node, asset_type))
        asset_type_prim = stage.DefinePrim('/%s/%s' % (self.root_node, asset_type), 'Scope')
        #asset_type_prim.SetKind('group')
        asset_prim_path = '/%s/%s/%s' % (self.root_node, asset_type, alias)

        asset_prim = stage.DefinePrim(asset_prim_path, 'Xform')
        #asset_prim.SetKind('component')

        # Replace any previous payload so the instance references exactly one asset assembly.
        payloads = asset_prim.GetPayloads()
        payloads.ClearPayloads()
        asset_path = '%s/%s/asset_assembly.usda' % (self.publish_asset_path, asset_name)

        payloads.AddPayload(asset_path)

        if variants:
            for variant_set_name, variant_value in variants.items():
                variant_value = str(variant_value)
                # Variant sets are re-queried on every iteration, presumably because
                # an earlier selection can expose further (nested) sets, so the
                # order of `variants` matters.
                variant_sets = asset_prim.GetVariantSets()
                v_sets_names = variant_sets.GetNames()

                # Sets the asset does not expose are silently skipped.
                if not variant_set_name in v_sets_names:
                    continue
                variant_set = variant_sets.GetVariantSet(variant_set_name)
                print(variant_set_name, variant_value)
                variant_set.SetVariantSelection(variant_value)

        print(stage.GetRootLayer().ExportToString())

        stage.GetRootLayer().Save()


    def set_prim_variant(self, prim, variant_set_name, variant_value):
        variant_sets = prim.GetVariantSets()
        v_sets_names = variant_sets.GetNames()

        if not variant_set_name in v_sets_names:
            return False

        variant_set = variant_sets.GetVariantSet(variant_set_name)
        variant_set.SetVariantSelection(variant_value)

        return True
    def maya_set_to_collection(self):
        """Turn Maya object sets into USD collections on the ``/asset`` prim.

        First makes ``/asset`` the stage's default prim. Then, for every
        objectSet in the Maya scene that is not listed in ``IGNORE_SETS``, is
        not renderable and is not a creaseSet, it:

        * applies a ``Usd.CollectionAPI`` named after the set on the default prim;
        * adds as include targets every member whose Maya long name, with ``|``
          replaced by ``/``, resolves to a valid prim;
        * sets the expansion rule to ``expandPrims``.

        Must run inside Maya (imports ``maya.cmds``).

        :returns: dict mapping Maya set name -> applied ``Usd.CollectionAPI``.
        """
        import maya.cmds as cmds

        collections = {}
        asset_prim = self.stage.GetPrimAtPath('/asset')
        self.stage.SetDefaultPrim(asset_prim)
        all_sets = cmds.ls(type='objectSet')

        default_prim = self.stage.GetDefaultPrim()
        print(default_prim)

        for set_name in all_sets:
            if set_name in IGNORE_SETS or cmds.sets(set_name, q=True, renderable=True):
                continue
            if cmds.nodeType(set_name) == 'creaseSet':
                continue
            collection_api = Usd.CollectionAPI.Apply(default_prim, set_name)

            objects = cmds.sets(set_name, q=True, l=True)
            # An empty set queries as None/[]. Passing that to cmds.ls() lists
            # every node in the scene, which would pull unrelated prims into
            # the collection, so treat it as "no members".
            full_objects = cmds.ls(objects, l=True) if objects else []
            print(full_objects)

            includes_rel = collection_api.GetIncludesRel()
            for maya_node in full_objects:
                prim_path = maya_node.replace('|', '/')
                if self.stage.GetPrimAtPath(prim_path).IsValid():
                    includes_rel.AddTarget(prim_path)
            collection_api.GetExpansionRuleAttr().Set(Usd.Tokens.expandPrims)

            collections[set_name] = collection_api

        return collections

    def maya_set(self):
        """Export Maya object-set membership onto the current stage.

        The following sets are skipped: those listed in ``IGNORE_SETS``,
        renderable sets, and creaseSets. Every other objectSet in the scene is
        exported as follows:

        * Whole-object members become include targets of a ``Usd.CollectionAPI``
          named after the set. It is applied on the stage's default prim, with
          expansion rule ``expandPrims``.
        * Component members (``node.f[i]``, ``node.vtx[i]``, ...) become one
          ``UsdGeom.Subset`` per (set, node). Subset name and family name are
          both the set name. The element type is ``'point'`` when the component
          name starts with ``v`` and ``'face'`` otherwise.
          NOTE(review): edge components are therefore exported as ``'face'``;
          confirm that is intended.

        Maya long names map to prim paths by replacing ``|`` with ``/``. Members
        that do not resolve to a valid prim are skipped. Must run inside Maya
        (imports ``maya.cmds``).
        """
        import re
        import maya.cmds as cmds

        # Raw string: '\[' in a plain literal is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+).
        index_pattern = re.compile(r'\[([0-9]+)\]')

        default_prim = self.stage.GetDefaultPrim()

        all_sets = cmds.ls(type='objectSet')
        collection_data = {}   # set name -> [maya long names of whole objects]
        subset_data = {}       # set name -> {node: {'type': ..., 'index': [...]}}
        pprint(all_sets)
        for set_name in all_sets:
            print('check', set_name)
            if set_name in IGNORE_SETS or cmds.sets(set_name, q=True, renderable=True):
                continue
            if cmds.nodeType(set_name) == 'creaseSet':
                continue

            objects = cmds.sets(set_name, q=True, l=True)
            if not objects:
                # Empty set: cmds.ls(None/[]) would list the whole scene.
                continue
            # fl=True flattens component ranges so each entry carries one index.
            full_objects = cmds.ls(objects, l=True, fl=True)

            for element in full_objects:
                if '.' not in element:
                    collection_data.setdefault(set_name, []).append(element)
                    continue

                obj_name, attribute = element.split('.', 1)
                data_type = 'point' if attribute.startswith('v') else 'face'
                index = index_pattern.findall(element)
                if index:
                    node_entry = subset_data.setdefault(set_name, {}).setdefault(
                        obj_name, {'type': data_type, 'index': []})
                    node_entry['index'].append(int(index[0]))

        pprint(collection_data)
        for set_name, nodes in subset_data.items():
            for node, node_subset in nodes.items():
                print(set_name, node)
                prim = self.stage.GetPrimAtPath(node.replace('|', '/'))
                if not prim.IsValid():
                    continue
                geom = UsdGeom.Imageable(prim)
                face_index = Vt.IntArray(node_subset['index'])
                UsdGeom.Subset.CreateGeomSubset(geom, set_name, node_subset['type'],
                                                face_index, familyName=set_name)

        for set_name, obj_list in collection_data.items():
            collection_api = Usd.CollectionAPI.Apply(default_prim, set_name)
            for maya_node in obj_list:
                asset_path = maya_node.replace('|', '/')
                if self.stage.GetPrimAtPath(asset_path).IsValid():
                    collection_api.GetIncludesRel().AddTarget(asset_path)
            collection_api.GetExpansionRuleAttr().Set(Usd.Tokens.expandPrims)

    def add_light_link(self, prim, light_path):

        light_link_attribute = prim.GetAttribute('primvars:arnold:light_group')
        if not light_link_attribute.IsValid():
            light_link_attribute = self.create_attribute(prim,
                                                         'primvars:arnold:light_group',
                                                         attribute_type='string_array',
                                                         value=[light_path])
            self.create_attribute(prim,
                                  'primvars:arnold:use_light_group',
                                  attribute_type=bool,
                                  value=True)
        else:
            value = light_link_attribute.Get()
            #size = value.size()
            #value.resize(size+1)
            #value[size+1] =
            print(value)
            print(list(value))
            value = list(value)
            value.append(light_path)
            light_link_attribute.Set(value=value)


def test_asset_layer():
    """Manual smoke test for the SGD asset 'LeoMesi' (Main_Characters).

    Calls create_entity_assembly(), open(), sets default_prim / asset_info and
    finally get_layers() on a fresh UsdManager.
    Reference assembly location:
    V:/SGD/publish/usd/assets/Main_Characters/LeoMesi/asset_assembly.usda
    """
    manager = UsdManager('SGD')
    manager.set_entity('LeoMesi', 'Asset', asset_type='Main_Characters')
    manager.create_entity_assembly()
    manager.open()
    manager.default_prim = 'asset'
    manager.asset_info = dict(asset_name='LeoMesi', asset_type='Main Characters')
    manager.get_layers()

    # Other manual checks:
    # manager.print()
    # manager.add_asset_version('model', 'Master', 129, 'V:/SGD/publish/model/Main_Characters/LeoMesi/Low/LeoMesi_Low_4f1bb89bcf9fc3dd/v007/usd/LeoMesi_geometry.usda')
    # manager.approve_layer_version('model', 'Master', 107)
    # manager.check_layer_in_assembly('model')


def test_shots():
    """Manual smoke test: open the TPT shot s00_ep01_sq020_sh010 and add PostAnim v007.

    Opens the shot's assembly file, calls create_entity_assembly(), then
    registers the PostAnim cache of the chosen version via add_asset_version().
    """
    manager = UsdManager('TPT')
    manager.set_entity('s00_ep01_sq020_sh010', 'Shot')
    assembly_path = manager.filename
    print(assembly_path)
    manager.open(assembly_path)
    manager.create_entity_assembly()

    # Layout camera example:
    # usd_path = 'V:/TPT/publish/layout/s00/s00_ep01/s00_ep01_sq010/s00_ep01_sq010_sh010/s00_ep01_sq010_sh010_Master_2aea712513585bf9/v008/usd/cam_s00_ep01_sq010_sh010.usd'
    # manager.add_asset_version('layout', 'Master', 8, usd_path)
    # CLI equivalent: update shot usd PostAnim s00_ep01_sq020_sh010 3 V:/TPT/publish/postanim/s00/s00_ep01/s00_ep01_sq020/s00_ep01_sq020_sh010/s00_ep01_sq020_sh010_Master_80de7776839bf18c/v003/usd/s00_ep01_sq020_sh010_cache.usd
    cache_version = 7
    cache_template = 'V:/TPT/publish/postanim/s00/s00_ep01/s00_ep01_sq020/s00_ep01_sq020_sh010/s00_ep01_sq020_sh010_Master_80de7776839bf18c/v%03d/usd/s00_ep01_sq020_sh010_cache.usd'
    cache_path = cache_template % cache_version
    # add_empty_layer=True can also be passed here.
    manager.add_asset_version('PostAnim', 'Master', cache_version, cache_path)


def test_delete_version():
    """Manual smoke test: call remove_layer_version('Shading', 'Master', 2) on TPT prop 'Book'."""
    manager = UsdManager('TPT')
    manager.set_entity('Book', 'Asset', asset_type='Props')
    manager.open(manager.filename)
    manager.remove_layer_version('Shading', 'Master', 2)



if __name__ == "__main__":
    # Running the module directly executes the delete-version smoke test.
    test_delete_version()