import os
import ast
import typing
import importlib
import maya.cmds as cmds
from pxr import Usd, UsdGeom, Sdf, Gf

import shotgrid_lib.shotgrid_helpers as helpers
import shotgrid_lib.database as database
import usd.lib.usd_rig as usd_rig
importlib.reload(usd_rig)


def get_rig_path(asset_name, variant, project):
    """Find the published rig file for *asset_name* / *variant* in *project*.

    Publishes of the 'Rig' step that are complete and not deleted are searched.
    One whose ``sg_status_list`` is ``'cmpt'`` is preferred; otherwise any
    matching publish is accepted.

    Args:
        asset_name: ShotGrid ``code`` of the Asset.
        variant: Key of the CustomEntity11 (variant) entry.
        project: Project passed to ``DataBase.fill``.

    Returns:
        ``(path, version)`` when a publish is found, it has a non-empty ``'rig'``
        entry in ``sg_files``, and the resulting path exists on disk.
        Otherwise ``(None, None)``.
    """
    sg_database = database.DataBase()
    sg_database.fill(project, precatch=False)

    asset_filters = [['code', 'is', asset_name]]
    sg_database.query_sg_database('Asset', filters=asset_filters)
    sg_database.query_sg_database('CustomEntity11', as_precache=True)
    sg_database.query_sg_database('Step', as_precache=True)

    asset_view = sg_database['Asset'][asset_name]
    variant_view = sg_database['CustomEntity11'][variant]
    step_view = sg_database['Step']['Rig']

    common_filters = {'sg_variant': variant_view,
                      'sg_step': step_view,
                      'sg_complete': True,
                      'sg_delete': False}
    # Try the 'cmpt'-status publish first, then fall back to any complete one.
    for publish_filters in ({**common_filters, 'sg_status_list': 'cmpt'}, common_filters):
        version = asset_view.sg_published_elements.find_with_filters(**publish_filters,
                                                                     single_item=True)
        if not version.empty:
            break
    else:
        return None, None

    rig_file = version.sg_files.get('rig')
    # Without this guard an empty entry yields '<folder>/', which exists (it is the
    # folder itself) and would be returned as the rig path.
    if not rig_file:
        return None, None

    path = '%s/%s' % (version.sg_published_folder, rig_file)
    if os.path.exists(path):
        return path, version
    return None, None

def get_parent_asset(connection, node_name, project):
    """Return the ShotGrid asset of the first 'Set' assembly along *node_name*.

    The '|'-separated components of *node_name* are checked in order. A component
    qualifies when it is an ``Assembly`` node whose ``assemblyType`` is ``'Set'``
    and whose ``sg_id`` attribute contains a ``'<entity>:<id>'`` code. For the
    first qualifying component, the publish with that id is fetched and its
    ``sg_asset`` field is returned.

    Args:
        connection: ShotGrid connection forwarded to ``helpers.get_publish_data_by_id``.
        node_name: Maya DAG path ('|'-separated).
        project: Project forwarded to ``helpers.get_publish_data_by_id``.

    Returns:
        The publish's ``sg_asset`` value, or ``{}`` when no qualifying node is found.
    """
    for node_subname in node_name.split('|'):
        if not node_subname:
            # Leading '|' of an absolute DAG path yields an empty component.
            continue
        if cmds.nodeType(node_subname) != 'Assembly':
            continue
        if cmds.getAttr('%s.assemblyType' % node_subname) != 'Set':
            continue

        sg_code = cmds.getAttr('%s.sg_id' % node_subname)
        # getAttr may return None for an unset string attribute; such nodes (and
        # codes without ':') are skipped instead of raising AttributeError.
        if not sg_code or ':' not in sg_code:
            continue

        _, _, sg_id = sg_code.partition(':')
        publish_data = helpers.get_publish_data_by_id(connection, int(sg_id), project=project)
        return publish_data['sg_asset']

    return {}



