import os
import logging
import math

from pprint import pprint

from pxr import Usd, UsdGeom, Sdf


logger = logging.getLogger(__name__)
# Uncomment to see the bounding-box diagnostics logged at DEBUG level:
# logger.setLevel(logging.DEBUG)

# Folders resolved relative to this file: the folder containing it, that
# folder's parent ("root"), and the "resources" folder under root, which holds
# the ``<turntable_type>_turntable_settings.usd`` layers.
TURNTABLE_BIN_FOLDER = os.path.dirname(os.path.abspath(__file__))
TURNTABLE_ROOT_FOLDER = os.path.split(TURNTABLE_BIN_FOLDER)[0]
TURNTABLE_RESOURCES_FOLDER = os.path.join(TURNTABLE_ROOT_FOLDER, 'resources')




class UsdTurntable:
    """Build a USD turntable stage for one asset.

    The stage sublayers the asset's geometry / UV / shading / lighting layers
    plus a per-``turntable_type`` settings layer from
    ``TURNTABLE_RESOURCES_FOLDER``. It keys a Y spin on the asset and light rig
    and sets up a camera, either framed from the asset's bounding box or copied
    from a user camera file. Each Arnold driver's ``arnold:filename`` is pointed
    at a per-AOV folder under ``images_path``.
    """

    # create_camera defines this prim and set_render_options points
    # ``arnold:camera`` at it; USD paths are case-sensitive, so both must use
    # exactly the same string.
    CAMERA_SCOPE_PATH = '/Cameras'
    CAMERA_PATH = '/Cameras/Camera'

    def __init__(self,
                 turntable_type,
                 asset_name,
                 project='sgd',
                 asset_type='',
                 geometry_path='',
                 uv_path='',
                 shading_path='',
                 lights_path='',
                 fx_path='',
                 camera_path='',
                 no_turn_camera=False,
                 images_path='',
                 output_path='',
                 start_frame=101,
                 end_frame=148):
        """Store the build configuration; no USD work happens until ``create_turntable``."""

        self.project = project

        self.turntable_type = turntable_type
        self.asset_name = asset_name

        self.asset_type = asset_type
        self.geometry_path = geometry_path
        self.uv_path = uv_path
        self.shading_path = shading_path
        self.lights_path = lights_path
        self.fx_path = fx_path
        self.camera_path = camera_path
        self.no_turn_camera = no_turn_camera
        self.images_path = images_path
        self.output_path = output_path

        self.start_frame = start_frame
        self.end_frame = end_frame

        self.stage = None
        # Frames per full 360-degree turn: half of the inclusive frame range.
        self.loop_frames = (self.end_frame - self.start_frame + 1) / 2
        self.loops = 1

    @staticmethod
    def _get_or_add_xform_op(prim, attr_name, add_op):
        """Return ``prim``'s ``attr_name`` attribute, calling ``add_op()`` to create it when missing."""
        attr = prim.GetAttribute(attr_name)
        if not attr.IsValid():
            attr = add_op()
        return attr

    @staticmethod
    def _get_or_create_attribute(prim, attr_name, type_name):
        """Return ``prim``'s ``attr_name`` attribute, creating it with ``type_name`` when missing."""
        attr = prim.GetAttribute(attr_name)
        if not attr.IsValid():
            attr = prim.CreateAttribute(attr_name, type_name)
        return attr

    def AddSpin(self, rotate_prim, loops):
        """Key a Y-axis spin of ``loops`` full turns on ``rotate_prim``.

        The prim is retyped to ``Xform`` and gets a rotateY op with suffix
        ``spin``. The op is keyed from 0 degrees at ``start_frame`` to
        ``360 * loops`` degrees, ``loop_frames`` frames later per turn.
        Invalid prims are ignored.
        """
        if not rotate_prim.IsValid():
            return

        rotate_prim.SetTypeName('Xform')

        transformable = UsdGeom.Xformable(rotate_prim)
        end_frame = self.start_frame + self.loop_frames * loops
        angle = 360 * loops

        spin = transformable.AddRotateYOp(opSuffix='spin')
        spin.Set(time=self.start_frame, value=0)
        spin.Set(time=end_frame, value=angle)

    def set_user_camera(self, camera, user_camera):
        """Copy per-frame camera data onto ``camera``.

        Args:
            camera: Camera prim to animate.
            user_camera: Mapping of frame -> dict shaped like the output of
                ``get_camera_transforms``. Required keys are ``translate``,
                ``rotate`` and ``scale`` (3 values each), plus
                ``clipping_range``, ``focus_distance``,
                ``horizontal_aperture`` and ``vertical_aperture``.
                ``focal_length`` is applied when present and not ``None``.
        """
        transformable_camera = UsdGeom.Xformable(camera)

        # Op creation order defines xformOpOrder: translate, rotateXYZ, scale.
        translate = self._get_or_add_xform_op(
            camera, 'xformOp:translate', transformable_camera.AddTranslateOp)
        rotate = self._get_or_add_xform_op(
            camera, 'xformOp:rotateXYZ', transformable_camera.AddRotateXYZOp)
        scale = self._get_or_add_xform_op(
            camera, 'xformOp:scale', transformable_camera.AddScaleOp)

        clipping_range = self._get_or_create_attribute(
            camera, 'clippingRange', Sdf.ValueTypeNames.Float2)
        focus_distance = self._get_or_create_attribute(
            camera, 'focusDistance', Sdf.ValueTypeNames.Float)
        horizontal_aperture = self._get_or_create_attribute(
            camera, 'horizontalAperture', Sdf.ValueTypeNames.Float)
        vertical_aperture = self._get_or_create_attribute(
            camera, 'verticalAperture', Sdf.ValueTypeNames.Float)
        # get_camera_transforms samples the focal length too; apply it so the
        # user camera's lens is preserved.
        focal_length = self._get_or_create_attribute(
            camera, 'focalLength', Sdf.ValueTypeNames.Float)

        for current_frame, frame_data in user_camera.items():
            # Key every attribute on the same time code.
            time_code = float(current_frame)

            t = frame_data['translate']
            r = frame_data['rotate']
            s = frame_data['scale']
            translate.Set(value=(t[0], t[1], t[2]), time=time_code)
            rotate.Set(value=(r[0], r[1], r[2]), time=time_code)
            scale.Set(value=(s[0], s[1], s[2]), time=time_code)

            clipping_range.Set(value=frame_data['clipping_range'], time=time_code)
            focus_distance.Set(value=frame_data['focus_distance'], time=time_code)
            horizontal_aperture.Set(value=frame_data['horizontal_aperture'], time=time_code)
            vertical_aperture.Set(value=frame_data['vertical_aperture'], time=time_code)

            user_focal_length = frame_data.get('focal_length')
            if user_focal_length is not None:
                focal_length.Set(value=user_focal_length, time=time_code)

    def set_turntable_base(self,
                           base_prim,
                           chart_prim,
                           Y_position,
                           scale,
                           camera_height,
                           camera_z):
        """Place the turntable base and the chart.

        Args:
            base_prim: Base prim. Translated to ``(0, Y_position, 0)`` and
                uniformly scaled by ``scale``.
            chart_prim: Chart prim. Translated to
                ``(1.5, camera_height - 1, camera_z - 10)`` with a uniform
                0.33 scale.
            Y_position: Base height (add_camera passes the asset's
                bounding-box bottom).
            scale: Uniform base scale.
            camera_height: Camera Y position.
            camera_z: Camera Z position.
        """
        transformable_base = UsdGeom.Xformable(base_prim)
        base_translate = self._get_or_add_xform_op(
            base_prim, 'xformOp:translate', transformable_base.AddTranslateOp)
        base_translate.Set(value=(0.0, Y_position, 0.0))

        base_scale = self._get_or_add_xform_op(
            base_prim, 'xformOp:scale', transformable_base.AddScaleOp)
        base_scale.Set(value=(scale, scale, scale))

        transformable_chart = UsdGeom.Xformable(chart_prim)
        chart_translate = self._get_or_add_xform_op(
            chart_prim, 'xformOp:translate', transformable_chart.AddTranslateOp)
        chart_translate.Set(value=(1.5, camera_height - 1, camera_z - 10))

        chart_scale = self._get_or_add_xform_op(
            chart_prim, 'xformOp:scale', transformable_chart.AddScaleOp)
        chart_scale.Set(value=(0.33, 0.33, 0.33))

    def fix_flat_geo(self, geo_prim, bb_depth):
        """Stand flat geometry up.

        Appends a translate op that lifts the prim by ``bb_depth / 2`` in Y,
        and a rotateXYZ op with suffix ``reposition`` that rotates it
        90 degrees about X.
        """
        transformable_geo = UsdGeom.Xformable(geo_prim)
        translate_geo = transformable_geo.AddTranslateOp()
        translate_geo.Set(value=(0.0, bb_depth / 2.0, 0.0))
        rotate_geo = transformable_geo.AddRotateXYZOp(opSuffix="reposition")
        rotate_geo.Set(value=(90.0, 0.0, 0.0))

    def get_cliping_planes(self, camera_z, longest_boundingbox_axis):
        """Return ``(near_clip, far_clip)`` for a camera at distance ``camera_z``.

        The far plane is ``camera_z + longest_boundingbox_axis``. The near
        plane is 1 unit; for far planes beyond 1e5 it is raised so that
        far/near stays at 1e5. If the scene is so small that 1 unit would sit
        at or beyond the far plane, near drops to ``far_clip * 1e-5`` so the
        range stays valid.

        Raises:
            ValueError: If the far plane is not positive.
        """
        far_clip = camera_z + longest_boundingbox_axis
        if far_clip <= 0:
            raise ValueError('Far clipping plane must be positive, got %s' % far_clip)

        log10 = math.log10(far_clip)
        if log10 <= 5:
            near_clip = 1
        else:
            near_clip = pow(10, log10 - 5)

        # A near plane at or past the far plane is an invalid clipping range.
        if near_clip >= far_clip:
            near_clip = far_clip * 1e-5

        return near_clip, far_clip

    def compute_from_bounding_box(self, geo_prim):
        """Derive camera framing and base placement from ``geo_prim``'s world bounds.

        Geometry whose height is under a fifth of both its width and depth is
        treated as flat and stood up with ``fix_flat_geo``.

        Returns:
            tuple: ``(camera_y, camera_z, max_base, boundingbox_bottom, clipping_range)``.
        """
        # NOTE(review): 20.955 matches USD's default *horizontal* aperture;
        # confirm it is intended as the vertical one here.
        vertical_aperture = 20.955
        focal_length = 50.0

        # Camera distance per unit of subject size.
        ratio = focal_length / (vertical_aperture / 2.0)

        bbox_cache = UsdGeom.BBoxCache(0, ['render', 'default'])
        bbox = bbox_cache.ComputeWorldBound(geo_prim).ComputeAlignedBox()
        bbox_min = bbox.GetMin()
        bbox_max = bbox.GetMax()

        bb_width = bbox_max[0] - bbox_min[0]
        bb_height = bbox_max[1] - bbox_min[1]
        bb_depth = bbox_max[2] - bbox_min[2]

        logger.debug('Height bounding box length: %s', bb_height)
        logger.debug('Width bounding box length: %s', bb_width)
        logger.debug('Depth bounding box length: %s', bb_depth)

        # Flat means thin relative to BOTH horizontal axes. fix_flat_geo stands
        # the asset up on its depth axis, so depth must be checked as well.
        if bb_height < bb_width / 5.0 and bb_height < bb_depth / 5.0:
            self.fix_flat_geo(geo_prim, bb_depth)
            camera_y = bb_depth / 2.0
        else:
            camera_y = (bbox_max[1] + bbox_min[1]) / 2.0

        max_base = max(bb_width, bb_depth) * .25
        longest_boundingbox_axis = max(bb_height, bb_width, bb_depth)
        boundingbox_bottom = bbox_min[1]

        camera_z = longest_boundingbox_axis * ratio

        logger.info('Camera Z position: %s', camera_z)

        clipping_range = self.get_cliping_planes(camera_z, longest_boundingbox_axis)
        logger.info('Max bounding box length: %s', longest_boundingbox_axis)

        return camera_y, camera_z, max_base, boundingbox_bottom, clipping_range

    def create_camera(self):
        """Define the camera scope and camera prims on the stage and return the camera prim."""
        self.stage.DefinePrim(self.CAMERA_SCOPE_PATH, 'Scope')
        return self.stage.DefinePrim(self.CAMERA_PATH, 'Camera')

    def set_camera_values(self, camera, clipping_range, camera_y, camera_z):
        """Author the clipping range and a static transform on ``camera``.

        The camera is placed at ``(0, camera_y, camera_z)`` with zero rotation
        and unit scale. Each value is keyed identically at ``start_frame`` and
        ``end_frame``.
        """
        camera.GetAttribute('clippingRange').Set(value=clipping_range)

        transformable = UsdGeom.Xformable(camera)
        keyframes = (self.start_frame, self.end_frame)

        camera_position = self._get_or_add_xform_op(
            camera, 'xformOp:translate', transformable.AddTranslateOp)
        for frame in keyframes:
            camera_position.Set(value=(0, camera_y, camera_z), time=frame)

        camera_rotate = self._get_or_add_xform_op(
            camera, 'xformOp:rotateXYZ', transformable.AddRotateXYZOp)
        for frame in keyframes:
            camera_rotate.Set(value=(0.0, 0.0, 0.0), time=frame)

        camera_scale = self._get_or_add_xform_op(
            camera, 'xformOp:scale', transformable.AddScaleOp)
        for frame in keyframes:
            camera_scale.Set(value=(1.0, 1.0, 1.0), time=frame)

    def add_camera(self, geo_prim, base_prim, chart_prim, camera, detail=False):
        """Frame ``camera`` on ``geo_prim`` and place the base and chart to match.

        With ``detail=True`` the camera is raised by a factor of 1.8 and moved
        five times closer.
        """
        camera_y, camera_z, max_base, boundingbox_bottom, clipping_range = \
            self.compute_from_bounding_box(geo_prim)

        if detail:
            camera_y *= 1.8
            camera_z /= 5

        self.set_camera_values(camera, clipping_range, camera_y, camera_z)

        self.set_turntable_base(base_prim,
                                chart_prim,
                                boundingbox_bottom,
                                max_base,
                                camera_y,
                                camera_z)

    def set_render_options(self):
        """Point Arnold at the turntable camera and route AOV outputs under ``images_path``.

        Also re-reads ``start_frame``/``end_frame`` from the stage. Each driver
        under ``/defaultArnoldDriver`` that has an ``arnold:name`` gets a
        per-frame ``arnold:filename`` of the form
        ``<images_path>/<aov>/<asset_name>.<frame:04d>.exr``. The ``RGBA``
        AOV is written to a folder named ``beauty``.
        """
        self.start_frame = int(self.stage.GetStartTimeCode())
        self.end_frame = int(self.stage.GetEndTimeCode())
        options_prim = self.stage.GetPrimAtPath('/options')
        attr = options_prim.GetAttribute('arnold:camera')
        # Must be exactly the prim create_camera defines and animates.
        attr.Set(self.CAMERA_PATH)

        drivers_prim = self.stage.GetPrimAtPath('/defaultArnoldDriver')

        for child in drivers_prim.GetChildren():
            attr = child.GetAttribute('arnold:filename')

            aov_name = child.GetAttribute('arnold:name').Get()
            if not aov_name:
                continue

            aov_name = aov_name.split('.')[-1]
            if aov_name == 'RGBA':
                aov_name = 'beauty'

            aov_folder = '%s/%s' % (self.images_path, aov_name)
            os.makedirs(aov_folder, exist_ok=True)

            for frame in range(self.start_frame, self.end_frame + 1):
                frame_str = '%04d' % frame
                new_path = '%s/%s.%s.exr' % (aov_folder, self.asset_name, frame_str)
                attr.Set(time=frame, value=new_path)

    def get_camera_transforms(self):
        """Sample the cameras in ``camera_path`` once per frame of that file's time range.

        Transforms are read from each camera's parent prim (its
        ``xformOp:translate``, ``xformOp:rotateXYZ`` and ``xformOp:scale``).
        Lens data is read from the camera prim itself.

        Returns:
            dict: int frame -> dict with ``translate``, ``rotate`` and
            ``scale`` (lists), plus ``focal_length``, ``focus_distance``,
            ``horizontal_aperture``, ``vertical_aperture`` and
            ``clipping_range``. Empty when ``camera_path`` is unset or the
            file contains no camera.
        """
        if not self.camera_path:
            return {}
        camera_stage = Usd.Stage.Open(self.camera_path)
        camera_prims = [x for x in camera_stage.Traverse() if x.IsA('Camera')]
        camera_data = {}

        if not camera_prims:
            return {}

        first_frame = int(camera_stage.GetStartTimeCode())
        last_frame = int(camera_stage.GetEndTimeCode()) + 1

        # NOTE(review): with several cameras, later ones overwrite earlier ones
        # frame by frame; the last camera traversed wins.
        for camera_prim in camera_prims:
            parent_transform = camera_prim.GetParent()
            logger.debug('Reading camera transform from %s', parent_transform)
            translate_attr = parent_transform.GetAttribute('xformOp:translate')
            rotate_attr = parent_transform.GetAttribute('xformOp:rotateXYZ')
            scale_attr = parent_transform.GetAttribute('xformOp:scale')

            focal_length_attr = camera_prim.GetAttribute('focalLength')
            focus_distance_attr = camera_prim.GetAttribute('focusDistance')
            horizontal_aperture_attr = camera_prim.GetAttribute('horizontalAperture')
            vertical_aperture_attr = camera_prim.GetAttribute('verticalAperture')
            clipping_range_attr = camera_prim.GetAttribute('clippingRange')

            for frame in range(first_frame, last_frame):
                camera_data[frame] = {
                    'translate': list(translate_attr.Get(time=frame)),
                    'rotate': list(rotate_attr.Get(time=frame)),
                    'scale': list(scale_attr.Get(time=frame)),
                    'focal_length': focal_length_attr.Get(time=frame),
                    'focus_distance': focus_distance_attr.Get(time=frame),
                    'horizontal_aperture': horizontal_aperture_attr.Get(time=frame),
                    'vertical_aperture': vertical_aperture_attr.Get(time=frame),
                    'clipping_range': clipping_range_attr.Get(time=frame),
                }

        return camera_data

    def create_stage(self, usd_path):
        """Create a new stage at ``usd_path`` and sublayer the asset and settings layers.

        Geometry and UV layers are added whenever their paths are set. Shading
        and lights layers are added only if their files exist. The settings
        layer ``self.options_path`` is always added. The stage up axis is set
        to Y.
        """
        self.stage = Usd.Stage.CreateNew(usd_path)
        self.root_layer = self.stage.GetRootLayer()

        if self.geometry_path:
            self.root_layer.subLayerPaths.append(self.geometry_path)

        if self.uv_path:
            self.root_layer.subLayerPaths.append(self.uv_path)

        if self.shading_path and os.path.exists(self.shading_path):
            self.root_layer.subLayerPaths.append(self.shading_path)

        if self.lights_path and os.path.exists(self.lights_path):
            self.root_layer.subLayerPaths.append(self.lights_path)

        self.root_layer.subLayerPaths.append(self.options_path)

        UsdGeom.SetStageUpAxis(self.stage, UsdGeom.Tokens.y)

    def add_prim_rotation(self):
        """Define the turntable prims and key their spins.

        The render and proxy prims spin ``self.loops`` turns. The light rig
        spins twice as many turns.

        Returns:
            tuple: ``(base_prim, chart_prim, render_prim, proxy_prim)``.
        """
        base_prim = self.stage.DefinePrim('/asset/light_rig/base')
        chart_prim = self.stage.DefinePrim('/asset/light_rig/chart')
        render_prim = self.stage.DefinePrim('/asset/render')
        proxy_prim = self.stage.DefinePrim('/asset/proxy')
        light_prim = self.stage.DefinePrim('/asset/light_rig/all')

        self.AddSpin(light_prim, self.loops * 2)
        self.AddSpin(render_prim, self.loops)
        self.AddSpin(proxy_prim, self.loops)

        return base_prim, chart_prim, render_prim, proxy_prim

    def create_turntable(self):
        """Build the full turntable stage at ``output_path`` and save it.

        If ``camera_path`` exists and contains a camera, that camera's
        animation is copied and the stage end is set to its last frame.
        Otherwise the camera is framed from the asset's bounding box.
        """
        logger.info('=== starts create turntable ===')
        logger.info('Output path: %s', self.output_path)

        if self.output_path:
            dirname = os.path.dirname(self.output_path)
            # A bare file name has no directory part; os.makedirs('') would raise.
            if dirname:
                os.makedirs(dirname, exist_ok=True)

        logger.info('Images path: %s', self.images_path)

        self.options_path = '%s/%s_turntable_settings.usd' % (TURNTABLE_RESOURCES_FOLDER, self.turntable_type)
        # Use forward slashes in the sublayer path.
        self.options_path = str(self.options_path).replace('\\', '/')

        self.create_stage(self.output_path)

        root_layer = self.stage.GetRootLayer()
        self.stage.SetStartTimeCode(self.start_frame)
        self.stage.SetEndTimeCode(self.end_frame)

        base_prim, chart_prim, render_prim, proxy_prim = self.add_prim_rotation()

        camera = self.create_camera()

        user_camera = {}
        if self.camera_path and os.path.exists(self.camera_path):
            user_camera = self.get_camera_transforms()
            if not user_camera:
                logger.warning('No camera found in %s; framing the asset automatically.',
                               self.camera_path)

        if user_camera:
            self.set_user_camera(camera, user_camera)
            # Match the stage end to the last user-camera frame.
            self.stage.SetEndTimeCode(max(user_camera))
        else:
            self.add_camera(render_prim, base_prim, chart_prim, camera)

        self.set_render_options()

        logger.info('Saving turntable file: %s', self.output_path)
        root_layer.Save()

