#ifndef _HWBREAK_H_
#define _HWBREAK_H_

#include <windows.h>

// Hardware Breakpoints, by Andrew Birkett <andy@nobugs.org>
class HWBreak
{
public:
    typedef enum { ReadWrite, Write } BreakPointType;
    
    // Set an available hardware breakpoint to halt on reads/writes to the dword at [address]

    //
    // For example:
    //  int i = 0;
    //  HWBreak(&i);
    //  int x = i + 2; (break on read)
    //  int i++;       (break on write)
    // 
    // NB: The debugger will probably stop on the line /after/ the breakpoint.

    HWBreak(void *address,  BreakPointType events = ReadWrite) 
    { 
        m_break_address = (DWORD) address;
        m_events = events;
        
        // Get a real handle to the current thread
        ::DuplicateHandle( 
            ::GetCurrentProcess(), ::GetCurrentThread(),
            ::GetCurrentProcess(),  &m_main_thread_id,
            0, FALSE, DUPLICATE_SAME_ACCESS);


        m_enable = true;
        RunHelperThread();
    }

    ~HWBreak()
    {
        if (m_bp_num != -1) {
            m_enable = false;
            RunHelperThread();
        }

        ::CloseHandle(m_main_thread_id);
    }

private:

        
    HANDLE m_main_thread_id;
    DWORD m_break_address;
    BreakPointType m_events;    
    DWORD m_bp_num; // Which hw breakpoint are we using? (0 - 3, or -1 if we didn't set a bp)
    
    bool m_enable; // Current worker thread action

    enum { THREAD_FAILED, THREAD_SUCCESS };
    
    // We can only safely alter the thread context of a suspended thread,
    // so we create a worker thread which suspends us, alters our context, and
    // restarts us.  It update m_bp_num.
    void RunHelperThread()
    { 
        // Start the helper thread
        DWORD thread_id = 0;
        HANDLE hThread = ::CreateThread( NULL, 0, threadFn, this, 0, &thread_id );

        // Wait until the thread has exited (ie. has updated our registers)
        DWORD status = ::WaitForSingleObject(hThread, INFINITE);
        
        if (status != WAIT_OBJECT_0) { 
            ::MessageBoxA(0, "Helper thread didn't exit cleanly", "HWBreakpoint", MB_OK);
        }

        DWORD exit_code = 0;
        if (!::GetExitCodeThread(hThread, &exit_code)) {
            ::MessageBoxA(0, "Failed to get worked thread exit code", "HWBreakPoint", MB_OK);
        }
        
        if (exit_code == THREAD_FAILED) {
            ::MessageBoxA(0, "Worker thread returned failure code", "HWBreakPoint", MB_OK);
        }


        if (m_bp_num == -1) {
            ::MessageBoxA(0, "All four hardware breakpoints are in use - breakpoint not set.", "HWBreakPoint", MB_OK);
        } 
        
        ::CloseHandle(hThread);
    }
 
    static int FindFreeBreakpoint(const CONTEXT &context)
    {
        for (int bp = 0; bp <= 3; bp++) {
            int mask = 1 << (bp * 2);
            if ((context.Dr7 & mask) == 0) return bp;
        }

        return -1;
    }

    // The thread gets the required action from that->m_enable.
    // If we're enabling a breakpoint, it'll write the breakpoint number (0-3) into that->m_bp_num (or -1 if there's none free)
    // If we're disabling a breakpoint, it finds the breakpoint number from that->m_bp_num
    static DWORD WINAPI threadFn(void *data)
    {

        HWBreak *that = (HWBreak *) data;
       
        // Suspend the main thread
        DWORD suspend_count = ::SuspendThread(that->m_main_thread_id);

        if (suspend_count ) return THREAD_FAILED;
        
        CONTEXT context;
        context.ContextFlags = CONTEXT_DEBUG_REGISTERS;

        DWORD ok = ::GetThreadContext(that->m_main_thread_id, &context);

        if (ok)
        {
            if (that->m_enable) {
                int bp = FindFreeBreakpoint(context);
                
                if (bp != -1)
                {            
                    DWORD enable = 0x1 << (bp *2);
                    DWORD rw = (that->m_events == ReadWrite ? 0x000f0000 : 0x000d0000) << (bp * 4);
                    DWORD mask = enable | rw;
                    

                    switch (bp)
                    {
                    case 0: context.Dr0 = that->m_break_address;  break;
                    case 1: context.Dr1 = that->m_break_address;  break;
                    case 2: context.Dr2 = that->m_break_address;  break;
                    case 3: context.Dr3 = that->m_break_address;  break;
                    }
                    
                    context.Dr7 |= mask;

                    ok = ::SetThreadContext(that->m_main_thread_id, &context);
                }

                that->m_bp_num = bp;
                
            } else {
                DWORD enable = 0x1 << (that->m_bp_num *2);
                context.Dr7 &= ~enable;
                ok = ::SetThreadContext(that->m_main_thread_id, &context);
            }
        }
        
        // Resume the main thread
        suspend_count = ::ResumeThread(that->m_main_thread_id);
        if (suspend_count == 0xFFFFFFFF) return THREAD_FAILED;

        return ok ? THREAD_SUCCESS : THREAD_FAILED;
    }   
};

#endif // _HWBREAK_H_
