#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import importlib
from pprint import pprint

from PySide.QtWidgets import QApplication, QWidget, QGroupBox, QHBoxLayout, QVBoxLayout, QTreeView, QMenu, QAbstractItemView, QComboBox, QStyledItemDelegate
from PySide.QtGui import QPen

from PySide.QtCore import Qt, QAbstractItemModel, QModelIndex

import dependency_tracker.lib.dependency_tracker as dependency_tracker
import shotgrid_lib.database as database

# Re-execute the already-imported tracker module every time this file is loaded.
# NOTE(review): presumably a development aid so edits to the tracker are picked
# up inside a long-running interpreter - confirm whether it should ship.
importlib.reload(dependency_tracker)


# Colour name used to draw a row's text, keyed by the row's status code.
# The last three codes all map to the same colour, so they are generated
# together; key order is the same as a hand-written literal would give.
STATUS_COLORS = {
    'apr': 'green',
    'clsd': 'red',
    **dict.fromkeys(('ip', 'rev', 'recd'), 'orange'),
}
class GUIWindow(QWidget):
    """Top-level window that shows a ``DictionaryTreeView`` inside a group box."""

    def __init__(self, data,  parent=None):
        """Build the window.

        :param data: object handed unchanged to ``DictionaryTreeView``.
        :param parent: optional parent widget.
        """
        super(GUIWindow, self).__init__(parent)
        self.setWindowTitle("QTreeView from Dictionary")

        # The view is kept on the instance so callers can reach it.
        self.tree_view = DictionaryTreeView(data)

        # Horizontal layout holding the view, wrapped in a titled group box.
        tree_layout = QHBoxLayout()
        tree_layout.addWidget(self.tree_view)

        tree_group = QGroupBox('TreeView')
        tree_group.setLayout(tree_layout)

        # Vertical layout is the window's own layout and holds the group box.
        window_layout = QVBoxLayout()
        window_layout.addWidget(tree_group)
        self.setLayout(window_layout)


class DelegateStatus(QStyledItemDelegate):
    """Delegate that draws cells as coloured text and edits column 1 with a combo box.

    Rows whose node has no ``item_view`` are drawn in blue; every other row is
    drawn in the colour ``STATUS_COLORS`` maps the node's ``status`` to (grey
    when the status is not listed).  Column 1 of rows that carry an
    ``item_view`` is edited through a ``QComboBox`` of variant names.
    """

    # Column whose cells are edited through the combo box.
    _VARIANT_COLUMN = 1

    def createEditor(self, parent, option, index):
        """Return a combo box for an editable variant cell, otherwise ``None``.

        Only column 1 of a ``TreeNode`` that carries an ``item_view`` gets an
        editor; rows without one have nothing to pick from and would make
        ``setModelData`` fail.
        """
        item = index.internalPointer()
        if (index.column() == self._VARIANT_COLUMN
                and isinstance(item, TreeNode)
                and item.item_view is not None):
            return QComboBox(parent)
        return None

    def setEditorData(self, editor, index):
        """Fill ``editor`` with the variant names published for the row's step.

        The names are read from the ``sg_published_elements`` of the row's
        ``sg_asset`` filtered by the row's ``sg_step``; the entry equal to the
        cell's current text is pre-selected.  Lookup failures are reported on
        stdout and leave the editor as it is.
        """
        try:
            tree_item = index.internalPointer()
            if index.column() == self._VARIANT_COLUMN and isinstance(tree_item, TreeNode):
                tree_item_view = tree_item.item_view
                step_view = tree_item_view.sg_step
                asset_view = tree_item_view.sg_asset

                published = asset_view.sg_published_elements.find_with_filters(sg_step=step_view)
                # Sorted so the drop-down order is stable between openings.
                all_variants = sorted(set(published.sg_variant_name))
                current_value = index.data()

                editor.clear()
                for i, variant_name in enumerate(all_variants):
                    editor.addItem(variant_name, variant_name)
                    if current_value == variant_name:
                        editor.setCurrentIndex(i)

        except (TypeError, AttributeError):
            print(f"No values in drop-down list for item at row: {index.row()} and column: {index.column()}")

    def setModelData(self, editor, model, index):
        """Apply the variant chosen in ``editor`` to the row's ``TreeNode``.

        The published element is looked up on the row's ``sg_asset`` (or on
        ``sg_context`` when the asset view is empty) and stored on the node
        with ``setViewData``.  The whole row is then announced as changed so
        the other columns are redrawn as well.
        """
        tree_item = index.internalPointer()
        if (index.column() != self._VARIANT_COLUMN
                or not isinstance(tree_item, TreeNode)
                or tree_item.item_view is None):
            return

        variant_name = str(editor.currentText())
        if not variant_name:
            # The combo was never populated; there is nothing to apply.
            return

        item_view = tree_item.item_view
        asset_view = item_view.sg_asset
        if asset_view.empty:
            asset_view = item_view.sg_context

        # Named sg_db so the module-level ``database`` import is not shadowed.
        sg_db = item_view._database
        step_view = item_view.sg_step
        variant_view = sg_db['CustomEntity11'][variant_name]
        new_variant = self.get_step_variant_version(asset_view.sg_published_elements,
                                                    step_view,
                                                    variant_view)
        tree_item.setViewData(new_variant, modified=True)

        # setViewData rewrites several columns of the row, so notify the
        # views about the complete row instead of only the edited cell.
        last_column = model.columnCount(index.parent()) - 1
        model.dataChanged.emit(index.sibling(index.row(), 0),
                               index.sibling(index.row(), last_column))

    def get_step_variant_version(self, items, step_view, variant_view):
        """Return one element of ``items`` matching the step and variant.

        An element whose ``sg_status_list`` is ``'apr'`` is preferred; when
        that query comes back empty the status filter is dropped.
        """
        selected = items.find_with_filters(sg_status_list='apr',
                                           sg_variant=variant_view,
                                           sg_step=step_view,
                                           single_item=True)

        if selected.empty:
            selected = items.find_with_filters(sg_variant=variant_view,
                                               sg_step=step_view,
                                               single_item=True)

        return selected

    def paint(self, painter, option, index):
        """Draw the cell's display text, left aligned, in its status colour."""
        item = index.internalPointer()
        if not item.item_view:
            color = 'blue'
        else:
            color = STATUS_COLORS.get(item.status, 'grey')
        text = index.data(Qt.DisplayRole)

        # save()/restore() must always be paired: the pen set here must not
        # leak into later cells, and an unmatched restore() corrupts the
        # painter's state stack.
        painter.save()
        painter.setPen(QPen(color))
        painter.drawText(option.rect, Qt.AlignLeft | Qt.AlignVCenter, text)
        painter.restore()

