/* math3d.c
 *
 * Math 3D Library
 *
 * 
 */

#include <stdio.h>
#include "math3d.h"

/* Computes a normal for the triangle (point1, point2, point3).
 *
 * result receives the cross product of the edge vectors (point1 - point2)
 * and (point2 - point3); no normalization is performed here.
 * The points are assumed to be given in counter clockwise winding order.
 */
void m3dFindNormalf(M3DVector3f result, const M3DVector3f point1,
                    const M3DVector3f point2, const M3DVector3f point3)
{
    M3DVector3f edge12, edge23;
    int axis;

    /* build both edge vectors one component at a time */
    for (axis = 0; axis < 3; axis++)
    {
        edge12[axis] = point1[axis] - point2[axis];
        edge23[axis] = point2[axis] - point3[axis];
    }

    /* the normal is perpendicular to both edges */
    m3dCrossProductf(result, edge12, edge23);
}

/* Overwrites m with the 3x3 identity matrix. */
void m3dLoadIdentity33f(M3DMatrix33f m)
{
    /* laid out column major; the identity just happens to read the same
     * either way round
     */
    static const M3DMatrix33f unit = { 1.0f, 0.0f, 0.0f,
                                       0.0f, 1.0f, 0.0f,
                                       0.0f, 0.0f, 1.0f };
    unsigned int k;

    /* copy every element of the table into the caller's matrix */
    for (k = 0; k < sizeof(M3DMatrix33f) / sizeof(unit[0]); k++)
    {
        m[k] = unit[k];
    }
}

/* Overwrites m with the 4x4 identity matrix. */
void m3dLoadIdentity44f(M3DMatrix44f m) /* 4x4 float */
{
    /* laid out column major; the identity just happens to read the same
     * either way round
     */
    static const M3DMatrix44f unit = { 1.0f, 0.0f, 0.0f, 0.0f,
                                       0.0f, 1.0f, 0.0f, 0.0f,
                                       0.0f, 0.0f, 1.0f, 0.0f,
                                       0.0f, 0.0f, 0.0f, 1.0f };
    unsigned int k;

    /* copy every element of the table into the caller's matrix */
    for (k = 0; k < sizeof(M3DMatrix44f) / sizeof(unit[0]); k++)
    {
        m[k] = unit[k];
    }
}

/* Computes the equation of the plane that the three given points lie in.
 *
 * The points are given in clockwise winding order, with the normal pointing
 * out of the clockwise face.  On return planeEq holds the coefficients
 * a, b, c and d of the plane equation, where (a, b, c) is the result of
 * m3dCrossProductf() passed through m3dNormalizeVectorf().
 */
void m3dGetPlaneEquationf(M3DVector4f planeEq, const M3DVector3f p1,
                          const M3DVector3f p2, const M3DVector3f p3)
{
    M3DVector3f to_p3, to_p2;
    int axis;

    /* two edge vectors, both starting at p1 */
    for (axis = 0; axis < 3; axis++)
    {
        to_p3[axis] = p3[axis] - p1[axis];
        to_p2[axis] = p2[axis] - p1[axis];
    }

    /* normal of the plane goes into the first three coefficients */
    m3dCrossProductf(planeEq, to_p3, to_p2);
    m3dNormalizeVectorf(planeEq);

    /* p3 lies in the plane, so substituting it back yields d */
    planeEq[3] = -(planeEq[0] * p3[0] + planeEq[1] * p3[1] + planeEq[2] * p3[2]);
}

/* Builds a 4x4 rotation matrix (column major) in m.
 *
 * angle is in radians, not degrees.  (x, y, z) is the rotation axis; it is
 * normalized here, so it does not need to be unit length.  If the axis has
 * zero length, m is set to the identity matrix instead.
 */
void m3dRotationMatrix44f(M3DMatrix44f m, float angle, float x, float y, float z)
{
    float sin_a, cos_a, axis_len, one_minus_cos;
    float x_sq, y_sq, z_sq, x_y, y_z, z_x, x_sin, y_sin, z_sin;

    sin_a = (float) (sin(angle));
    cos_a = (float) (cos(angle));

    axis_len = (float) (sqrt(x*x + y*y + z*z));

    /* no usable axis: fall back to the identity matrix */
    if (axis_len == 0.0f)
    {
        m3dLoadIdentity44f(m);
        return;
    }

    /* scale the axis to unit length */
    x /= axis_len;
    y /= axis_len;
    z /= axis_len;

    /* products shared by several matrix entries */
    x_sq = x * x;
    y_sq = y * y;
    z_sq = z * z;
    x_y = x * y;
    y_z = y * z;
    z_x = z * x;

    x_sin = x * sin_a;
    y_sin = y * sin_a;
    z_sin = z * sin_a;
    one_minus_cos = 1.0f - cos_a;

    /* element (row, col) lives at m[col * 4 + row] */

    /* first column */
    m[0] = (one_minus_cos * x_sq) + cos_a;
    m[1] = (one_minus_cos * x_y) + z_sin;
    m[2] = (one_minus_cos * z_x) - y_sin;
    m[3] = 0.0f;

    /* second column */
    m[4] = (one_minus_cos * x_y) - z_sin;
    m[5] = (one_minus_cos * y_sq) + cos_a;
    m[6] = (one_minus_cos * y_z) + x_sin;
    m[7] = 0.0f;

    /* third column */
    m[8] = (one_minus_cos * z_x) + y_sin;
    m[9] = (one_minus_cos * y_z) - x_sin;
    m[10] = (one_minus_cos * z_sq) + cos_a;
    m[11] = 0.0f;

    /* fourth column: no translation */
    m[12] = 0.0f;
    m[13] = 0.0f;
    m[14] = 0.0f;
    m[15] = 1.0f;
}

