import os
import importlib

from pprint import pprint

import maya.cmds as cmds
from pxr import Usd, UsdGeom, Sdf, Kind, Gf

import usd.lib.usd_manager as usd_manager
importlib.reload(usd_manager)


def check_visible(node_path):
    """Return True if the Maya DAG node at *node_path* and its ancestors are visible.

    Walks upward along the '|'-separated DAG path, checking each node:

    - an empty/None path counts as visible;
    - a 'SetAssembly' node ends the walk as visible (its ancestors are not checked);
    - a 'GeometryAssembly' whose ``cmds.assembly(..., a=True, q=True)`` query
      returns 'Disabled' is hidden;
    - a node whose ``visibility`` attribute is falsy is hidden;
    - once a node without any '|' in its path passes the checks, it is treated
      as the top of the hierarchy and the result is visible.
    """
    path = node_path
    while path:
        kind = cmds.nodeType(path)
        if kind == 'SetAssembly':
            return True
        if kind == 'GeometryAssembly' and cmds.assembly(path, a=True, q=True) == 'Disabled':
            return False
        if not cmds.getAttr('%s.visibility' % path):
            return False
        parent, separator, _leaf = path.rpartition('|')
        if not separator:
            # No parent component left: this node was the top of the path.
            return True
        path = parent
    return True


def get_attribute_difference(attribute, current_prim, reference_prim):
    """Return the current value of *attribute* when it differs from the reference.

    Args:
        attribute: property name to look up on both prims.
        current_prim: prim whose value is reported.
        reference_prim: prim providing the baseline value.

    Returns:
        The value read from *current_prim*, or None when the property is not
        listed on *current_prim* or its value equals the reference value. A
        property missing from *reference_prim* has a baseline of None.
    """
    on_current = attribute in current_prim.GetPropertyNames()
    on_reference = attribute in reference_prim.GetPropertyNames()

    if not on_current:
        # Both sides count as None, which compares equal: no difference.
        return None

    value = current_prim.GetAttribute(attribute).Get()
    baseline = reference_prim.GetAttribute(attribute).Get() if on_reference else None
    return None if value == baseline else value

def decompose_prim(prim):
    """Split the local transform of *prim* into translation, rotation and scale.

    Args:
        prim: Usd.Prim, wrapped with UsdGeom.Xformable to read its local matrix.

    Returns:
        tuple: ``(xformable, translation, rotation, scale)`` where *translation*
        comes from ``Gf.Matrix4d.ExtractTranslation``, *rotation* is a
        Gf.Rotation and *scale* is a Gf.Vec3d of the row lengths of the
        matrix's upper-left 3x3 block.
    """
    xformable = UsdGeom.Xformable(prim)

    local_transformation = xformable.GetLocalTransformation()
    translation = local_transformation.ExtractTranslation()
    # ExtractRotation() is only reliable on a pure rotation matrix; with scale
    # baked into the matrix it yields a wrong axis/angle. Strip scale and shear
    # first so scaled prims get their real rotation.
    rotation = local_transformation.RemoveScaleShear().ExtractRotation()
    # Row lengths of the 3x3 block are the per-axis scale magnitudes. GetLength()
    # is non-negative, so the sign of a mirrored axis is not preserved.
    scale = Gf.Vec3d(*(row.GetLength() for row in local_transformation.ExtractRotationMatrix()))

    return xformable, translation, rotation, scale

def compare_transform(current_prim, reference_prim, attribute_differences):
    """Record the local transform of *current_prim* if it differs from *reference_prim*.

    Both prims are decomposed with decompose_prim() and compared component by
    component using exact equality (no tolerance). If translation, rotation or
    scale differs, ALL three current components are stored in
    *attribute_differences*, which is mutated in place:

    - ``'translate'``: translation vector
    - ``'rotate'``: rotation decomposed about the X, Y and Z axes, authored
      downstream as a rotateXYZ op
    - ``'scale'``: scale vector

    Nothing is written when the transforms match.

    Why all three: the override writer (create_override_child) authors fresh
    xform ops with AddTranslateOp/AddRotateXYZOp/AddScaleOp on a new layer. The
    resulting xformOpOrder replaces the referenced prim's order as a whole, so a
    lone 'rotate' would discard the prim's existing translation and scale.
    """
    _, current_translation, current_rotation, current_scale = decompose_prim(current_prim)
    _, reference_translation, reference_rotation, reference_scale = decompose_prim(reference_prim)

    translation_changed = current_translation != reference_translation
    rotation_changed = current_rotation != reference_rotation
    scale_changed = current_scale != reference_scale

    if translation_changed:
        print('set translate', current_translation)
        print('reference:', reference_translation)

    if not (translation_changed or rotation_changed or scale_changed):
        return

    attribute_differences['translate'] = current_translation
    attribute_differences['rotate'] = current_rotation.Decompose(
        (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0))
    attribute_differences['scale'] = current_scale


