/* -*-c++-*- OpenThreads - Copyright (C) 1998-2007 Robert Osfield 
 *
 * This library is open source and may be redistributed and/or modified under  
 * the terms of the OpenSceneGraph Public License (OSGPL) version 0.0 or 
 * (at your option) any later version.  The full license is in LICENSE file
 * included with this distribution, and on the openscenegraph.org website.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 * OpenSceneGraph Public License for more details.
*/

#ifndef _OPENTHREADS_REENTRANTMUTEX_
#define _OPENTHREADS_REENTRANTMUTEX_

#include <OpenThreads/Thread>
#include <OpenThreads/Mutex>
#include <OpenThreads/ScopedLock>

namespace OpenThreads {

class ReentrantMutex : public OpenThreads::Mutex
{
    public:

        ReentrantMutex():
            _threadHoldingMutex(0),
            _lockCount(0) {}

        virtual ~ReentrantMutex() {}

        virtual int lock()
        {
            {
                OpenThreads::ScopedLock<OpenThreads::Mutex> lock(_lockCountMutex);
                if (_threadHoldingMutex==OpenThreads::Thread::CurrentThread() && _lockCount>0)
                {
                    ++_lockCount;
                    return 0;
                }
            }

            int result = Mutex::lock();
            if (result==0)
            {
                OpenThreads::ScopedLock<OpenThreads::Mutex> lock(_lockCountMutex);

                _threadHoldingMutex = OpenThreads::Thread::CurrentThread();
                _lockCount = 1;
            }
            return result;
        }


        virtual int unlock()
        {
            OpenThreads::ScopedLock<OpenThreads::Mutex> lock(_lockCountMutex);
        #if 0
            if (_threadHoldingMutex==OpenThreads::Thread::CurrentThread() && _lockCount>0)
            {
                --_lockCount;
                if (_lockCount<=0)
                {
                    _threadHoldingMutex = 0;
                    return Mutex::unlock();
                }
            }
            else
            {
                osg::notify(osg::NOTICE)<<"Error: ReentrantMutex::unlock() - unlocking from the wrong thread."<<std::endl;
            }
        #else
            if (_lockCount>0)
            {
                --_lockCount;
                if (_lockCount<=0)
                {
                    _threadHoldingMutex = 0;
                    return Mutex::unlock();
                }
            }
        #endif    
            return 0;
        }
        

        virtual int trylock()
        {
            {
                OpenThreads::ScopedLock<OpenThreads::Mutex> lock(_lockCountMutex);
                if (_threadHoldingMutex==OpenThreads::Thread::CurrentThread() && _lockCount>0)
                {
                    ++_lockCount;
                    return 0;
                }
            }

            int result = Mutex::trylock();
            if (result==0)
            {
                OpenThreads::ScopedLock<OpenThreads::Mutex> lock(_lockCountMutex);

                _threadHoldingMutex = OpenThreads::Thread::CurrentThread();
                _lockCount = 1;
            }
            return result;
        }

    private:

        ReentrantMutex(const ReentrantMutex&):OpenThreads::Mutex() {}

        ReentrantMutex& operator =(const ReentrantMutex&) { return *(this); }
        
        OpenThreads::Thread* _threadHoldingMutex;
        
        OpenThreads::Mutex  _lockCountMutex;
        unsigned int        _lockCount;

};

}

#endif
