#include <Arduino.h>
#include <unity.h>
#include <avr/sleep.h>
#include "avr-fast-multiply.h"
#include "lambda_timer.hpp"
#include "unity_print_timers.hpp"

// Asserts that fast_multiply(a, b) equals the native product.
//
// The expected value is computed as `(uint32_t)a * b`, i.e. it wraps modulo
// 2^32, so fast_multiply() is expected to wrap identically.
//
//   Ta - type of the first operand (uint16_t or uint32_t in this file)
//   a  - first operand
//   b  - second, 16-bit operand
template <typename Ta>
static void assert_fast_multiply(Ta a, uint16_t b)
{
    // The operands are embedded in the failure message so a failing case can be
    // reproduced. Each argument is cast to the exact type its conversion
    // specifier expects: uint32_t is `unsigned long` on AVR but `unsigned int`
    // on most 32/64-bit targets, so passing it uncast to "%lu" is undefined
    // behaviour off-AVR.
    char szMsg[128];
    snprintf(szMsg, sizeof(szMsg), "(a=%lu, b=%u)", (unsigned long)a, (unsigned)b);

    const uint32_t expected = (uint32_t)a * b;
    TEST_ASSERT_EQUAL_UINT32_MESSAGE(expected, fast_multiply(a, b), szMsg);
}

// Checks 16x16 products of fast_multiply() against the native multiply.
// Covered: every combination of 0 and 1; operands on either side of the
// 8-bit boundary; the largest 16-bit operands; and a coarse sweep of the
// 16-bit range.
static void test_fast_multiply16x16(void)
{
    // 0 and 1 in every combination: (0,0), (0,1), (1,0), (1,1).
    for (uint16_t lhs = 0U; lhs < 2U; ++lhs)
    {
        for (uint16_t rhs = 0U; rhs < 2U; ++rhs)
        {
            assert_fast_multiply<uint16_t>(lhs, rhs);
        }
    }

    // Operands just below, at and just above the 8-bit maximum.
    const uint16_t u8Max = UINT8_MAX;
    const uint16_t u8MaxLess1 = u8Max - 1U;
    const uint16_t u8MaxPlus1 = u8Max + 1U;
    assert_fast_multiply<uint16_t>(u8MaxLess1, u8MaxLess1);
    assert_fast_multiply<uint16_t>(u8Max, u8MaxLess1);
    assert_fast_multiply<uint16_t>(u8MaxLess1, u8Max);
    assert_fast_multiply<uint16_t>(u8Max, u8Max);
    assert_fast_multiply<uint16_t>(u8MaxPlus1, u8MaxPlus1);
    assert_fast_multiply<uint16_t>(u8Max, u8MaxPlus1);
    assert_fast_multiply<uint16_t>(u8MaxPlus1, u8Max);

    // The largest 16-bit operands.
    const uint16_t u16Max = UINT16_MAX;
    assert_fast_multiply<uint16_t>(u16Max - 1U, u16Max);
    assert_fast_multiply<uint16_t>(u16Max, u16Max);

    // Coarse sweep: every 443rd multiplicand against every 499th multiplier,
    // both strictly below UINT16_MAX. 32-bit counters avoid wrap-around.
    const uint32_t strideA = 443U;
    const uint32_t strideB = 499U;
    uint32_t a = 0U;
    while (a < (uint32_t)UINT16_MAX)
    {
        uint32_t b = 0U;
        while (b < (uint32_t)UINT16_MAX)
        {
            assert_fast_multiply<uint16_t>((uint16_t)a, (uint16_t)b);
            b += strideB;
        }
        a += strideA;
    }
}