def compare_prims(current_prim, reference_prim):
    """Recursively diff the children of *current_prim* against *reference_prim*.

    Only children that exist under *reference_prim* are visited. Children that
    exist solely on the current side are not reported.

    Args:
        current_prim: Usd.Prim from the edited (Maya) stage.
        reference_prim: Usd.Prim from the published reference stage.

    Returns:
        dict: child name -> entry. An entry may hold ``'attributes'`` (transform
        keys from compare_transform plus ``'active'``) and/or ``'children'`` (a
        nested dict of the same shape). Children without differences are omitted.
    """
    this_differences = {}
    for reference_child in reference_prim.GetChildren():
        child_name = reference_child.GetName()
        current_child = current_prim.GetChild(child_name)

        print('current child', current_child)
        if not current_child.IsValid():
            # The child no longer exists on the current stage. Record it as
            # inactive: 'active' is the key create_override_child turns into
            # SetActive(), so the removal actually gets authored.
            print('hidden')
            this_differences[child_name] = {'attributes': {'active': False}}
            continue

        asset_name_attr = current_child.GetAttribute('atlantis:asset_name')
        asset_name = asset_name_attr.Get() if asset_name_attr.IsValid() else None

        asset_type_attr = current_child.GetAttribute('atlantis:asset_type')
        asset_type = asset_type_attr.Get() if asset_type_attr.IsValid() else None

        # Descend into children without an asset name (presumably grouping
        # prims) and into 'Collections' assets.
        if asset_name is None or asset_type == 'Collections':
            child_differences = compare_prims(current_child, reference_child)
            if child_differences:
                this_differences[child_name] = {'children': child_differences}

        current_loaded = current_child.IsLoaded()
        reference_loaded = reference_child.IsLoaded()
        current_active = current_child.IsActive()
        reference_active = reference_child.IsActive()
        print(current_loaded, current_active, reference_active)

        attribute_differences = {}
        compare_transform(current_child, reference_child, attribute_differences)

        if current_active != reference_active:
            attribute_differences['active'] = current_active
        # A load-state mismatch is also expressed through 'active'. Because it is
        # checked last, it wins over the activation comparison above.
        if current_loaded != reference_loaded:
            attribute_differences['active'] = current_loaded

        if attribute_differences:
            this_differences.setdefault(child_name, {})['attributes'] = attribute_differences

    return this_differences


def add_attribute(prim, attribute, value=None, attr_type=Sdf.ValueTypeNames.String):
    """Create *attribute* on *prim* with type *attr_type* and optionally assign it.

    Args:
        prim: Usd.Prim to author on.
        attribute: attribute name.
        value: value to set; nothing is set when None.
        attr_type: Sdf value type name, string by default.

    Returns:
        The Usd.Attribute returned by ``prim.CreateAttribute``.
    """
    created_attr = prim.CreateAttribute(attribute, attr_type)
    if value is None:
        return created_attr
    created_attr.Set(value)
    return created_attr


