import maya.cmds as cmds
from pprint import pprint
import shotgrid_lib.database as database

import mayaUsd
from pxr import Usd, UsdGeom


class SceneManagerCore:
    """Inspect and update USD content loaded in the current Maya scene.

    Data is collected from ``UsdAssembly`` nodes and from the stages behind
    ``mayaUsdProxyShape`` nodes. Assembly-style nodes can be updated from a
    ShotGrid publish.
    """

    # Maya node type of the USD assemblies tracked in ``loaded_elements``.
    ASSEMBLY_NODE_TYPE = 'UsdAssembly'
    # Ancestor node types that mark an assembly as nested, so it is skipped.
    # NOTE(review): no node listed by this class has type 'Assembly'; it is kept
    # only in case a node type with that exact name exists in the pipeline — confirm.
    PARENT_ASSEMBLY_TYPES = ('UsdAssembly', 'Assembly')
    # Plain attributes copied verbatim by get_assembly_data.
    ASSET_ATTRIBUTES = ('usdPath', 'assetName', 'assetType')
    # Attribute-name suffixes treated as variant selectors by inspect_usd_proxies.
    VARIANT_SUFFIXES = ('_variant', '_version')

    def __init__(self, context_type, entity_name, project, sg_database=None):
        """Store the context and scan the scene for loaded assemblies.

        Args:
            context_type: Context type; stored on the instance.
            entity_name: Entity name; stored on the instance.
            project: Project; used to fill a new database when none is given.
            sg_database: Existing ``shotgrid_lib`` database to reuse. When
                ``None``, a new ``database.DataBase`` is created and filled
                for *project* with ``precatch=True``.
        """
        self.context_type = context_type
        self.entity_name = entity_name
        self.project = project
        # Long assembly node name -> dict built by get_assembly_data.
        self.loaded_elements = {}
        if sg_database is not None:
            self.database = sg_database
        else:
            # No database supplied: build and fill one for this project.
            self.database = database.DataBase()
            self.database.fill(project=self.project, precatch=True)

        self.get_loaded_elements()

    def get_loaded_elements(self):
        """Record every top-level ``UsdAssembly`` node in ``loaded_elements``.

        Assemblies nested under another assembly (see ``is_child``) are
        skipped. Keys are long DAG names. Existing entries are not cleared.
        """
        for assembly in cmds.ls(type=self.ASSEMBLY_NODE_TYPE, l=True) or []:
            if self.is_child(assembly):
                continue
            self.loaded_elements[assembly] = self.get_assembly_data(assembly)

    def get_assembly_data(self, assembly_node):
        """Collect asset attributes and variant selections from an assembly node.

        Args:
            assembly_node: Name of the Maya assembly node to read.

        Returns:
            dict: Each attribute in ``ASSET_ATTRIBUTES`` present on the node
            maps to its raw value. Each attribute ending in ``'_variant'``
            maps to ``{'name': attr, 'value': <current value as string>,
            'values': <enum option names>}``.
        """
        node_data = {}
        for attribute in cmds.listAttr(assembly_node) or []:
            if attribute in self.ASSET_ATTRIBUTES:
                node_data[attribute] = cmds.getAttr('%s.%s' % (assembly_node, attribute))
            elif attribute.endswith('_variant'):
                node_data[attribute] = {
                    'name': attribute,
                    'value': cmds.getAttr('%s.%s' % (assembly_node, attribute), asString=True),
                    'values': self._enum_values(assembly_node, attribute),
                }
        return node_data

    def get_scene_hash_list(self):
        """Return the non-empty ``'hash'`` values stored in ``loaded_elements``.

        NOTE(review): get_assembly_data never records a ``'hash'`` key, so
        entries it produced contribute nothing here unless ``loaded_elements``
        is populated elsewhere — confirm.
        """
        return [
            publish_data['hash']
            for publish_data in self.loaded_elements.values()
            if publish_data.get('hash', '')
        ]

    @staticmethod
    def _ancestor_paths(long_name):
        """Return every non-empty ancestor DAG path of *long_name*, root first.

        ``'|a|b|c'`` yields ``['|a', '|a|b']``. The node itself is excluded.
        """
        bits = long_name.split('|')
        ancestors = []
        for index in range(1, len(bits)):
            path = '|'.join(bits[:index])
            if path:  # a leading '|' produces an empty first segment
                ancestors.append(path)
        return ancestors

    def is_child(self, assembly):
        """Return True if any ancestor of the DAG path *assembly* is an assembly.

        Args:
            assembly: DAG path of the node, ideally a long ``|``-separated name.
        """
        return any(
            cmds.nodeType(path) in self.PARENT_ASSEMBLY_TYPES
            for path in self._ancestor_paths(assembly)
        )

    def get_valid_version(self, view):
        """Pick the preferred publish from *view* by status priority.

        The highest entry with ``sg_status_list`` equal to ``'cmpt'`` is
        returned first, then ``'rev'``, then ``'clsd'``. If none match, the
        highest entry of the whole view is returned.

        NOTE(review): assumes *view* is a shotgrid_lib view supporting
        ``find_with_value``, ``.empty`` and ``max()`` — confirm.
        """
        for status in ('cmpt', 'rev', 'clsd'):
            matching = view.find_with_value('sg_status_list', status)
            if not matching.empty:
                return max(matching)
        return max(view)

    @staticmethod
    def _is_outdated(stored_id, publish_id):
        """Return True if the stored ShotGrid id differs from *publish_id*.

        *stored_id* may carry a ``prefix:`` part; only the text after the last
        ``':'`` is compared. A missing or non-numeric stored id counts as
        outdated.
        """
        id_token = str(stored_id or '').split(':')[-1]
        if not id_token.isdigit():
            return True
        return int(id_token) != int(publish_id)

    def check_update_node(self, maya_node, publish_view):
        """Point *maya_node* at *publish_view* if it references another publish.

        Nothing is changed when the node's ``shotgrid_id`` already matches
        ``publish_view.id``. Otherwise the ``path``, ``variant``, ``version``,
        ``shotgrid_id`` and ``hash`` string attributes are rewritten. The file
        is chosen from ``publish_view.sg_files`` using the node's ``fileTag``.
        """
        stored_id = cmds.getAttr('%s.shotgrid_id' % maya_node)
        if not self._is_outdated(stored_id, publish_view.id):
            return

        file_tag = cmds.getAttr('%s.fileTag' % maya_node)
        full_asset_path = '%s/%s' % (publish_view.sg_published_folder, publish_view.sg_files[file_tag])
        cmds.setAttr('%s.path' % maya_node, full_asset_path, type='string')
        cmds.setAttr('%s.variant' % maya_node, publish_view.sg_variant_name, type='string')
        cmds.setAttr('%s.version' % maya_node, str(publish_view.sg_version_number), type='string')
        cmds.setAttr('%s.shotgrid_id' % maya_node, str(publish_view.id), type='string')
        cmds.setAttr('%s.hash' % maya_node, str(publish_view.sg_hash), type='string')

    def get_prim_variants(self, prim):
        """Return the variant sets of a USD prim.

        Returns:
            dict: ``{'selected': {set_name: current_selection},
            'variants': {set_name: [variant names]}}``.
        """
        variant_sets = prim.GetVariantSets()
        variants = {'selected': {}, 'variants': {}}
        for set_name in variant_sets.GetNames():
            variant_set = variant_sets.GetVariantSet(set_name)
            variants['selected'][set_name] = variant_set.GetVariantSelection()
            variants['variants'][set_name] = variant_set.GetVariantNames()
        return variants

    def get_children_data(self, parent_prim, assets_data):
        """Recursively collect asset prims under *parent_prim* into *assets_data*.

        A child with a valid ``asset_name`` attribute is recorded as an asset
        instance and is not descended into. Any other child is searched
        recursively. *assets_data* is mutated in place and also returned.

        Entries have the shape ``{asset_name: {'asset_variants':
        {set: [names]}, 'instances': {prim_name: {set: selection}}}}``. The
        available variants come from the first instance found.
        """
        for child in parent_prim.GetChildren():
            asset_name_attr = child.GetAttribute('asset_name')
            if not asset_name_attr.IsValid():
                assets_data = self.get_children_data(child, assets_data)
                continue

            asset_name = asset_name_attr.Get()
            prim_name = child.GetName()
            variants = self.get_prim_variants(child)
            if asset_name not in assets_data:
                assets_data[asset_name] = {'asset_variants': variants['variants'],
                                           'instances': {prim_name: variants['selected']}
                                           }
            else:
                assets_data[asset_name]['instances'][prim_name] = variants['selected']

        return assets_data

    @staticmethod
    def _enum_values(node, attribute):
        """Return the option names of an enum attribute, or [] if it has none."""
        values = cmds.attributeQuery(attribute, node=node, listEnum=True)
        if not values:
            return []
        return values[0].split(':')

    @classmethod
    def _is_variant_attribute(cls, attribute):
        """Return True if *attribute* ends with one of ``VARIANT_SUFFIXES``."""
        return attribute.endswith(cls.VARIANT_SUFFIXES)

    def inspect_usd_proxies(self):
        """Collect variant data from proxy-shape stages and ``UsdAssembly`` nodes.

        Returns:
            dict: Proxy-shape long names map to
            ``{'root': {'variants': <get_prim_variants of the default prim>}}``.
            Asset names map to ``{'asset_variants': ..., 'instances': ...}``,
            as built by get_children_data. ``UsdAssembly`` nodes add instances
            keyed by node name, holding their ``_variant``/``_version``
            attribute values as strings.
        """
        all_entities = {}
        for proxy_shape in cmds.ls(type='mayaUsdProxyShape', long=True) or []:
            stage = mayaUsd.ufe.getStage(proxy_shape)
            if stage is None:
                continue
            default_prim = stage.GetDefaultPrim()
            # A stage without a default prim has nothing to inspect from here.
            if not default_prim.IsValid():
                continue
            all_entities[proxy_shape] = {'root': {'variants': self.get_prim_variants(default_prim)}}
            self.get_children_data(default_prim, all_entities)

        for assembly in cmds.ls(type=self.ASSEMBLY_NODE_TYPE) or []:
            asset_name = cmds.getAttr('%s.assetName' % assembly)
            asset_data = all_entities.get(asset_name, {'asset_variants': {}, 'instances': {}})
            variant_attributes = [attr for attr in cmds.listAttr(assembly) or []
                                  if self._is_variant_attribute(attr)]

            # Available options are only filled in if nothing recorded them yet.
            if not asset_data['asset_variants']:
                for attr in variant_attributes:
                    asset_data['asset_variants'][attr] = self._enum_values(assembly, attr)

            asset_data['instances'][assembly] = {
                attr: cmds.getAttr('%s.%s' % (assembly, attr), asString=True)
                for attr in variant_attributes
            }
            all_entities[asset_name] = asset_data

        return all_entities

    def inspect_scene(self):
        """Return the scene's USD variant data (see ``inspect_usd_proxies``)."""
        return self.inspect_usd_proxies()
