#pragma once
#include <assert.h>

/**
 * A smart pointer to an object which implements AddRef/Release.
 *
 * The pointer holds at most one reference on the pointee: it calls AddRef()
 * when it starts sharing an object and Release() when it lets go of it
 * (destruction, reassignment, safe_release()). ReferencedType must provide
 * AddRef() and Release(); GetRefCount() is only required if get_ref_count()
 * is used.
 */
template<typename ReferencedType>
class ref_count_ptr {
    using reference_type = ReferencedType*;

public:
    /** Creates an empty pointer that references nothing. */
    ref_count_ptr() noexcept : reference_(nullptr) {
    }

    /**
     * Wraps a raw pointer.
     *
     * @param in_reference Object to reference; may be nullptr.
     * @param bAddRef      When true (default) AddRef() is called on a non-null
     *                     object. When false the existing reference is adopted
     *                     as-is and no AddRef() call is made.
     */
    explicit ref_count_ptr(ReferencedType* in_reference, bool bAddRef = true)
        : reference_(in_reference) {
        if (reference_ && bAddRef) {
            reference_->AddRef();
        }
    }

    /** Copy constructor: shares the object and adds one reference to it. */
    ref_count_ptr(const ref_count_ptr& copy) : reference_(copy.reference_) {
        if (reference_) {
            reference_->AddRef();
        }
    }

    /**
     * Converting copy constructor from a pointer to a related type.
     * The raw pointer is converted with static_cast, so the conversion must be
     * a valid static_cast between the two pointee types.
     */
    template<typename CopyReferencedType>
    explicit ref_count_ptr(const ref_count_ptr<CopyReferencedType>& copy)
        : reference_(static_cast<ReferencedType*>(copy.get_reference())) {
        if (reference_) {
            reference_->AddRef();
        }
    }

    /** Move constructor: takes over the source's reference and leaves it empty. */
    ref_count_ptr(ref_count_ptr&& move) noexcept : reference_(move.reference_) {
        move.reference_ = nullptr;
    }

    /**
     * Converting move constructor from a pointer to a related type.
     * Takes over the source's reference (no AddRef/Release calls) and leaves
     * the source empty. Only pointers are touched, hence noexcept.
     */
    template<typename MoveReferencedType>
    explicit ref_count_ptr(ref_count_ptr<MoveReferencedType>&& move) noexcept
        : reference_(static_cast<ReferencedType*>(move.get_reference())) {
        move.reference_ = nullptr;
    }

    /** Releases the held reference, if any. */
    ~ref_count_ptr() {
        if (reference_) {
            reference_->Release();
        }
    }

    /**
     * Re-targets this pointer at a raw object (or nullptr).
     * Assigning the object that is already held is a no-op.
     */
    ref_count_ptr& operator=(ReferencedType* in_reference) {
        if (reference_ != in_reference) {
            // AddRef the new object before releasing the old one, so that the
            // new object already holds our reference by the time Release() runs.
            ReferencedType* old_reference = reference_;
            reference_ = in_reference;
            if (reference_) {
                reference_->AddRef();
            }
            if (old_reference) {
                old_reference->Release();
            }
        }
        return *this;
    }

    /** Copy assignment: shares in_ptr's object (delegates to the raw-pointer assignment). */
    ref_count_ptr& operator=(const ref_count_ptr& in_ptr) {
        return *this = in_ptr.reference_;
    }

    /**
     * Converting copy assignment from a pointer to a related type.
     * Requires CopyReferencedType* to be implicitly convertible to ReferencedType*.
     */
    template<typename CopyReferencedType>
    ref_count_ptr& operator=(const ref_count_ptr<CopyReferencedType>& in_ptr) {
        return *this = in_ptr.get_reference();
    }

    /**
     * Move assignment: takes over in_ptr's reference, leaves in_ptr empty and
     * releases the object previously held by this pointer. Self-move is a no-op.
     */
    ref_count_ptr& operator=(ref_count_ptr&& in_ptr) noexcept {
        if (this != &in_ptr) {
            ReferencedType* old_reference = reference_;
            reference_ = in_ptr.reference_;
            in_ptr.reference_ = nullptr;
            if (old_reference) {
                old_reference->Release();
            }
        }
        return *this;
    }

    /**
     * Converting move assignment from a pointer to a related type.
     * Requires MoveReferencedType* to be implicitly convertible to ReferencedType*.
     */
    template<typename MoveReferencedType>
    ref_count_ptr& operator=(ref_count_ptr<MoveReferencedType>&& in_ptr) noexcept {
        // in_ptr is a different type (or the non-template overload would have
        // been selected), so there is no need to test &in_ptr != this.
        ReferencedType* old_reference = reference_;
        reference_ = in_ptr.reference_;
        in_ptr.reference_ = nullptr;
        if (old_reference) {
            old_reference->Release();
        }
        return *this;
    }

    /** Member access on the referenced object; the pointer is returned unchecked. */
    ReferencedType* operator->() const {
        return reference_;
    }

    /** Implicit conversion to the raw pointer; the reference count is not changed. */
    operator reference_type() const {
        return reference_;
    }

    /**
     * Drops the current reference and returns the address of the internal raw
     * pointer so that it can be filled in from outside.
     * NOTE(review): a pointer stored through the returned address is released
     * by this object later without an AddRef() from this class - presumably
     * the writer is expected to hand over an already-counted reference.
     */
    ReferencedType** get_init_reference() {
        *this = nullptr;
        return &reference_;
    }

    /** Returns the raw pointer without changing the reference count. */
    ReferencedType* get_reference() const {
        return reference_;
    }

    /** Free-function form of is_valid(): true when in_reference is non-null. */
    friend bool is_valid_ref(const ref_count_ptr& in_reference) {
        return in_reference.reference_ != nullptr;
    }

    /** Returns true when an object is currently referenced. */
    bool is_valid() const {
        return reference_ != nullptr;
    }

    /** Releases the held reference (if any) and resets this pointer to null. */
    void safe_release() {
        *this = nullptr;
    }

    /**
     * Returns the pointee's GetRefCount(), or 0 when this pointer is empty.
     * Asserts that a referenced object reports a non-zero count.
     */
    unsigned int get_ref_count() const {
        unsigned int result = 0;
        if (reference_) {
            result = reference_->GetRefCount();
            // A live ref counted pointer (*this) implies a non-zero ref count.
            assert(result > 0);
        }
        return result;
    }

    /**
     * Exchanges the referenced objects of two pointers.
     * No AddRef/Release calls are made, so this is cheaper than assignment.
     */
    void swap(ref_count_ptr& in_ptr) noexcept {
        ReferencedType* old_reference = reference_;
        reference_ = in_ptr.reference_;
        in_ptr.reference_ = old_reference;
    }

    // Disabled serialization support, kept for reference:
    // void Serialize(FArchive& Ar)
    // {
    //     reference_type PtrReference = Reference;
    //     Ar << PtrReference;
    //     if(Ar.IsLoading())
    //     {
    //         *this = PtrReference;
    //     }
    // }

private:
    /** The referenced object, or nullptr when empty. */
    ReferencedType* reference_;

    // Other instantiations need access to reference_ for converting moves.
    template<typename OtherType>
    friend class ref_count_ptr;

public:
    /** Two pointers are equal when they reference the same object (or are both empty). */
    bool operator==(const ref_count_ptr& b) const {
        return get_reference() == b.get_reference();
    }

    /** Compares the referenced object against a raw pointer. */
    bool operator==(ReferencedType* b) const {
        return get_reference() == b;
    }
};