import re
import logging

import maya.cmds as cmds
from pxr import Usd, Sdf, UsdGeom, Kind, UsdShade, Vt

logger = logging.getLogger(__name__)

# Built-in Maya sets that are skipped when converting between Maya sets and
# USD collections (in both directions).
IGNORE_SETS = [
    'defaultLightSet',
    'defaultObjectSet',
    'initialParticleSE',
    'initialShadingGroup',
    'lambert1SG',
]

# Category key -> name of its root Maya set. 'Model_sets' is the grouping set
# used when importing collections with the geometry loader.
MAYA_ROOT_SETS = {'Modeling': 'Model_sets'}


def maya_set_to_usd_collection(manager):
    """Convert the scene's Maya object sets into USD collections.

    Every Maya ``objectSet`` except those in ``IGNORE_SETS``, renderable sets
    and ``creaseSet`` nodes is turned into a ``Usd.CollectionAPI`` instance
    (named after the set) applied to the stage's default prim.

    * Whole-node members become collection targets directly (Maya ``|``
      paths are mapped to USD ``/`` paths).
    * Component members (``node.f[i]`` -> ``'face'``, ``node.v...[i]`` ->
      ``'point'``) are grouped per owning node into a ``UsdGeom.Subset``
      named after the set, created under that node's prim, and the subset
      prim (``<node>|<set_name>``) is targeted by the collection instead.
      Other component types are skipped with a warning.

    Targets whose prim does not exist on the stage are not added.

    Args:
        manager: Object exposing ``stage`` (a ``Usd.Stage``).

    Returns:
        dict: Maya set name -> list of Maya long names requested as members
        (including the ``<node>|<set_name>`` subset entries). Empty when the
        stage has no valid default prim.
    """
    logger.info('== Converting Maya sets to USD collections ==')

    stage = manager.stage
    default_prim = stage.GetDefaultPrim()
    if not default_prim.IsValid():
        # Applying CollectionAPI to an invalid prim cannot succeed.
        logger.warning('Stage has no valid default prim; no collections created')
        return {}

    index_re = re.compile(r'\[([0-9]+)\]')
    collection_data = {}   # set name -> list of Maya long names
    subset_data = {}       # set name -> node -> {'type', 'rel', 'index'}

    # --- Gather set membership from Maya ---------------------------------
    for set_name in cmds.ls(type='objectSet'):
        if set_name in IGNORE_SETS or cmds.sets(set_name, q=True, renderable=True):
            continue
        if cmds.nodeType(set_name) == 'creaseSet':
            continue

        objects = cmds.sets(set_name, q=True, l=True)
        # An empty set queries as None; never pass that (or []) on to cmds.ls,
        # which does not treat a missing object list as "no objects".
        if not objects:
            continue

        for element in cmds.ls(objects, l=True, fl=True):
            if '.' not in element:
                collection_data.setdefault(set_name, []).append(element)
                continue

            # Component member: make sure the set still gets a collection even
            # if none of its components can be converted.
            collection_data.setdefault(set_name, [])
            obj_name, attribute = element.split('.', 1)

            if attribute.startswith('f'):
                data_type = 'face'
            elif attribute.startswith('v'):
                data_type = 'point'
            else:
                # Edges, UVs, CVs, ... have no face/point meaning; writing
                # their indices as face indices would produce wrong subsets.
                logger.warning('Skipping unsupported component %s in set %s',
                               element, set_name)
                continue

            match = index_re.search(attribute)
            if not match:
                continue

            entry = subset_data.setdefault(set_name, {}).setdefault(
                obj_name,
                {'type': data_type, 'rel': f'{obj_name}|{set_name}', 'index': []})
            if entry['type'] != data_type:
                # One GeomSubset per node and set can only hold one element type.
                logger.warning('Skipping %s: set %s already holds %s components on %s',
                               element, set_name, entry['type'], obj_name)
                continue
            entry['index'].append(int(match.group(1)))

    # --- Author one GeomSubset per (set, node) ----------------------------
    for set_name, node_entries in subset_data.items():
        members = collection_data.setdefault(set_name, [])
        for node, entry in node_entries.items():
            members.append(entry['rel'])
            prim = stage.GetPrimAtPath(node.replace('|', '/'))
            if not prim.IsValid():
                continue
            UsdGeom.Subset.CreateGeomSubset(UsdGeom.Imageable(prim),
                                            set_name,
                                            entry['type'],
                                            Vt.IntArray(entry['index']),
                                            familyName=set_name)

    # --- Author the collections on the default prim -----------------------
    for set_name, members in collection_data.items():
        collection_api = Usd.CollectionAPI.Apply(default_prim, set_name)
        includes = collection_api.CreateIncludesRel()
        for maya_node in members:
            usd_path = maya_node.replace('|', '/')
            if stage.GetPrimAtPath(usd_path).IsValid():
                includes.AddTarget(usd_path)
        collection_api.CreateExpansionRuleAttr(Usd.Tokens.expandPrims)

    return collection_data


