/*  -*-c++-*- 
 *  Copyright (C) 2008 Cedric Pinson <mornifle@plopbyte.net>
 *
 * This library is open source and may be redistributed and/or modified under  
 * the terms of the OpenSceneGraph Public License (OSGPL) version 0.0 or 
 * (at your option) any later version.  The full license is in LICENSE file
 * included with this distribution, and on the openscenegraph.org website.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 * OpenSceneGraph Public License for more details.
*/

#ifndef OSGANIMATION_SKINNING_H
#define OSGANIMATION_SKINNING_H

#include <iostream>
#include <vector>
#include <map>
#include <string>
#include <algorithm>
#include <osgAnimation/VertexInfluence>
#include <osgAnimation/Bone>
#include <osg/Matrix>
#include <osg/Vec3>
#include <osg/Quat>

namespace osgAnimation 
{

    /// This class manages the data layout used for software skinning.
    /// It uses the technique described in this paper: http://www.intel.com/cd/ids/developer/asmo-na/eng/172124.htm
    /// The idea is to prepare the data so that skinning only requires v' = M x v, with M a combined matrix as below:
    /// M = Mbone1 * w1 + Mbone2 * w2 + ...
    /// An M matrix is unique for a given set of vertices, so to fully compute the skinned mesh
    /// you have to iterate over each UniqBoneSetVertexSet.
    class TransformVertexFunctor
    {
    public:
        typedef osg::Matrix MatrixType;
        typedef osgAnimation::Bone BoneType;
        typedef Bone::BoneMap BoneMap;

        class BoneWeight 
        {
        public:
            BoneWeight(BoneType* bone, float weight) : _bone(bone), _weight(weight) {}
            const BoneType* getBone() const { return _bone.get(); }
            float getWeight() const { return _weight; }
            void setWeight(float w) { _weight = w; }
        protected:
            osg::observer_ptr<BoneType> _bone;
            float _weight;
        };

        typedef std::vector<BoneWeight> BoneWeightList;
        typedef std::vector<int> VertexList;

        class UniqBoneSetVertexSet 
        {
        public:
            BoneWeightList& getBones() { return _bones; }
            VertexList& getVertexes() { return _vertexes; }
            
            void resetMatrix() 
            {
                _result.set(0, 0, 0, 0,
                            0, 0, 0, 0,
                            0, 0, 0, 0,
                            0, 0, 0, 1);
            }
            void accummulateMatrix(const osg::Matrix& invBindMatrix, const osg::Matrix& matrix, osg::Matrix::value_type weight)
            {
                osg::Matrix m = invBindMatrix * matrix;
                osg::Matrix::value_type* ptr = m.ptr();
                osg::Matrix::value_type* ptrresult = _result.ptr();
                ptrresult[0] += ptr[0] * weight;
                ptrresult[1] += ptr[1] * weight;
                ptrresult[2] += ptr[2] * weight;

                ptrresult[4] += ptr[4] * weight;
                ptrresult[5] += ptr[5] * weight;
                ptrresult[6] += ptr[6] * weight;

                ptrresult[8] += ptr[8] * weight;
                ptrresult[9] += ptr[9] * weight;
                ptrresult[10] += ptr[10] * weight;

                ptrresult[12] += ptr[12] * weight;
                ptrresult[13] += ptr[13] * weight;
                ptrresult[14] += ptr[14] * weight;
            }
            void computeMatrixForVertexSet()
            {
                if (_bones.empty())
                {
                    osg::notify(osg::WARN) << "TransformVertexFunctor::UniqBoneSetVertexSet no bones found" << std::endl;
                    _result = MatrixType::identity();
                    return;
                }
                resetMatrix();

                int size = _bones.size();
                for (int i = 0; i < size; i++)
                {
                    const BoneType* bone = _bones[i].getBone();
                    const MatrixType& invBindMatrix = bone->getInvBindMatrixInSkeletonSpace();
                    const MatrixType& matrix = bone->getMatrixInSkeletonSpace();
                    osg::Matrix::value_type w = _bones[i].getWeight();
                    accummulateMatrix(invBindMatrix, matrix, w);
                }
            }
            const MatrixType& getMatrix() const { return _result;}
        protected:
            BoneWeightList _bones;
            VertexList _vertexes;
            MatrixType _result;
        };


