import importlib
import os
import logging
import importlib.resources as imp_resources

from PySide import QtWidgets
try:
    from shiboken6 import wrapInstance
except:
    from shiboken2 import wrapInstance

from maya import cmds
from maya import mel
from maya import OpenMayaUI
from maya.app.general.mayaMixin import MayaQWidgetDockableMixin

import library.ui.style_sheets as style_sheets
import library.ui.style_sheets.colors as stylesheet_colors

logger = logging.getLogger(__name__)

# Maps the step prefix of a USD variant-set name (the text before the first
# '_', e.g. 'model' in 'model_variant') to the step name used as a key in
# scan_usd_shapes() results. Prefixes not listed here are used unchanged.
# NOTE(review): scan_asset_assemblies keys model data as 'Modeling' while this
# maps 'model' to 'Model' — confirm which step name consumers expect.
VARIANT_TO_STEP = {'setDressing': 'SetDressing',
                   'model': 'Model',
                   'shader': 'Shading'}
def application_main_window():
    """Return Maya's main window wrapped as a ``QtWidgets.QMainWindow``."""
    main_window_ptr = OpenMayaUI.MQtUtil.mainWindow()
    return wrapInstance(int(main_window_ptr), QtWidgets.QMainWindow)

def file_name():
    """Return the scene name Maya reports for the currently open file."""
    scene_path = cmds.file(query=True, sceneName=True)
    return scene_path

def file_new():
    """Force a new empty scene, then set its file type to Maya Binary."""
    # force=True: do not prompt about unsaved changes in the current scene.
    cmds.file(force=True, new=True)
    cmds.file(type='mayaBinary')

def file_open(path):
    """Open the Maya scene at ``path`` and add it to Maya's recent-files list.

    Args:
        path: scene file path ending in ``.ma`` (Maya ASCII) or ``.mb``
            (Maya Binary); the extension is matched case-insensitively.

    Raises:
        ValueError: if ``path`` has any other extension. This is checked
            before Maya is touched, so an unsupported path opens nothing.
    """
    extension = os.path.splitext(path)[-1].lower()
    if extension == '.ma':
        file_type = 'mayaAscii'
    elif extension == '.mb':
        file_type = 'mayaBinary'
    else:
        # Raise explicitly: an ``assert`` is stripped under ``python -O``.
        raise ValueError('Unsupported Maya scene extension %r: %s' % (extension, path))

    logger.debug('Opening scene %s', path)
    # f=1 forces the open (discarding unsaved changes); ignore version mismatches.
    cmds.file(path, f=1, options='v=0;', ignoreVersion=1, o=1)
    # MEL treats backslashes in string literals as escapes (e.g. "\n"), so pass
    # forward slashes, which Maya accepts on every platform.
    mel_path = path.replace('\\', '/')
    mel.eval('addRecentFile "%s" "%s"' % (mel_path, file_type))

def file_rename(path):
    """Rename the current Maya scene to ``path`` (without saving it).

    Creates the parent directory of ``path`` if needed, then re-selects the
    current selection so Maya refreshes its main-window title.

    Args:
        path: new scene path.
    """
    folder = os.path.dirname(path)
    # dirname() is '' for a bare file name, and os.makedirs('') would raise.
    # exist_ok avoids a race between an existence check and the creation.
    if folder:
        os.makedirs(folder, exist_ok=True)

    cmds.file(rename=path)

    # Force a refresh of the Maya main window title, which does not update on
    # its own after a rename. TODO: report bug to Autodesk.
    selection = cmds.ls(selection=True)
    cmds.select('lambert1', replace=True)
    if selection:
        cmds.select(selection, replace=True)
    else:
        # Restore an empty selection explicitly instead of passing [] to select.
        cmds.select(clear=True)

