import copy
import math
import os
import s4studio, bpy, bmesh
from mathutils import Vector, Color
from s4studio.animation.blender import load_rig
from s4studio.animation.rig import SkeletonRig
from s4studio.blender import swizzle_uv, invalid_face, create_marker_node, collect_mesh_data, SimMeshData, set_context, \
    Sims4StudioException, apply_all_modifiers, set_active, set_selection_state, link_object, get_objects, rotate_obj, \
    add_uv_layer, create_color, set_object_hide, is_geonode
from s4studio.buybuild import ObjectComponentModel
from s4studio.buybuild.catalog import CatalogProductObject, ProductBase, CatalogProductStairs, CatalogProductRailing, \
    CatalogProductFireplace, CatalogProductFence
from s4studio.buybuild.geometry import Model, ModelLod, VertexFormat, ObjectMesh, VertexBuffer, IndexBuffer, \
    ObjectSkinController, VertexBufferShadow, IndexBufferShadow, calculate_pos_scales, calculate_uv_scales
from s4studio.core import ResourceKey
from s4studio.helpers import first, FNV32
from s4studio.material import PackedPreset, Preset
from s4studio.material.blender import MaterialLoader
from s4studio.model import VisualProxy
import s4studio.material.blender
from s4studio.model.geometry import SkinController, Vertex, BoundingBox
from s4studio.model.material import MaterialDefinition, MaterialSet