// Checks 32x16 products of fast_multiply() against the native multiply.
// Covered: trivial operands; the 8-bit and 16-bit boundaries; multiplicands
// wider than 16 bits; and multiplicands near the top of the 32-bit range.
static void test_fast_multiply32x16(void)
{
    assert_fast_multiply<uint32_t>(0, 0);
    assert_fast_multiply<uint32_t>(0, 1);
    assert_fast_multiply<uint32_t>(1, 0);
    assert_fast_multiply<uint32_t>(1, 1);
    assert_fast_multiply<uint32_t>(11, 7);

    assert_fast_multiply<uint32_t>((uint32_t)UINT8_MAX-1U, (uint16_t)UINT8_MAX-1U);
    assert_fast_multiply<uint32_t>((uint32_t)UINT8_MAX, (uint16_t)UINT8_MAX-1U);
    assert_fast_multiply<uint32_t>((uint32_t)UINT8_MAX-1U, (uint16_t)UINT8_MAX);
    assert_fast_multiply<uint32_t>((uint32_t)UINT8_MAX, (uint16_t)UINT8_MAX);
    assert_fast_multiply<uint32_t>((uint32_t)UINT8_MAX+1U, (uint16_t)UINT8_MAX+1U);
    assert_fast_multiply<uint32_t>((uint32_t)UINT8_MAX, (uint16_t)UINT8_MAX+1U);
    assert_fast_multiply<uint32_t>((uint32_t)UINT8_MAX+1U, (uint16_t)UINT8_MAX);

    assert_fast_multiply<uint32_t>((uint32_t)UINT16_MAX-1U, UINT16_MAX);
    assert_fast_multiply<uint32_t>((uint32_t)UINT16_MAX, UINT16_MAX);
    assert_fast_multiply<uint32_t>((uint32_t)UINT16_MAX+1U, UINT16_MAX);

    assert_fast_multiply<uint32_t>(UINT16_MAX/2, 3);
    // Promote to 32 bits *before* multiplying. On AVR `unsigned int` is 16 bits,
    // so `UINT16_MAX*3U` would wrap to 65533 and this case would never exercise
    // a multiplicand wider than 16 bits. The product itself exceeds 32 bits; the
    // expected value in assert_fast_multiply() wraps modulo 2^32.
    assert_fast_multiply<uint32_t>((uint32_t)UINT16_MAX*3UL, UINT16_MAX);

    assert_fast_multiply<uint32_t>(UINT32_MAX/4UL, 3U);
    assert_fast_multiply<uint32_t>(UINT32_MAX/3UL, 2U);
    assert_fast_multiply<uint32_t>(UINT32_MAX/2UL, 2U);
    assert_fast_multiply<uint32_t>(UINT32_MAX-1UL, 1U);
}

// The macros below ensure that the performance test functions perform the
// same number of operations, so the native and optimized variants are
// compared apples-to-apples.
//
// Each macro expands to a single fully-parenthesized expression with no
// trailing semicolon, so it is safe anywhere an expression is expected.

// Native multiply. The first operand is widened to 32 bits so that the product
// of two 16-bit values is computed in (at least) 32-bit arithmetic.
#define PERF_NATIVE_MUL(a, b) ((uint32_t)(a) * (b))
// Library multiply under test.
#define PERF_OPTIMIZED_MUL(a, b) (fast_multiply((a), (b)))

// Number of operand pairs processed by each benchmark pass.
static constexpr uint16_t ARRAY_SIZE{512U};

// 16-bit operand buffers shared by the performance tests, which fill them
// with random values before timing. Static storage, so zero-initialized.
static uint16_t u16_1[ARRAY_SIZE]{};
static uint16_t u16_2[ARRAY_SIZE]{};

// Benchmark body for 16x16 multiplies: for every index i below ARRAY_SIZE,
// adds mul_op(u16_1[i], u16_2[i]) to `checkSum`, which must be in scope at
// the expansion site. mul_op is any function-like macro or callable.
#define PERF_TEST_FUN_BODY_16_16(mul_op)                     \
    for (uint16_t idx = 0U; idx < ARRAY_SIZE; idx++)         \
    {                                                        \
        checkSum += mul_op(u16_1[idx], u16_2[idx]);          \
    }

// Benchmarks 16x16 multiplication: native multiply vs fast_multiply().
//
// Both variants must produce the same checksum. On AVR the optimized variant
// must also be strictly faster.
static void test_fast_multiply_perf_16x16(void)
{
    // Randomness here is all about ensuring that the compiler doesn't optimize away the shifts
    // (which it won't do in normal operation, when the multipliers are unknown at compile time).
    randomSeed(rand());
    for (uint16_t loop=0; loop<ARRAY_SIZE; ++loop) {
        u16_1[loop] = (uint16_t)random(UINT16_MAX/2U, UINT16_MAX);
        u16_2[loop] = (uint16_t)random(4, UINT16_MAX/2U);
    }
    const uint16_t iters = 500;
    // inMin/inMax/step are forwarded to compare_executiontime(); the lambdas
    // below ignore the value they are passed (unnamed first parameter).
    const uint16_t inMin = 0;
    const uint16_t inMax = 1;
    const uint16_t step = 1;

    auto nativeTest =    [] (uint16_t, uint32_t &checkSum) { PERF_TEST_FUN_BODY_16_16(PERF_NATIVE_MUL); };
    auto optimizedTest = [] (uint16_t, uint32_t &checkSum) { PERF_TEST_FUN_BODY_16_16(PERF_OPTIMIZED_MUL); };
    // NOTE(review): presumably timeA corresponds to nativeTest and timeB to optimizedTest.
    auto comparison = compare_executiontime<uint16_t, uint32_t>(iters, inMin, inMax, step, nativeTest, optimizedTest);

    MESSAGE_TIMERS(comparison.timeA.timer, comparison.timeB.timer);
    TEST_ASSERT_EQUAL_UINT32(comparison.timeA.result, comparison.timeB.result);

#if defined(__AVR__) // We only expect a speed improvement on AVR
    // Unity: TEST_ASSERT_LESS_THAN(threshold, actual) passes when actual < threshold.
    TEST_ASSERT_LESS_THAN(comparison.timeA.timer.duration_micros(), comparison.timeB.timer.duration_micros());
#endif
}