class DictionaryTreeView(QTreeView):
    """Tree view over a ``TreeModel`` with right-click menus on columns 1 and 2.

    Column 1 offers the variant names published for the row's step, column 2
    the version numbers that share the row's ``sg_hash``.
    """

    def __init__(self,data, parent=None):
        """Create the view, its delegate and its model.

        :param data: object passed to ``TreeModel`` as its data source.
        :param parent: optional parent widget.
        """
        super().__init__(parent=parent)
        self.data = data
        # The parent widget is deliberately not stored as ``self.parent``:
        # that attribute would shadow QObject.parent() on this instance.
        # Use ``self.parent()`` to reach it.

        self.headers = ["Item", "Variant", 'Version', 'Statuts', 'Modified']
        self.setContextMenuPolicy(Qt.CustomContextMenu)
        self.customContextMenuRequested.connect(self._open_menu)
        self.setSelectionBehavior(QAbstractItemView.SelectRows)

        # Both helpers are parented to the view so they live as long as it
        # does; the view itself does not take ownership of them.
        self.setItemDelegate(DelegateStatus(self))

        model = TreeModel(self.headers, self.data, parent=self)

        self.setModel(model)
        self.expandAll()
        self.resizeColumnToContents(0)

    def flags(self, index):
        """Do nothing; kept only so existing callers keep working.

        NOTE(review): item flags are answered by the model, nothing in this
        file calls this method - confirm it can be removed.
        """
        return

    def _pick_from_menu(self, position, options, current_option):
        """Show ``options`` in a popup menu and return the newly chosen text.

        Returns ``None`` when the menu is dismissed or when the entry that is
        already current is picked again.
        """
        menu = QMenu()
        for option in options:
            menu.addAction(option)

        chosen = menu.exec_(position)
        if chosen is None:
            return None

        # text() has to be *called*; comparing the bound method itself with a
        # string is always unequal.
        text = str(chosen.text())
        if text == current_option:
            return None
        return text

    def open_variant_menu(self, position, tree_item, current_option):
        """Let the user pick another variant for ``tree_item``.

        :param position: global position at which the menu pops up.
        :param tree_item: ``TreeNode`` of the clicked row.
        :param current_option: text currently shown in the clicked cell.
        """
        tree_item_view = tree_item.item_view
        # Named sg_db so the module-level ``database`` import is not shadowed.
        sg_db = tree_item_view._database
        step_view = tree_item_view.sg_step
        asset_view = tree_item_view.sg_asset

        published = asset_view.sg_published_elements.find_with_filters(sg_step=step_view)
        # Sorted so the menu order is stable between openings.
        all_variants = sorted(set(published.sg_variant_name))

        variant_name = self._pick_from_menu(position, all_variants, current_option)
        if variant_name is None:
            return

        variant_view = sg_db['CustomEntity11'][variant_name]
        new_variant = self.get_step_variant_version(asset_view.sg_published_elements,
                                                    step_view,
                                                    variant_view)
        tree_item.setViewData(new_variant, modified=True)

    def open_version_menu(self, position, tree_item, current_option):
        """Let the user pick another version for ``tree_item``.

        Candidates are the ``CustomEntity09`` records sharing the row's
        ``sg_hash`` with ``sg_complete=True`` and ``sg_delete=False``.

        :param position: global position at which the menu pops up.
        :param tree_item: ``TreeNode`` of the clicked row.
        :param current_option: text currently shown in the clicked cell.
        """
        tree_item_view = tree_item.item_view
        sg_db = tree_item_view._database

        item_hash = tree_item_view.sg_hash
        all_versions = sg_db['CustomEntity09'].find_with_filters(sg_hash=item_hash,
                                                                 sg_complete=True,
                                                                 sg_delete=False)
        if all_versions.empty:
            print('cant find versions')
            return

        # Sort the numbers before formatting: sorting the padded strings
        # would put '1000' ahead of '999'.
        version_labels = ['%03d' % version_number
                          for version_number in sorted(all_versions.sg_version_number)]

        picked = self._pick_from_menu(position, version_labels, current_option)
        if picked is None:
            return

        new_version = sg_db['CustomEntity09'].find_with_filters(sg_hash=item_hash,
                                                                sg_version_number=int(picked),
                                                                single_item=True)
        tree_item.setViewData(new_version, modified=True)

    def _open_menu(self, position):
        """Open the menu matching the cell under ``position``.

        ``position`` is the point delivered by ``customContextMenuRequested``.
        Only rows whose ``item_view.type`` is ``'CustomEntity09'`` get a menu,
        and only on column 1 (variant) and column 2 (version).
        """
        # Resolve the clicked cell directly instead of indexing the selection
        # list with a column number, which may be -1 or out of range.
        index = self.indexAt(position)
        if not index.isValid():
            return

        tree_item = index.internalPointer()

        tree_item_view = tree_item.item_view
        if tree_item_view is None or tree_item_view.type != 'CustomEntity09':
            return

        current_option = index.data()
        global_position = self.viewport().mapToGlobal(position)

        column = index.column()
        if column == 1:
            self.open_variant_menu(global_position,  tree_item, current_option)
        elif column == 2:
            self.open_version_menu(global_position,  tree_item, current_option)
        else:
            return

        # The node may have rewritten several columns; announce the whole row
        # so it is redrawn.
        model = self.model()
        last_column = model.columnCount(index.parent()) - 1
        model.dataChanged.emit(index.sibling(index.row(), 0),
                               index.sibling(index.row(), last_column))

    def get_step_variant_version(self, items, step_view, variant_view):
        """Return one element of ``items`` matching the step and variant.

        An element whose ``sg_status_list`` is ``'apr'`` is preferred; when
        that query comes back empty the status filter is dropped.
        """
        selected = items.find_with_filters(sg_status_list='apr',
                                           sg_variant=variant_view,
                                           sg_step=step_view,
                                           single_item=True)

        if selected.empty:
            selected = items.find_with_filters(sg_variant=variant_view,
                                               sg_step=step_view,
                                               single_item=True)

        return selected