def load_mesh(armature_rig, model_mesh, mesh_name, materials, state_hash=None):
    """Create a Blender mesh object from an S4Studio ObjectMesh.

    Steps:
    - Build the geometry with bmesh.
    - When the mesh has a skin controller, add vertex groups and skin weights.
    - When ``armature_rig`` is truthy, add an armature modifier.
    - Add vertex color layers and UV layers.
    - Smooth-shade the object, rotate it pi/2 about X, apply the rotation and deselect it.

    @param armature_rig: Blender object to bind the mesh to, or None. Bone names
        are only resolved when its type is 'ARMATURE'.
    @param model_mesh: ObjectMesh to import.
    @param mesh_name: name used for both the mesh datablock and the object.
    @param materials: path to a diffuse texture. A material is loaded with
        s4studio.blender.load_material and appended only when this path exists.
    @param state_hash: geometry state passed to get_vertices/get_triangles.
    @return: the new Blender mesh object.
    """
    # Blender version, used below to gate ensure_lookup_table and custom normals.
    bvmajor, bvminor, bvrevision = bpy.app.version

    assert isinstance(model_mesh, s4studio.buybuild.geometry.ObjectMesh)
    # Map FNV32(bone name) -> bone name so hashed bone references can be shown
    # as readable vertex group names.
    bone_name_map = {}
    if armature_rig and armature_rig.type == 'ARMATURE':
        for bone_hash in armature_rig.data.bones:
            bone_name_map[FNV32.hash(bone_hash.name)] = bone_hash.name

    vertices = model_mesh.get_vertices(state_hash)
    faces = model_mesh.get_triangles(state_hash)

    mesh = bpy.data.meshes.new(mesh_name)

    mesh_obj = bpy.data.objects.new(mesh_name, mesh)
    link_object(mesh_obj)
    # The object must be active: bpy.ops calls below (vertex_color_add,
    # shade_smooth, transform_apply) act on the active object.
    set_active( mesh_obj)

    mesh_obj.show_transparent = True

    if os.path.exists(materials):
        mesh_material = s4studio.blender.load_material('DIFFUSE', materials, False)
        mesh_obj.data.materials.append(mesh_material)
        pass
    skin_controller = model_mesh.skin_controller
    vrtf = model_mesh.get_vertex_format()
    # vertex_groups is indexed by blend index; vertex_group_map is keyed by bone hash.
    vertex_groups = []
    vertex_group_map = {}

    if model_mesh.skin_controller:
        # One vertex group per bone reference, in reference order. Bones not
        # found on the armature are named by their hash in hex.
        for bone_hash in model_mesh.bone_references:
            bone_name = '0x%08X' % bone_hash
            if bone_hash in bone_name_map:
                bone_name = bone_name_map[bone_hash]
            vertex_group = mesh_obj.vertex_groups.new(name=bone_name)
            vertex_groups.append(vertex_group)
            vertex_group_map[bone_hash] = vertex_group
    if armature_rig:
        mesh_skin = mesh_obj.modifiers.new(type='ARMATURE', name="%s_skin" % mesh_name)
        mesh_skin.use_bone_envelopes = False
        mesh_skin.object = armature_rig

    bm = bmesh.new()
    bm.from_mesh(mesh)

    for vertex in vertices:
        bm.verts.new(vertex.position)

    # Index access into bm.verts requires a lookup table on Blender 2.74+.
    if (bvmajor == 2 and bvminor >= 74) or bvmajor > 2 :
        bm.verts.ensure_lookup_table()

    # Indices of faces that were NOT created. The UV pass below relies on this
    # list to keep its loop indexing aligned.
    faces_skipped = []
    for face_index, face in enumerate(faces):
        if invalid_face(face):
            print('[%s]Face[%04i] %s has duplicate points, skipped' % (mesh_name, face_index, face))
            faces_skipped.append(face_index)
            continue
        f = [bm.verts[face_point] for face_point in face]
        try:
            # Points are rotated to (1, 2, 0); the UV pass uses the same order.
            bm.faces.new([f[1], f[2], f[0]])
        except ValueError as ve:
            # bmesh raised ValueError for this face; skip it.
            faces_skipped.append(face_index)
            print(ve)

    for vertex_index, vertex in enumerate(vertices):
        if vertex.normal:
            bm.verts[vertex_index].normal = Vector(vertex.normal[:3])

    bm.to_mesh(mesh)

    # Number of COLOR declarations in the vertex format.
    color_channels = 0
    for declaration in model_mesh.get_vertex_format().declarations:
        if declaration.usage == VertexFormat.USAGE.COLOR:
            color_channels += 1

    matd = model_mesh.get_material_definition()
    color_channel_names = []

    if matd.shader_name == FNV32.hash('AnimatedTree'):
        # AnimatedTree: the last colour holds four rustle components. Each
        # component becomes its own layer (value repeated into all four RGBA
        # slots). With two colour declarations, the first colour is kept as a tint layer.
        color_decls = len(list(filter(lambda d: d.usage == VertexFormat.USAGE.COLOR,vrtf.declarations)))
        if color_decls == 2: color_channel_names.append('Surface Color Tint')
        color_channel_names.append('Rustle X')
        color_channel_names.append('Rustle Y')
        color_channel_names.append('Rustle Z')
        color_channel_names.append('Rustle Phase')
        color_channels = len(color_channel_names)
        for v in vertices:
            tint = v.colour[0] if color_decls == 2 else None
            rustle = v.colour[-1]
            v.colour = []
            if tint:
                v.colour.append(tint);
            for i in range(4):
                v.colour.append([rustle[i]]*4)

    # NOTE(review): channels without a name all request 'color'; Blender may
    # rename duplicates (e.g. 'color.001') — confirm this is acceptable.
    for ci in range(color_channels):
        cn = color_channel_names[ci] if ci <  len(color_channel_names) else 'color'
        if bpy.app.version < (3, 4):
            bpy.ops.mesh.vertex_color_add()
            mesh.vertex_colors[ci].name  = cn
        else:
            mesh.color_attributes.new(name=cn, type='BYTE_COLOR',domain='CORNER')
    # NOTE(review): on Blender >= 3.4 the layers above come from
    # color_attributes, but they are written through the deprecated
    # mesh.vertex_colors alias. Confirm it still exists in the target versions.
    if color_channels:
        for loop_index, loop in enumerate(mesh.loops):
            vertex = vertices[loop.vertex_index]
            for ci in range(color_channels):
                mesh.vertex_colors[ci].data[loop_index].color = create_color([c for c in vertex.colour[ci]])

    if model_mesh.skin_controller:
        assert isinstance(skin_controller, ObjectSkinController)
        for vertex_index, vertex in enumerate(vertices):
            if vertex.blend_indices and vertex.blend_weights:
                for blend_index, blend_bone_index in enumerate(vertex.blend_indices):
                    if blend_bone_index >= 0:
                        weight = vertex.blend_weights[blend_index]
                        if weight > 0.0:
                            ix = int(blend_bone_index)
                            # Blend index beyond the known bone references:
                            # pad with placeholder 'missing_bone_N' groups so
                            # the index is addressable.
                            while ix >= len(vertex_groups):
                                bone_name = 'missing_bone_%s' % (len(vertex_groups))
                                bone_hash = FNV32.hash(bone_name)
                                vertex_group = mesh_obj.vertex_groups.new(name=bone_name)
                                vertex_groups.append(vertex_group)
                                vertex_group_map[bone_hash] = vertex_group
                            blend_vertex_group = vertex_groups[int(blend_bone_index)]
                            blend_vertex_group.add((vertex_index,), weight, 'ADD')


    # One UV layer per UV declaration, named 'uv_<usage_index>'.
    for declaration in model_mesh.get_vertex_format().declarations:
        if declaration.usage == VertexFormat.USAGE.UV:
            add_uv_layer(mesh, 'uv_%i' %declaration.usage_index)


    normals = []
    # Running count of skipped faces, so skipped faces do not consume UV loop slots.
    uv_face_skipped = 0
    for face_index, face in enumerate(faces):
        if invalid_face(face):
            print('skipped face %s: %s ' % (face_index, face))
            uv_face_skipped += 1
            continue
        if face_index in faces_skipped:
            uv_face_skipped += 1
            continue
        # Same (1, 2, 0) point order as face creation above.
        for face_point_index, face_point_vertex_index in enumerate([face[1], face[2], face[0]]):
            vertex = vertices[face_point_vertex_index]
            if vertex.normal:
                normals.append(vertex.normal[:3])
            if vertex.uv:
                # Assumes the created triangles' loops are stored 3 per face,
                # in creation order — TODO confirm for bmesh.to_mesh.
                for uv_channel_index, uv_coord in enumerate(vertex.uv):
                    mesh.uv_layers[uv_channel_index].data[
                        face_point_index + ((face_index - uv_face_skipped) * 3)].uv = swizzle_uv(uv_coord)

    remove_doubles = True
    # Custom split normals are disabled, so the shade_smooth branch below is the one taken.
    use_custom_normals = False
    set_context('OBJECT', None)
    set_context('OBJECT', mesh_obj)

    if ((bvmajor == 2 and bvminor >= 74) or bvmajor > 2)  and any(normals) and use_custom_normals:
        mesh.use_auto_smooth = True
        mesh.show_edge_sharp = True
        mesh.normals_split_custom_set(normals)
        if remove_doubles:
            set_context('EDIT', mesh_obj)
            bpy.ops.mesh.select_all(action='SELECT')
            bpy.ops.mesh.remove_doubles()
        set_context('OBJECT', mesh_obj)
    else:
        bpy.ops.object.shade_smooth()
    try:
        # Rotate pi/2 about X and bake it into the data (presumably game
        # Y-up -> Blender Z-up; confirm). Failure is reported, not fatal.
        rotate_obj(math.pi / 2.0, 'X')
        bpy.ops.object.transform_apply(rotation=True)
    except Exception as e:
        print('unable to rotate the mesh')
        print(e)
        pass
    set_selection_state(mesh_obj, False)
    mesh_obj.active_shape_key_index = 0
    return mesh_obj


