import os

from PySide import QtWidgets, QtCore

import shotgrid_lib.database as database
import library.sandbox as sandbox_builder


# Shotgrid fields, outermost first, used to nest each entity type into the
# chooser tree. Dotted names are resolved one link at a time.
entities_order = {
    'Asset': ['sg_asset_type', 'code'],
    'Shot': ['sg_sequence.episode.code', 'sg_sequence.code', 'code'],
}

# Combo-box label for each level below the root "Context" selector, listed in
# the same order as the matching ``entities_order`` entry.
context_label = {
    'Asset': ['Asset Type', 'Asset Name'],
    'Shot': ['Episode', 'Sequence', 'Shot'],
}

# Minimum combo-box width (passed to setMinimumWidth), keyed by label.
# Labels missing from this table fall back to 100.
column_width = {
    'Context': 70,
    'Asset Type': 120,
    'Asset Name': 110,
    'Episode': 80,
    'Sequence': 115,
    'Shot': 150,
}

class ContextBitWidget(QtWidgets.QComboBox):
    """One level of a cascading context chooser combo box.

    The items are the sorted keys of ``tree_data``, preceded by a placeholder
    item equal to the widget's label. Picking a key whose value is a non-empty
    dict creates (or refreshes) a child ``ContextBitWidget`` one level deeper,
    added to ``widget_layout``. The current choice is mirrored into the
    environment variable ``PIPE_<LABEL>``: the label is upper-cased and its
    spaces become underscores.
    """

    def __init__(self,
                 tree_data=None,
                 parent_entity=None,
                 widget_layout=None,
                 parent=None,
                 depth=0):
        """Create the combo box and populate it from ``tree_data``.

        Args:
            tree_data: nested dict for this level; its keys become the items.
                ``None`` is treated as an empty tree.
            parent_entity: the widget one level up. When ``depth > 0`` it is
                asked for the context type via ``get_contex_type()``.
            widget_layout: layout that child widgets are appended to.
            parent: Qt parent widget.
            depth: 0 for the root "Context" selector. Deeper levels take their
                label from ``context_label``.
        """
        super(ContextBitWidget, self).__init__(parent=parent)
        self.parent_entity = parent_entity
        self.widget_layout = widget_layout
        self.depth = depth
        self.label = None
        self.child_widget = None
        self.set_data(tree_data)
        self.setSizeAdjustPolicy(QtWidgets.QComboBox.AdjustToContents)
        # Connected after set_data, so the initial population does not
        # trigger update_child_widget.
        self.currentIndexChanged.connect(self.update_child_widget)

    def get_sandbox_path_by_env(self):
        """Point the Maya sandbox at the path for the context in the environment.

        Reads ``PROJECT`` (raises KeyError if it is unset) and the ``PIPE_*``
        context variables. Asset variables take precedence over shot
        variables. Does nothing if neither ``PIPE_ASSET_NAME`` nor
        ``PIPE_SHOT`` is set, or if the generated path does not exist.
        """
        project = os.environ['PROJECT']
        sandbox_solver = sandbox_builder.Sandbox(project=project)

        if 'PIPE_ASSET_NAME' in os.environ:
            sandbox_solver.set_context(asset_name=os.environ.get('PIPE_ASSET_NAME', ''),
                                       asset_type=os.environ.get('PIPE_ASSET_TYPE', ''))
        elif 'PIPE_SHOT' in os.environ:
            print('shot path')
            # NOTE(review): episode_name defaults to None here, while the
            # other context values default to '' -- confirm this is intended.
            sandbox_solver.set_context(shot_name=os.environ.get('PIPE_SHOT', ''),
                                       sequence_name=os.environ.get('PIPE_SEQUENCE', ''),
                                       episode_name=os.environ.get('PIPE_EPISODE'))
        else:
            return
        path = sandbox_solver.generate_path()
        if os.path.exists(path):
            print('Set maya sandbox to: %s' % path)
            sandbox_solver.set_maya_sandbox(path)

    def find_entity(self, entity_name, children_tree):
        """Return True if ``entity_name`` is a key anywhere in ``children_tree``.

        Non-dict values are treated as empty subtrees.
        """
        if not isinstance(children_tree, dict):
            return False
        if entity_name in children_tree:
            return True
        return any(self.find_entity(entity_name, subtree)
                   for subtree in children_tree.values())

    def get_contex_type(self):
        """Return the context type (e.g. 'Asset' or 'Shot') this chain belongs to.

        The root widget (depth 0) reports its current text. Deeper widgets ask
        their parent. A deeper widget with no parent returns '', which makes
        ``set_data`` fall back to the '-----' label instead of raising
        AttributeError.
        """
        if self.depth == 0:
            return str(self.currentText())
        if self.parent_entity is None:
            return ''
        return self.parent_entity.get_contex_type()

    def setValue(self, entity_name):
        """Select the chain of items that leads to ``entity_name``.

        If ``entity_name`` is a key at this level, it is selected directly.
        Otherwise the first key whose subtree contains it is selected, and the
        child widget continues the search. If the name is not found, this
        level's environment variable is cleared and the selection is left
        unchanged.
        """
        if self.label:
            self.remove_env_var()

        if entity_name in self.tree_data:
            self.set_env_var(entity_name)
            self.setCurrentText(entity_name)
            return

        for key, children in self.tree_data.items():
            if self.find_entity(entity_name, children):
                # setCurrentText may emit currentIndexChanged, which builds
                # or refreshes the child widget before it is used below.
                self.setCurrentText(key)
                self.set_env_var(key)
                if self.child_widget:
                    self.child_widget.setValue(entity_name)
                # Stop at the first match, as the direct-hit branch does.
                # Otherwise a later match would override this selection.
                return

    def _env_var_name(self):
        """Return the ``PIPE_*`` environment variable name for the current label."""
        return 'PIPE_%s' % self.label.upper().replace(' ', '_')

    def set_env_var(self, value):
        """Store ``value`` in this level's environment variable.

        The variable is removed instead when ``value`` is empty or equals the
        placeholder label. Does nothing when the widget has no label.
        """
        if self.label is None:
            return
        env_var = self._env_var_name()
        if not value or not self.label or value == self.label:
            os.environ.pop(env_var, None)
            return

        os.environ[env_var] = value

    def remove_env_var(self):
        """Remove this level's environment variable, if the widget has a label."""
        if self.label is None:
            return
        os.environ.pop(self._env_var_name(), None)

    def set_data(self, tree_data):
        """Repopulate the combo box from ``tree_data`` and recompute the label.

        ``None`` is treated as an empty tree. The label becomes the first item
        and also picks the minimum width from ``column_width``.
        """
        self.remove_env_var()
        self.tree_data = tree_data if tree_data is not None else {}
        self.bit_items = sorted(self.tree_data)

        # clear()/addItems() may emit currentIndexChanged once the signal is
        # connected (i.e. when a child widget is being refreshed).
        self.clear()
        if self.depth == 0:
            self.label = 'Context'
        else:
            label_list = context_label.get(self.get_contex_type(), [])
            if label_list:
                self.label = label_list[self.depth - 1]
            else:
                self.label = '-----'
        self.bit_items.insert(0, self.label)
        self.addItems(self.bit_items)
        width = column_width.get(self.label, 100)
        self.setMinimumWidth(width)

    def remove_child(self):
        """Recursively reset and hide the child widget chain.

        Clears each child's environment variable, label and data.
        """
        if self.child_widget:
            self.child_widget.remove_env_var()
            self.child_widget.label = None
            self.child_widget.tree_data = {}
            self.child_widget.remove_child()
            self.child_widget.setHidden(True)

    def update_child_widget(self):
        """React to a selection change (slot for currentIndexChanged).

        On a leaf or the placeholder, hides the child chain, updates the
        environment variable and refreshes the Maya sandbox. Otherwise it
        creates the child widget, or shows and repopulates the existing one.
        """
        text = str(self.currentText())

        child_data = self.tree_data.get(text)

        if not child_data:
            if self.child_widget:
                self.remove_child()
            self.set_env_var(text)
            self.get_sandbox_path_by_env()
            return

        if not self.child_widget:
            self.child_widget = ContextBitWidget(tree_data=child_data,
                                                 parent_entity=self,
                                                 widget_layout=self.widget_layout,
                                                 depth=self.depth + 1)
            self.widget_layout.addWidget(self.child_widget)
        else:
            self.child_widget.setHidden(False)
            self.child_widget.set_data(child_data)
        self.set_env_var(text)