def file_save(rename_to=None):
    """Save the current scene, optionally renaming it first, and record it in recent files.

    The save type follows the scene's extension: ``.ma`` (any case) saves as
    Maya ASCII; anything else, including an unnamed scene, saves as Maya Binary.

    Args:
        rename_to: if truthy, rename the scene to this path (via ``file_rename``)
            before saving.
    """
    if rename_to:
        file_rename(rename_to)

    extension = os.path.splitext(file_name())[-1].lower()
    maya_type = 'mayaAscii' if extension == '.ma' else 'mayaBinary'
    path = cmds.file(save=True, type=maya_type, force=True)

    # Updating the recent-files menu is best-effort and must never fail the
    # save. mel.eval raises RuntimeError on MEL errors. Forward slashes keep
    # Windows backslashes from being read as MEL escapes.
    try:
        mel.eval('addRecentFile "%s" "%s"' % (path.replace('\\', '/'), maya_type))
    except RuntimeError:
        logger.debug('Could not add %s to recent files', path, exc_info=True)

def playblast(workflowTaskType):
    """Create a playblast for ``workflowTaskType`` and return the result of starting it.

    The playblast module is imported lazily, only when a playblast is requested.
    """
    from modules.createPlayblast import createPlayblast

    blaster = createPlayblast(task=workflowTaskType)
    return blaster.start()


def scan_rig_assemblies():
    """Collect rig and shading inputs from every ``RigAssembly`` node in the scene.

    Returns:
        dict: ``{asset_name: {alias: {'Rig': entry, 'Shading': entry}}}``, where
        ``alias`` is the node name with its last ``_``-separated token removed.
        The shading variant/version are read from the node only when its
        ``shadingPath`` is set and exists on disk. Otherwise they fall back to
        ``'Master'`` / ``'latest'``. A shading version that is not a non-digit
        string is converted to ``int``.
    """
    collected = {}
    for assembly in cmds.ls(type='RigAssembly'):
        asset = cmds.getAttr('%s.assetName' % assembly)
        alias = assembly.rsplit('_', 1)[0]
        rig_path = cmds.getAttr('%s.rigPath' % assembly)
        rig_variant = cmds.getAttr('%s.rigVariant' % assembly)
        rig_version = cmds.getAttr('%s.rigVersion' % assembly)

        # Fields shared by the Rig and Shading entries of this node.
        shared = {'asset_name': asset,
                  'code': asset,
                  'loader_type': 'rig assembly',
                  'type': 'Asset',
                  'node': assembly}

        rig_entry = {'variant': rig_variant,
                     'version_number': int(rig_version),
                     'path': rig_path,
                     **shared}

        shading_path = cmds.getAttr('%s.shadingPath' % assembly)
        if not (shading_path and os.path.exists(shading_path)):
            shading_variant, shading_version = 'Master', 'latest'
        else:
            shading_variant = cmds.getAttr('%s.shadingVariant' % assembly)
            shading_version = cmds.getAttr('%s.shadingVersion' % assembly)

        # Only non-digit strings (e.g. 'latest') are kept as labels.
        keep_label = isinstance(shading_version, str) and not shading_version.isdigit()
        if not keep_label:
            shading_version = int(shading_version)

        shading_entry = {'variant': shading_variant,
                         'version_number': shading_version,
                         'path': shading_path,
                         **shared}

        collected.setdefault(asset, {})[alias] = {'Rig': rig_entry,
                                                  'Shading': shading_entry}
    return collected

def _read_usd_assembly_step(node, attr_prefix, asset_name, path):
    """Return the input entry for one step of a UsdAssembly node, or None.

    Reads ``<attr_prefix>_variant`` and ``<attr_prefix>_version`` as strings.
    Returns None when the node has no ``<attr_prefix>_variant`` attribute. An
    all-digit version becomes an ``int``; any other value is kept as-is.
    """
    if not cmds.attributeQuery('%s_variant' % attr_prefix, n=node, exists=True):
        return None
    variant = cmds.getAttr('%s.%s_variant' % (node, attr_prefix), asString=True)
    version = cmds.getAttr('%s.%s_version' % (node, attr_prefix), asString=True)
    # getAttr can return None for an empty string attribute; only convert real digits.
    if isinstance(version, str) and version.isdigit():
        version = int(version)
    return {'variant': variant,
            'version_number': version,
            'path': path,
            'asset_name': asset_name,
            'code': asset_name,
            'loader_type': 'usd assembly',
            'type': 'Asset',
            'node': node}