def load_lod(armature_rig, lod, material_path, state_hash=None):
    """Import every mesh of a ModelLod into the Blender scene.

    Mesh ``i`` is created as ``'s4studio_mesh_<i>'`` and its cut number is
    set to ``str(i)``. ``'<material_path>/<i>.png'`` is passed to load_mesh
    as that mesh's diffuse texture path.

    @param armature_rig: armature (or marker) object the meshes bind to.
    @param lod: ModelLod whose ``meshes`` are imported.
    @param material_path: directory holding per-mesh ``<index>.png`` textures.
        A value that is not a filesystem path (load_model passes a material
        loader object here) gives an empty texture path, so no texture is loaded.
    @param state_hash: geometry state to import. Defaults to None so that
        three-argument callers such as load_model work.
    """
    # os.path.join would raise TypeError on a non-path object.
    texture_dir = material_path if isinstance(material_path, (str, os.PathLike)) else None
    for mesh_index, model_mesh in enumerate(lod.meshes):
        # load_mesh only loads a texture when the path exists; '' never does.
        diffuse_tex = os.path.join(texture_dir, '%s.png' % mesh_index) if texture_dir is not None else ''
        mesh_name = 's4studio_mesh_%s' % mesh_index
        mesh = load_mesh(armature_rig, model_mesh, mesh_name, diffuse_tex, state_hash)
        mesh.data.s4studio.cut = str(mesh_index)