def create_breakdown(task_view, asset_data):
    """Append a new CustomEntity12 (breakdown) record for *asset_data*.

    The record links the task's entity and project to the asset, its geometry and
    shading variants, instance number and alias. When ``asset_data['set_name']``
    is truthy, the set's breakdown (looked up with instance number 1 through
    ``get_breakdown``) is stored as ``sg_parent_asset``. The record is appended to
    the task's database.

    Args:
        task_view: Task view exposing ``_database``, ``entity`` and ``project``.
        asset_data: Mapping shaped like the dict returned by ``get_alias``.
    """
    # Lookup data for the parent set: same fields, but keyed on the set itself.
    parent_lookup = {**asset_data,
                     'asset_name': asset_data['set_name'],
                     'instance_number': 1}

    breakdown = database.Record('CustomEntity12')
    db = task_view._database

    db.query_sg_database('CustomEntity11')
    geometry_variant_view = db['CustomEntity11'][asset_data['geometry_variant']]
    shading_variant_view = db['CustomEntity11'][asset_data['shading_variant']]

    wanted_codes = [asset_data['asset_name'], asset_data['set_name']]
    db.query_sg_database('Asset', filters=[['code', 'in', wanted_codes]])
    asset_entity = db['Asset'][asset_data['asset_name']]

    breakdown.sg_link = task_view.entity
    breakdown.project = task_view.project
    breakdown.sg_geometry_variant = geometry_variant_view
    breakdown.sg_shading_variant = shading_variant_view
    breakdown.sg_instance = int(asset_data['instance_number'])
    breakdown.sg_alias = asset_data['alias']
    breakdown.sg_name_in_parent = asset_data['path_in_parent']
    breakdown.sg_variant = geometry_variant_view
    breakdown.sg_asset = asset_entity
    breakdown.code = asset_data['alias']

    if asset_data['set_name']:
        breakdown.sg_parent_asset = get_breakdown(task_view, parent_lookup)

    db.append(breakdown)



def get_breakdown(task_view, asset_data):
    """Look up the CustomEntity12 breakdown for an asset instance in the task's entity.

    Args:
        task_view: Task view exposing ``_database`` and ``entity``.
        asset_data: Mapping with at least ``'asset_name'`` and ``'instance_number'``.

    Returns:
        The result of ``find_with_filters(..., single_item=True)`` on the
        CustomEntity12 table, matching the asset, the task entity and the instance
        number. Callers check its ``.empty`` flag.
    """
    db = task_view._database
    shot = task_view.entity
    name = asset_data['asset_name']
    instance = asset_data['instance_number']

    db.query_sg_database('Asset', filters=[['code', 'is', name]])
    asset_entity = db['Asset'][name]

    db.query_sg_database('CustomEntity12', filters=[
        ['sg_link', 'name_is', shot.code],
        ['sg_asset', 'name_is', name],
        ['sg_instance', 'is', instance],
    ])

    return db['CustomEntity12'].find_with_filters(sg_asset=asset_entity,
                                                  sg_link=shot,
                                                  sg_instance=instance,
                                                  single_item=True)

def get_world_transform_xform(prim: Usd.Prim) -> typing.Tuple[Gf.Vec3d, Gf.Rotation, Gf.Vec3d]:
    """
    Get the world (local-to-world) transformation of a prim using Xformable.

    See https://openusd.org/release/api/class_usd_geom_xformable.html

    Args:
        prim: The prim whose world transformation is computed.
    Returns:
        A tuple of:
        - Translation vector.
        - Rotation as a Gf.Rotation (axis plus angle), extracted after scale and
          shear have been removed from the world matrix.
        - Scale vector: the length of each row of the world matrix's
          ExtractRotationMatrix() result.
    """
    xform = UsdGeom.Xformable(prim)
    time = Usd.TimeCode.Default()  # evaluate at the default time code, not a specific frame
    world_transform: Gf.Matrix4d = xform.ComputeLocalToWorldTransform(time)
    translation: Gf.Vec3d = world_transform.ExtractTranslation()
    # ExtractRotation() is only accurate for a pure rotation matrix; strip scale/shear
    # first so scaled prims still report the correct rotation.
    rotation: Gf.Rotation = world_transform.RemoveScaleShear().ExtractRotation()
    scale: Gf.Vec3d = Gf.Vec3d(*(v.GetLength() for v in world_transform.ExtractRotationMatrix()))
    return translation, rotation, scale


