/**
 * MIT License
 *
 * @brief Input-state model utilities, edge-detection helpers, and interop adapters.
 *
 * @file InputModel.h
 * @author Little Man Builds (Darren Osborne)
 * @date 2025-09-12
 * @copyright Copyright (c) 2025 Little Man Builds
 */

#pragma once

#include <bitset>
#include <cstddef>
#include <cstdint>
#include <type_traits>

namespace snapshot::input
{
    // ---- Index & time aliases ---- //

    /**
     * @brief Unsigned 32-bit millisecond count used for input timestamps.
     * @details Being 32 bits wide, values wrap after roughly 49.7 days.
     */
    using millis_u32 = ::std::uint32_t;

    /**
     * @brief Convert a logical id (scoped/unscoped enum or integer) into a bit position.
     * @tparam T Enum or integral type of the id.
     * @param v Id value; enumerators map to their underlying numeric value.
     * @return std::size_t Zero-based position, obtained via static_cast.
     */
    template <class T>
    [[nodiscard]] constexpr std::size_t idx(T v) noexcept
    {
        using result_type = std::size_t;
        return static_cast<result_type>(v);
    }

    // ---- Internal helpers (C++14 tag-dispatch, no if constexpr) ---- //

    /**
     * @brief Narrow-bitset conversion (intended for M <= 32): widen to_ulong() to 64 bits.
     * @tparam M Bitset width.
     * @param b Bitset to convert; bit i of the result mirrors b.test(i).
     * @return std::uint64_t The bitset value zero-extended to 64 bits.
     */
    template <std::size_t M>
    inline std::uint64_t bitset_mask64_impl_fast(const std::bitset<M> &b) noexcept
    {
        const unsigned long word = b.to_ulong();
        const std::uint64_t widened = word;
        return widened;
    }

    /**
     * @brief Slow path: convert a bitset of width M <= 64 to a 64-bit mask by testing each bit.
     * @details Avoids to_ulong()/to_ullong(). Widths above 64 are rejected at compile time,
     *          because shifting a 64-bit value by 64 or more is undefined behaviour.
     * @tparam M Bitset width (must be <= 64).
     * @param b Source bitset (all M bits considered).
     * @return std::uint64_t 64-bit mask (bit i set iff b.test(i)).
     */
    template <std::size_t M>
    inline std::uint64_t bitset_mask64_impl_slow(const std::bitset<M> &b) noexcept
    {
        static_assert(M <= 64, "bitset_mask64_impl_slow() requires M <= 64");
        std::uint64_t m = 0;
        for (std::size_t i = 0; i < M; ++i)
        {
            if (b.test(i))
                m |= (std::uint64_t{1} << i); // i < 64 guaranteed by the static_assert.
        }
        return m;
    }

    /**
     * @brief Dispatch to the fast mask conversion path (M ≤ 32).
     * @tparam M Bitset width.
     * @param b Source bitset.
     * @param std::true_type Tag selecting the fast path.
     * @return std::uint64_t 64-bit mask.
     */
    template <std::size_t M>
    inline std::uint64_t bitset_mask64_dispatch(const std::bitset<M> &b, std::true_type) noexcept
    {
        return bitset_mask64_impl_fast(b);
    }

    /**
     * @brief Dispatch to the slow mask conversion path (M > 32).
     * @tparam M Bitset width.
     * @param b Source bitset.
     * @param std::false_type Tag selecting the slow path.
     * @return std::uint64_t 64-bit mask.
     */
    template <std::size_t M>
    inline std::uint64_t bitset_mask64_dispatch(const std::bitset<M> &b, std::false_type) noexcept
    {
        return bitset_mask64_impl_slow(b);
    }

    // ---- State ---- //