def preserve_mesh(s4mesh_orig, s4mesh, state_hash=None):
    """Copy an original mesh's geometry, unchanged, into a new mesh's buffers.

    Triangle indices are appended to ``s4mesh``'s index buffer and vertices
    are written to its vertex buffer, using the original mesh's UV and
    position scales. The new mesh's start/offset/count fields are then set
    to match.

    @param s4mesh_orig: source ObjectMesh.
    @param s4mesh: destination ObjectMesh whose buffers are appended to.
    @param state_hash: geometry state read from the source (None by default).
    """
    assert isinstance(s4mesh, ObjectMesh)
    assert isinstance(s4mesh_orig, ObjectMesh)
    vertex_buffer = s4mesh.vertex_buffer
    vertex_format = s4mesh.get_vertex_format()
    index_buffer = s4mesh.index_buffer
    assert isinstance(vertex_format, VertexFormat)
    assert isinstance(vertex_buffer, VertexBuffer)
    assert isinstance(index_buffer, IndexBuffer)

    source_vertices = s4mesh_orig.get_vertices(state_hash)
    source_triangles = s4mesh_orig.get_triangles(state_hash)
    source_uv_scales = s4mesh_orig.get_uv_scales()
    source_pos_scales = s4mesh_orig.get_pos_scale()

    # This mesh's indices start wherever the shared index buffer currently ends.
    first_index = len(index_buffer.buffer)
    index_buffer.buffer.extend([point for triangle in source_triangles for point in triangle])

    stream_offset = vertex_buffer.buffer.write_vertices(
        vertex_format, source_vertices, source_uv_scales, source_pos_scales, s4mesh.get_shader())

    s4mesh.start_index = first_index
    s4mesh.start_vertex = 0
    s4mesh.stream_offset = stream_offset
    s4mesh.primitive_count = len(source_triangles)
    s4mesh.vertex_count = len(source_vertices)