def get_alias(stagePath, sdfPath):
    """Collect breakdown data for the asset prim at *sdfPath* on a Maya USD stage.

    Args:
        stagePath: Proxy shape path passed to ``mayaUsd.ufe.getStage``.
        sdfPath: Prim path string of the asset inside that stage.

    Returns:
        ``(asset_prim, asset_data)``. ``asset_data`` has these keys:

        - ``asset_name`` / ``instance_number``: the prim's ``asset_name`` and
          ``instance_number`` attributes.
        - ``assembly_name``: ``'<prim name>_PROP'``.
        - ``alias``: the prim name.
        - ``path_in_parent``: *sdfPath*.
        - ``geometry_variant`` / ``shading_variant``: the current selections of the
          ``model_variant`` / ``shader_variant`` variant sets (empty string when
          there is no selection).
        - ``set_name``: the ``atlantis:asset_name`` attribute of the stage's default
          prim, or None when the stage has no default prim.
        - ``transform``: world translation, rotation angles about the X, Y and Z
          axes from ``Gf.Rotation.Decompose``, and scale.
    """
    import mayaUsd

    stage = mayaUsd.ufe.getStage(stagePath)

    root_prim = stage.GetDefaultPrim()
    # Avoid calling GetAttribute on an invalid prim when there is no default prim.
    set_name = root_prim.GetAttribute('atlantis:asset_name').Get() if root_prim else None

    asset_prim = stage.GetPrimAtPath(sdfPath)
    asset_name = asset_prim.GetAttribute('asset_name').Get()
    instance_number = asset_prim.GetAttribute('instance_number').Get()

    # Read-only query. AddVariantSet() would author a variant set on the edit target
    # whenever one is missing, modifying the stage just to read a selection.
    prim_sets = asset_prim.GetVariantSets()
    geometry_variant = prim_sets.GetVariantSelection('model_variant')
    shading_variant = prim_sets.GetVariantSelection('shader_variant')

    assembly_name = sdfPath.split('/')[-1]

    xformable = UsdGeom.Xformable(asset_prim)
    # Evaluate at the default time code rather than a specific frame.
    world_transform = xformable.ComputeLocalToWorldTransform(Usd.TimeCode.Default())
    translation = world_transform.ExtractTranslation()
    # ExtractRotation() is only accurate for a pure rotation; remove scale/shear first
    # so scaled props get correct angles.
    rotation = world_transform.RemoveScaleShear().ExtractRotation().Decompose(
        (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))
    scale = Gf.Vec3d(*(v.GetLength() for v in world_transform.ExtractRotationMatrix()))

    asset_data = {
        'asset_name': asset_name,
        'instance_number': instance_number,
        'assembly_name': '%s_PROP' % assembly_name,
        'alias': assembly_name,
        'path_in_parent': sdfPath,
        'geometry_variant': geometry_variant,
        'shading_variant': shading_variant,
        'set_name': set_name,
        'transform': {'translation': translation, 'rotation': rotation, 'scale': scale},
    }
    return asset_prim, asset_data