def scan_asset_assemblies():
    """Collect model and shading inputs from every ``UsdAssembly`` node in the scene.

    Returns:
        dict: ``{asset_name: {alias: {'Modeling': entry, 'Shading': entry}}}``.
        ``alias`` is the node name when it ends in ``_<digits>``. Otherwise it
        is the node name with its last ``_``-separated token removed. A step is
        present only when the node has that step's ``*_variant`` attribute, and
        nodes with neither step are skipped.
    """
    all_inputs = {}
    for node in cmds.ls(type='UsdAssembly'):
        asset_name = cmds.getAttr('%s.assetName' % node)
        bits = node.rsplit('_', 1)
        alias = node if bits[-1].isdigit() else bits[0]
        path = cmds.getAttr('%s.usdPath' % node)

        # Build this node's step dict first. A node with shading but no model
        # attributes is still recorded; indexing all_inputs[asset_name][alias]
        # directly would raise KeyError for it.
        steps = {}
        for attr_prefix, step in (('model', 'Modeling'), ('shading', 'Shading')):
            entry = _read_usd_assembly_step(node, attr_prefix, asset_name, path)
            if entry is not None:
                steps[step] = entry

        if not steps:
            continue
        logger.debug('UsdAssembly %s inputs: %s', node, steps)
        all_inputs.setdefault(asset_name, {})[alias] = steps
    return all_inputs


def scan_transforms():
    """Collect the model input described by the top-level ``|asset`` transform.

    Returns an empty dict unless ``|asset`` exists and has an ``asset_name``
    attribute. Otherwise returns
    ``{asset_name: {asset_name: {'Model': entry}}}``. An empty or falsy
    ``model_version`` is reported as version 1.
    """
    root = '|asset'
    if not cmds.objExists(root) or not cmds.attributeQuery('asset_name', n=root, exists=True):
        return {}

    def read(attr):
        return cmds.getAttr('%s.%s' % (root, attr))

    asset_name = read('asset_name')
    asset_type = read('asset_type')
    raw_version = read('model_version')
    model_version = int(raw_version) if raw_version else 1
    variant = read('model_variant')

    model_entry = {'version_number': model_version,
                   'variant': variant,
                   'asset_name': asset_name,
                   'code': asset_name,
                   'asset_type': asset_type,
                   'loader_type': 'import',
                   'type': 'Asset',
                   'node': root}
    return {asset_name: {asset_name: {'Model': model_entry}}}

def check_children(stage, root_prim):
    """Placeholder: iterate the direct children of ``root_prim`` and do nothing.

    ``stage`` is currently unused and the function returns None.
    """
    children = root_prim.GetChildren()
    for _child in children:
        continue


def get_parent_set(prim):
    """Return the asset name of the nearest ancestor "set" prim of ``prim``.

    Walks up the parent chain and returns the value of the first ancestor's
    valid ``atlantis:asset_name`` attribute (printing it). Returns None once an
    invalid parent is reached. ``prim`` itself is not inspected.
    """
    current = prim
    while True:
        ancestor = current.GetParent()
        if not ancestor.IsValid():
            return None
        name_attr = ancestor.GetAttribute('atlantis:asset_name')
        if name_attr.IsValid():
            set_name = name_attr.Get()
            print('Set name', set_name)
            return set_name
        current = ancestor