def save_mesh(blender_mesh, s4mesh_orig, s4mesh_new, geometry_state):
    """Write a Blender mesh object's geometry into an ObjectMesh.

    Procedure:
    - Call apply_all_modifiers on ``blender_mesh`` and collect its data with
      collect_mesh_data.
    - Write that data into ``s4mesh_new``'s vertex and index buffers.
    - If ``geometry_state`` is truthy, only the state whose hash equals it
      gets the Blender geometry; every other state is rewritten from
      ``s4mesh_orig``.
    - Recompute bounds, position scales and (when the mesh has them) UV scales.

    @param blender_mesh: Blender object to export.
    @param s4mesh_orig: original ObjectMesh. It supplies states, bone
        references, extra bounds and the geometry of untouched states.
    @param s4mesh_new: ObjectMesh whose buffers and header fields are filled in.
    @param geometry_state: hash of the state to replace, or falsy to export
        the Blender geometry as the whole mesh.
    @raise Sims4StudioException: if a vertex references a bone index beyond
        the collected bone list, or if the mesh needs UVs but has none.
    """
    print('saving mesh geostate: %s' % geometry_state)
    apply_all_modifiers(blender_mesh)
    assert isinstance(s4mesh_new, ObjectMesh)
    assert isinstance(s4mesh_orig, ObjectMesh)
    # Stateful meshes pass the original bone order to collect_mesh_data
    # (presumably so untouched states keep valid blend indices).
    sim_mesh_data = collect_mesh_data(blender_mesh, s4mesh_orig.bone_references if len(s4mesh_orig.states) else [])
    print('BONES: %s' % sim_mesh_data.bones)
    assert isinstance(sim_mesh_data, SimMeshData)
    vbuf = s4mesh_new.vertex_buffer
    vrtf = s4mesh_new.get_vertex_format()
    ibuf = s4mesh_new.index_buffer
    original_uv_scales = s4mesh_new.get_uv_scales()
    has_uv_scales = s4mesh_new.has_uv_scales()
    assert isinstance(vrtf, VertexFormat)
    assert isinstance(vbuf, VertexBuffer)
    assert isinstance(ibuf, IndexBuffer)

    # No bones collected: reference 'transformBone' so the mesh has at least one bone.
    if len(sim_mesh_data.bones) == 0:
        sim_mesh_data.bone_names.append('transformBone')
        sim_mesh_data.bones.append(FNV32.hash('transformBone'))
    s4mesh_new.bone_references = sim_mesh_data.bones

    if s4mesh_orig.flags & ObjectMesh.Flags.EXTRA_BOUNDS:
        print('Preserving bone bounding boxes %s' % (len(s4mesh_orig.extra_bounds),))
        orig_bone_refs = s4mesh_orig.bone_references
        # Reuse the original per-bone box for bones the original mesh had.
        # Bones it did not have get a fresh, empty BoundingBox (never a bool).
        s4mesh_new.extra_bounds = [
            s4mesh_orig.extra_bounds[orig_bone_refs.index(bh)] if bh in orig_bone_refs else BoundingBox()
            for bh in s4mesh_new.bone_references]
        print('Preservied bounds %s' % (len(s4mesh_new.extra_bounds),))
        print(s4mesh_new.extra_bounds)
    else:
        print('Skipping bone bounds %s' % (hex(s4mesh_orig.flags)))

    expected_uv_count = sum(1 for d in vrtf.declarations if d.usage == VertexFormat.USAGE.UV)

    matd = s4mesh_orig.get_material_definition()
    if matd.shader_name == FNV32.hash('AnimatedTree'):
        # AnimatedTree: fold the Blender layers (optional tint + four
        # single-component rustle layers, first channel of each) back into
        # [tint,] rustle.
        color_decls = [d for d in vrtf.declarations if d.usage == VertexFormat.USAGE.COLOR]
        for vertex in sim_mesh_data.vertices:
            tint = [0] * 4 if len(color_decls) == 2 else None
            if tint and len(vertex.colour) == 5:
                tint = vertex.colour[0]
                tint[3] = 1.0
                vertex.colour.pop(0)
            rustle = [0] * 4
            for i in range(min(4, len(vertex.colour))):
                rustle[i] = vertex.colour[i][0]
            vertex.colour = []
            if tint:
                vertex.colour.append(tint)
            vertex.colour.append(rustle)

    print('%s vertices in mesh data' % len(sim_mesh_data.vertices))
    for vertex in sim_mesh_data.vertices:
        assert isinstance(vertex, Vertex)
        diff_uv_count = expected_uv_count - len(vertex.uv)
        if vertex.blend_indices:
            m = max(vertex.blend_indices)
            if m > 0 and m >= len(s4mesh_new.bone_references):
                raise Sims4StudioException('Mesh Import Error','Your mesh group %s has more vertex groups than the original mesh.'%(blender_mesh.name))
        if diff_uv_count > 0:
            if len(vertex.uv) == 0:
                raise Sims4StudioException('Mesh Import Error','Your mesh group %s requires a UV map but does not have one.'%(blender_mesh.name))
            # Pad missing UV channels by duplicating the first channel.
            for _ in range(diff_uv_count):
                vertex.uv.append(vertex.uv[0])

    # Gather every vertex that will be written (all states) to size bounds and scales.
    total_mesh_bounds = BoundingBox()
    all_vertices = []
    s4mesh_new.primitive_count = 0
    if not len(s4mesh_orig.states):
        all_vertices.extend(sim_mesh_data.vertices)
        total_mesh_bounds.add(sim_mesh_data.bounds)
        s4mesh_new.primitive_count += int(len(sim_mesh_data.indices) / 3)
    else:
        for old_state, new_state in zip(s4mesh_orig.states, s4mesh_new.states):
            assert isinstance(new_state, ObjectMesh.State)
            assert isinstance(old_state, ObjectMesh.State)
            if int(new_state.state) == int(geometry_state):
                all_vertices.extend(sim_mesh_data.vertices)
                if len(sim_mesh_data.vertices):
                    total_mesh_bounds.add(sim_mesh_data.bounds)
                s4mesh_new.primitive_count += int(len(sim_mesh_data.indices) / 3)
            else:
                vertices = s4mesh_orig.get_vertices(old_state.state)
                all_vertices.extend(vertices)
                total_mesh_bounds.add(s4mesh_orig.bounds)
                s4mesh_new.primitive_count += old_state.primitive_count
    print('Calculated bounds %s' % total_mesh_bounds)
    s4mesh_new.bounds = total_mesh_bounds
    s4mesh_new.vertex_count = len(all_vertices)

    max_pos_size = vrtf.max_size_for_usage(VertexFormat.USAGE.POSITION)
    max_uv_size = vrtf.max_size_for_usage(VertexFormat.USAGE.UV)

    pos_scales = calculate_pos_scales(total_mesh_bounds, max_pos_size)
    s4mesh_new.set_pos_scales(pos_scales)
    uv_scales = original_uv_scales
    if has_uv_scales:
        uv_scales = calculate_uv_scales(all_vertices, max_uv_size)
        if not uv_scales:
            uv_scales = original_uv_scales
        # Keep the original scale for any channel whose calculated scale is 0.
        for i, s in enumerate(uv_scales):
            if s == 0 and i < len(original_uv_scales):
                print('restoring uv scale %s at index %s' % (original_uv_scales[i], i))
                uv_scales[i] = original_uv_scales[i]
        s4mesh_new.set_uv_scales(uv_scales)

    start_index = len(ibuf.buffer)
    s4mesh_new.start_index = start_index
    if not geometry_state:
        offset = vbuf.buffer.write_vertices(vrtf, sim_mesh_data.vertices, uv_scales, pos_scales, s4mesh_new.get_shader())

        ibuf.buffer.extend(sim_mesh_data.indices)
        s4mesh_new.start_vertex = 0
        s4mesh_new.stream_offset = offset
        s4mesh_new.primitive_count = int(len(sim_mesh_data.indices) / 3)
        s4mesh_new.vertex_count = int(len(sim_mesh_data.vertices))
        s4mesh_new.bounds = sim_mesh_data.bounds
    else:
        s4mesh_new.vertex_count = 0
        s4mesh_new.primitive_count = 0
        s4mesh_new.start_vertex = 0
        s4mesh_new.stream_offset = vbuf.buffer.stream.tell()
        # Each state's indices are rebased past the vertices of earlier states.
        min_vertex_index = 0
        for old_state, new_state in zip(s4mesh_orig.states, s4mesh_new.states):
            assert isinstance(new_state, ObjectMesh.State)
            assert isinstance(old_state, ObjectMesh.State)
            if int(new_state.state) == int(geometry_state):
                print('replacing state %s' % new_state.state)
                offset = vbuf.buffer.write_vertices(vrtf, sim_mesh_data.vertices, uv_scales, pos_scales, s4mesh_new.get_shader())
                start_index = len(ibuf.buffer)
                ibuf.buffer.extend([ix + min_vertex_index for ix in sim_mesh_data.indices])
                new_state.start_index = start_index
                new_state.min_vertex_index = min_vertex_index
                new_state.stream_offset = offset
                new_state.primitive_count = int(len(sim_mesh_data.indices) / 3)
                new_state.vertex_count = int(len(sim_mesh_data.vertices))
            else:
                print('skipping state %s' % new_state.state)
                vertices = s4mesh_orig.get_vertices(old_state.state)
                triangles = s4mesh_orig.get_triangles(old_state.state)
                start_index = len(ibuf.buffer)
                for tri in triangles:
                    for index in tri:
                        ibuf.buffer.append(index + min_vertex_index)

                offset = vbuf.buffer.write_vertices(vrtf, vertices, uv_scales, pos_scales, s4mesh_new.get_shader())

                new_state.start_index = start_index
                new_state.min_vertex_index = min_vertex_index
                new_state.stream_offset = offset
                new_state.primitive_count = int(len(triangles))
                new_state.vertex_count = int(len(vertices))
            s4mesh_new.vertex_count += new_state.vertex_count
            s4mesh_new.primitive_count += new_state.primitive_count
            min_vertex_index += new_state.vertex_count






