import bpy
import io
import json
from mathutils import Quaternion, Matrix, Vector, Euler, Color
import logging
import math
import traceback
import sys
from s4studio.helpers import FNV32
from s4studio.model.geometry import Vertex, BoundingBox

# mathutils Quaternion in (w, x, y, z) order with w = x = cos(45 deg):
# a 90-degree rotation about the X axis. Not referenced in this part of the
# file; presumably used elsewhere to convert between the game's and
# Blender's up-axis conventions -- TODO confirm.
AXIS_FIXER = Quaternion((0.7071067690849304, 0.7071067690849304, 0.0, 0.0))

def get_spec(base_bsdf):
    """Return the specular input socket of a Principled BSDF node.

    Accepts either the 'Specular' or the 'Specular IOR Level' identifier
    (presumably because the socket name differs between Blender versions).
    Returns None when the node has neither.
    """
    wanted = ('Specular', 'Specular IOR Level')
    return next((sock for sock in base_bsdf.inputs if sock.identifier in wanted), None)

def eprint(*args, **kwargs):
    """Behave like print(), but write to the current sys.stderr."""
    stream = sys.stderr
    print(*args, file=stream, **kwargs)

def get_mat_nodes(material, node_type):
    """Return a list of the nodes in material's node tree whose ``type`` equals node_type."""
    return list(filter(lambda node: node.type == node_type, material.node_tree.nodes))

def get_mat_node(material, node_type):
    """Return the first node of ``node_type`` in material's node tree.

    Raises IndexError when the tree contains no node of that type.
    """
    matching = get_mat_nodes(material, node_type)
    first = matching[0]
    return first