/* Builds a projection matrix that "squishes" an object into a plane.
 *
 * planeEq holds the plane coefficients a, b, c, d; use
 * m3dGetPlaneEquationf(planeEq, point1, point2, point3) to obtain them.
 * vLightPos enters the matrix only through its negated components.
 * NOTE(review): that makes the light behave like a direction rather than a
 * position -- confirm against callers.
 */
void m3dMakePlanarShadowMatrixf(M3DMatrix44f proj, const M3DVector4f planeEq,
                                const M3DVector3f vLightPos)
{
    /* plane coefficients, copied out for readability */
    const float pa = planeEq[0];
    const float pb = planeEq[1];
    const float pc = planeEq[2];
    const float pd = planeEq[3];

    /* negated light vector */
    const float lx = -vLightPos[0];
    const float ly = -vLightPos[1];
    const float lz = -vLightPos[2];

    /* first group of four entries */
    proj[0] = pb * ly + pc * lz;
    proj[1] = -pa * ly;
    proj[2] = -pa * lz;
    proj[3] = 0.0f;

    /* second group */
    proj[4] = -pb * lx;
    proj[5] = pa * lx + pc * lz;
    proj[6] = -pb * lz;
    proj[7] = 0.0f;

    /* third group */
    proj[8] = -pc * lx;
    proj[9] = -pc * ly;
    proj[10] = pa * lx + pb * ly;
    proj[11] = 0.0f;

    /* fourth group */
    proj[12] = -pd * lx;
    proj[13] = -pd * ly;
    proj[14] = -pd * lz;
    proj[15] = pa * lx + pb * ly + pc * lz;
}

/* Multiplies two 4x4 column-major matrices: product = a * b.
 *
 * Element (row, col) of a matrix lives at index col * 4 + row.
 * The result is accumulated in a local matrix and copied out at the end, so
 * product may safely be the same array as a and/or b.
 */
void m3dMatrixMultiply44f(M3DMatrix44f product, const M3DMatrix44f a,
                          const M3DMatrix44f b)
{
    M3DMatrix44f tmp;
    int row, col, k;

    for (row = 0; row < 4; row++)
    {
        /* row of a, reused for all four columns of the result */
        const float a0 = a[0 * 4 + row];
        const float a1 = a[1 * 4 + row];
        const float a2 = a[2 * 4 + row];
        const float a3 = a[3 * 4 + row];

        for (col = 0; col < 4; col++)
        {
            /* dot product of row 'row' of a with column 'col' of b */
            tmp[col * 4 + row] = a0 * b[col * 4 + 0] + a1 * b[col * 4 + 1]
                               + a2 * b[col * 4 + 2] + a3 * b[col * 4 + 3];
        }
    }

    /* inputs are no longer needed; now it is safe to overwrite product */
    for (k = 0; k < 16; k++)
    {
        product[k] = tmp[k];
    }
}

/* Prints a 4x4 column-major matrix to stdout, one matrix row per line,
 * followed by an empty line.
 */
void m3dPrintMatrix44f(const M3DMatrix44f m)
{
    int row;

    for (row = 0; row < 4; row++)
    {
        int col;

        /* element (row, col) is stored at m[col * 4 + row] */
        for (col = 0; col < 4; col++)
        {
            printf("%5.2f ", m[col * 4 + row]);
        }
        puts("");
    }

    /* blank separator line after the matrix */
    puts("");
}

/* invert 4x4 matrix, contributed by David Moore (See Mesa bug #6748)
 *
 * Writes the inverse of m into dst.
 * Returns 0 on success, or -1 if the determinant of m is exactly zero, in
 * which case dst is left untouched.
 * All intermediate values go into the local matrix inv and dst is written
 * only in the final loop, after the last read of m, so dst may be the same
 * array as m.
 */