class TreeModel(QAbstractItemModel):
    """Item model exposing a nested mapping as a tree of ``TreeNode`` objects.

    Implements the interface a ``QAbstractItemModel`` subclass must provide:
    ``index()``, ``parent()``, ``rowCount()``, ``columnCount()`` and
    ``data()``.  The root node is never shown; it stores the header labels.
    """

    def __init__(self, headers, data, parent=None):
        """Build the node tree.

        :param headers: column labels; copied into the hidden root node.
        :param data: object whose ``dependencies_data`` attribute holds the
            nested mapping to display.
        :param parent: optional parent QObject.
        """
        super(TreeModel, self).__init__(parent)

        self.rootItem = TreeNode(list(headers))

        self.shot_view = data.dependencies_data

        # Stacks used by createData(): the chain of nodes currently open and
        # the indentation value at which each of them was opened.
        self.parents = [self.rootItem]
        self.indentations = [0]

        # createData() increments the level first, so -1 makes the top level 0.
        self.createData(self.shot_view, -1)

    def flags(self, index):
        """Return the item flags for ``index``.

        Every valid cell is enabled and selectable.  Only column 1 of a row
        that carries an ``item_view`` is editable, because that is the only
        cell the delegate can build a meaningful editor for.
        """
        if not index.isValid():
            return Qt.NoItemFlags

        item_flags = Qt.ItemIsEnabled | Qt.ItemIsSelectable
        item = index.internalPointer()
        if index.column() == 1 and item.item_view is not None:
            item_flags |= Qt.ItemIsEditable
        return item_flags

    def createData(self, data, indent):
        """Append one row per entry of ``data`` and recurse into nested mappings.

        :param data: mapping of row key to value.  A value containing the key
            ``'publish'`` becomes a row filled from ``value['publish']``; any
            other value becomes a row labelled with its key and is itself
            processed as a mapping one level deeper.
        :param indent: nesting level of the caller; this call works at
            ``indent + 1``.
        """
        indent += 1
        position = 4 * indent

        for key, value in data.items():

            if position > self.indentations[-1]:
                # Deeper than the currently open node: the row added last
                # becomes the parent of what follows.
                if self.parents[-1].childCount() > 0:
                    self.parents.append(self.parents[-1].child(self.parents[-1].childCount() - 1))
                    self.indentations.append(position)
            else:
                # Same level or shallower: close nodes until the open one
                # sits at this level.
                while position < self.indentations[-1] and len(self.parents) > 0:
                    self.parents.pop()
                    self.indentations.pop()

            parent = self.parents[-1]
            parent.insertChildren(parent.childCount(), 1, parent.columnCount())

            if 'publish' in value:
                parent.child(parent.childCount() - 1).setViewData(value['publish'])

            else:
                parent.child(parent.childCount() - 1).setData(0, key)
                self.createData(value, indent)

    def index(self, row, column, index=QModelIndex()):
        """Return the model index at ``row``/``column`` under the parent ``index``.

        An invalid ``QModelIndex`` is returned when no such cell exists.
        """
        if not self.hasIndex(row, column, index):
            return QModelIndex()
        if not index.isValid():
            item = self.rootItem
        else:
            item = index.internalPointer()

        child = item.child(row)
        if child:
            return self.createIndex(row, column, child)
        return QModelIndex()

    def parent(self, index):
        """Return the index of the parent of ``index``.

        Top-level rows, and invalid indexes, yield an invalid ``QModelIndex``.
        """
        if not index.isValid():
            return QModelIndex()
        item = index.internalPointer()
        if item is None:
            return QModelIndex()

        parent = item.parentItem
        # Identity is what is meant here; the None guard covers a node that
        # was never attached to a parent.
        if parent is None or parent is self.rootItem:
            return QModelIndex()
        return self.createIndex(parent.childNumber(), 0, parent)

    def rowCount(self, index=QModelIndex()):
        """Return the number of child rows under ``index``.

        An invalid ``index`` stands for the hidden root.
        """
        if index.isValid():
            parent = index.internalPointer()
        else:
            parent = self.rootItem
        return parent.childCount()

    def columnCount(self, index=QModelIndex()):
        """Return the number of columns, which is the number of header labels."""
        return self.rootItem.columnCount()

    def data(self, index, role=Qt.DisplayRole):
        """Return the display text of the cell at ``index``.

        ``None`` is returned for an invalid index and for every role other
        than ``Qt.DisplayRole``.
        """
        if not index.isValid():
            # An invalid index has column -1; looking that up on the root
            # would hand back the last header label.
            return None
        if role == Qt.DisplayRole:
            return index.internalPointer().data(index.column())
        return None

    def headerData(self, section, orientation, role=Qt.DisplayRole):
        """Return the label of horizontal header ``section`` for the display role."""
        if orientation == Qt.Horizontal and role == Qt.DisplayRole:
            return self.rootItem.data(section)
        return None