def add_texture_to_material(material, map_name, image, texture_index, is_cas):
    """Attach ``image`` to ``material`` and return the next texture-slot index.

    Blender < 2.80 (legacy texture slots):
        If ``is_cas`` and a texture named 'BaseTexture' exists, that texture
        is placed in slot ``texture_index`` first. A new IMAGE texture named
        ``map_name`` holding ``image`` then goes into the next slot. Both
        slots use UV coordinates. The returned index is advanced past every
        slot created here.

    Blender >= 2.80 (node-based materials):
        Enables nodes and BLEND alpha mode on the material. When ``is_cas``,
        a new Principled BSDF ('Mannequin Shader', textured with the image
        named 'BaseTexture') and the existing Principled BSDF
        ('Diffuse Shader', textured with ``image``) are combined by a Mix
        Shader whose first input (its factor) is the diffuse texture's alpha.
        Otherwise ``image`` drives the existing Principled BSDF's Base Color
        and Alpha directly. ``texture_index`` is returned unchanged.

    Args:
        material: bpy material to modify.
        map_name: name for the new texture datablock (legacy path only).
        image: bpy image to attach.
        texture_index: first free texture slot (legacy path only).
        is_cas: also wire in the 'BaseTexture' layer.

    Returns:
        The (possibly advanced) texture slot index.
    """
    bt = 'BaseTexture'
    if bpy.app.version < (2, 80, 0):
        # Add base texture
        if is_cas:
            # Only add the base layer if a 'BaseTexture' texture is loaded.
            if bt in bpy.data.textures:
                material.texture_slots.create(texture_index)
                texture_slot = material.texture_slots[texture_index]
                texture = bpy.data.textures[bt]
                texture_slot.texture = texture
                texture_slot.texture_coords = 'UV'
                texture.type = 'IMAGE'
                texture_index += 1
        # Slot for the map being loaded.
        material.texture_slots.create(texture_index)
        texture_slot = material.texture_slots[texture_index]
        texture_index += 1
        texture = bpy.data.textures.new(name=map_name, type='IMAGE')
        texture_slot.texture = texture
        texture_slot.texture_coords = 'UV'
        texture.type = 'IMAGE'
        texture.image = image
    else:
        material.use_nodes = True
        output_node = get_mat_node(material, 'OUTPUT_MATERIAL')
        material.blend_method = 'BLEND'
        material.show_transparent_back = False
        if is_cas:

            # Mannequin layer: a fresh Principled BSDF fed by 'BaseTexture'.
            # NOTE(review): unlike the legacy path, this does not check that
            # an image named 'BaseTexture' exists; the lookup below raises if
            # it is missing.
            base_bsdf = material.node_tree.nodes.new("ShaderNodeBsdfPrincipled")
            base_bsdf.name = 'Mannequin Shader'
            get_spec(base_bsdf).default_value = 0
            base_texture = material.node_tree.nodes.new('ShaderNodeTexImage')
            base_texture.name = 'Mannequin Texture'
            base_texture.image = bpy.data.images[bt]
            material.node_tree.links.new(base_bsdf.inputs['Base Color'], base_texture.outputs['Color'])
            material.node_tree.links.new(base_bsdf.inputs['Alpha'], base_texture.outputs['Alpha'])

            # Diffuse layer: reuses the first Principled BSDF in the tree.
            # NOTE(review): assumes the default BSDF created by use_nodes comes
            # before the 'Mannequin Shader' added above in node order -- confirm.
            diffuse_bsdf = get_mat_node(material, 'BSDF_PRINCIPLED')
            diffuse_bsdf.name = 'Diffuse Shader'
            get_spec(diffuse_bsdf).default_value = 0
            diffuse_texture = material.node_tree.nodes.new('ShaderNodeTexImage')
            diffuse_texture.name = 'Diffuse Texture'
            diffuse_texture.image = image
            material.node_tree.links.new(diffuse_bsdf.inputs['Base Color'], diffuse_texture.outputs['Color'])
            material.node_tree.links.new(diffuse_bsdf.inputs['Alpha'], diffuse_texture.outputs['Alpha'])

            # Blend the two shaders: mix factor = diffuse alpha, shader 1 =
            # mannequin, shader 2 = diffuse; the result feeds the output surface.
            mix_shader = material.node_tree.nodes.new('ShaderNodeMixShader')
            mix_shader.name = 'Skin Texture Blend'
            material.node_tree.links.new(mix_shader.inputs[0], diffuse_texture.outputs['Alpha'])
            material.node_tree.links.new(mix_shader.inputs[1], base_bsdf.outputs['BSDF'])
            material.node_tree.links.new(mix_shader.inputs[2], diffuse_bsdf.outputs['BSDF'])
            material.node_tree.links.new(output_node.inputs['Surface'], mix_shader.outputs['Shader'])
        else:
            # Single layer: image -> existing Principled BSDF -> output.
            diffuse_bsdf = get_mat_node(material,'BSDF_PRINCIPLED')
            get_spec(diffuse_bsdf).default_value = 0
            diffuse_texture = material.node_tree.nodes.new('ShaderNodeTexImage')
            diffuse_texture.image = image
            material.node_tree.links.new(diffuse_bsdf.inputs['Base Color'], diffuse_texture.outputs['Color'])
            material.node_tree.links.new(output_node.inputs['Surface'], diffuse_bsdf.outputs['BSDF'])
            material.node_tree.links.new(diffuse_bsdf.inputs['Alpha'], diffuse_texture.outputs['Alpha'])
    return texture_index


def load_material(map_name, map_resource, is_cas=True):
    """Create a material named ``map_name`` textured with the image at ``map_resource``.

    Side effects: sets the scene's render image settings to PNG / RGBA,
    creates a 512x512 image datablock whose filepath is ``map_resource``,
    calls ``save_render`` on it with that same path, and (on Blender < 2.80)
    configures the texture slot added last according to ``map_name``.

    Args:
        map_name: name used for the material, image and texture.
        map_resource: file path of the texture image.
        is_cas: forwarded to add_texture_to_material.

    Returns:
        The new bpy material.
    """
    material = bpy.data.materials.new(map_name)
    texture_index = 0

    material.specular_color = Color([.1, .1, .1])

    # Load image
    # save_render below uses the scene's render image settings.
    bpy.context.scene.render.image_settings.file_format = 'PNG'
    bpy.context.scene.render.image_settings.color_mode = 'RGBA'
    img = bpy.data.images.new(map_name, 512, 512)
    img.source = 'FILE'
    img.filepath = map_resource
    # NOTE(review): this writes to map_resource, i.e. the source file itself,
    # presumably re-encoding it as PNG/RGBA; confirm the overwrite is intended.
    img.save_render(map_resource)
    img.filepath = map_resource

    texture_index = add_texture_to_material(material, map_name, img, texture_index, is_cas)

    # Configure texture slots (legacy pre-2.80 material system only).
    if bpy.app.version < (2, 80, 0):
        # texture_index - 1 is the slot holding the map_name texture.
        texture_slot = material.texture_slots[texture_index - 1]
        if map_name in ('DiffuseMap', 'Multiplier'):
            texture_slot.use_stencil = True
        elif map_name == 'Mask':
            texture_slot.use = False
            texture_slot.blend_type = 'OVERLAY'
            texture_slot.use_rgb_to_intensity = True
        elif map_name in ('Specular', 'SpecularMap', 'Clothing Specular'):
            texture_slot.use_map_specular = True
            texture_slot.use_map_color_spec = True
            texture_slot.use_map_hardness = True
            texture_slot.use_map_color_diffuse = False
        elif map_name == 'NormalMap':
            texture_slot.use_map_normal = True
            texture_slot.normal_factor = 0.01
            texture_slot.use_map_color_diffuse = False
    return material


