#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import importlib

from PySide import QtWidgets, QtGui, QtCore

import library.ui.base_tree_view as base_tree_view
import dependency_tracker.lib.dependency_tracker as dependency_tracker
import shotgrid_lib.database as database

# Re-execute these project modules so edits to them are picked up without
# restarting the interpreter (presumably a development convenience -- review
# before shipping). Order matters: dependency_tracker first, then
# base_tree_view, whose classes are subclassed below.
for _stale_module in (dependency_tracker, base_tree_view):
    importlib.reload(_stale_module)
del _stale_module

# Pen colour used for the version/status columns, keyed by the publish's
# ``sg_status_list`` short code ('apr' is the approved status). Codes not
# listed here are painted grey.
STATUS_COLORS = dict(
    apr='green',
    clsd='red',
    ip='orange',
    rev='orange',
    recd='orange',
)

class GUIWindow(QtWidgets.QWidget):
    """Top-level window that hosts a DependenciesTreeView inside a group box."""

    def __init__(self, parent=None):
        """Set up the UI.

        Layout nesting: window vbox -> 'TreeView' group box -> hbox -> tree view.
        """
        super(GUIWindow, self).__init__(parent)

        self.setWindowTitle("QTreeView from Dictionary")
        self.setMinimumSize(1000, 600)

        groupbox_model = QtWidgets.QGroupBox('TreeView')  # frame around the view
        hbox_model = QtWidgets.QHBoxLayout()  # layout inside the group box
        vbox = QtWidgets.QVBoxLayout()  # the window's own layout

        self.tree_view = DependenciesTreeView()

        hbox_model.addWidget(self.tree_view)
        groupbox_model.setLayout(hbox_model)
        vbox.addWidget(groupbox_model)
        self.setLayout(vbox)

    def update_model(self, data):
        """Forward ``data`` to the tree view, which rebuilds its model from it."""
        self.tree_view.update_model(data)

    def getValue(self):
        """Return the tree view's current selection (see DependenciesTreeView.getValue)."""
        return self.tree_view.getValue()


class DelegateStatus(QtWidgets.QStyledItemDelegate):
    """Paint each dependency cell as coloured text.

    Rows whose node has no ``item_view`` (group rows) show the model's display
    text in yellow. Rows bound to a publish view show, per column: step name,
    variant name, zero-padded version number, status code and a "Modified"
    flag, coloured via ``STATUS_COLORS`` and greyed out when disabled.
    """

    def paint(self, painter, option, index):
        item = index.internalPointer()
        # Always pair save() with restore(): the group-row branch used to
        # call restore() without a save(), unbalancing the painter state and
        # leaking the pen change into subsequent painting.
        painter.save()
        try:
            if not item.item_view:
                color = 'yellow'
                text = index.data(QtCore.Qt.DisplayRole)
                # drawText() cannot take None (e.g. a cell with no data).
                text = '' if text is None else str(text)
            else:
                color, text = self._column_style(item, index.column())

            painter.setPen(QtGui.QPen(color))
            painter.drawText(option.rect, QtCore.Qt.AlignLeft | QtCore.Qt.AlignVCenter, text)
        finally:
            painter.restore()

    @staticmethod
    def _column_style(item, column):
        """Return ``(color, text)`` for ``column`` of a node that has an item_view."""
        view = item.item_view
        color = 'white'
        text = ''
        if column == 0:
            text = view.sg_step_name
        elif column == 1:
            text = view.sg_variant_name
            # A single variant means there is nothing to switch to.
            if len(item.get_variant_list()) == 1:
                color = 'grey'
        elif column == 2:
            text = '%03d' % view.sg_version_number
            color = STATUS_COLORS.get(item.status, 'grey')
        elif column == 3:
            text = view.sg_status_list
            color = STATUS_COLORS.get(item.status, 'grey')
        elif column == 4 and item.modified and item.enabled:
            color = 'red'
            text = 'Modified'

        # Disabled rows are greyed out regardless of column.
        if not item.enabled:
            color = 'grey'
        return color, text