def scan_usd_shapes():
    """Collect asset inputs from the loadable prims of every mayaUsd proxy shape.

    Each loadable prim contributes one entry per step, unless
    ``get_parent_set`` finds an ancestor carrying ``atlantis:asset_name``.
    - Keys: asset name (the prim's ``asset_name`` attribute, else the prim
      name), then prim name, then step.
    - Steps come from variant sets named ``<step>_<kind>``. ``<step>`` is
      mapped through ``VARIANT_TO_STEP``.
    - For kind ``variant``, the selection is stored as ``variant`` and all
      options as ``variants``.
    - For kind ``version``, the selection is stored as ``version_number``
      (an ``int`` when all digits).
    - Variant sets whose name has no ``_`` are ignored.

    Returns:
        dict: ``{asset_name: {instance_name: {step: entry}}}``.
    """
    import mayaUsd

    all_inputs = {}
    for node in cmds.ls(type='mayaUsdProxyShape', l=True):
        stage = mayaUsd.ufe.getStage(node)
        if stage is None:
            # No stage available for this proxy shape: skip it rather than
            # failing the whole scan.
            logger.debug('No USD stage for proxy shape %s', node)
            continue

        for prim_path in stage.FindLoadable():
            prim_node = stage.GetPrimAtPath(prim_path)
            instance_name = prim_node.GetName()
            asset_attr = prim_node.GetAttribute('asset_name')
            asset_name = asset_attr.Get() if asset_attr.IsValid() else instance_name

            # Skip prims nested under a set prim (an ancestor with 'atlantis:asset_name').
            parent_asset = get_parent_set(prim_node)
            logger.debug('Parent set of %s: %s', prim_path, parent_asset)
            if parent_asset:
                continue

            asset_type = prim_node.GetAttribute('asset_type').Get()
            steps = all_inputs.setdefault(asset_name, {}).setdefault(instance_name, {})

            variant_sets = prim_node.GetVariantSets()
            for variant_set_name in variant_sets.GetNames():
                step_prefix, separator, var_type = variant_set_name.partition('_')
                if not separator:
                    # Not a '<step>_<kind>' name. Unpacking split('_', 1) would
                    # raise ValueError here and abort the whole scan.
                    logger.debug('Ignoring variant set %s on %s', variant_set_name, prim_path)
                    continue

                variant_set = variant_sets.GetVariantSet(variant_set_name)
                selected = variant_set.GetVariantSelection()
                step = VARIANT_TO_STEP.get(step_prefix, step_prefix)
                entry = steps.setdefault(step, {'asset_name': asset_name,
                                                'code': asset_name,
                                                'loader_type': 'usdProxy',
                                                'asset_type': asset_type,
                                                'type': 'Asset',
                                                'node': (node, instance_name)})

                if var_type == 'variant':
                    entry['variant'] = selected
                    entry['variants'] = variant_set.GetVariantNames()
                elif var_type == 'version':
                    entry['version_number'] = int(selected) if selected.isdigit() else selected

    return all_inputs


def scan_file():
    """Collect every asset input found in the open scene.

    Runs scan_transforms, scan_rig_assemblies, scan_asset_assemblies and
    scan_usd_shapes in that order. Their results are merged into
    ``{asset_name: {instance_name: {step: entry}}}``.

    Instances are merged per asset, so an asset reported by several scanners
    keeps the instances from all of them. When two scanners report the same
    instance name for an asset, the later scanner's entry wins.
    """
    all_inputs = {}
    scanners = (scan_transforms,
                scan_rig_assemblies,
                scan_asset_assemblies,
                scan_usd_shapes)
    for scanner in scanners:
        for asset_name, instances in scanner().items():
            # A top-level dict.update() would replace the whole per-asset dict
            # and drop instances found by earlier scanners.
            all_inputs.setdefault(asset_name, {}).update(instances)
    return all_inputs


def set_scene_settings(entity_data, handles=10):
    """Apply an entity's scene settings: time unit, render resolution and frame range.

    Args:
        entity_data: mapping with these keys:
            - Required: ``fps``, ``fps_name`` (passed to
              ``cmds.currentUnit(time=...)``), ``width``, ``height`` and
              ``context_type``.
            - Optional: ``deviceAspectRatio``, set only when present, and
              ``pixelAspect``, which defaults to 1.0.
            - When ``context_type`` is ``'Shot'``: ``start_frame`` and
              ``end_frame`` are required as well.
        handles: frames by which the animation range extends beyond the shot's
            playback range on each side. Defaults to 10.
    """
    logger.info('Set scene config ')
    logger.debug('Set fps to %s %s', entity_data['fps'], entity_data['fps_name'])
    cmds.currentUnit(time=entity_data['fps_name'])

    cmds.setAttr('defaultResolution.width', entity_data['width'])
    cmds.setAttr('defaultResolution.height', entity_data['height'])
    logger.debug('Set pixel aspect ratio')
    if 'deviceAspectRatio' in entity_data:
        cmds.setAttr('defaultResolution.deviceAspectRatio', entity_data['deviceAspectRatio'])
    cmds.setAttr('defaultResolution.pixelAspect', entity_data.get('pixelAspect', 1.0))

    if entity_data['context_type'] != 'Shot':
        return

    start_frame = entity_data['start_frame']
    end_frame = entity_data['end_frame']
    logger.info('Set frame range to %s %s', start_frame, end_frame)

    # Set one option per call: the padded animation range first, then the
    # playback range inside it.
    cmds.playbackOptions(ast=start_frame - handles, e=True)
    cmds.playbackOptions(aet=end_frame + handles, e=True)
    cmds.playbackOptions(minTime=start_frame, e=True)
    cmds.playbackOptions(maxTime=end_frame, e=True)
    cmds.currentTime(start_frame, e=True)