def s3_4x3_to_Matrix(s3_matrix):
    """
    Build a square mathutils.Matrix from the rows of a game matrix.

    Each row of ``s3_matrix`` is copied as a tuple, and a homogeneous
    (0, 0, 0, 1) row is appended last. The input rows must therefore have
    4 components each to match that final row.
    """
    rows = list(map(tuple, s3_matrix))
    rows += [(0.0, 0.0, 0.0, 1.0)]
    return Matrix(rows)


def swizzle_uv(uv):
    """Flip the V coordinate (v -> 1 - v) and return a (u, v) tuple."""
    u, v = uv[0], uv[1]
    return (u, 1 - v)


def unswizzle_uv(uv):
    """Undo swizzle_uv: flip V back (v -> 1 - v) and return a [u, v] list."""
    flipped_v = 1 - uv[1]
    return [uv[0], flipped_v]


def swizzle_v3(v3):
    """Return the first three components of ``v3`` as a mathutils.Vector (order unchanged)."""
    x, y, z = v3[0], v3[1], v3[2]
    return Vector([x, y, z])


def quat_wxyz(quaternion):
    """
    Reorder a quaternion from xyzw (The Sims 3 order) to wxyz (Blender order).

    Returns a 4-tuple.
    """
    x, y, z, w = (quaternion[i] for i in range(4))
    return w, x, y, z


def argb_rgb(argb):
    """Drop the leading alpha component of an ARGB sequence.

    Returns the red, green and blue components (indices 1-3) as a slice of
    the input, so the result has the same sequence type as ``argb``.
    """
    # Slice end is exclusive: [1:4] keeps R, G and B (the old [1:3] lost B).
    return argb[1:4]


def quat_xyzw(quaternion):
    """Reorder a wxyz (Blender) quaternion into xyzw order; the inverse of quat_wxyz.

    Returns a 4-element list.
    """
    w, x, y, z = quaternion[0], quaternion[1], quaternion[2], quaternion[3]
    return [x, y, z, w]


def invalid_face(face):
    """Return True if ``face`` is empty/None or repeats any vertex index.

    Elements are compared with ``==`` (no hashing required).
    """
    if not face:
        return True
    seen = []
    for idx in face:
        if idx in seen:
            return True
        seen.append(idx)
    return False


def create_marker_node(name, rotate=False):
    """Add a cube-shaped empty called ``name`` to the scene and return it.

    Calls set_context('OBJECT') first. The empty is shown in front of other
    geometry with a display size of 0.1. When ``rotate`` is true its Euler
    rotation is set to pi/2 about X.
    """
    set_context('OBJECT')
    bpy.ops.object.add(type='EMPTY')
    empty = bpy.context.active_object
    empty.name = name
    show_in_front(empty, True)
    set_empty_size(empty, .1)
    set_empty_type(empty, 'CUBE')
    if not rotate:
        return empty
    empty.rotation_euler = Euler([math.pi / 2, 0, 0])
    return empty


def apply_all_modifiers(ob):
    """Make ``ob`` the active, sole-selected object in OBJECT mode and convert it to a mesh.

    Per the function name, the Convert-to-Mesh operator is used here to bake
    the modifier stack into the mesh data.
    """
    set_context('OBJECT', ob)
    object_ops = bpy.ops.object
    object_ops.convert(target='MESH')


