#include "BleOtaUploader.h"
#include "BleOtaHeadCodes.h"
#include "ArduinoBleOTA.h"

namespace
{
// Reply to an accepted BEGIN request. handleBegin() fills one of these and
// sends its raw bytes, so the struct is packed to keep padding out of the
// transmitted message.
#pragma pack(push, 1)
struct BeginResponse
{
    uint8_t head;            // response head code (handleBegin() sends OK)
    uint32_t attributeSize;  // filled from BLE_OTA_ATTRIBUTE_SIZE
    uint32_t bufferSize;     // BLE_OTA_BUFFER_SIZE when buffering is on, else 0
};
#pragma pack(pop)

// The struct is transmitted byte-for-byte; fail the build if packing is lost.
static_assert(sizeof(BeginResponse) == sizeof(uint8_t) + 2 * sizeof(uint32_t),
              "BeginResponse must be packed without padding");
}

// Constructs an idle uploader: no storage attached, the enabled / uploading /
// installing flags cleared and the remaining members value-initialized.
// Unless BLE_OTA_NO_BUFFER is defined, buffering starts switched on.
BleOtaUploader::BleOtaUploader() :
    crc(),
    storage(nullptr),
#ifndef BLE_OTA_NO_BUFFER
    buffer(),
    withBuffer(true),
#endif
    enabled(false),
    uploading(false),
    installing(false),
    firmwareLength()
{}

// Attaches the storage backend that uploads are written to. Only the address
// is kept, so the referenced object has to outlive its use by this uploader.
void BleOtaUploader::begin(OTAStorage& storage)
{
    this->storage = &storage;
}

// Polling hook: does nothing until handleEnd() has flagged a pending
// installation, then hands over to handleInstall().
void BleOtaUploader::pull()
{
    if (not installing)
    {
        return;
    }

    handleInstall();
}

// Switches upload acceptance on or off. While disabled, handleBegin() answers
// BEGIN requests with UPLOAD_DISABLED.
void BleOtaUploader::setEnabling(bool enabling)
{
    enabled = enabling;
}

// Entry point for a received message. The first byte is the head code that
// selects the handler; the remaining bytes are passed on as its payload.
// Messages are dropped once an installation is pending; an empty message or
// an unknown head code is reported as INCORRECT_FORMAT.
void BleOtaUploader::onData(const uint8_t* data, size_t length)
{
    if (installing)
    {
        return;
    }

    // At least the head byte must be present.
    if (length < 1)
    {
        handleError(INCORRECT_FORMAT);
        return;
    }

    const uint8_t head = data[0];
    const uint8_t* const payload = data + 1;
    const size_t payloadLength = length - 1;

    switch (head)
    {
    case BEGIN:
        handleBegin(payload, payloadLength);
        break;
    case PACKAGE:
        handlePackage(payload, payloadLength);
        break;
    case END:
        handleEnd(payload, payloadLength);
        break;
    case SET_PIN_CODE:
        handleSetPinCode(payload, payloadLength);
        break;
    case REMOVE_PIN_CODE:
        handleRemovePinCode(payload, payloadLength);
        break;
    default:
        handleError(INCORRECT_FORMAT);
        break;
    }
}

void BleOtaUploader::handleBegin(const uint8_t* data, size_t length)
{
    if (uploading)
        terminateUpload();

    if (not enabled)
    {
        send(UPLOAD_DISABLED);
        return;
    }

    if (length != sizeof(uint32_t))
    {
        handleError(INCORRECT_FORMAT);
        return;
    }
    memcpy(&firmwareLength, data, length);

    if (storage == nullptr or not storage->open(firmwareLength))
    {
        firmwareLength = 0;
        handleError(INTERNAL_STORAGE_ERROR);
        return;
    }

    if (storage->maxSize() and firmwareLength > storage->maxSize())
    {
        terminateUpload();
        handleError(INCORRECT_FIRMWARE_SIZE);
        return;
    }

    uploading = true;
    crc.restart();

    #ifndef BLE_OTA_NO_BUFFER
    buffer.clear();
    uint32_t bufferSize = withBuffer ? BLE_OTA_BUFFER_SIZE : 0;
    #else
    uint32_t bufferSize = 0;
    #endif
    BeginResponse resp{OK, BLE_OTA_ATTRIBUTE_SIZE, bufferSize};
    send((const uint8_t*)(&resp), sizeof(BeginResponse));

    ArduinoBleOTA.uploadCallbacks->onBegin(firmwareLength);
}

// Handles a PACKAGE message carrying the next slice of firmware bytes.
//
// Ignored when no upload is running. The bytes are added to the running CRC;
// if the total received exceeds the announced firmware length the upload is
// terminated. A response (OK or the error) is sent for every package when
// buffering is off, and otherwise only for a package that would not fit into
// the buffer together with what is already held - in that case the buffer is
// flushed to storage before the new bytes are stored.
void BleOtaUploader::handlePackage(const uint8_t* data, size_t length)
{
    if (not uploading)
    {
        return;
    }

    // Decided up front: terminateUpload() below changes withBuffer.
    #ifndef BLE_OTA_NO_BUFFER
    const bool acknowledge = not withBuffer or buffer.size() + length > BLE_OTA_BUFFER_SIZE;
    #else
    const bool acknowledge = true;
    #endif

    crc.add(data, length);
    const bool receivedTooMuch = crc.count() > firmwareLength;
    if (receivedTooMuch)
    {
        terminateUpload();
        if (acknowledge)
        {
            handleError(INCORRECT_FIRMWARE_SIZE);
        }
        return;
    }

    #ifndef BLE_OTA_NO_BUFFER
    if (acknowledge)
    {
        flushBuffer();
    }
    #endif

    fillData(data, length);
    if (acknowledge)
    {
        send(OK);
    }
}