class TreeNode(object):
    """One row of the tree: its column values, its children and its view.

    Attributes are documented where they are assigned in ``__init__``.
    """

    def __init__(self, data, parent=None):
        """Create a node.

        :param data: list holding one value per column; stored, not copied.
        :param parent: parent ``TreeNode``, or ``None`` for a root.
        """
        self.parentItem = parent
        self.itemData = data
        # View currently shown by this row (None until setViewData is called).
        self.item_view = None
        # First view ever passed to setViewData; reference for ``modified``.
        self.original_view = None
        self.children = []
        # Status of the current view, '' until a non-empty view is set.
        self.status = ''
        # True when the current view differs from ``original_view``.
        self.modified = False

    def child(self, row):
        """Return the child at ``row``; raises ``IndexError`` when out of range."""
        return self.children[row]

    def childCount(self):
        """Return the number of direct children."""
        return len(self.children)

    def childNumber(self):
        """Return this node's row within its parent, or 0 for a parentless node."""
        if self.parentItem is not None:
            return self.parentItem.children.index(self)
        return 0

    def columnCount(self):
        """Return the number of column values held by this node."""
        return len(self.itemData)

    def data(self, column):
        """Return the value stored for ``column``."""
        return self.itemData[column]

    def insertChildren(self, position, count, columns):
        """Insert ``count`` blank children at ``position``.

        Each new child holds ``columns`` empty strings.

        :return: ``True`` on success, ``False`` when ``position`` is outside
            ``0..childCount()``.
        """
        if position < 0 or position > len(self.children):
            return False

        new_nodes = [TreeNode([''] * columns, self) for _ in range(count)]
        self.children[position:position] = new_nodes
        return True

    def parent(self):
        """Return the parent node, or ``None`` for a root."""
        return self.parentItem

    def setViewData(self, item_view, modified=False):
        """Show ``item_view`` in this row and recompute the modified marker.

        The first view ever received is remembered as ``original_view``.
        Columns 0-3 are filled from ``sg_step_name``, ``sg_variant_name``,
        ``sg_version_number`` (zero padded to three digits) and
        ``sg_status_list``.  Column 4 reads ``'Modified'`` when the variant
        name or version number differs from the original view, ``''``
        otherwise.

        :param item_view: view object; when its ``empty`` attribute is true it
            is stored as ``item_view`` but the columns are left untouched.
        :param modified: accepted for call compatibility; the flag is always
            derived from the comparison with ``original_view``.
        """
        if self.original_view is None:
            self.original_view = item_view

        self.item_view = item_view
        if item_view.empty:
            return

        self.status = item_view.sg_status_list

        self.itemData[0] = item_view.sg_step_name
        self.itemData[1] = item_view.sg_variant_name
        self.itemData[2] = '%03d' % item_view.sg_version_number
        self.itemData[3] = item_view.sg_status_list

        unchanged = (
            self.original_view.sg_variant_name == item_view.sg_variant_name
            and self.original_view.sg_version_number == item_view.sg_version_number
        )
        self.modified = not unchanged
        self.itemData[4] = '' if unchanged else 'Modified'

    def setData(self, column, value):
        """Store ``value`` in ``column``.

        :return: ``True`` on success, ``False`` when ``column`` is out of range.
        """
        if column < 0 or column >= len(self.itemData):
            return False

        self.itemData[column] = value
        return True



if __name__ == '__main__':
    # Hard-coded demo context.
    project, pipeline_step = 'TPT', 'Animation'
    entity_name = 's00_ep01_sq010_sh010'
    variant = 'Master'

    # The database object is created before the tracker, as in the original
    # statement order; it is not referenced again in this script.
    sg_database = database.DataBase()
    dependencies_solver = dependency_tracker.DependencyTracker(project, pipeline_step)

    # The returned pair is not used here; the window is given the solver itself.
    entity_view, _ = dependencies_solver.get_dependencies(entity_name)

    app = QApplication(sys.argv)
    app.setStyle("plastique")  # ("cleanlooks")

    form = GUIWindow(dependencies_solver)
    form.show()

    exit_code = app.exec_()
    sys.exit(exit_code)