def matrix_mult(a, b):
    """Multiply two mathutils objects with the operator the running Blender expects.

    Blender < 2.80 uses ``*``; newer versions use matrix multiplication.
    ``__matmul__`` is called explicitly rather than written as ``a @ b``,
    presumably so the module still parses on Pythons older than 3.5.
    """
    legacy_api = bpy.app.version < (2, 80, 0)
    return a * b if legacy_api else a.__matmul__(b)


def rotate_obj(amount, axis):
    """Rotate by ``amount`` radians about ``axis`` using bpy.ops.transform.rotate.

    The operator acts on the current selection. ``axis`` is 'X', 'Y' or 'Z',
    optionally prefixed with '-'. The operator runs with a context override
    whose 'area' is the first VIEW_3D area on the current screen (IndexError
    if the screen has none).
    """
    ov=bpy.context.copy()
    ov['area']=[a for a in bpy.context.screen.areas if a.type=="VIEW_3D"][0]
    if bpy.app.version < (2, 80, 0):
        # Pre-2.80 operator takes an axis vector.
        # NOTE(review): only '-X' is recognised among negated axes; '-Y',
        # '-Z' and unknown values fall back to (1, 0, 0).
        a = (1, 0, 0)
        if axis == 'X':
            a = (1, 0, 0)
        if axis == '-X':
            a = (-1, 0, 0)
        elif axis == 'Y':
            a = (0, 1, 0)
        elif axis == 'Z':
            a = (0, 0, 1)
        bpy.ops.transform.rotate(ov,value=amount, axis=a)
    else:
        # current_active / current_mode are read only for the log line below.
        current_active = get_active()
        current_mode = bpy.context.object.mode if bpy.context.object else None
        print('ROTATING mode= %s active=%s'%(current_mode,current_active))
        # 2.80+ takes an axis letter, so a '-' prefix becomes a negated angle.
        if axis[0]== '-':
            axis = axis[1]
            amount = amount * -1
        # NOTE(review): -amount is passed here while the pre-2.80 branch passes
        # amount, presumably compensating for a change in the operator's sign
        # convention -- confirm.
        if bpy.app.version >= (4,0,0):
            # 4.0+: override applied via temp_override instead of passing the
            # dict positionally (presumably no longer accepted there).
            with bpy.context.temp_override(**ov):
                bpy.ops.transform.rotate(value=-amount, orient_axis=axis)
        else:
            bpy.ops.transform.rotate(ov,value=-amount, orient_axis=axis)
def set_color(mesh_data, color_index, loop, color):
    """Store ``color`` on one loop of a mesh color layer.

    On Blender 3.4+ it goes through ``color_attributes`` and is assigned via
    ``color_srgb``; older versions write ``vertex_colors[...].color``.
    """
    if bpy.app.version < (3, 4):
        mesh_data.vertex_colors[color_index].data[loop].color = color
        return
    mesh_data.color_attributes[color_index].data[loop].color_srgb = color
def set_color_no_srgb(mesh_data, color_index, loop, color):
    """Store ``color`` on one loop of a mesh color layer via its raw ``color`` property.

    Uses ``color_attributes`` on Blender 3.4+ and ``vertex_colors`` before that.
    """
    layers = mesh_data.color_attributes if bpy.app.version >= (3, 4) else mesh_data.vertex_colors
    layers[color_index].data[loop].color = color
def get_color(mesh_data, color_index, loop):
    """Read the color of one loop from a mesh color layer.

    On Blender 3.4+ the value comes from ``color_attributes`` via
    ``color_srgb``; older versions read ``vertex_colors[...].color``.
    """
    if bpy.app.version < (3, 4):
        return mesh_data.vertex_colors[color_index].data[loop].color
    return mesh_data.color_attributes[color_index].data[loop].color_srgb
def get_color_no_srgb(mesh_data, color_index, loop):
    """Read one loop's raw ``color`` from a mesh color layer.

    Uses ``color_attributes`` on Blender 3.4+ and ``vertex_colors`` before that.
    """
    layers = mesh_data.color_attributes if bpy.app.version >= (3, 4) else mesh_data.vertex_colors
    return layers[color_index].data[loop].color
def create_color(components):
    """Build a color value suited to the running Blender version.

    Before 2.80 a mathutils.Color of the first three components is returned.
    From 2.80 on, a 3-component input becomes a new [r, g, b, 1.0] list and
    any other input is returned unchanged.
    """
    if bpy.app.version >= (2, 80, 0):
        if len(components) != 3:
            return components
        return [components[i] for i in range(3)] + [1.0]
    return Color(components[:3])