def save_lod(model_lod, geometry_state):
    """Export tagged Blender mesh objects back into a ModelLod.

    Blender mesh objects are matched to LOD meshes by their S4Studio cut
    number, which is the mesh index as a string with whitespace stripped.
    When several objects share a cut, the first one found is used.

    Each LOD mesh is deep-copied:
    - Non-shadow vertex/index buffers are shared by all copies and cleared
      before writing.
    - A mesh with a matching Blender object is rebuilt from it via save_mesh.
    - Any other mesh is copied from the original via preserve_mesh.

    @param model_lod: ModelLod whose ``meshes`` list is replaced.
    @param geometry_state: geometry state hash passed to save_mesh.
    @raise Sims4StudioException: if a tagged object uses Geometry Nodes, or
        if no tagged mesh objects exist.
    """
    s4studio_meshes = {}
    for o in get_objects():
        if o.type != 'MESH':
            continue
        name = str(o.data.s4studio.cut or '').strip()
        if not name:
            continue
        if is_geonode(o):
            raise Sims4StudioException('Mesh Import Error', "Geometry Nodes are not supported at this time. Please convert '%s' to a standard mesh before importing." % o.name)
        # Check the normalised key (the same one stored below) so that
        # whitespace variants of a cut cannot replace the first object found.
        if name not in s4studio_meshes:
            set_object_hide(o, False)
            s4studio_meshes[name] = o
    if len(s4studio_meshes) == 0:
        raise Sims4StudioException('Mesh Import Error','No meshes found to import. Make sure you set the "Cut Number" in the S4Studio Mesh Tools panel for each mesh you are importing.')
    assert isinstance(model_lod, ModelLod)
    s4meshes = []
    shared_vbuf = None
    shared_ibuf = None
    for mesh_index, m in enumerate(model_lod.meshes):
        assert isinstance(m, ObjectMesh)
        name = str(mesh_index)
        new_mesh = copy.deepcopy(m)
        assert isinstance(new_mesh, ObjectMesh)
        # The first non-shadow buffer becomes the shared one; later meshes reuse it.
        if not isinstance(m.vertex_buffer, VertexBufferShadow):
            if not shared_vbuf:
                shared_vbuf = new_mesh.vertex_buffer
                shared_vbuf.buffer.clear()
            else:
                new_mesh.vertex_buffer = shared_vbuf
        else:
            new_mesh.vertex_buffer.buffer.clear()
        if not isinstance(m.index_buffer, IndexBufferShadow):
            if not shared_ibuf:
                shared_ibuf = new_mesh.index_buffer
                assert isinstance(shared_ibuf, IndexBuffer)
                shared_ibuf.buffer.clear()
            else:
                new_mesh.index_buffer = shared_ibuf
        else:
            new_mesh.index_buffer.buffer.clear()
        if name in s4studio_meshes:
            blender_mesh = s4studio_meshes[name]
            save_mesh(blender_mesh, m, new_mesh, geometry_state)
            print('new_mesh:')
            print(new_mesh.extra_bounds)
        else:
            preserve_mesh(m, new_mesh)
        s4meshes.append(new_mesh)
    model_lod.meshes = s4meshes