class DependenciesTreeView(QtWidgets.QTreeView):
    """Tree view of an entity's dependencies with per-column context menus.

    Right-clicking a cell opens a menu that depends on its column:

    * column 0 -- enable/disable the row (or every child of a group row);
    * column 1 -- switch variant (CustomEntity09 rows only);
    * column 2 -- switch version (CustomEntity09 rows only).
    """

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        # NOTE(review): this attribute shadows QWidget.parent(); kept for
        # callers that may read it, but ``self.parent()`` is unusable here.
        self.parent = parent

        self.headers = ["Item", "Variant", 'Version', 'Status', 'Modified']
        self.setContextMenuPolicy(QtCore.Qt.CustomContextMenu)
        self.customContextMenuRequested.connect(self._open_menu)
        self.setSelectionBehavior(QtWidgets.QAbstractItemView.SelectRows)

        # Item views do not take ownership of their delegate, so parent it to
        # the view and keep a Python reference to stop it being collected.
        self._status_delegate = DelegateStatus(self)
        self.setItemDelegate(self._status_delegate)

    def update_model(self, data):
        """Rebuild the model from ``data`` and show it fully expanded.

        ``data`` must expose a ``dependencies_data`` attribute (read by TreeModel).
        """
        self.data = data

        # NOTE(review): ``self.model`` shadows QAbstractItemView.model(); kept
        # because getValue() (and possibly external code) uses the attribute.
        self.model = TreeModel(self.headers, self.data)

        self.setModel(self.model)
        self.expandAll()
        self.resizeColumnToContents(0)

    def open_enable_menu(self, position, tree_item, current_option):
        """Let the user enable or disable ``tree_item``.

        For a group row (no ``item_view``) the choice is applied to each child.
        ``current_option`` is unused; it keeps the signature parallel to the
        other ``open_*_menu`` methods.
        """
        menu = QtWidgets.QMenu()
        menu.addAction('Enabled')
        menu.addAction('Disabled')

        menu_click = menu.exec_(position)
        if not menu_click:
            return

        enabled = menu_click.text() == 'Enabled'
        if tree_item.item_view is None:
            for child in tree_item.children:
                child.enabled = enabled
        else:
            tree_item.enabled = enabled
        # Node state changed without a model signal; repaint so the delegate
        # shows it immediately.
        self.viewport().update()

    def open_variant_menu(self, position, tree_item, current_option):
        """Offer every known variant and switch ``tree_item`` to the chosen one.

        The node is switched to that variant's default publish. Picking
        ``current_option`` (the cell's current text) does nothing.
        """
        all_variants = tree_item.get_variant_list()

        menu = QtWidgets.QMenu()
        for variant in all_variants:
            menu.addAction(variant)

        menu_click = menu.exec_(position)
        if not menu_click or menu_click.text() == current_option:
            return

        new_variant = tree_item.get_default_version(str(menu_click.text()))
        if new_variant is None:
            # The variant has no usable publish; keep the current one.
            return
        tree_item.setViewData(new_variant)
        self.viewport().update()

    def open_version_menu(self, position, tree_item, current_option):
        """Offer every version of the current variant and load the chosen one.

        Versions are listed newest first. The newest approved ('apr') version
        is the menu's default action and is highlighted in green.
        """
        versions_data = tree_item.get_version_list()
        menu = QtWidgets.QMenu()
        default_action = None
        # Sort numerically so that e.g. '1000' comes before '999'.
        for version in sorted(versions_data, key=int, reverse=True):
            # QMenu.addAction(text) returns an action owned by the menu.
            action = menu.addAction(version)
            if default_action is None and versions_data[version] == 'apr':
                default_action = action

        menu.setDefaultAction(default_action)
        menu.setStyleSheet("QMenu::item:default { color: green; }")

        menu_click = menu.exec_(position)
        if not menu_click or menu_click.text() == current_option:
            return

        publish_db = tree_item.item_view._database
        new_version = publish_db['CustomEntity09'].find_with_filters(
            sg_hash=tree_item.item_view.sg_hash,
            sg_version_number=int(menu_click.text()),
            single_item=True)
        tree_item.setViewData(new_version)
        self.viewport().update()

    def _open_menu(self, position):
        """Dispatch a context-menu request to the menu for the clicked column.

        ``position`` is in viewport coordinates (as emitted by
        customContextMenuRequested on a scroll area).
        """
        # Act on the cell under the cursor, not on whatever happens to be
        # selected: the previous selectedIndexes()[column] lookup could target
        # another row, or the last index when columnAt() returned -1.
        index = self.indexAt(position)
        if not index.isValid():
            return
        tree_item = index.internalPointer()
        if tree_item is None:
            return

        column = index.column()
        tree_item_view = tree_item.item_view
        global_position = self.viewport().mapToGlobal(position)
        current_option = index.data()

        if column == 0:
            self.open_enable_menu(global_position, tree_item, current_option)
            return

        # Variant/version switching is only supported for CustomEntity09 publishes.
        if tree_item_view is None or tree_item_view.type != 'CustomEntity09':
            return

        if column == 1:
            self.open_variant_menu(global_position, tree_item, current_option)
        elif column == 2:
            self.open_version_menu(global_position, tree_item, current_option)

    def getValue(self):
        """Collect the current publish choice for every asset row.

        Returns:
            dict: ``{asset_name: {step_name: {'publish': view, 'enabled': bool}}}``,
            built from the model's top-level rows and their direct children.
            Children without a publish view are skipped.
        """
        values = {}
        for row in range(self.model.rowCount()):
            asset_index = self.model.index(row, 0)
            asset_name = asset_index.data()
            values[asset_name] = {}
            asset_node = asset_index.internalPointer()
            if asset_node is None:
                continue
            for child in asset_node.children:
                if child.item_view is None:
                    # setViewData() ignored an empty publish for this child.
                    continue
                step_name = child.item_view.sg_step_name
                values[asset_name][step_name] = {'publish': child.item_view,
                                                 'enabled': child.enabled}

        return values