def launch_turntable(turntable_type='set_model',
                     project='',
                     asset_name='',
                     asset_type='',

                     geometry_path='',
                     uv_path='',
                     shading_path='',
                     lights_path='',
                     fx_path='',
                     camera_path='',
                     no_turn_camera=False,
                     images_path='',
                     usd_path='',
                     start_frame=101,
                     end_frame=148):
    """Build a turntable USD file and report where it was written.

    A thin wrapper around ``UsdTurntable``. ``usd_path`` becomes the builder's
    ``output_path``; every other argument is forwarded unchanged. The output
    path is also printed to stdout.

    Returns:
        dict: ``turntable_path``, ``starts_frame`` and ``end_frame`` read back
        from the builder after the build.
    """
    builder = UsdTurntable(
        turntable_type,
        asset_name,
        project=project,
        asset_type=asset_type,
        geometry_path=geometry_path,
        uv_path=uv_path,
        shading_path=shading_path,
        lights_path=lights_path,
        fx_path=fx_path,
        camera_path=camera_path,
        no_turn_camera=no_turn_camera,
        images_path=images_path,
        output_path=usd_path,
        start_frame=start_frame,
        end_frame=end_frame,
    )
    builder.create_turntable()

    turntable_path = builder.output_path
    # NOTE(review): stdout output and the 'starts_frame' key spelling are kept
    # as-is in case callers parse them.
    print(turntable_path)
    return {
        'turntable_path': turntable_path,
        'starts_frame': builder.start_frame,
        'end_frame': builder.end_frame,
    }