def load_model(package, modl, armature_rig, material):
    """Import the first LOD (``modl.lods[0]``) of a Model into Blender.

    The LOD may be embedded in the entry, or referenced by key and fetched
    from ``package``.
    """
    first_lod = modl.lods[0].model
    if not isinstance(first_lod, ModelLod):
        first_lod = package.find_key(first_lod.key).fetch(ModelLod)
    # NOTE(review): load_lod uses its third argument as a texture directory
    # and this call passes no state_hash. That only works if load_lod's
    # state_hash has a default and non-path materials are tolerated — confirm.
    load_lod(armature_rig, first_lod, material)


def load_product(vpxy, package, resource_name, presets=None):
    """Load the rig, model and material preset referenced by a VisualProxy.

    Steps:
    - Try to load the SkeletonRig entry as an armature. If that fails or
      no rig exists, create a marker node named ``resource_name`` instead.
    - Choose the first of ``presets`` if any are given. Otherwise use the
      package's PackedPreset for the model, or a default Preset.
    - Import the model.

    @param vpxy: VisualProxy listing the product's resources.
    @param package: package the resources are looked up in.
    @param resource_name: name for the marker node when no rig loads.
    @param presets: optional list of material presets.
    @return: the armature (or marker node) the model was bound to.
    @raise Sims4StudioException: if the VisualProxy has no Model entry.
    """
    armature_rig = None
    rig = None
    rig_entry = first(vpxy.entries,
                      lambda e: isinstance(e, VisualProxy.MiscEntry) and e.resource.key.t == SkeletonRig.ID)
    if rig_entry:
        rig = package.find_key(rig_entry.resource.key)
    if rig:
        try:
            rig = rig.fetch(SkeletonRig)
            armature_rig = load_rig(rig)
        except Exception as e:
            # Best effort: fall back to a marker node below instead of aborting.
            print('Unable to load rig, please patch your game...')
            print(e)
    if not armature_rig:
        print('No rig found')
        armature_rig = create_marker_node(resource_name, True)
    model_entry = first(vpxy.entries,
                        lambda e: isinstance(e, VisualProxy.MiscEntry) and e.resource.key.t == Model.ID)
    if model_entry is None:
        raise Sims4StudioException('Mesh Import Error', 'No model found for %s.' % resource_name)
    modl = package.find_key(model_entry.resource.key).fetch(Model)
    if presets and any(presets):
        preset = presets[0]
    else:
        preset = package.find_key(ResourceKey(t=PackedPreset.ID, g=1, i=modl.key.i))
        if preset:
            preset = preset.fetch(PackedPreset)
        else:
            preset = Preset()
    ml = MaterialLoader(package, preset)
    load_model(package, modl, armature_rig, ml)
    return armature_rig