// 32-bit multiplicands for the 32x16 benchmark (static, so zero-initialized).
static uint32_t u32_1[ARRAY_SIZE]{};

// Benchmark body for 32x16 multiplies: for every index i below ARRAY_SIZE,
// adds mul_op(u32_1[i], u16_2[i]) to `checkSum`, which must be in scope at
// the expansion site. mul_op is any function-like macro or callable.
#define PERF_TEST_FUN_BODY_32_16(mul_op)                     \
    for (uint16_t idx = 0U; idx < ARRAY_SIZE; idx++)         \
    {                                                        \
        checkSum += mul_op(u32_1[idx], u16_2[idx]);          \
    }

static void test_fast_multiply_perf_32x16(void)
{
    // Randomness here is all about ensuring that the compiler doesn't optimize away the shifts
    // (which it won't do in normal operaton when the multipliers unknown at compile time.)
    randomSeed(rand());
    for (uint16_t loop=0; loop<ARRAY_SIZE; ++loop) { \
        u16_1[loop] = (uint16_t)random(UINT16_MAX/2U, UINT16_MAX); \
        u32_1[loop] = (uint32_t)random(4, UINT16_MAX*2UL); \
    }
    const uint16_t iters = 500;
    const uint16_t inMin = 0;
    const uint16_t inMax = 1;
    const uint16_t step = 1;

    auto nativeTest =    [] (uint16_t, uint32_t &checkSum) { PERF_TEST_FUN_BODY_32_16(PERF_NATIVE_MUL); };
    auto optimizedTest = [] (uint16_t, uint32_t &checkSum) { PERF_TEST_FUN_BODY_32_16(PERF_OPTIMIZED_MUL); };
    auto comparison = compare_executiontime<uint16_t, uint32_t>(iters, inMin, inMax, step, nativeTest, optimizedTest);

    MESSAGE_TIMERS(comparison.timeA.timer, comparison.timeB.timer);
    TEST_ASSERT_EQUAL_UINT32(comparison.timeA.result, comparison.timeB.result);

#if defined(__AVR__) // We only expect a speed improvement on AVR
    TEST_ASSERT_LESS_THAN(comparison.timeA.timer.duration_micros(), comparison.timeB.timer.duration_micros());
#endif
}

// Runs every test in this file inside a single Unity session.
static void run_all_tests(void)
{
    UNITY_BEGIN();
    RUN_TEST(test_fast_multiply16x16);
    RUN_TEST(test_fast_multiply32x16);
    RUN_TEST(test_fast_multiply_perf_16x16);
    RUN_TEST(test_fast_multiply_perf_32x16);
    UNITY_END();
}

// Arduino entry point: configures the LED that loop() blinks, then runs the
// test suite once. Under the simulator the CPU is halted afterwards; on real
// hardware there is a start-up delay before the run.
void setup()
{
    pinMode(LED_BUILTIN, OUTPUT);

#if defined(SIMULATOR)
    run_all_tests();

    // Tell SimAVR we are done: sleep with interrupts disabled, which the
    // simulator presumably detects as the end of the program.
    cli();
    sleep_enable();
    sleep_cpu();
#else
    // NOTE!!! Wait for >2 secs
    // if board doesn't support software reset via Serial.DTR/RTS
    delay(2000);
    run_all_tests();
#endif
}

// Arduino main loop, entered once setup() has run the suite: blinks the
// built-in LED (250 ms on, 250 ms off) to show that testing has finished.
void loop()
{
    const unsigned long halfPeriodMs = 250UL;

    digitalWrite(LED_BUILTIN, HIGH);
    delay(halfPeriodMs);

    digitalWrite(LED_BUILTIN, LOW);
    delay(halfPeriodMs);
}