def load_rig(asset_data, path, rig_version):
    """Create a RigAssembly node for *rig_version* under 'Props' and place its root control.

    Identification attributes are set on the new node and locked right after each
    assignment. The namespaced root control ``<representationNamespace>:x_world_rig_ctl``
    then receives the translation, rotation and scale from ``asset_data['transform']``.

    Args:
        asset_data: Mapping with ``assembly_name``, ``instance_number``,
            ``geometry_variant`` and ``transform`` (``translation``/``rotation``/``scale``,
            each indexable by 0..2).
        path: Rig file path stored in ``rigPath``.
        rig_version: Published rig version view (``sg_asset``, ``sg_version_number``,
            ``sg_hash``).

    Returns:
        The name of the created RigAssembly node.
    """
    root_control = 'x_world_rig_ctl'

    asset_name = rig_version.sg_asset.code
    asset_type = rig_version.sg_asset.sg_asset_type

    group = 'Props' if cmds.objExists('Props') else cmds.createNode('transform', name='Props')

    node_name = cmds.createNode('RigAssembly', name=asset_data['assembly_name'], parent=group)

    # Same order as before: the node may react to these assignments (it exposes a
    # representationNamespace afterwards) -- NOTE(review): confirm whether order matters.
    locked_attributes = (
        ('assetName', asset_name, {'type': 'string'}),
        ('assetType', asset_type, {'type': 'string'}),
        ('instanceNumber', asset_data['instance_number'], {}),
        ('rigPath', path, {'type': 'string'}),
        ('rigVariant', asset_data['geometry_variant'], {'type': 'string'}),
        ('rigVersion', str(rig_version.sg_version_number), {'type': 'string'}),
        ('rigHash', rig_version.sg_hash, {'type': 'string'}),
    )
    for attribute, value, flags in locked_attributes:
        plug = '%s.%s' % (node_name, attribute)
        cmds.setAttr(plug, value, **flags)
        cmds.setAttr(plug, lock=1)

    namespace = cmds.getAttr('%s.representationNamespace' % node_name)
    root_node = '%s:%s' % (namespace, root_control)

    transform = asset_data['transform']
    for data_key, maya_attribute in (('translation', 'translate'),
                                     ('rotation', 'rotate'),
                                     ('scale', 'scale')):
        values = transform[data_key]
        for index, axis in enumerate('XYZ'):
            cmds.setAttr('%s.%s%s' % (root_node, maya_attribute, axis), values[index])

    return node_name


def _show_ok_dialog(title, message):
    """Show a modal Maya dialog with a single 'OK' button."""
    cmds.confirmDialog(title=title,
                       message=message,
                       button=['OK'],
                       defaultButton='OK',
                       cancelButton='OK',
                       dismissString='OK')


def swap_to_rig(task=None, project=None):
    """Replace the selected USD asset prim with its published rig.

    Steps:

    1. Resolve the selected prim and its breakdown data.
    2. Find a rig publish for the asset's geometry variant.
    3. Unless the asset is already loaded, create a RigAssembly for the rig and
       deactivate the prim.
    4. Register a breakdown entry, but only when none exists yet for this
       shot/asset/instance.

    Args:
        task: Task view whose database and entity are used for breakdown lookups.
        project: Project used to search for the rig publish.
    """
    from mayaUsd.lib import proxyAccessor as pa

    stagePath, sdfPath = pa.getSelectedDagAndPrim()
    # Guard against selections that do not resolve to a prim path; otherwise
    # get_alias would operate on an invalid prim.
    if not sdfPath:
        _show_ok_dialog('No prim selected', 'Select an asset prim inside a USD stage')
        return

    asset_prim, asset_data = get_alias(stagePath, sdfPath)

    asset_breakdown = get_breakdown(task, asset_data)
    path, rig_version = get_rig_path(asset_data['asset_name'], asset_data['geometry_variant'], project)

    if not rig_version:
        print('Cant find a rig version of %s' % asset_data['asset_name'])
        _show_ok_dialog('Can\'t find rig', 'Can\'t find a valid rig version for that asset')
        return

    if cmds.objExists(asset_data['assembly_name']) and not asset_breakdown.empty:
        _show_ok_dialog('Item already loaded', 'This asset is already loaded in the scene')
        print('Item already loaded as rig')
        return

    load_rig(asset_data, path, rig_version)

    asset_prim.SetActive(False)

    # create_breakdown always appends a new record; skip it when one already
    # exists for this shot/asset/instance to avoid duplicate breakdown entries.
    if asset_breakdown.empty:
        create_breakdown(task, asset_data)