def load_all(package):
    """Load every VisualProxy in ``package``, best effort.

    A product that fails to load is reported (with its error) and skipped.
    Interrupts such as KeyboardInterrupt still propagate.
    """
    for vpxy_index in package.find_all_type(VisualProxy.ID):
        print('loading vpxy : %s' % vpxy_index.key)
        vpxy = vpxy_index.fetch(VisualProxy)
        try:
            load_product(vpxy, package, str(vpxy_index.key), None)
        except Exception as e:
            print('unable to load %s' % vpxy_index.key)
            print(e)


def load_object(objd, package):
    """Load a CatalogProductObject's model into Blender.

    The object component is resolved, then its 'modelKey' entry gives the
    VisualProxy to load.

    @param objd: catalog object definition.
    @param package: package containing the object's resources.
    @return: the armature (or marker node) returned by load_product.
    """
    print('Loading object...')
    assert isinstance(objd, CatalogProductObject)
    component = package.find_key(objd.object_component).fetch(ObjectComponentModel)
    assert isinstance(component, ObjectComponentModel)
    model_key = component.component_data['modelKey'][1]
    proxy = package.find_key(model_key).fetch(VisualProxy)
    return load_product(proxy, package, objd.resource_name, presets=objd.presets)


def load_stairs(cstr, package):
    """Load the 4x steps model of a stairs catalog entry.

    @return: the armature (or marker node) returned by load_product.
    """
    print('Loading stairs...')
    assert isinstance(cstr, CatalogProductStairs)
    steps_key = cstr.steps_4x_model
    steps_proxy = package.find_key(steps_key).fetch(VisualProxy)
    return load_product(steps_proxy, package, cstr.resource_name, presets=cstr.presets)


def load_railing(cral, package):
    """Load the 4x railing model of a railing catalog entry.

    @return: the armature (or marker node) returned by load_product.
    """
    print('Loading railing!')
    assert isinstance(cral, CatalogProductRailing)
    railing_key = cral.railing_4x_model
    railing_proxy = package.find_key(railing_key).fetch(VisualProxy)
    return load_product(railing_proxy, package, cral.resource_name, presets=cral.presets)


def load_fireplace(cfire, package):
    """Load a fireplace by loading its mantle's catalog object.

    @return: the result of load_object for the mantle.
    """
    print('Loading fireplace...')
    assert isinstance(cfire, CatalogProductFireplace)
    mantle_key = cfire.mantle
    mantle = package.find_key(mantle_key).fetch(CatalogProductObject)
    return load_object(mantle, package)


def load_fence(cfen, package):
    """Load the model of a fence catalog entry.

    @return: the armature (or marker node) returned by load_product.
    """
    print('Loading fence...')
    assert isinstance(cfen, CatalogProductFence)
    fence_key = cfen.model
    fence_proxy = package.find_key(fence_key).fetch(VisualProxy)
    return load_product(fence_proxy, package, cfen.resource_name, presets=cfen.presets)