    /**
     * @brief Snapshot of level-based inputs at a moment in time.
     * @tparam N Number of logical inputs (compile-time).
     * @details
     *   - Buttons are level-based (true == pressed/active).
     *   - stamp_ms is meant to carry milliseconds since boot; the struct itself does not
     *     enforce any timebase, so producer and consumer must agree on it.
     *   - Designed to be trivially small and fast to copy.
     *   - Id-based members (is_pressed, is_released, set_button) require idx(id) < N:
     *     std::bitset::test()/set() throw std::out_of_range for a bad position, and since
     *     these members are noexcept that ends in std::terminate.
     */
    template <std::size_t N>
    struct State
    {
        /// @brief Logical button levels; true == pressed/active.
        std::bitset<N> buttons{};

        /// @brief Sample timestamp in milliseconds (intended: ms since boot). 32-bit, so it wraps.
        millis_u32 stamp_ms{0};

        // ---- Convenience API ---- //

        /**
         * @brief Total number of logical inputs (compile-time constant).
         * @return std::size_t Count of inputs (N).
         */
        static constexpr std::size_t size() noexcept { return N; }

        /**
         * @brief Number of active bits (pressed buttons).
         * @return std::size_t Count of 1-bits.
         */
        [[nodiscard]] std::size_t count() const noexcept { return buttons.count(); }

        /**
         * @brief Whether any button is pressed.
         * @return true If any bit is 1.
         * @return false Otherwise.
         */
        [[nodiscard]] bool any() const noexcept { return buttons.any(); }

        /**
         * @brief Whether all buttons are pressed.
         * @return true If all bits are 1.
         * @return false Otherwise.
         */
        [[nodiscard]] bool all() const noexcept { return buttons.all(); }

        /**
         * @brief Whether no buttons are pressed.
         * @return true If all bits are 0.
         * @return false Otherwise.
         */
        [[nodiscard]] bool none() const noexcept { return buttons.none(); }

        /**
         * @brief Get current level for a logical id.
         * @tparam Id Enum or integral id type.
         * @param id Logical input id; precondition idx(id) < N (otherwise std::terminate).
         * @return true If pressed/active.
         * @return false Otherwise.
         */
        template <class Id>
        [[nodiscard]] bool is_pressed(Id id) const noexcept
        {
            return buttons.test(idx(id));
        }

        /**
         * @brief Get inverted level (released) for a logical id.
         * @tparam Id Enum or integral id type.
         * @param id Logical input id; precondition idx(id) < N (otherwise std::terminate).
         * @return true If released/inactive.
         * @return false Otherwise.
         */
        template <class Id>
        [[nodiscard]] bool is_released(Id id) const noexcept
        {
            return !buttons.test(idx(id));
        }

        /**
         * @brief Set/clear a logical button bit.
         * @tparam Id Enum or integral id type.
         * @param id Logical input id; precondition idx(id) < N (otherwise std::terminate).
         * @param pressed New level to set.
         */
        template <class Id>
        void set_button(Id id, bool pressed) noexcept
        {
            buttons.set(idx(id), pressed);
        }

        /**
         * @brief Clear all button bits (set to released). stamp_ms is left unchanged.
         */
        void clear() noexcept { buttons.reset(); }

        // ---- Mask interop (SwitchBank-friendly) ---- //

        /**
         * @brief Convert to a 32-bit mask (low N bits). Requires N <= 32.
         * @tparam M Compile-time copy of N (defaults to N). It appears to exist only to defer
         *           the static_assert until the member is used; leave it defaulted.
         * @return std::uint32_t Bit mask suitable for a 32-bit “switch bank”.
         */
        template <std::size_t M = N>
        [[nodiscard]] std::uint32_t mask32() const noexcept
        {
            static_assert(M <= 32, "mask32() requires N <= 32");
            return static_cast<std::uint32_t>(buttons.to_ulong());
        }

        /**
         * @brief Convert to a 64-bit mask (low N bits). Requires N <= 64.
         * @tparam M Compile-time copy of N (defaults to N); leave it defaulted.
         * @return std::uint64_t Bit mask suitable for a 64-bit “switch bank”.
         */
        template <std::size_t M = N>
        [[nodiscard]] std::uint64_t mask64() const noexcept
        {
            static_assert(M <= 64, "mask64() requires N <= 64");
            return bitset_mask64_dispatch(buttons, std::integral_constant<bool, (M <= 32)>{});
        }