class PublishedElementNode(base_tree_view.TreeNode):
    """Tree node wrapping one published-element view and its alternatives.

    The first non-empty view given to :meth:`setViewData` becomes
    ``original_view``. Later views (chosen from the variant/version menus)
    replace ``item_view``. ``modified`` is set when they differ from the
    original in variant or version.
    """

    def __init__(self, data, parent=None):
        super(PublishedElementNode, self).__init__(data, parent=parent)

        self.item_view = None  # publish currently shown for this node
        self.original_view = None  # first publish assigned; baseline for `modified`
        self.status = ''
        self.modified = False
        self.enabled = True

        self.variant = None
        self.version = None

        self.favorite_variant = None
        self.favorite_version = None
        self.loaded_variant = None
        self.loaded_version = None

        # variant code -> {'versions', 'approved', 'latest', 'recommended'}.
        # Filled by prefilter_versions(). Starts empty so the getters below
        # never hit an AttributeError when no publishes were found.
        self.published_versions = {}

    def get_version_list(self):
        """Return ``{'%03d' version number: sg_status_list}`` for the current variant.

        Returns an empty dict when the current variant has no cached publishes.
        """
        variant_data = self.published_versions.get(self.variant)
        if variant_data is None:
            return {}
        return {'%03d' % version.sg_version_number: version.sg_status_list
                for version in variant_data['versions']}

    def get_variant_list(self):
        """Return the codes of all variants with cached publishes."""
        return list(self.published_versions)

    def get_default_version(self, variant_name):
        """Return the publish to load when switching to ``variant_name``.

        Prefers the approved publish and falls back to the latest one.
        Returns None when the variant has no publish.

        Raises:
            KeyError: if ``variant_name`` is not a cached variant.
        """
        versions = self.published_versions[variant_name]
        approved = versions['approved']
        # Check ``.empty`` explicitly (as prefilter_versions does) instead of
        # relying on the truthiness of a view object.
        if approved is not None and not approved.empty:
            return approved
        return versions['latest'] or None

    def prefilter_versions(self, item_view):
        """Cache every complete, non-deleted publish sharing ``item_view``'s step.

        Publishes are read from ``item_view.sg_asset``, or from ``sg_context``
        when the asset is empty. They are grouped by variant code into
        ``self.published_versions``::

            {variant_code: {'versions': view, 'approved': view or None,
                            'latest': view or None, 'recommended': view or None}}

        'recommended' is the approved publish when one exists, else the latest.
        """
        self.published_versions = {}

        entity_view = item_view.sg_asset
        if entity_view.empty:
            entity_view = item_view.sg_context

        step_view = item_view.sg_step
        published_elements = entity_view.sg_published_elements
        if published_elements.empty:
            return
        step_publish = published_elements.find_with_filters(sg_step=step_view,
                                                            sg_complete=True,
                                                            sg_delete=False)

        variants_data = {}
        for variant_view in step_publish.sg_variant:
            if variant_view.code in variants_data:
                continue
            variant_publish = step_publish.find_with_filters(sg_variant=variant_view)
            approved_version = None
            latest_version = None
            recommended_version = None
            if not variant_publish.empty:
                approved_version = variant_publish.find_with_filters(sg_status_list='apr',
                                                                     single_item=True)
                # NOTE(review): assumes max() over the view yields the newest
                # publish -- confirm the view's ordering.
                latest_version = max(variant_publish)
                recommended_version = latest_version if approved_version.empty else approved_version
            variants_data[variant_view.code] = {'versions': variant_publish,
                                                'approved': approved_version,
                                                'latest': latest_version,
                                                'recommended': recommended_version}
        self.published_versions = variants_data

    def setViewData(self, item_view):
        """Show ``item_view`` on this node; ignore ``None`` or empty views.

        The first view assigned also becomes ``original_view`` and triggers
        prefilter_versions().
        """
        if item_view is None or item_view.empty:
            return

        if self.original_view is None:
            self.original_view = item_view
            self.prefilter_versions(item_view)

        self.item_view = item_view

        self.status = item_view.sg_status_list
        self.variant = item_view.sg_variant_name
        self.version = item_view.sg_version_number

        self.modified = not (
            self.original_view.sg_variant_name == item_view.sg_variant_name
            and self.original_view.sg_version_number == item_view.sg_version_number)