def set_interpolation_mode(mode):
    """Set the default interpolation type for newly inserted keyframes.

    The setting lives under ``user_preferences`` before Blender 2.80 and
    under ``preferences`` afterwards.
    """
    prefs = bpy.context.user_preferences if bpy.app.version < (2, 80, 0) else bpy.context.preferences
    prefs.edit.keyframe_new_interpolation_type = mode
def add_uv_layer(mesh, name):
    """Create and return a new UV layer called ``name`` on ``mesh``.

    Uses ``uv_textures`` before Blender 2.80 and ``uv_layers`` afterwards.
    """
    layers = mesh.uv_textures if bpy.app.version < (2, 80, 0) else mesh.uv_layers
    return layers.new(name=name)


def set_armature_display(armature_data, type):
    """Set how an armature's bones are drawn (``draw_type`` pre-2.80, ``display_type`` after)."""
    attr_name = 'draw_type' if bpy.app.version < (2, 80, 0) else 'display_type'
    setattr(armature_data, attr_name, type)


def show_in_front(obj, val):
    """Toggle drawing ``obj`` in front of other geometry (``show_x_ray`` pre-2.80, ``show_in_front`` after)."""
    attr_name = 'show_x_ray' if bpy.app.version < (2, 80, 0) else 'show_in_front'
    setattr(obj, attr_name, val)


def get_objects():
    """Return the objects visible to the current context.

    Before Blender 2.80 this is the scene's object collection. From 2.80 on
    it is a set combining the active collection's objects and the scene's
    objects.
    """
    if bpy.app.version >= (2, 80, 0):
        return set(bpy.context.collection.objects).union(bpy.context.scene.objects)
    return bpy.context.scene.objects


def link_object(obj):
    """Link ``obj`` into the scene (pre-2.80) or the active collection (2.80+)."""
    if bpy.app.version < (2, 80, 0):
        container = bpy.context.scene.objects
    else:
        container = bpy.context.collection.objects
    container.link(obj)


def unlink_object(obj):
    """Unlink ``obj`` from the scene (pre-2.80) or the active collection (2.80+)."""
    owner = bpy.context.scene if bpy.app.version < (2, 80, 0) else bpy.context.collection
    owner.objects.unlink(obj)


def get_active():
    """Return the active object (from the scene pre-2.80, the view layer afterwards)."""
    owner = bpy.context.scene if bpy.app.version < (2, 80, 0) else bpy.context.view_layer
    return owner.objects.active


def get_object_hide(obj):
    """Return whether ``obj`` is hidden (``hide`` pre-2.80, ``hide_get()`` afterwards)."""
    return obj.hide if bpy.app.version < (2, 80, 0) else obj.hide_get()


def set_object_hide(obj, val):
    """Hide or unhide ``obj`` (``hide`` pre-2.80, ``hide_set()`` afterwards)."""
    if bpy.app.version >= (2, 80, 0):
        obj.hide_set(val)
        return
    obj.hide = val


def set_active(value):
    """Make ``value`` the active object (scene pre-2.80, view layer afterwards)."""
    owner = bpy.context.scene if bpy.app.version < (2, 80, 0) else bpy.context.view_layer
    owner.objects.active = value


def set_selection_state(obj, is_selected):
    """Select or deselect ``obj`` (``select`` pre-2.80, ``select_set()`` afterwards)."""
    if bpy.app.version >= (2, 80, 0):
        obj.select_set(is_selected)
        return
    obj.select = is_selected


def get_selection_state(obj):
    """Return whether ``obj`` is selected.

    Reads the ``select`` property before Blender 2.80 and calls
    ``select_get()`` afterwards; both branches return the result.
    """
    if bpy.app.version < (2, 80, 0):
        return obj.select
    return obj.select_get()


def set_empty_type(empty, type):
    """Set an empty's display shape (``empty_draw_type`` pre-2.80, ``empty_display_type`` after)."""
    attr_name = 'empty_draw_type' if bpy.app.version < (2, 80, 0) else 'empty_display_type'
    setattr(empty, attr_name, type)