        /**
         * @brief Assign from a 32-bit mask (bit i -> button i). Requires N <= 32.
         * @details Writes bits 0..min(M, N)-1 (all N bits with the default M); mask bits at
         *          or above N are ignored. stamp_ms is not touched.
         * @tparam M Compile-time copy of N (defaults to N); leave it defaulted.
         * @param m 32-bit mask.
         */
        template <std::size_t M = N>
        void from_mask32(std::uint32_t m) noexcept
        {
            static_assert(M <= 32, "from_mask32() requires N <= 32");
            // Never walk past the bitset's real width: with an explicit M > N,
            // bitset::set() would throw std::out_of_range inside this noexcept function.
            constexpr std::size_t limit = (M < N) ? M : N;
            for (std::size_t i = 0; i < limit; ++i)
                buttons.set(i, (m >> i) & 0x1u);
        }

        /**
         * @brief Assign from a 64-bit mask (bit i -> button i). Requires N <= 64.
         * @details Writes bits 0..min(M, N)-1 (all N bits with the default M); mask bits at
         *          or above N are ignored. stamp_ms is not touched.
         * @tparam M Compile-time copy of N (defaults to N); leave it defaulted.
         * @param m 64-bit mask.
         */
        template <std::size_t M = N>
        void from_mask64(std::uint64_t m) noexcept
        {
            static_assert(M <= 64, "from_mask64() requires N <= 64");
            // Same clamp as from_mask32(): never index beyond the bitset's width.
            constexpr std::size_t limit = (M < N) ? M : N;
            for (std::size_t i = 0; i < limit; ++i)
                buttons.set(i, (m >> i) & 0x1ull);
        }
    };

    // ---- Snapshot helpers (pure, no I/O) ---- //

    /**
     * @brief Inputs that went from released to pressed (0→1) between two snapshots.
     * @tparam N Number of inputs.
     * @param prev Earlier snapshot.
     * @param cur Later snapshot.
     * @return std::bitset<N> Bit i set iff prev had input i released and cur has it pressed.
     */
    template <std::size_t N>
    [[nodiscard]] inline std::bitset<N>
    rising_edges(const State<N> &prev, const State<N> &cur) noexcept
    {
        std::bitset<N> edges = prev.buttons;
        edges.flip();
        edges &= cur.buttons;
        return edges;
    }

    /**
     * @brief Inputs that went from pressed to released (1→0) between two snapshots.
     * @tparam N Number of inputs.
     * @param prev Earlier snapshot.
     * @param cur Later snapshot.
     * @return std::bitset<N> Bit i set iff prev had input i pressed and cur has it released.
     */
    template <std::size_t N>
    [[nodiscard]] inline std::bitset<N>
    falling_edges(const State<N> &prev, const State<N> &cur) noexcept
    {
        std::bitset<N> released_now = cur.buttons;
        released_now.flip();
        return released_now & prev.buttons;
    }

    /**
     * @brief Inputs whose level differs between two snapshots (either direction).
     * @tparam N Number of inputs.
     * @param prev Earlier snapshot.
     * @param cur Later snapshot.
     * @return std::bitset<N> XOR of the two button sets.
     */
    template <std::size_t N>
    [[nodiscard]] inline std::bitset<N>
    changed_edges(const State<N> &prev, const State<N> &cur) noexcept
    {
        std::bitset<N> diff{prev.buttons};
        diff ^= cur.buttons;
        return diff;
    }

    // ---- Compact masks (telemetry / SwitchBank interop) ---- //

    /**
     * @brief Rising-edge mask as 32-bit value (N <= 32).
     *
     * @tparam N Number of inputs.
     * @tparam M Compile-time copy of N (defaults to N).
     * @param p Previous snapshot.
     * @param c Current snapshot.
     * @return std::uint32_t Rising edges as bit mask.
     */
    template <std::size_t N, std::size_t M = N>
    [[nodiscard]] inline std::uint32_t
    rising_mask32(const State<N> &p, const State<N> &c) noexcept
    {
        static_assert(M <= 32, "rising_mask32() requires N <= 32");
        return ((~p.buttons) & c.buttons).to_ulong();
    }

