// ***************************************************************************
//
// ReferenceCounting.cpp
//
// Example using C++ handle classes to implement reference counting
// 
// ***************************************************************************

#include <iostream>
#include <string>
#include <vector>

using namespace std;

// PAUSE(): waits for console input by reading two characters from cin
// (the second read presumably absorbs a newline left in the input buffer).
// Evaluates to the value of the second cin.get().
#define PAUSE() (cin.get(),cin.get())

// ASSERT( x ): evaluates x; when it is false, prints the expression text,
// ": Failed: ", the source file name and the line number to cout, then calls
// PAUSE().  Taking the address of the stream returned by the output chain
// yields a non-null pointer, so the && always goes on to evaluate PAUSE().
// Unlike <cassert>'s assert, it does not abort and is not disabled by NDEBUG.
#define ASSERT( x )  \
((x) ? 0 : &(cout << #x ": Failed: " __FILE__ " " << __LINE__ << endl) \
&& PAUSE() ) 
// SAY( x ): streams x to cout.
#define SAY( x ) (cout << x)


// ReferenceCountingPointer<T>
//
// Handle class that shares a T through an intrusive reference count.  T must
// provide IncrementCount() and DecrementCount(); a handle calls the former
// when it starts referring to an object and the latter when it stops.  The
// handle never deletes the pointee itself -- destroying the object when the
// count reaches zero is T's job (ReferenceCountedObject::DecrementCount does
// this, for example).  A null handle refers to nothing and touches no count.
template <class T> class ReferenceCountingPointer {
public:
    ReferenceCountingPointer() : pointer( 0 ) {}

    // Not explicit on purpose: allows "Pointer p = new T;" and
    // "handles.push_back( new T );".  Takes a reference to p.
    ReferenceCountingPointer( T* p ) 
        : pointer( p ) { IncrementCount(); }
    ReferenceCountingPointer( const ReferenceCountingPointer<T>& other ) 
        : pointer( other.pointer ) { IncrementCount(); }
    ~ReferenceCountingPointer() { DecrementCount(); }

    // Rebinds this handle to other's pointee.
    //
    // The new pointee is captured and its count incremented BEFORE the old
    // pointee is released.  Releasing may destroy the old object, and if
    // `other` is stored inside that object (e.g. "node = node->next"),
    // releasing first would destroy `other` and then read freed memory.
    // `pointer` is also updated before the release, so nothing in this
    // handle is touched after a call that may destroy arbitrary objects.
    // Incrementing first also makes self-assignment safe without a check.
    ReferenceCountingPointer<T>&
    operator=( const ReferenceCountingPointer<T>& other ) {
        T* const newPointer = other.pointer;
        if (newPointer) newPointer->IncrementCount();
        T* const oldPointer = pointer;
        pointer = newPointer;
        if (oldPointer) oldPointer->DecrementCount();
        return *this;
    }

    T* operator->() const { return pointer; }
    T& operator*() const { return *pointer; }

    // Ordering and equality by pointee address, so handles can be used with
    // STL containers/algorithms that need them (e.g. Microsoft's STL
    // implementation).
    bool operator<( const ReferenceCountingPointer<T>& other ) const { 
        return pointer < other.pointer;
    }
    bool operator==( const ReferenceCountingPointer<T>& other ) const {
        return pointer == other.pointer;
    }
    bool operator!=( const ReferenceCountingPointer<T>& other ) const {
        return pointer != other.pointer;
    }

private:
    T* pointer;
    void IncrementCount() { if (pointer) pointer->IncrementCount(); }
    void DecrementCount() { if (pointer) pointer->DecrementCount(); }
};


// ReferenceCountedObject
//
// Base class supplying the intrusive count used by ReferenceCountingPointer.
// The count starts at zero; DecrementCount() deletes the object (through the
// virtual destructor) when the count drops to zero, so derived objects that
// are managed this way must be allocated with new.  The constructor and
// destructor are protected: the class is only meant to be derived from.
class ReferenceCountedObject {
private:
    int referenceCount;
public:
    void IncrementCount() { referenceCount++; }
    void DecrementCount() { if (--referenceCount == 0) delete this; }

    // Assignment copies state between two existing objects but must not
    // transfer counts: each object's count tracks only the handles that
    // point at that object.  Kept public, like the implicit one it replaces.
    ReferenceCountedObject& operator=( const ReferenceCountedObject& ) {
        return *this;
    }
protected:
    ReferenceCountedObject() : referenceCount( 0 ) {}
    // A copy is a brand-new object that no handle refers to yet, so its
    // count starts at zero instead of inheriting the source's count.
    ReferenceCountedObject( const ReferenceCountedObject& ) : referenceCount( 0 ) {}
    virtual ~ReferenceCountedObject() { }
};

// A pixel is stored as a single char.
typedef char Pixel;

// Fixed screen dimensions, in pixels.  Declared as enumerators so they are
// integral compile-time constants usable as array bounds.
enum {
    SCREEN_WIDTH  = 240,
    SCREEN_HEIGHT = 480
};

class ScreenImage : public ReferenceCountedObject {

    Pixel pixels[SCREEN_WIDTH * SCREEN_HEIGHT];

public:
    typedef ReferenceCountingPointer<ScreenImage> Pointer;

    ScreenImage()  { totalInstances++; }
    ~ScreenImage() { totalInstances--; }
    void SetPixel( int i, Pixel p ) { pixels[i] = p; }
    Pixel GetPixel( int i ) { return pixels[i]; }

    static int totalInstances;
};

/*static*/
int ScreenImage::totalInstances = 0;


void test1();
void test2();

// Runs both test functions, checking after each one that every ScreenImage
// it created has been destroyed, then waits for console input before exiting.
// Standard C++ requires main to return int (void main is ill-formed).
int main() {
    test1();
    ASSERT( ScreenImage::totalInstances == 0 );
    cout << "test 1 done...\n";
    test2();
    ASSERT( ScreenImage::totalInstances == 0 );
    cout << "test 2 done...\n";
    PAUSE();
    return 0;
}

// test1: exercises a single handle going out of scope, then copy
// construction and reassignment among three handles.  It checks pixel values
// read through each handle and the number of live ScreenImage objects after
// every step that could create or destroy one.
void test1() {
    cout << "Basic pointer tests...\n";

    {
        // The only handle leaves scope, so its image must be deleted.
        ScreenImage::Pointer image = new ScreenImage;
        image->SetPixel( 0, 0 );
    }
    ASSERT( ScreenImage::totalInstances == 0 );

    ScreenImage::Pointer a = new ScreenImage;
    ScreenImage::Pointer b = new ScreenImage;
    ASSERT( ScreenImage::totalInstances == 2 );
    a->SetPixel( 0, 0 );
    b->SetPixel( 0, 1 );
    ASSERT( a->GetPixel( 0 ) == 0 );
    ASSERT( b->GetPixel( 0 ) == 1 );

    // Copy construction shares b's image; no new image is created.
    ScreenImage::Pointer c = b;
    ASSERT( c->GetPixel( 0 ) == 1 );
    ASSERT( ScreenImage::totalInstances == 2 );

    // b still refers to the second image, so rebinding c must not free it.
    // The count check comes before any read through b.
    c = a;
    ASSERT( ScreenImage::totalInstances == 2 );
    ASSERT( a->GetPixel( 0 ) == 0 );
    ASSERT( b->GetPixel( 0 ) == 1 );
    ASSERT( c->GetPixel( 0 ) == 0 );
    ASSERT( (*a).GetPixel( 0 ) == 0 );

    // b held the last reference to the second image, which is now deleted.
    b = a;
    ASSERT( a->GetPixel( 0 ) == 0 );
    ASSERT( b->GetPixel( 0 ) == 0 );
    ASSERT( c->GetPixel( 0 ) == 0 );

    ASSERT( ScreenImage::totalInstances == 1 );
}

void test2() { 
    cout << "test 2 starting...\n";

    vector<ScreenImage::Pointer> images;

    for (int i=0; i<10; i++)
        images.push_back( new ScreenImage );

    ASSERT( ScreenImage::totalInstances == 10 );

    for (i=0; i<10; i++)
        images.push_back( images[0] );

    ASSERT( ScreenImage::totalInstances == 10 );
}
