#pragma once

/** @file
 * @brief Optimized multiplication (the fast path targets AVR-GCC; other compilers get plain C). See @ref group-opt-mul
*/

#include <stdint.h>

/// @cond
// Add the carry flag into `reg` (reg += C).
// First clears __zero_reg__ (r1), then performs `adc reg, __zero_reg__`.
// Relies on CLR (an EOR) not modifying the carry flag, so the carry produced by the
// preceding ADD/ADC is still available to the ADC.
// Side effect: leaves __zero_reg__ == 0, which is needed after a MUL has overwritten r1.
#define ADC_ZERO(reg) \
    "clr __zero_reg__ \n\t" \
    "adc  " #reg ", __zero_reg__ \n\t"

// Accumulate the 16-bit MUL product held in r1:r0 into the byte pair high_reg:low_reg.
// The carry out of high_reg is left in the carry flag, typically consumed by ADC_ZERO.
#define ADD_MUL_RESULT(low_reg, high_reg) \
    "add  " #low_reg ", r0 \n\t" \
    "adc  " #high_reg ", r1 \n\t"

// Inline-asm template body: full 32-bit product of two 16-bit operands.
//   result = aL*bL + ((aL*bH + aH*bL) << 8) + (aH*bH << 16)
// Expects asm operands named [a] and [b] (16-bit) and [result] (32-bit); the
// %A..%D modifiers select bytes from least to most significant.
// aL*bL and aH*bH are stored directly with MOVW (bytes A:B and C:D), then the two
// cross products are added into bytes B:C with the carry propagated into byte D.
// Clobbers r0/r1 (MUL output); the trailing ADC_ZERO leaves __zero_reg__ (r1) == 0.
// [result] is written before [a]/[b] are fully consumed, so it must be early-clobber.
#define ASM_U16xU16_32() \
    "mul  %A[a], %A[b] \n\t" \
    "movw %A[result], r0 \n\t" \
    \
    "mul  %B[a], %B[b] \n\t" \
    "movw %C[result], r0 \n\t" \
    \
    "mul  %B[b], %A[a] \n\t" \
    ADD_MUL_RESULT(%B[result], %C[result]) \
    ADC_ZERO(%D[result]) \
    \
    "mul  %B[a], %A[b] \n\t" \
    ADD_MUL_RESULT(%B[result], %C[result]) \
    ADC_ZERO(%D[result])


/// @endcond

/// @defgroup group-opt-mul Optimised multiplication
///
/// @brief Fast unsigned multiply: widening 16 x 16 => 32 bits, plus 32 x 16 => 32 bits (truncated)
///
/// As of AVR-GCC 15.2.0, the compiler does not implement a widening multiply for 16-bit operands, i.e. u16 x u16 => u32.
/// The built-in u16 multiply does not widen (i.e. u16 x u16 => **u16**), so the high half of the product is lost. To keep
/// it, one of the operands has to be cast to uint32_t - but this forces a full 32-bit multiplication, i.e. u32 x u32 => u32.
/// This is very inefficient.
///
/// Usage:
/// @code
///      uint16_t a, b;
///      uint32_t result = fast_multiply(a, b);
/// @endcode
///
/// @note The code is usable on all architectures, but the optimization only applies to AVR-GCC.
/// Other compilers will see a standard multiply operation.
/// @{
/**
 * @brief Widening unsigned multiply: 16 x 16 => 32 bits, no truncation.
 *
 * @param a 16-bit multiplicand.
 * @param b 16-bit multiplier.
 * @return The full 32-bit product `a * b`.
 */
static inline uint32_t fast_multiply(uint16_t a, uint16_t b)
{
    // MUL/MOVW only exist on AVR cores that define __AVR_HAVE_MUL__; cores without a
    // hardware multiplier (e.g. classic ATtiny) would fail to assemble, so they use C.
#if defined(__AVR__) && defined(__AVR_HAVE_MUL__)
    uint32_t result;

    // Operands may live in any register: MUL accepts r0..r31, and AVR-GCC always
    // starts a 32-bit value on an even register, which is what MOVW needs.
    // ASM_U16xU16_32 ends with ADC_ZERO, so __zero_reg__ (r1) is zero afterwards.
    __asm__ (
        ASM_U16xU16_32()
        : [result] "=&r" (result)
        : [a] "r" (a), [b] "r" (b)
        : "r0", "r1"
    );

    return result;
#else
    return (uint32_t)a * b;
#endif
}

/**
 * @brief Unsigned multiply 32 x 16 => 32 bits (truncated).
 *
 * Same result as `a * b` computed in uint32_t: the product is reduced modulo 2^32.
 *
 * @param a 32-bit multiplicand.
 * @param b 16-bit multiplier.
 * @return The low 32 bits of `a * b`.
 */
static inline uint32_t fast_multiply(uint32_t a, uint16_t b)
{
    // MUL/MOVW only exist on AVR cores that define __AVR_HAVE_MUL__; cores without a
    // hardware multiplier (e.g. classic ATtiny) would fail to assemble, so they use C.
#if defined(__AVR__) && defined(__AVR_HAVE_MUL__)
    uint32_t result;

    // result = a_low16 * b                         (full 32-bit partial product)
    //        + ((a_high16 * b) mod 2^16) << 16     (only bytes C/D survive truncation)
    __asm__ (
        // a_low * b -> result bytes A..D
        ASM_U16xU16_32()

        // a_high_low * b_low -> added into bytes C:D
        "mul  %C[a], %A[b] \n\t"
        ADD_MUL_RESULT(%C[result], %D[result])
        // a_high_high * b_low -> only its low byte reaches byte D
        "mul  %D[a], %A[b] \n\t"
        "add  %D[result], r0 \n\t"
        // a_high_low * b_high -> only its low byte reaches byte D
        "mul  %C[a], %B[b] \n\t"
        "add  %D[result], r0 \n\t"
        // a_high_high * b_high lies entirely above bit 31 and is skipped.

        // The last MUL left garbage in r1; the compiler assumes __zero_reg__ == 0.
        "clr  __zero_reg__ \n\t"

        : [result] "=&r" (result)
        : [a] "r" (a), [b] "r" (b)
        : "r0", "r1"
    );

    return result;
#else
    return a * b;
#endif
}

/**
 * @brief Operand-swapped convenience overload; forwards to
 *        fast_multiply(uint32_t, uint16_t).
 *
 * @param a 16-bit operand.
 * @param b 32-bit operand.
 * @return The low 32 bits of `a * b`.
 */
static inline uint32_t fast_multiply(uint16_t a, uint32_t b)
{
    // Multiplication commutes, so reuse the 32 x 16 implementation.
    const uint32_t wide = b;
    const uint16_t narrow = a;
    return fast_multiply(wide, narrow);
}
/// @}