    /**
     * @brief Falling-edge mask as 32-bit value (N <= 32).
     */
    template <std::size_t N, std::size_t M = N>
    [[nodiscard]] inline std::uint32_t
    falling_mask32(const State<N> &p, const State<N> &c) noexcept
    {
        static_assert(M <= 32, "falling_mask32() requires N <= 32");
        return (p.buttons & (~c.buttons)).to_ulong();
    }

    /**
     * @brief Changed-bit mask as 32-bit value (N <= 32).
     */
    template <std::size_t N, std::size_t M = N>
    [[nodiscard]] inline std::uint32_t
    changed_mask32(const State<N> &p, const State<N> &c) noexcept
    {
        static_assert(M <= 32, "changed_mask32() requires N <= 32");
        return (p.buttons ^ c.buttons).to_ulong();
    }

    /**
     * @brief Rising-edge mask as 64-bit value (N <= 64).
     */
    template <std::size_t N, std::size_t M = N>
    [[nodiscard]] inline std::uint64_t
    rising_mask64(const State<N> &p, const State<N> &c) noexcept
    {
        static_assert(M <= 64, "rising_mask64() requires N <= 64");
        const auto b = (~p.buttons) & c.buttons;
        return bitset_mask64_dispatch(b, std::integral_constant<bool, (M <= 32)>{});
    }

    /**
     * @brief Falling-edge mask as 64-bit value (N <= 64).
     */
    template <std::size_t N, std::size_t M = N>
    [[nodiscard]] inline std::uint64_t
    falling_mask64(const State<N> &p, const State<N> &c) noexcept
    {
        static_assert(M <= 64, "falling_mask64() requires N <= 64");
        const auto b = p.buttons & (~c.buttons);
        return bitset_mask64_dispatch(b, std::integral_constant<bool, (M <= 32)>{});
    }

    /**
     * @brief Changed-bit mask as 64-bit value (N <= 64).
     */
    template <std::size_t N, std::size_t M = N>
    [[nodiscard]] inline std::uint64_t
    changed_mask64(const State<N> &p, const State<N> &c) noexcept
    {
        static_assert(M <= 64, "changed_mask64() requires N <= 64");
        const auto b = p.buttons ^ c.buttons;
        return bitset_mask64_dispatch(b, std::integral_constant<bool, (M <= 32)>{});
    }

    // ---- Edge iteration ---- //

    /**
     * @brief Visit each bit that changed between prev and cur.
     * @details Changed bits are visited in ascending index order. For each one the callable
     *          receives the index, the new level read from @p cur (true == now pressed,
     *          i.e. a rising edge), and cur.stamp_ms. The function is noexcept only when
     *          invoking @p on_edge is, so a throwing callback propagates its exception
     *          instead of triggering std::terminate.
     * @tparam N Number of logical inputs.
     * @tparam F Callable compatible with void(std::size_t index, bool pressed, millis_u32 stamp_ms).
     * @param prev Previous snapshot.
     * @param cur Current snapshot.
     * @param on_edge Invoked (as an lvalue) once per changed bit.
     */
    template <std::size_t N, class F>
    inline void for_each_edge(const State<N> &prev,
                              const State<N> &cur,
                              F &&on_edge) noexcept(noexcept(on_edge(std::size_t{0}, true, cur.stamp_ms)))
    {
        const auto changed = changed_edges(prev, cur);
        if (!changed.any())
            return; // Nothing toggled: skip the per-bit scan.

        for (std::size_t i = 0; i < N; ++i)
        {
            if (changed.test(i))
            {
                const bool pressed = cur.buttons.test(i);
                on_edge(i, pressed, cur.stamp_ms);
            }
        }
    }

    // ---- Adapters: Universal_Button & SwitchBank ---- //