        void init(const BoneMap& map, const osgAnimation::VertexInfluenceSet::UniqVertexSetToBoneSetList& influence)
        {
            _boneSetVertexSet.clear();

            int size = influence.size();
            _boneSetVertexSet.resize(size);
            for (int i = 0; i < size; i++) 
            {
                const osgAnimation::VertexInfluenceSet::UniqVertexSetToBoneSet& inf = influence[i];
                int nbBones = inf.getBones().size();
                BoneWeightList& boneList = _boneSetVertexSet[i].getBones();

                double sumOfWeight = 0;
                for (int b = 0; b < nbBones; b++) 
                {
                    const std::string& bname = inf.getBones()[b].getBoneName();
                    float weight = inf.getBones()[b].getWeight();
                    BoneMap::const_iterator it = map.find(bname);
                    if (it == map.end()) 
                    {
                        osg::notify(osg::WARN) << "TransformVertexFunctor Bone " << bname << " not found, skip the influence group " <<bname  << std::endl;
                        continue;
                    }
                    BoneType* bone = it->second.get();
                    boneList.push_back(BoneWeight(bone, weight));
                    sumOfWeight += weight;
                }
                // if a bone referenced by a vertexinfluence is missed it can make the sum less than 1.0
                // so we check it and renormalize the all weight bone
                const double threshold = 1e-4;
                if (!_boneSetVertexSet[i].getBones().empty() && 
                    (sumOfWeight < 1.0 - threshold ||  sumOfWeight > 1.0 + threshold))
                {
                    for (int b = 0; b < (int)boneList.size(); b++)
                        boneList[b].setWeight(boneList[b].getWeight() / sumOfWeight);
                }
                _boneSetVertexSet[i].getVertexes() = inf.getVertexes();
            }
        }


        template <class V> void compute(const V* src, V* dst)
        {
            int size = _boneSetVertexSet.size();
            for (int i = 0; i < size; i++) 
            {
                UniqBoneSetVertexSet& uniq = _boneSetVertexSet[i];
                uniq.computeMatrixForVertexSet();
                const MatrixType& matrix = uniq.getMatrix();

                const VertexList& vertexes = uniq.getVertexes();
                int vertexSize = vertexes.size();
                for (int j = 0; j < vertexSize; j++) 
                {
                    int idx = vertexes[j];
                    dst[idx] = src[idx] * matrix;
                }
            }
        }

        template <class V> void compute(const MatrixType& transform, const MatrixType& invTransform, const V* src, V* dst)
        {
            // the result of matrix mult should be cached to be used for vertexes transform and normal transform and maybe other computation
            int size = _boneSetVertexSet.size();
            for (int i = 0; i < size; i++) 
            {
                UniqBoneSetVertexSet& uniq = _boneSetVertexSet[i];
                uniq.computeMatrixForVertexSet();
                MatrixType matrix =  transform * uniq.getMatrix() * invTransform;

                const VertexList& vertexes = uniq.getVertexes();
                int vertexSize = vertexes.size();
                for (int j = 0; j < vertexSize; j++)
                {
                    int idx = vertexes[j];
                    dst[idx] = src[idx] * matrix;
                }
            }
        }


        template <class V> void computeNormal(const MatrixType& transform, const MatrixType& invTransform, const V* src, V* dst)
        {
            int size = _boneSetVertexSet.size();
            for (int i = 0; i < size; i++) 
            {
                UniqBoneSetVertexSet& uniq = _boneSetVertexSet[i];
                uniq.computeMatrixForVertexSet();
                MatrixType matrix =  transform * uniq.getMatrix() * invTransform;

                const VertexList& vertexes = uniq.getVertexes();
                int vertexSize = vertexes.size();
                for (int j = 0; j < vertexSize; j++) 
                {
                    int idx = vertexes[j];
                    dst[idx] = MatrixType::transform3x3(src[idx],matrix);
                }
            }
        }

    protected:

        std::vector<UniqBoneSetVertexSet> _boneSetVertexSet;
    };
}

#endif