def usd_collection_to_maya_set(manager,
                               set_names=None,
                               loader_type='geometry'):
    """Recreate Maya object sets from USD collections on the default prim.

    For each requested name that has a ``collection:<name>`` property on the
    default prim, the targets of ``collection:<name>:includes`` are resolved:

    * ``GeomSubset`` prims become Maya components on their parent node:
      ``'face'`` subsets as ``.f[i]``, ``'point'`` subsets as ``.vtx[i]``.
      Other element types, and subsets whose parent node does not exist in
      Maya, are skipped with a warning.
    * Any other prim becomes the Maya node at the matching DAG path, if it
      exists.

    Missing Maya sets are created, and for the ``'geometry'`` loader they are
    added to the ``Model_sets`` root set. Existing sets are extended.

    Args:
        manager: Object exposing ``stage`` (a ``Usd.Stage``) and
            ``get_default_prim()``.
        set_names: Collection names to import. When ``None``, the scene's
            plain ``objectSet`` nodes (minus ``IGNORE_SETS``) are used. If
            the resulting list is empty, names come from the default prim's
            ``maya_inputs`` attribute when it is authored.
        loader_type: ``'geometry'`` groups newly created sets under
            ``Model_sets``; any other value leaves them ungrouped.

    Returns:
        None.
    """
    logger.info('== Converting USD collections to Maya sets ==')

    if set_names is None:
        set_names = [a for a in cmds.ls(type='objectSet')
                     if a not in IGNORE_SETS and
                     cmds.objectType(a) == 'objectSet']

    root_prim = manager.get_default_prim()
    if not root_prim.HasAPI(Usd.CollectionAPI):
        logger.info('No collections detected in this usd file')
        return

    usd_sets_attr = root_prim.GetAttribute('maya_inputs')
    if usd_sets_attr:
        usd_sets = usd_sets_attr.Get()
        if usd_sets and not set_names:
            set_names = usd_sets

    root_set = None
    if loader_type == 'geometry':
        root_set = 'Model_sets'
        if not cmds.objExists(root_set):
            # empty=True: without it cmds.sets fills the new set with the
            # current selection.
            root_set = cmds.sets(name=root_set, empty=True)

    # USD GeomSubset elementType -> Maya component name.
    component_prefix = {'face': 'f', 'point': 'vtx'}

    for set_name in set_names:
        if not root_prim.GetAttribute(f'collection:{set_name}'):
            continue

        includes = root_prim.GetRelationship(f'collection:{set_name}:includes')
        targets = includes.GetTargets() if includes else []
        if not targets:
            continue

        nodes_to_add = []
        for prim_path in targets:
            prim = manager.stage.GetPrimAtPath(prim_path)
            if not prim.IsValid():
                continue

            if prim.GetTypeName() != 'GeomSubset':
                maya_dag_path = str(prim_path).replace('/', '|')
                if cmds.objExists(maya_dag_path):
                    nodes_to_add.append(maya_dag_path)
                continue

            element_type = prim.GetAttribute('elementType').Get()
            prefix = component_prefix.get(element_type)
            if prefix is None:
                logger.warning('Skipping GeomSubset %s: unsupported elementType %r',
                               prim_path, element_type)
                continue

            parent_dag_path = str(prim.GetPath().GetParentPath()).replace('/', '|')
            if not cmds.objExists(parent_dag_path):
                logger.warning('Skipping GeomSubset %s: %s does not exist in Maya',
                               prim_path, parent_dag_path)
                continue

            # An unauthored indices attribute reads back as None.
            indices = prim.GetAttribute('indices').Get() or []
            nodes_to_add.extend(f'{parent_dag_path}.{prefix}[{index}]'
                                for index in indices)

        logger.info('Adding %d nodes to set %s', len(nodes_to_add), set_name)

        if not cmds.objExists(set_name):
            if nodes_to_add:
                new_set = cmds.sets(nodes_to_add, name=set_name)
            else:
                # empty=True keeps the current selection out of the new set.
                new_set = cmds.sets(name=set_name, empty=True)
            if root_set is not None:
                cmds.sets(new_set, e=True, forceElement=root_set)
        elif nodes_to_add:
            # Guarded: with no items, cmds.sets would act on the selection.
            cmds.sets(nodes_to_add, e=True, forceElement=set_name)