    /**
     * @brief Overwrite a State from any "bits-like" source and stamp it.
     * @tparam N Number of logical inputs.
     * @tparam BitsLike Type providing test(std::size_t) whose result converts to bool
     *                  (e.g. a std::bitset or a Universal_Button bitset).
     * @param s Destination state; all N button bits and stamp_ms are overwritten.
     * @param bits Source, queried at positions 0..N-1.
     *             NOTE(review): a std::bitset<K> with K < N throws std::out_of_range from
     *             test(), which terminates here because this function is noexcept.
     * @param stamp_ms Timestamp in milliseconds stored into s.stamp_ms.
     */
    template <std::size_t N, class BitsLike>
    inline void assign_from_bits(State<N> &s,
                                 const BitsLike &bits,
                                 millis_u32 stamp_ms) noexcept
    {
        std::size_t pos = 0;
        while (pos < N)
        {
            const bool level = static_cast<bool>(bits.test(pos));
            s.buttons.set(pos, level);
            ++pos;
        }
        s.stamp_ms = stamp_ms;
    }

    /**
     * @brief Build a std::bitset<N> from a 32-bit mask (bit i -> index i).
     * @details Pass a std::uint32_t argument explicitly: a plain int such as
     *          to_bitset<8>(5) is ambiguous against the std::uint64_t overload.
     * @tparam N Number of logical inputs.
     * @tparam M Compile-time copy of N (defaults to N).
     * @param mask 32-bit mask; bits at or above N are ignored.
     * @return std::bitset<N> Bitset whose low N bits mirror @p mask.
     */
    template <std::size_t N, std::size_t M = N>
    [[nodiscard]] inline std::bitset<N>
    to_bitset(std::uint32_t mask) noexcept
    {
        static_assert(M <= 32, "to_bitset(uint32_t) requires N <= 32");
        std::bitset<N> out{};
        std::uint32_t rest = mask;
        for (std::size_t bit = 0; bit < N; ++bit, rest >>= 1)
            out.set(bit, (rest & 0x1u) != 0u);
        return out;
    }

    /**
     * @brief Build a std::bitset<N> from a 64-bit mask (bit i -> index i).
     * @tparam N Number of logical inputs.
     * @tparam M Compile-time copy of N (defaults to N).
     * @param mask 64-bit mask; bits at or above N are ignored.
     * @return std::bitset<N> Bitset whose low N bits mirror @p mask.
     */
    template <std::size_t N, std::size_t M = N>
    [[nodiscard]] inline std::bitset<N>
    to_bitset(std::uint64_t mask) noexcept
    {
        static_assert(M <= 64, "to_bitset(uint64_t) requires N <= 64");
        std::bitset<N> out{};
        std::uint64_t rest = mask;
        for (std::size_t bit = 0; bit < N; ++bit, rest >>= 1)
            out.set(bit, (rest & 0x1ull) != 0ull);
        return out;
    }

    /**
     * @brief Pack a std::bitset<N> into a 32-bit mask (N <= 32).
     * @tparam N Number of logical inputs.
     * @tparam M Compile-time copy of N (defaults to N).
     * @param b Source bitset.
     * @return std::uint32_t Bit i set iff b.test(i).
     */
    template <std::size_t N, std::size_t M = N>
    [[nodiscard]] inline std::uint32_t
    to_mask32(const std::bitset<N> &b) noexcept
    {
        static_assert(M <= 32, "to_mask32() requires N <= 32");
        const unsigned long wide = b.to_ulong();
        return static_cast<std::uint32_t>(wide);
    }

    /**
     * @brief Pack a std::bitset<N> into a 64-bit mask (N <= 64).
     * @details Narrow widths (M <= 32) use the to_ulong() path; wider ones test each bit.
     * @tparam N Number of logical inputs.
     * @tparam M Compile-time copy of N (defaults to N).
     * @param b Source bitset.
     * @return std::uint64_t Bit i set iff b.test(i).
     */
    template <std::size_t N, std::size_t M = N>
    [[nodiscard]] inline std::uint64_t
    to_mask64(const std::bitset<N> &b) noexcept
    {
        static_assert(M <= 64, "to_mask64() requires N <= 64");
        using narrow_tag = std::integral_constant<bool, (M <= 32)>;
        const narrow_tag tag{};
        return bitset_mask64_dispatch(b, tag);
    }

} ///< Namespace snapshot::input.