class _FixedContextParent(object):
    """Minimal stand-in for a parent ``ContextBitWidget``.

    A ``ContextBitWidget`` deeper than the root calls
    ``parent_entity.get_contex_type()`` to choose its labels. When the chooser
    is restricted to a single entity type, there is no root selector to ask,
    so this object answers with that fixed type.
    """

    def __init__(self, context_type):
        self.context_type = context_type

    def get_contex_type(self):
        """Return the fixed context type given at construction."""
        return self.context_type


class DccContextWidget(QtWidgets.QFrame):
    """Horizontal strip of cascading ``ContextBitWidget`` combos.

    Each combo picks one level of the Shotgrid context (asset or shot) for a
    project.
    """

    def __new__(cls, *args, **kargs):
        # Singleton: every instantiation returns the same Python object.
        # Python still runs __init__ on each call, so the database is
        # re-queried and the UI rebuilt every time.
        # NOTE(review): confirm that repeated initialisation is intended.
        if not hasattr(cls, 'instance'):
            cls.instance = super(DccContextWidget, cls).__new__(cls)

        return cls.instance

    def __init__(self, entity_types=None, project='TPT', parent=None):
        """Load Shotgrid data for ``project`` and build the chooser.

        Args:
            entity_types: None to offer both assets and shots, or 'Asset' /
                'Shot' to restrict the chooser to that type.
            project: project name passed to the database.
            parent: Qt parent widget.
        """
        super(DccContextWidget, self).__init__(parent=parent)
        self.setObjectName('ContextChooserWidget')
        self.entity_types = entity_types
        self.project = project
        self.widget_list = []
        self.database = database.DataBase()
        self.database.fill(self.project, precatch=False)

        if self.entity_types != 'Shot':
            self.database.query_sg_database('Asset', as_precache=True)

        if self.entity_types != 'Asset':
            self.database.query_sg_database('Shot', as_precache=True)

        self.config_ui()

    def config_ui(self):
        """Build the entity tree and the first combo box of the chooser."""
        self.get_entities_tree()

        self.main_layout = QtWidgets.QHBoxLayout()
        self.main_layout.setSpacing(0)
        self.main_layout.setContentsMargins(0, 0, 0, 0)
        if self.entity_types is None:
            # Root selector: the user first picks the entity type.
            self.entity_type_widget = ContextBitWidget(tree_data=self.entities_tree,
                                                       widget_layout=self.main_layout)
        else:
            # Restricted to one type, so start one level down. A depth > 0
            # widget asks its parent_entity for the context type; without one
            # it raised AttributeError, so give it a fixed stand-in.
            self.entity_type_widget = ContextBitWidget(
                tree_data=self.entities_tree[self.entity_types],
                parent_entity=_FixedContextParent(self.entity_types),
                widget_layout=self.main_layout,
                depth=1)
        self.main_layout.addWidget(self.entity_type_widget)
        self.widget_list.append(self.entity_type_widget)

        self.setLayout(self.main_layout)

    def setValue(self, entity_name):
        """Select the combo chain that leads to ``entity_name``."""
        print('set value', entity_name)
        self.entity_type_widget.setValue(entity_name)

    def get_field_value(self, entity, attribute):
        """Resolve a possibly dotted field name on ``entity``.

        Each dotted segment is followed while the intermediate value is a
        ``database.View``. If the chain stops early, the last value reached is
        returned.
        """
        variable, _, rest = attribute.partition('.')
        value = entity.get_field_value(variable)
        if rest and isinstance(value, database.View):
            value = self.get_field_value(value, rest)
        return value

    def get_entities_tree(self):
        """Build ``self.entities_tree`` as nested dicts following ``entities_order``.

        Example shape: ``{'Asset': {asset_type: {code: {}}}}``. Leaves are
        empty dicts. Only ``self.entity_types`` is included when it is set.
        """
        entities_tree = {}
        for entity_type, entity_hierarchy in entities_order.items():
            if self.entity_types is not None and entity_type != self.entity_types:
                continue
            type_tree = entities_tree[entity_type] = {}
            for entity in self.database[entity_type]:
                node = type_tree
                for attribute in entity_hierarchy:
                    node = node.setdefault(self.get_field_value(entity, attribute), {})
        self.entities_tree = entities_tree

    def create_menu(self):
        """Register this chooser through ``packages_io.add_context_chooser``."""
        import packages_io
        print('create context chooser')
        packages_io.add_context_chooser(self)

def test():
    """Manual smoke test: open the chooser for project 'TPT' and preselect 'Goldy'.

    Other sample names used during development:
    's00_ep02_sq020_sh060' and 'LeoHero'.
    """
    # This attribute must be set before the QApplication is created.
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_ShareOpenGLContexts)
    application = QtWidgets.QApplication()
    chooser = DccContextWidget(project='TPT')
    chooser.show()
    chooser.setValue('Goldy')

    application.exec()


if __name__ == '__main__':
    test()