def test_turntable():
    """Development smoke test: build a turntable for a hard-coded TPT set.

    All inputs are UNC paths on the ``\\\\columbus`` share, so this only runs
    where that share is reachable.
    """
    # Alternative inputs kept for quick switching while testing:
    #   geometry: \\columbus\SGD\TPT\publish\model\Main_Characters\Maina\Master\Maina_Master_9f3caf1ce71ddbd7\v018\usd/Maina_geometry.usd
    #   shading:  \\columbus\SGD\TPT\publish\shading\Main_Characters\Maina\Master\Maina_Master_ec8f7a9de98dab63\v003\usd/Maina.usda
    geometry_path = r'\\columbus\SGD\TPT\publish\set\Sets\BedroomLeo\Master\BedroomLeo_Master_cd99cb9b1eef8c2e\v011\usd/bedroomleo.usda'
    usd_path = r'\\columbus\SGD\TPT\publish\model\Main_Characters\Maina\Master\Maina_Master_9f3caf1ce71ddbd7\v018\turntable/test.usda'
    images_path = r'\\columbus\SGD\TPT\publish\model\Main_Characters\Maina\Master\Maina_Master_9f3caf1ce71ddbd7\v018\turntable'
    shading_path = ''
    lights_path = r'\\columbus\SGD\TPT\publish\asset_lighting\Others\Template\model_turntable\Template_model_turntable_eef31bbb8084d27b\v024\usd/Template.usd'
    camera_path = r'\\columbus\SGD\TPT\publish\set\Sets\BedroomLeo\Master\BedroomLeo_Master_cd99cb9b1eef8c2e\v011\camera_cache\cam_bedroomLeo.usd'

    print('launch turntable test')
    launch_turntable(turntable_type='set_model',
                     project='TPT',
                     asset_name='Maina',
                     geometry_path=geometry_path,
                     images_path=images_path,
                     lights_path=lights_path,
                     shading_path=shading_path,
                     camera_path=camera_path,
                     no_turn_camera=True,
                     usd_path=usd_path)

if __name__ == '__main__':
    # Run as a script: execute the hard-coded development turntable build.
    test_turntable()