def set_empty_size(empty, sz):
    """Set an empty's display size (``empty_draw_size`` pre-2.80, ``empty_display_size`` after)."""
    attr_name = 'empty_draw_size' if bpy.app.version < (2, 80, 0) else 'empty_display_size'
    setattr(empty, attr_name, sz)


def get_empty_size(empty):
    """Return an empty's display size (``empty_draw_size`` pre-2.80, ``empty_display_size`` after)."""
    attr_name = 'empty_draw_size' if bpy.app.version < (2, 80, 0) else 'empty_display_size'
    return getattr(empty, attr_name)


def set_context(mode=None, select=None,printInfo=True):
    """Select only ``select``, make it the active object and switch to ``mode``.

    Args:
        mode: object interaction mode to enter, e.g. 'OBJECT' or 'EDIT'.
        select: object to select and make active; None deselects everything
            and leaves no active object.
        printInfo: when true, print the mode transition to stdout.

    NOTE(review): with the default mode=None and a context object present,
    this reaches mode_set(mode=None), which presumably fails; callers in
    this file always pass a mode.
    """
    current_active = get_active()
    # NOTE(review): the mode is read from bpy.context.object, which may not
    # be the same object as get_active() returned -- confirm.
    current_mode = bpy.context.object.mode if bpy.context.object else None
    if(printInfo):
        print('changing mode from %s to %s select=%s'%(current_mode,mode,select))
    # Selection is reset even when the early return below is taken.
    for obj in bpy.data.objects:
        set_selection_state(obj, obj == select)

    if current_active == select and mode == current_mode:
        return
    if current_active != select:
        if current_active:
            # Leave the old active object's edit/pose/etc. mode before
            # switching; unhide it and make it selectable first.
            if current_mode != 'OBJECT':
                set_object_hide(current_active, False)
                current_active.hide_select = False
                bpy.ops.object.mode_set(mode='OBJECT')
        set_active(select)
    if bpy.context.object and bpy.context.object.mode != mode:
        bpy.ops.object.mode_set(mode=mode)
    if printInfo:
        print('Mode is now %s'%(bpy.context.object.mode if bpy.context.object else None))


def approximate_vector(v, precision=2):
    """Return str() of a list holding each component of ``v`` rounded to ``precision`` digits."""
    rounded = list(map(lambda component: round(component, precision), v))
    return str(rounded)


def equals_float_array(a, b, precision=2):
    """Return True when ``a`` and ``b`` match element-wise after rounding to ``precision`` digits.

    The sequences must have equal length (checked with ``assert``).
    """
    assert len(a) == len(b)
    return all(round(x, precision) == round(y, precision) for x, y in zip(a, b))


def equals_vector(a, b, precision=.25):
    """Return True when the Euclidean distance between ``a`` and ``b`` is below ``precision``."""
    distance = (Vector(a) - Vector(b)).length
    return math.fabs(distance) < precision


# Maps a vertex's i-th bone influence (0-3) to the slot of the weight list it
# is written to in collect_mesh_data; currently the identity mapping.
blend_index_map = [0, 1, 2, 3]


class SimMeshData(object):
    """Container for the mesh data produced by collect_mesh_data.

    Attributes:
        vertices: list of Vertex objects.
        indices: flat list of triangle vertex indices.
        bones: bone hashes derived from vertex group names.
        bone_names: vertex group names.
        bounds: BoundingBox grown to enclose the vertex positions.
    """

    def __init__(self):
        bounds = BoundingBox()
        bounds.init_values()
        self.vertices, self.indices = [], []
        self.bones, self.bone_names = [], []
        self.bounds = bounds