def set_cameras_data(resolution_data):
    """Set the film aperture of every camera in the scene.

    Args:
        resolution_data: mapping with an optional ``aperture`` entry, indexed as
            ``(width, height)`` and divided by 25.4 (mm per inch) before being
            written, i.e. expected in millimetres. When ``aperture`` is missing
            or empty, nothing is changed.
    """
    aperture = resolution_data.get('aperture')
    if not aperture:
        return

    # Maya's film aperture attributes are in inches. Convert once, outside the
    # loop, since the values are the same for every camera.
    mm_per_inch = 25.4
    width = aperture[0] / mm_per_inch
    height = aperture[1] / mm_per_inch

    for camera in cmds.ls(type='camera'):
        cmds.setAttr('%s.horizontalFilmAperture' % camera, width)
        cmds.setAttr('%s.verticalFilmAperture' % camera, height)

def get_maya_widget(control_name):
    """Return the Maya UI control named ``control_name`` wrapped as a ``QWidget``.

    Uses the module-level ``wrapInstance`` (shiboken6, falling back to
    shiboken2), the same binding ``application_main_window`` uses.

    Raises:
        RuntimeError: if Maya has no control with that name
            (``MQtUtil.findControl`` returned None).
    """
    pointer = OpenMayaUI.MQtUtil.findControl(control_name)
    if pointer is None:
        raise RuntimeError('Maya control not found: %s' % control_name)
    return wrapInstance(int(pointer), QtWidgets.QWidget)

def add_context_chooser(context_chooser):
    """Append ``context_chooser`` to the layout of Maya's ``flowLayout1`` control."""
    flow_layout = get_maya_widget('flowLayout1').layout()
    flow_layout.addWidget(context_chooser)
    # Ask the layout to re-lay out its items now that a widget was added.
    flow_layout.update()

class MainWindow(MayaQWidgetDockableMixin, QtWidgets.QMainWindow):
    """Dockable Maya window styled from a ``.qss`` resource.

    Args:
        title: window title.
        style_sheet: base name of a ``.qss`` resource in the
            ``library.ui.style_sheets`` package, used by ``apply_style``.
        parent: parent widget. When None (the default), the window is parented
            to Maya's main window.
    """

    def __init__(self, title, style_sheet='square_style', parent=None):
        # An explicit parent takes precedence; otherwise parent to Maya.
        host = parent if parent is not None else application_main_window()
        super().__init__(parent=host)
        self.style_sheet = style_sheet
        self.title = title
        self.setWindowTitle(self.title)
        self.setContentsMargins(0, 0, 0, 0)

    def apply_style(self, central_widget):
        """Load the style sheet, fill in color placeholders and apply it to ``central_widget``.

        Every ``<key>`` in the ``.qss`` text is replaced by the matching value
        from ``stylesheet_colors.stylesheet_color()``. Styling is best-effort:
        if the file cannot be read or decoded, a warning is logged and the
        widget is left unstyled.
        """
        style_sheet_resource = imp_resources.path(style_sheets, '%s.qss' % self.style_sheet)
        colors = stylesheet_colors.stylesheet_color()

        try:
            with style_sheet_resource as style_sheet_path:
                with open(style_sheet_path, 'r', encoding='utf-8') as handle:
                    style = handle.read()
        except (OSError, UnicodeDecodeError):
            logger.warning('Could not load style sheet %r', self.style_sheet, exc_info=True)
            return

        for key, value in colors.items():
            style = style.replace('<%s>' % key, value)
        central_widget.setStyleSheet(style)