// Handles an END message whose payload is the 32-bit checksum of the firmware.
//
// Replies NOK when no upload is running. A byte count that differs from the
// announced length (INCORRECT_FIRMWARE_SIZE) or a checksum mismatch
// (CHECKSUM_ERROR) terminates the upload; a payload of the wrong size is
// reported as INCORRECT_FORMAT and leaves the upload running. On success any
// buffered bytes are flushed, OK is sent, the onEnd upload callback is invoked
// and the installation is flagged for the next pull().
void BleOtaUploader::handleEnd(const uint8_t* data, size_t length)
{
    if (not uploading)
    {
        handleError(NOK);
        return;
    }

    if (crc.count() != firmwareLength)
    {
        terminateUpload();
        handleError(INCORRECT_FIRMWARE_SIZE);
        return;
    }

    if (length != sizeof(uint32_t))
    {
        handleError(INCORRECT_FORMAT);
        return;
    }

    // length was just verified to equal sizeof(uint32_t).
    uint32_t firmwareCrc = 0;
    memcpy(&firmwareCrc, data, sizeof(firmwareCrc));

    const bool checksumMatches = crc.calc() == firmwareCrc;
    if (not checksumMatches)
    {
        terminateUpload();
        handleError(CHECKSUM_ERROR);
        return;
    }

    #ifndef BLE_OTA_NO_BUFFER
    flushBuffer();
    #endif

    send(OK);

    ArduinoBleOTA.uploadCallbacks->onEnd();
    installing = true;
}

// Handles a SET_PIN_CODE message whose payload is a 32-bit pin code.
// Refused with NOK while an upload is running and with INCORRECT_FORMAT when
// the payload is not exactly four bytes. Otherwise the pin code is passed to
// the security callbacks and their result is reported as OK or NOK.
void BleOtaUploader::handleSetPinCode(const uint8_t* data, size_t length)
{
    if (uploading)
    {
        handleError(NOK);
        return;
    }

    if (length != sizeof(uint32_t))
    {
        handleError(INCORRECT_FORMAT);
        return;
    }

    // length was just verified to equal sizeof(uint32_t).
    uint32_t pinCode = 0;
    memcpy(&pinCode, data, sizeof(pinCode));

    const bool accepted = ArduinoBleOTA.securityCallbacks->setPinCode(pinCode);
    send(accepted ? OK : NOK);
}

// Handles a REMOVE_PIN_CODE message, which must carry no payload.
// Refused with NOK while an upload is running and with INCORRECT_FORMAT when
// any payload bytes are present. Otherwise the request is passed to the
// security callbacks and their result is reported as OK or NOK.
void BleOtaUploader::handleRemovePinCode(const uint8_t* data, size_t length)
{
    if (uploading)
    {
        handleError(NOK);
        return;
    }

    if (length != 0)
    {
        handleError(INCORRECT_FORMAT);
        return;
    }

    const bool removed = ArduinoBleOTA.securityCallbacks->removePinCode();
    send(removed ? OK : NOK);
}

// Finalizes a completed upload: closes the storage, asks it to apply the new
// firmware and then never returns to the caller.
// NOTE(review): the pauses before close() and apply() presumably give the
// final OK response time to be delivered - confirm.
void BleOtaUploader::handleInstall()
{
    delay(250);
    storage->close();
    delay(250);
    storage->apply();
    // Spin here should apply() return.
    // NOTE(review): apply() presumably restarts the device - confirm.
    while (true);
}

// Reports a failure: sends errorCode as a one-byte response and then passes
// the same code to the upload callbacks' onError().
void BleOtaUploader::handleError(uint8_t errorCode)
{
    send(errorCode);
    ArduinoBleOTA.uploadCallbacks->onError(errorCode);
}

// Sends a response consisting of a single head byte.
void BleOtaUploader::send(uint8_t head)
{
    send(&head, 1);
}

// Sends length bytes starting at data by forwarding them to ArduinoBleOTA.send().
void BleOtaUploader::send(const uint8_t* data, size_t length)
{
    ArduinoBleOTA.send(data, length);
}

// Aborts the current upload: clears and closes the storage and resets the
// uploading flag and the expected firmware length.
// storage is dereferenced without a null check.
void BleOtaUploader::terminateUpload()
{
    storage->clear();
    storage->close();
    uploading = false;
    firmwareLength = 0;

    // Buffering is switched off from here on; nothing visible in this file
    // switches it back on, so later uploads announce a buffer size of 0.
    // NOTE(review): looks like a deliberate fallback to unbuffered transfer
    // after a failed upload - confirm.
    #ifndef BLE_OTA_NO_BUFFER
    withBuffer = false;
    #endif
}

// Stores the received bytes one at a time: into the buffer while buffering is
// on, otherwise straight into the storage. With BLE_OTA_NO_BUFFER defined the
// bytes always go straight into the storage.
void BleOtaUploader::fillData(const uint8_t* data, size_t length)
{
    const uint8_t* const end = data + length;
    for (const uint8_t* current = data; current != end; ++current)
    {
        #ifndef BLE_OTA_NO_BUFFER
        if (withBuffer)
        {
            buffer.push(*current);
        }
        else
        {
            storage->write(*current);
        }
        #else
        storage->write(*current);
        #endif
    }
}

#ifndef BLE_OTA_NO_BUFFER
// Drains the buffer into the storage: each value returned by buffer.shift()
// is written out until the buffer reports empty.
void BleOtaUploader::flushBuffer()
{
    while (not buffer.isEmpty())
    {
        const auto nextByte = buffer.shift();
        storage->write(nextByte);
    }
}
#endif

// Global uploader instance defined by this translation unit.
BleOtaUploader bleOtaUploader{};