#pragma once
#include <algorithm>
#include <cstddef>
#include <memory>
#include <mutex>
#include <new>
#include <utility>
#include <vector>
// Non-portable libstdc++ internal header; <algorithm> is the standard home of these declarations.
#include <bits/ranges_algo.h>

/// Fixed-size chunk allocator that carves chunks out of larger heap blocks.
///
/// Free chunks form an intrusive singly linked list: while a chunk is free,
/// its first sizeof(void*) bytes hold the pointer to the next free chunk.
/// Blocks are released only by the destructor, so memory is never returned to
/// the heap while the pool lives. allocate() and deallocate() are serialized
/// by mutex_.
class memory_pool {
public:
    /// @param chunk_size       Requested bytes per chunk. Rounded up to a
    ///                         multiple of alignof(void*) that is at least
    ///                         sizeof(void*), so every chunk can hold the
    ///                         free-list link at a properly aligned address.
    /// @param chunks_per_block Chunks carved from each heap block; 0 is
    ///                         treated as 1 so allocate() can always grow.
    memory_pool(const size_t chunk_size, const size_t chunks_per_block)
        : chunk_size_(round_chunk_size(chunk_size)),
          chunks_per_block_(chunks_per_block == 0 ? 1 : chunks_per_block),
          current_block_(nullptr), free_list_(nullptr) {
        std::lock_guard lock(mutex_);
        allocate_block();
    }

    /// Takes over `other`'s blocks and free list. `other` is left with no
    /// blocks (a moved-from std::vector is empty) and grows a fresh block on
    /// its next allocate().
    memory_pool(memory_pool&& other) noexcept
        : chunk_size_(other.chunk_size_), chunks_per_block_(other.chunks_per_block_),
          blocks_(std::move(other.blocks_)),
          current_block_(std::exchange(other.current_block_, nullptr)),
          free_list_(std::exchange(other.free_list_, nullptr)) {}

    // Copying would double-free the blocks; assignment is not supported.
    memory_pool(const memory_pool&) = delete;
    auto operator=(const memory_pool&) -> memory_pool& = delete;
    auto operator=(memory_pool&&) -> memory_pool& = delete;

    ~memory_pool() {
        for (void* block : blocks_) {
            delete[] static_cast<char*>(block);
        }
    }

    /// Returns one chunk of at least the requested size, adding a whole block
    /// when no free chunk is left. The chunk's contents are unspecified.
    auto allocate() -> void* {
        // The emptiness check and the growth happen under the same lock that
        // pops the chunk; checking unlocked let two threads race for the last
        // chunk and one of them dereference an empty (null) free list.
        std::lock_guard lock(mutex_);
        if (free_list_ == nullptr) {
            allocate_block();
        }
        void* const chunk = free_list_;
        free_list_ = *static_cast<void**>(chunk);
        return chunk;
    }

    /// Pushes `chunk` back onto the free list. `chunk` must have come from
    /// allocate() on this pool and must not be returned twice. Null is ignored.
    void deallocate(void* chunk) {
        if (chunk == nullptr) {
            return;
        }
        std::lock_guard lock(mutex_);
        push_free(chunk);
    }

private:
    /// Allocates one block and threads all of its chunks onto the free list.
    /// Caller must hold mutex_ (std::mutex is not recursive, so this must not
    /// lock it again).
    void allocate_block() {
        // The unique_ptr owns the block until blocks_ has recorded it, so a
        // throwing push_back cannot leak it.
        std::unique_ptr<char[]> block(new char[chunk_size_ * chunks_per_block_]);
        blocks_.push_back(block.get());
        char* const base = block.release();
        current_block_ = base;

        for (size_t i = 0; i < chunks_per_block_; ++i) {
            push_free(base + i * chunk_size_);
        }
    }

    /// Makes `chunk` the new head of the free list. Caller must hold mutex_.
    void push_free(void* chunk) noexcept {
        *static_cast<void**>(chunk) = free_list_;
        free_list_ = chunk;
    }

    /// Rounds `requested` up so a chunk can store the free-list pointer and
    /// consecutive chunks stay pointer-aligned.
    static auto round_chunk_size(const size_t requested) -> size_t {
        constexpr size_t link_align = alignof(void*);
        const size_t at_least = requested < sizeof(void*) ? sizeof(void*) : requested;
        return (at_least + link_align - 1) / link_align * link_align;
    }

    size_t chunk_size_;
    size_t chunks_per_block_;
    std::vector<void*> blocks_;   // every block ever allocated; freed in the destructor
    void* current_block_;         // most recently allocated block
    void* free_list_;             // head of the intrusive free list (null when empty)
    std::mutex mutex_;
};

template<class T>
class obj_mempool {
public:
    template<typename ...Args>
    static auto construct(Args&&... args) -> T* {
        auto obj = alloc();
        new (obj) T(std::forward<Args>(args)...);
        return obj;
    }
    static auto construct() -> T* {
        auto obj = alloc();
        new (obj) T();
        return obj;
    }
    static void free(T* p) {
        p->~T();
        deallocate(p);
    }
    static void free_all() {
        for (auto obj : objs_) {
            obj->~T();
            pool_.deallocate(obj);
        }
        objs_.clear();
    }
    static auto objs() -> const std::vector<T*>& {
        return objs_;
    }
    static auto has_obj(T* p) -> bool {
        return std::find(objs_.begin(), objs_.end(), p) != objs_.end();
    }
    static auto safe_free(T* p) -> bool {
        if (has_obj(p)) {
            free(p);
            return true;
        }
        return false;
    }

    static auto alloc() -> T* {
        T* p = static_cast<T*>(pool_.allocate());
        objs_.push_back(p);
        return p;
    }
    static void deallocate(T* p) {
        pool_.deallocate(p);
        auto e = std::ranges::remove(objs_, p);
        objs_.erase(e.begin(), e.end());
    }
private:
    inline static auto pool_ = memory_pool(sizeof(T), 64);
    inline static std::vector<T*> objs_;
};

template<class T>
class pool_obj {
public:
    auto operator new(size_t size) -> void* {
        return obj_mempool<T>::alloc();
    }
    void operator delete(void* p) {
        obj_mempool<T>::deallocate(static_cast<T*>(p));
    }
};