/*
 * Copyright 2025 Bloomberg Finance LP
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef INCLUDED_BUILDBOXCOMMON_FUTUREGROUP_H
#define INCLUDED_BUILDBOXCOMMON_FUTUREGROUP_H

#include <atomic>
#include <functional>
#include <future>
#include <stdexcept>
#include <ThreadPool.h>
#include <tuple>
#include <utility>
#include <vector>

#include <buildboxcommon_exception.h>

/*
 * Scope-bound owner of a group of asynchronous tasks that all return `R`.
 *
 * Every task submitted through add() is tracked by the group. On
 * destruction the group first raises a stop flag and then waits on every
 * tracked future. A task that only begins executing after the stop flag has
 * been raised does not call the user's function; it throws
 * std::runtime_error("FutureGroup has been stopped") instead.
 *
 * add() does not lock the internal list of futures, so calls to add() on
 * the same group from several threads must be synchronised by the caller.
 */
template <typename R> class FutureGroup {
  public:
    // Create a group that schedules its tasks on `threadPool`. The pool is
    // not owned. A null pool is accepted: tasks are then created with
    // std::launch::deferred, i.e. they run lazily in whichever thread first
    // waits on (or gets from) the returned future.
    explicit FutureGroup(ThreadPool *threadPool) : d_threadPool(threadPool) {}

    // The group is neither default-constructible, copyable nor movable:
    // queued tasks hold a reference to `d_stopped`, so the object must stay
    // at a fixed address for its whole lifetime.
    FutureGroup() = delete;
    FutureGroup(const FutureGroup &) = delete;
    FutureGroup &operator=(const FutureGroup &) = delete;
    FutureGroup(FutureGroup &&) = delete;
    FutureGroup &operator=(FutureGroup &&) = delete;

    // Raise the stop flag, then block until every tracked future is ready.
    // In deferred mode (null pool) waiting is what executes a task that
    // nobody has waited on yet; because the flag is already set, such a task
    // ends with the "stopped" error rather than running the user's function.
    ~FutureGroup()
    {
        // Set the flag before waiting so that tasks which have not started
        // yet bail out quickly instead of doing their work.
        d_stopped.store(true);
        for (const auto &future : d_futures) {
            if (future.valid()) {
                future.wait();
            }
        }
    }

    // Schedule `func(args...)` and return a shared future for its result.
    //
    // `func` and `args` are copied/moved into the task (std::make_tuple
    // decays the arguments) and are handed to `func` as rvalues when the
    // task runs. The call must yield exactly `R`; this is enforced at
    // compile time.
    //
    // The returned future is also retained by the group, which waits on it
    // in the destructor. If the task starts after the group's destructor has
    // begun, it throws std::runtime_error; with std::async that exception is
    // stored in the future (NOTE(review): presumably the same holds for
    // ThreadPool::enqueue -- confirm against its implementation).
    //
    // Exceptions thrown while allocating the tracking slot or while
    // scheduling propagate to the caller; in that case nothing is left
    // scheduled-but-untracked and the group is unchanged.
    template <typename F, typename... Args>
    std::shared_future<R> add(F &&func, Args &&...args)
    {
        using Ret = std::invoke_result_t<F, Args...>;
        static_assert(std::is_same_v<Ret, R>,
                      "All tasks must return the same type R");

        // `mutable` because the stored callable and arguments are moved out
        // when the task is invoked; the task is therefore single-shot.
        auto wrapped =
            [stop = std::ref(d_stopped), func = std::forward<F>(func),
             args = std::make_tuple(std::forward<Args>(args)...)]() mutable {
                if (stop.get().load()) {
                    throw std::runtime_error("FutureGroup has been stopped");
                }
                return std::apply(std::move(func), std::move(args));
            };

        // Create the tracking slot *before* scheduling. The task refers to
        // `d_stopped`, so it must never be queued without the destructor
        // knowing about it: if growing the vector were to throw after the
        // task had been scheduled, the task could outlive this object and
        // read a destroyed flag.
        d_futures.emplace_back();
        try {
            d_futures.back() =
                d_threadPool == nullptr
                    ? std::async(std::launch::deferred, std::move(wrapped))
                          .share()
                    : d_threadPool->enqueue(std::move(wrapped)).share();
        }
        catch (...) {
            // Scheduling failed, so no task exists; drop the empty slot.
            d_futures.pop_back();
            throw;
        }
        return d_futures.back();
    }

  private:
    // Pool used to run tasks; null selects deferred execution. Not owned.
    ThreadPool *d_threadPool;
    // Set by the destructor; read by each task right before it would call
    // the user's function.
    std::atomic_bool d_stopped = false;
    // One entry per successfully scheduled task, in submission order.
    std::vector<std::shared_future<R>> d_futures;
};

#endif
