import os
import importlib

from pprint import pprint

from PySide import QtCore, QtGui, QtWidgets

from pxr import Usd

import usd.lib.usd_manager as usd_manager
importlib.reload(usd_manager)


# Maps a filter keyword (the ``prim_type`` argument of UsdStageTree) to the
# USD schema type names that count as a match for that filter.
PRIM_TYPES = {
    'geometry': ['Mesh'],
    'light': [
        'DomeLight',
        'RectLight',
        'DistantLight',
        'SphereLight',
        'ArnoldSpotLight',
    ],
}

class PrimTreeItem(QtGui.QStandardItem):
    """Standard item that wraps one ``Usd.Prim`` for display in a tree model.

    The displayed text is ``"<prim name> - <prim type name>"``; the wrapped
    prim remains reachable through the ``prim`` attribute.
    """

    def __init__(self, prim, parent=None, checkable=False):
        # NOTE(review): the C++ QStandardItem constructors take no parent;
        # confirm the Qt binding in use accepts this keyword.
        super().__init__(parent=parent)
        # NOTE(review): why 4 child columns is not visible here — confirm.
        self.setColumnCount(4)
        # Keep the prim so views/slots can get back to the USD data.
        self.prim = prim
        self.setCheckable(checkable)
        label = '%s - %s' % (prim.GetName(), prim.GetTypeName())
        self.setText(label)