def create_override_child(stage, prim_name, value, parent_prim):
    """Author an ``over`` for one entry of a compare_prims() difference tree.

    Args:
        stage: stage receiving the overrides (its OverridePrim is called).
        prim_name: name of the child prim to override.
        value: difference entry for that prim, with optional ``'attributes'``
            and ``'children'`` dicts.
        parent_prim: path string of the parent prim.

    Attribute keys handled:
        ``'active'``: passed to SetActive().
        ``'visible'``: accepted as an alias of ``'active'``. A difference tree
            may mark a child missing from the current stage with
            ``'visible': False``, and deactivating is how an over removes it.
        ``'translate'`` / ``'rotate'`` / ``'scale'``: authored as translate,
            rotateXYZ and scale xform ops, in the order the keys appear.

    Returns:
        The overridden prim, or None when *value* is empty.
    """
    if not value:
        return None

    prim_path = '%s/%s' % (parent_prim, prim_name)
    attributes = value.get('attributes', {})
    children = value.get('children', {})

    asset_prim = stage.OverridePrim(prim_path)

    xformable = None
    active = True
    for key, attr_value in attributes.items():
        if key in ('active', 'visible'):
            active = attr_value
            asset_prim.SetActive(attr_value)
        elif key in ('translate', 'rotate', 'scale'):
            if xformable is None:
                xformable = UsdGeom.Xformable(asset_prim)
            if key == 'translate':
                op = xformable.AddTranslateOp()
            elif key == 'rotate':
                op = xformable.AddRotateXYZOp()
            else:
                op = xformable.AddScaleOp()
            op.Set(value=(attr_value[0], attr_value[1], attr_value[2]))

    if not active:
        # Descendants of an inactive prim are not composed, so overriding them is pointless.
        return asset_prim

    for child_name, child_value in children.items():
        if child_value:
            create_override_child(stage, child_name, child_value, prim_path)
    return asset_prim


def compare_set(usd_shape,
                asset_name,
                output_usd,
                project,
                sets_roots='/World/Sets',
                instance_name=None,
                shotgrid_id=0,
                setview_variant='',
                setview_version=''):
    """Write a layer of overrides describing how a set in Maya differs from its publish.

    The prim at *sets_roots* in the Maya stage (obtained from *usd_shape*) is
    compared with ``/asset`` of the asset's published 'SetDressing' layer. That
    layer is opened with the same 'setDressing_variant' / 'setDressing_version'
    selections as the Maya set. Differences are authored as overs into a new
    stage created at *output_usd*, whose root layer is then saved.

    Args:
        usd_shape: Maya node passed to UsdManager.get_stage_from_maya.
        asset_name: name of the Set asset whose published layer is the baseline.
        output_usd: path of the override layer to create.
        project: project identifier passed to usd_manager.UsdManager.
        sets_roots: prim path of the set in the Maya stage; '/asset' when falsy.
        instance_name: not used by this function.
        shotgrid_id, setview_variant, setview_version: stored, along with
            asset_name and output_usd, in ``asset_info`` of the Maya-side manager.

    NOTE(review): overrides are authored under the set root path, while the
    output's default prim is an empty over at '/asset'. The two coincide only
    when *sets_roots* is falsy; confirm which path consumers expect.
    NOTE(review): ``asset_info`` is assigned on the Maya-side manager rather than
    on the output stage; whether it reaches the saved layer depends on
    UsdManager (not visible here). Verify.
    """
    scene_manager = usd_manager.UsdManager(project)
    scene_manager.get_stage_from_maya(usd_shape)
    scene_stage = scene_manager.stage

    set_root = sets_roots or '/asset'
    set_prim = scene_stage.GetPrimAtPath(set_root)
    print('Set prim:', set_prim)

    # Capture the Maya set's variant selections so the baseline matches them.
    selections = [(variant_set, scene_manager.get_selected_variant(set_prim, variant_set))
                  for variant_set in ('setDressing_variant', 'setDressing_version')]

    published_manager = usd_manager.UsdManager(project)
    published_manager.set_entity(asset_name, 'Asset', 'Sets')

    layer_path = published_manager.get_layer_filename('SetDressing')
    print(layer_path)
    published_manager.open(layer_path)
    for variant_set, selection in selections:
        published_manager.set_selected_variant(published_manager.default_prim, variant_set, selection)

    baseline_prim = published_manager.stage.GetPrimAtPath('/asset')
    divider = '-' * 10
    print(divider)
    print(baseline_prim)
    print(set_prim)
    differences = compare_prims(set_prim, baseline_prim)
    print(divider)
    pprint(differences)

    override_stage = Usd.Stage.CreateNew(output_usd)
    override_stage.SetDefaultPrim(override_stage.OverridePrim('/asset'))

    for child_name, child_diff in differences.items():
        override_prim = create_override_child(override_stage, child_name, child_diff, set_root)
        print(override_prim)

    scene_manager.asset_info = {'asset_name': asset_name,
                                'asset_type': 'Set',
                                'setview_variant': setview_variant,
                                'setview_version': setview_version,
                                'setview_shotgrid_id': shotgrid_id,
                                'setview_publish_path': output_usd
                                }
    override_stage.GetRootLayer().Save()