int m3dInvertMatrix44f(M3DMatrix44f dst, const M3DMatrix44f m)
{
    M3DMatrix44f inv;
    double det;
    int i;

    /* each entry of the unscaled inverse is a signed sum of six triple
     * products of elements of m (a 3x3 sub-determinant)
     */
    inv[0] =   m[5]*m[10]*m[15] - m[5]*m[11]*m[14] - m[9]*m[6]*m[15]
             + m[9]*m[7]*m[14]  + m[13]*m[6]*m[11] - m[13]*m[7]*m[10];
    inv[4] =  -m[4]*m[10]*m[15] + m[4]*m[11]*m[14] + m[8]*m[6]*m[15]
             - m[8]*m[7]*m[14]  - m[12]*m[6]*m[11] + m[12]*m[7]*m[10];
    inv[8] =   m[4]*m[9]*m[15]  - m[4]*m[11]*m[13] - m[8]*m[5]*m[15]
             + m[8]*m[7]*m[13]  + m[12]*m[5]*m[11] - m[12]*m[7]*m[9];
    inv[12] = -m[4]*m[9]*m[14]  + m[4]*m[10]*m[13] + m[8]*m[5]*m[14]
             - m[8]*m[6]*m[13]  - m[12]*m[5]*m[10] + m[12]*m[6]*m[9];
    inv[1] =  -m[1]*m[10]*m[15] + m[1]*m[11]*m[14] + m[9]*m[2]*m[15]
             - m[9]*m[3]*m[14]  - m[13]*m[2]*m[11] + m[13]*m[3]*m[10];
    inv[5] =   m[0]*m[10]*m[15] - m[0]*m[11]*m[14] - m[8]*m[2]*m[15]
             + m[8]*m[3]*m[14]  + m[12]*m[2]*m[11] - m[12]*m[3]*m[10];
    inv[9] =  -m[0]*m[9]*m[15]  + m[0]*m[11]*m[13] + m[8]*m[1]*m[15]
             - m[8]*m[3]*m[13]  - m[12]*m[1]*m[11] + m[12]*m[3]*m[9];
    inv[13] =  m[0]*m[9]*m[14]  - m[0]*m[10]*m[13] - m[8]*m[1]*m[14]
             + m[8]*m[2]*m[13]  + m[12]*m[1]*m[10] - m[12]*m[2]*m[9];
    inv[2] =   m[1]*m[6]*m[15]  - m[1]*m[7]*m[14]  - m[5]*m[2]*m[15]
             + m[5]*m[3]*m[14]  + m[13]*m[2]*m[7]  - m[13]*m[3]*m[6];
    inv[6] =  -m[0]*m[6]*m[15]  + m[0]*m[7]*m[14]  + m[4]*m[2]*m[15]
             - m[4]*m[3]*m[14]  - m[12]*m[2]*m[7]  + m[12]*m[3]*m[6];
    inv[10] =  m[0]*m[5]*m[15]  - m[0]*m[7]*m[13]  - m[4]*m[1]*m[15]
             + m[4]*m[3]*m[13]  + m[12]*m[1]*m[7]  - m[12]*m[3]*m[5];
    inv[14] = -m[0]*m[5]*m[14]  + m[0]*m[6]*m[13]  + m[4]*m[1]*m[14]
             - m[4]*m[2]*m[13]  - m[12]*m[1]*m[6]  + m[12]*m[2]*m[5];
    inv[3] =  -m[1]*m[6]*m[11]  + m[1]*m[7]*m[10]  + m[5]*m[2]*m[11]
             - m[5]*m[3]*m[10]  - m[9]*m[2]*m[7]   + m[9]*m[3]*m[6];
    inv[7] =   m[0]*m[6]*m[11]  - m[0]*m[7]*m[10]  - m[4]*m[2]*m[11]
             + m[4]*m[3]*m[10]  + m[8]*m[2]*m[7]   - m[8]*m[3]*m[6];
    inv[11] = -m[0]*m[5]*m[11]  + m[0]*m[7]*m[9]   + m[4]*m[1]*m[11]
             - m[4]*m[3]*m[9]   - m[8]*m[1]*m[7]   + m[8]*m[3]*m[5];
    inv[15] =  m[0]*m[5]*m[10]  - m[0]*m[6]*m[9]   - m[4]*m[1]*m[10]
             + m[4]*m[2]*m[9]   + m[8]*m[1]*m[6]   - m[8]*m[2]*m[5];

    /* determinant: m[0..3] combined with the matching entries of inv */
    det = m[0]*inv[0] + m[1]*inv[4] + m[2]*inv[8] + m[3]*inv[12];
    /* exact comparison: only a determinant of precisely zero is rejected */
    if (det == 0)
        return -1;

    /* from here on det holds the reciprocal of the determinant */
    det = 1.0 / det;

    /* scale every entry; the product is computed in double and then
     * converted back on assignment to dst
     */
    for (i = 0; i < 16; i++)
        dst[i] = inv[i] * det;

    return 0;
}

/* Returns the smallest power of two that is greater than or equal to n.
 *
 * Both n == 0 and n == 1 yield 1.
 * If n is larger than the biggest power of two an unsigned int can hold,
 * no valid answer exists and 0 is returned.
 */
unsigned int m3dPowerOfTwo(unsigned int n)
{
    /* only the highest bit set: the largest representable power of two */
    const unsigned int top_bit = ~(~0u >> 1);
    unsigned int ret = 1;

    while (ret < n)
    {
        /* one more shift would wrap ret to 0, and 0 < n would then hold
         * forever; report failure instead of looping endlessly
         */
        if (ret == top_bit)
        {
            return 0;
        }
        ret <<= 1;
    }

    return ret;
}