def collect_mesh_data(mesh_object, existing_bones=None):
    """Triangulate, transform and flatten a Blender mesh object into SimMeshData.

    The object is modified in place: hidden geometry is revealed, faces are
    converted to triangles, all transforms are applied, and the object is
    rotated -90 degrees about X (via rotate_obj) before that rotation is
    applied as well.

    Each Blender loop becomes a Vertex. Loops that share a Blender vertex
    reuse an existing Vertex when their UVs agree to 4 decimal places (and,
    where normals are compared, their normals agree within equals_vector's
    tolerance).

    Args:
        mesh_object: the bpy mesh object to read (and modify).
        existing_bones: optional list of bone hashes to start from. NOTE: the
            list is used directly, so new hashes are appended to the
            caller's list.

    Returns:
        SimMeshData with vertices, a flat triangle index list, bone hashes,
        bone names and the bounds of the vertex positions. If the mesh has
        no loops, only bones and bone_names are filled in.
    """
    print('Collecting mesh data: existing_bones= %s'%(existing_bones))
    bvmajor, bvminor, bvrevision = bpy.app.version
    print('collecting mesh data from %s'%mesh_object.name)
    mesh_data = mesh_object.data
    # vertex group name -> index into vertex_groups (the bone hash list)
    vertex_group_map = {}
    vertex_groups = []
    if existing_bones:
        vertex_groups = existing_bones
    bone_names = []
    dta = SimMeshData()
    if len(mesh_data.loops) == 0:
        dta.bones = vertex_groups
        dta.bone_names = bone_names
        print('no mesh loops found')
        return dta
    # Triangulate everything so the loop list can be read 3 at a time below.
    set_context('OBJECT', mesh_object)
    set_context('EDIT', mesh_object)
    bpy.ops.mesh.reveal()
    bpy.ops.mesh.select_all(action='SELECT')
    bpy.ops.mesh.quads_convert_to_tris()
    bpy.ops.mesh.select_all(action='DESELECT')
    set_context('OBJECT', mesh_object)
    # Bake transforms, then rotate -90 deg about X and bake that too.
    bpy.ops.object.transform_apply(location=True, rotation=True, scale=True)
    rotate_obj(-math.pi / 2.0, 'X')
    bpy.ops.object.transform_apply(rotation=True)
    has_uv = len(mesh_data.uv_layers) > 0
    if has_uv:
        mesh_data.calc_tangents()
    # Establish bone groups
    for vertex_group_index in range(len(mesh_object.vertex_groups)):
        vertex_group = mesh_object.vertex_groups[vertex_group_index]
        # Names of the form '0x' + 8 hex digits are taken as literal hashes;
        # any other name is hashed with FNV32.
        hash = int(vertex_group.name[2:],16) if len(vertex_group.name) == 10 and str(vertex_group.name[:2]).upper()== '0X' else  FNV32.hash(vertex_group.name)
        if not hash in vertex_groups:
            vertex_groups.append(hash)
        if not vertex_group.name in bone_names:
            bone_names.append(vertex_group.name)
        vertex_group_map[vertex_group.name] = vertex_groups.index(hash)
    pass

    sim_vertices = []
    indices = []

    # Vertex Index matched to a list of Sim Vertices split by UV coordinates
    vertex_map = {}

    # Enumerate blender loops to split vertices if their loop is different
    for loop_index, loop in enumerate(mesh_data.loops):
        # Assumes every face is a triangle whose 3 loops are consecutive
        # (relies on the triangulation above).
        face_point_index = loop_index % 3
        face_index = int((loop_index - face_point_index) / 3)

        # Add a triangle
        if face_point_index == 0:
            indices.append([])

        # Initialize vertex map for this index
        if not loop.vertex_index in vertex_map:
            vertex_map[loop.vertex_index] = []

        # Final vertex for face
        sim_vertex = None

        # Collect UV coordinates for all layers
        loop_uv = []
        if has_uv:
            for uv_layer_index, uv_layer in enumerate(mesh_data.uv_layers):
                uv = uv_layer.data[loop_index].uv
                loop_uv.append(swizzle_uv(uv))

        # Collect tangent vector for loop
        # (the 4th component is the bitangent sign; the 3-element default is
        # only kept when there is no UV layer and is then never stored)
        loop_tangent = [1, 0, 0]
        if has_uv:
            loop_tangent = [loop.tangent.x, loop.tangent.y, loop.tangent.z, loop.bitangent_sign]

        use_custom_normals = False

        # NOTE(review): `and` binds tighter than `or`, so this reads
        # (2.74 <= version < 3) or (version >= 3 and use_custom_normals).
        # Loop normals are therefore used on Blender 2.74-2.79 regardless of
        # use_custom_normals, and vertex normals on 3.x+. Confirm whether
        # `(...) and use_custom_normals` was intended before changing.
        if (bvmajor == 2 and bvminor >= 74) or bvmajor > 2  and use_custom_normals:
            loop_normal = [loop.normal.x, loop.normal.y, loop.normal.z]
        else:
            vertex_normal = mesh_data.vertices[loop.vertex_index].normal
            loop_normal = [vertex_normal.x, vertex_normal.y, vertex_normal.z]

        # Check for existing matching vertex
        for v in vertex_map[loop.vertex_index]:
            assert isinstance(v, Vertex)
            # Compare loop's UV to existing vertex
            uv_equal = True
            for i in range(len(v.uv)):
                if not equals_float_array(loop_uv[i], v.uv[i], 4):
                    uv_equal = False
                    break
            # Compare loop's normal to existing vertex
            normal_equal = False
            tangent_equal = not has_uv
            # Same precedence caveat as the condition above.
            if (bvmajor == 2 and bvminor >= 74) or bvmajor > 2  and use_custom_normals:
                normal_equal = equals_vector(loop_normal, v.normal)
                if has_uv:
                    tangent_equal = equals_vector(loop_tangent, v.tangent)
            else:
                normal_equal = True
                tangent_equal = True

            # tangent_equal is computed but deliberately not part of the test.
            if uv_equal and normal_equal:  # and tangent_equal:
                sim_vertex = v
                break
        # Create a new vertex if no matches are found
        if not sim_vertex:
            # Collect blender vertex
            vertex = mesh_data.vertices[loop.vertex_index]
            sim_vertex = Vertex()
            sim_vertex.uv = []
            # Set position
            sim_vertex.position = [vertex.co.x, vertex.co.y, vertex.co.z]
            dta.bounds.add(Vector(sim_vertex.position))
            # Set normal vector
            sim_vertex.normal = loop_normal
            # Set bone weights
            # Only the first 4 entries of vertex.groups are used, in their
            # stored order (not sorted by weight). A vertex with no groups
            # keeps weight 1.0 on blend index 0. Unused index slots repeat
            # the first bone's index.
            blend_index = [0] * 4
            weight_list = [1.0, 0, 0, 0]
            for i in range(min(len(vertex.groups), 4)):
                bone_idx = vertex_group_map[mesh_object.vertex_groups[int(vertex.groups[i].group)].name]
                if i == 0:
                    blend_index = [bone_idx] * 4
                weight_list[blend_index_map[i]] = min(vertex.groups[i].weight, 1.0)
                blend_index[i] = bone_idx
            sim_vertex.blend_indices = blend_index
            sim_vertex.blend_weights = weight_list

            # Set vertex color
            # One entry per color layer: the layer's color components for
            # this loop with a trailing 0 appended.
            for i in range(len(mesh_data.vertex_colors)):
                if sim_vertex.colour is None:
                    sim_vertex.colour = []
                blender_color = mesh_data.vertex_colors[i].data[loop_index].color[:]
                sims_color = [c for c in blender_color]
                sims_color.append(0)
                sim_vertex.colour.append(sims_color)

            if has_uv:
                # Set loop data
                sim_vertex.uv = loop_uv
                sim_vertex.tangent = loop_tangent

            # Add new vertex to the main list
            sim_vertices.append(sim_vertex)
            vertex_map[loop.vertex_index].append(sim_vertex)

        # NOTE(review): list.index is a linear search, making this loop
        # O(n^2) in vertex count; a vertex -> index dict would avoid it.
        indices[face_index].append(sim_vertices.index(sim_vertex))

    # NOTE(review): this counts distinct Blender vertex indices seen in the
    # loops, not the number of extra vertices created by splitting, despite
    # the log message.
    split_count = 0
    for i in vertex_map.keys():
        verts = vertex_map[i]
        split_count += 1
    print('%s vertices split.' % split_count)

    # Flatten triangles, rotating each one's indices to (2, 0, 1). This is a
    # cyclic rotation, so winding direction is preserved.
    final_indices = []
    for face in indices:
        a = [face[2], face[0], face[1]]
        for i in a:
            final_indices.append(i)

    dta.indices = final_indices
    dta.vertices = sim_vertices
    dta.bones = vertex_groups
    dta.bone_names = bone_names
    return dta

def is_geonode(obj):
    """Return True if any modifier on ``obj`` has type 'NODES' (a Geometry Nodes modifier)."""
    return any(mod.type == 'NODES' for mod in obj.modifiers)

class Sims4StudioException(Exception):
    """Project-specific exception type.

    All constructor arguments are forwarded unchanged to Exception.
    """

    def __init__(self, *args, **kwargs):
        super(Sims4StudioException, self).__init__(*args, **kwargs)