class TreeModel(base_tree_view.TreeModel):
    """Item model laying dependency data out as a tree of PublishedElementNode.

    The constructor takes an object exposing ``dependencies_data`` and hands
    that mapping to the base model.
    """

    def __init__(self, headers, data, parent=None):
        dependencies = data.dependencies_data
        # NOTE(review): despite its name this holds the dependencies mapping,
        # not a shot view.
        self.shot_view = dependencies
        super(TreeModel, self).__init__(headers, dependencies, parent=parent,
                                        tree_node_class=PublishedElementNode)

    def createData(self, data, indent):
        """Recursively add one child node per entry of ``data``.

        Entries holding a ``'publish'`` key become leaf nodes bound to that
        view. Any other entry becomes a named group node and is recursed into.
        """
        depth = indent + 1
        # Presumably the base class resolves the parent node from this
        # 4-per-level offset -- confirm in base_tree_view.get_parent.
        offset = 4 * depth

        for name, entry in data.items():
            parent_node = self.get_parent(offset)
            parent_node.insertChildren(parent_node.childCount(), 1, parent_node.columnCount())
            new_node = parent_node.child(parent_node.childCount() - 1)

            if 'publish' not in entry:
                new_node.setData(0, name)
                self.createData(entry, depth)
                continue

            new_node.setViewData(entry['publish'])


if __name__ == '__main__':
    # Manual smoke test: resolve the dependencies of one hard-coded shot,
    # display them, and print the initial selection.
    # NOTE(review): sg_database is not used below; kept in case constructing
    # the DataBase has side effects -- confirm and drop if not.
    sg_database = database.DataBase()
    project = 'TPT'
    pipeline_step = 'Animation'
    entity_name = 's00_ep01_sq010_sh010'
    dependencies_solver = dependency_tracker.DependencyTracker(project, pipeline_step)
    # The returned values are unused here. The call presumably fills
    # dependencies_solver.dependencies_data, which the tree model reads.
    entity_view, _ = dependencies_solver.get_dependencies(entity_name)

    app = QtWidgets.QApplication(sys.argv)
    app.setStyle("plastique")  # "cleanlooks" is an alternative style
    form = GUIWindow()
    form.show()
    form.update_model(dependencies_solver)

    return_data = form.getValue()
    print(return_data)
    for asset_data in return_data.values():
        for step_data in asset_data.values():
            print(step_data)
    sys.exit(app.exec_())