class UsdStageTree(QtWidgets.QTreeView):
    """Tree view showing the prim hierarchy of a USD stage.

    The stage comes from ``manager`` when one is given. Otherwise a new
    ``usd_manager.UsdManager`` is created for ``project``, and its stage is
    opened from ``usd_path`` or obtained from ``maya_node``. If none of these
    is supplied, the view stays empty.

    When ``prim_type`` is given, only prims whose type name is listed under
    that key in ``PRIM_TYPES``, plus their ancestors, are shown.

    Signals:
        selected_prim(PrimTreeItem): emitted when a row is clicked.
    """

    selected_prim = QtCore.Signal(PrimTreeItem)

    def __init__(self, project, usd_path='', maya_node='', parent=None, prim_type=None, manager=None, checkable=False):
        # Forward ``parent`` so Qt parenting/ownership works as callers expect
        # (it used to be accepted and silently dropped).
        super(UsdStageTree, self).__init__(parent)
        self.usd_path = usd_path
        self.maya_node = maya_node
        self.project = project
        self.prim_type = prim_type
        self.checkable = checkable
        self.prim_list = []
        self.setIndentation(10)

        # Build the (empty) model first so the widget is fully usable even
        # when there is no stage to load and we return early below.
        self.usd_model = QtGui.QStandardItemModel()
        self.usd_model.setHorizontalHeaderLabels(['Prim'])
        self.setModel(self.usd_model)
        self.setUniformRowHeights(True)

        if manager:
            self.manager = manager
        else:
            self.manager = usd_manager.UsdManager(self.project)

            if self.usd_path:
                self.manager.open(self.usd_path)
            elif self.maya_node:
                self.manager.get_stage_from_maya(self.maya_node)
            else:
                # No stage source: leave an empty tree.
                return

        if self.prim_type:
            self.prim_list = self.get_prim_list()

        self.fill_tree()
        self.clicked.connect(self.emit_selected)

    def emit_selected(self, item):
        """Emit ``selected_prim`` with the model item for a clicked row.

        Args:
            item: the ``QModelIndex`` passed by the view's ``clicked`` signal
                (parameter name kept for backward compatibility). If it is
                not a valid index, the first selected index is used instead.
        """
        index = None
        if isinstance(item, QtCore.QModelIndex) and item.isValid():
            index = item
        else:
            selected = self.selectedIndexes()
            if not selected:
                # Nothing clicked/selected; the old code raised IndexError.
                return
            index = selected[0]

        model_item = self.usd_model.itemFromIndex(index)
        if model_item is not None:
            self.selected_prim.emit(model_item)

    def ____mousePressEvent(self, event):
        """Disabled debug variant of ``mousePressEvent``.

        The mangled name means Qt never calls it. Renaming it to
        ``mousePressEvent`` would log presses and emit ``selected_prim``.
        """
        print('event')
        print(event.type())
        if event.type() == QtCore.QEvent.MouseButtonPress:
            # indexAt() works in viewport coordinates; mapping onto the view
            # itself would be offset by the header.
            pos = self.viewport().mapFromGlobal(QtGui.QCursor.pos())
            item_index = self.indexAt(pos)
            print(item_index)
            row = item_index.row()

            if event.button() == QtCore.Qt.RightButton:
                print("Right clicked on row %s" % row)

            else:
                print("Left clicked on row %s" % row)

            item = self.usd_model.itemFromIndex(item_index)
            if item:
                self.selected_prim.emit(item)

        super().mousePressEvent(event)

    def add_parents(self, prim):
        """Return the path strings of ``prim`` and each of its ancestors.

        Paths are ordered from ``prim`` upward; the pseudo-root ``'/'`` is
        not included.
        """
        paths = []
        while True:
            path = prim.GetPath().pathString
            if path == '/':
                return paths
            paths.append(path)
            prim = prim.GetParent()

    def get_prim_list(self):
        """Return paths of prims matching ``self.prim_type`` and their ancestors.

        Duplicates (shared ancestors) are removed, keeping first-seen order.
        Within this class the result is only used for membership tests.
        """
        wanted_types = PRIM_TYPES.get(self.prim_type, [])
        all_prims = []
        for prim in self.manager.stage.Traverse():
            if prim.GetTypeName() in wanted_types:
                all_prims += self.add_parents(prim)
        return list(dict.fromkeys(all_prims))

    def fill_children(self, prim, allowed_paths=None):
        """Recursively build ``PrimTreeItem`` nodes for the children of ``prim``.

        Args:
            prim: the ``Usd.Prim`` whose children are converted.
            allowed_paths: optional set of prim path strings to keep. Defaults
                to ``set(self.prim_list)``. It is only consulted when filtering
                is active, i.e. ``prim_type`` is set or ``prim_list`` is
                non-empty.

        Returns:
            list of ``PrimTreeItem`` for the kept children, each already
            holding its own kept descendants as rows.
        """
        # Filter whenever a prim_type was requested, even if nothing matched:
        # an empty match must give an empty tree, not the whole stage.
        filtering = bool(self.prim_type) or bool(self.prim_list)
        if filtering and allowed_paths is None:
            # Build the lookup set once and share it across the recursion.
            allowed_paths = set(self.prim_list)

        nodes = []
        for child in prim.GetChildren():
            if filtering and child.GetPath().pathString not in allowed_paths:
                continue
            node = PrimTreeItem(child, checkable=self.checkable)
            for grandchild in self.fill_children(child, allowed_paths):
                node.appendRow(grandchild)
            nodes.append(node)

        return nodes

    def fill_tree(self):
        """Populate the model from the stage, one top-level row per root prim."""
        pseudo_root = self.manager.stage.GetPseudoRoot()
        # appendRow(list) would place every root prim in a separate *column*
        # of a single row; append each root prim as its own row instead.
        for root_item in self.fill_children(pseudo_root):
            self.usd_model.appendRow(root_item)

    def get_selected_indexes(self, prim_list, parent, selection):
        """Uncheck every item under ``parent`` and collect the ones to check.

        Recursively visits all descendants of ``parent`` and sets each one to
        ``Unchecked``. Items whose prim is in ``prim_list``, and whose type is
        not in ``ignored_types``, are appended to ``selection`` (mutated in
        place) and printed.

        Returns:
            None.
        """
        if not parent:
            return
        ignored_types = ('Xform', 'Scope', 'Material', 'Shader')
        for row in range(parent.rowCount()):
            for column in range(parent.columnCount()):
                child = parent.child(row, column)
                if not child:
                    continue
                path = child.prim.GetPath().pathString
                prim_type = child.prim.GetTypeName()
                child.setCheckState(QtCore.Qt.Unchecked)
                if prim_type not in ignored_types and child.prim in prim_list:
                    selection.append(child)
                    print('Selected:', path)

                if child.hasChildren():
                    self.get_selected_indexes(prim_list, child, selection)

    def select_prims(self, prim_list):
        """Check the tree items whose prims are in ``prim_list``.

        Every item is unchecked first. Items of the types skipped by
        ``get_selected_indexes`` are never checked. Does nothing unless the
        tree was created with ``checkable=True``.
        """
        if not self.checkable:
            return

        selection = []
        root = self.usd_model.invisibleRootItem()
        self.get_selected_indexes(prim_list, root, selection)

        for item in selection:
            item.setCheckState(QtCore.Qt.Checked)

def test_prim_attribute(path='V:/TPT/publish/usd/assets/Main_Characters/LeoHero/asset_assembly.usda',
                        project='TPT'):
    """Manual smoke test: build a ``UsdStageTree`` for ``path`` in ``project``.

    Prints whether ``path`` exists, then constructs the tree regardless.
    A QApplication must already exist.

    Returns:
        the constructed ``UsdStageTree``, so callers can keep a reference
        to it or call ``show()`` on it.
    """
    print(os.path.exists(path))
    window = UsdStageTree(project, path)
    return window


if __name__ == '__main__':
    # Qt aborts if a QWidget is created before a QApplication exists. Reuse the
    # host's instance when there is one; the module-level name keeps it alive.
    app = QtWidgets.QApplication.instance() or QtWidgets.QApplication([])
    test_prim_attribute()