#include <string.h>
#include <stddef.h>
#include <stdlib.h>
#include "MQTTClient.h"
#include <ERa/ERaDefine.hpp>
#include <Utility/ERaUtility.hpp>

/**
 * Duplicate a NUL-terminated C string into a freshly malloc()'d buffer.
 *
 * @param str  source string; may be nullptr
 * @return     heap copy (release with free()), or nullptr if `str` is nullptr
 *             or the allocation fails
 */
inline char* lwmqtt_strdup(const char* str) {
  if (!str) {
    return nullptr;
  }

  // Size includes the terminating NUL so it is copied along with the text.
  const size_t bytes = strlen(str) + 1;
  char* dup = static_cast<char*>(malloc(bytes));
  if (dup != nullptr) {
    memcpy(dup, str, bytes);
  }
  return dup;
}

inline void lwmqtt_arduino_timer_set(void *ref, uint32_t timeout) {
  // cast timer reference
  auto t = (lwmqtt_arduino_timer_t *)ref;

  // set timeout
  t->timeout = timeout;

  // set start
  if (t->millis != nullptr) {
    t->start = t->millis();
  } else {
    t->start = ERaMillis();
  }
}

/**
 * lwmqtt timer "get" hook: return the time left on the timer referenced by
 * `ref`. The result is positive while time remains and zero or negative once
 * the timeout has elapsed.
 *
 * The current time comes from the timer's custom clock source when set,
 * otherwise from ERaMillis().
 */
inline int32_t lwmqtt_arduino_timer_get(void *ref) {
  // cast timer reference
  auto t = (lwmqtt_arduino_timer_t *)ref;

  // get now
  uint32_t now;
  if (t->millis != nullptr) {
    now = t->millis();
  } else {
    now = ERaMillis();
  }

  // Elapsed time. Unsigned subtraction wraps modulo 2^32, so this is exact
  // even when the counter rolled over since `start`. The previous
  // `UINT32_MAX - start + now` form was one tick short on roll-over.
  const uint32_t diff = now - t->start;

  // Remaining time, computed in 64 bits and clamped so the conversion to
  // int32_t can never overflow. The old `-(diff - timeout)` negated an
  // unsigned value and could wrap for large differences.
  const int64_t remaining = (int64_t)t->timeout - (int64_t)diff;
  if (remaining > INT32_MAX) {
    return INT32_MAX;
  }
  if (remaining < INT32_MIN) {
    return INT32_MIN;
  }
  return (int32_t)remaining;
}

/**
 * lwmqtt network "read" hook: read up to `len` bytes into `buffer` within
 * `timeout` milliseconds (measured with ERaMillis()).
 *
 * @param read  set to the number of bytes actually read
 * @return LWMQTT_SUCCESS if at least one byte was read,
 *         LWMQTT_NETWORK_TIMEOUT if nothing arrived before the timeout,
 *         LWMQTT_NETWORK_FAILED_READ if the client disconnected while waiting
 */
inline lwmqtt_err_t lwmqtt_arduino_network_read(void *ref, uint8_t *buffer, size_t len, size_t *read,
                                                uint32_t timeout) {
  auto *network = static_cast<lwmqtt_arduino_network_t *>(ref);
  const uint32_t begin = ERaMillis();
  size_t remaining = len;

  *read = 0;

  while (remaining > 0 && (ERaMillis() - begin < timeout)) {
    const int got = network->client->read(buffer, remaining);

    if (got <= 0) {
      // Nothing available yet: yield so RTOS based boards can let the wifi
      // task deliver data, then give up if the link has dropped.
      mqtt_yield_fix();
      if (!network->client->connected()) {
        return LWMQTT_NETWORK_FAILED_READ;
      }
      continue;
    }

    buffer += got;
    *read += got;
    remaining -= got;
  }

  return (*read == 0) ? LWMQTT_NETWORK_TIMEOUT : LWMQTT_SUCCESS;
}

/**
 * lwmqtt network "write" hook: send all `len` bytes of `buffer`, in chunks of
 * at most 1024 bytes. The timeout argument is ignored.
 *
 * @param sent  set to the total bytes written on success; left at 0 on failure
 * @return LWMQTT_SUCCESS, or LWMQTT_NETWORK_FAILED_WRITE if a chunk write
 *         reported 0 bytes
 */
inline lwmqtt_err_t lwmqtt_arduino_network_write(void *ref, uint8_t *buffer, size_t len, size_t *sent,
                                                 uint32_t /*timeout*/) {
  auto *network = static_cast<lwmqtt_arduino_network_t *>(ref);
  size_t total = 0;

  *sent = 0;

  for (size_t left = len; left > 0;) {
    const size_t chunk = network->client->write(buffer, ERaMin(left, (size_t)1024));
    if (chunk == 0) {
      return LWMQTT_NETWORK_FAILED_WRITE;
    }
    buffer += chunk;
    total += chunk;
    left -= chunk;
  }

  *sent = total;
  return LWMQTT_SUCCESS;
}

static void MQTTClientHandler(lwmqtt_client_t * /*client*/, void *ref, lwmqtt_string_t topic,
                              lwmqtt_message_t message) {
  // get callback
  auto cb = (MQTTClientCallback *)ref;

  // null terminate topic
  char terminated_topic[topic.len + 1];
  memcpy(terminated_topic, topic.data, topic.len);
  terminated_topic[topic.len] = '\0';

  // null terminate payload if available
  if (message.payload != nullptr) {
    message.payload[message.payload_len] = '\0';
  }

  // call the advanced callback and return if available
  if (cb->advanced != nullptr) {
    cb->advanced(cb->client, terminated_topic, (char *)message.payload, (int)message.payload_len);
    return;
  }
#if MQTT_HAS_FUNCTIONAL
  if (cb->functionAdvanced != nullptr) {
    cb->functionAdvanced(cb->client, terminated_topic, (char *)message.payload, (int)message.payload_len);
    return;
  }
#endif

  // return if simple callback is not set
#if MQTT_HAS_FUNCTIONAL
  if (cb->simple == nullptr && cb->functionSimple == nullptr) {
    return;
  }
#else
  if (cb->simple == nullptr) {
    return;
  }
#endif

  // create topic string
  const char *str_topic = (const char *)terminated_topic;

  // create payload string
  const char *str_payload = (const char *)message.payload;

  // call simple callback
#if MQTT_HAS_FUNCTIONAL
  if (cb->functionSimple != nullptr) {
    cb->functionSimple(str_topic, str_payload);
  } else {
    cb->simple(str_topic, str_payload);
  }
#else
  cb->simple(str_topic, str_payload);
#endif
}

MQTTClient::MQTTClient(int readBufSize, int writeBufSize) {
  // Buffer allocation lives in init(); the constructor just forwards the sizes.
  init(readBufSize, writeBufSize);
}

// Release everything the client owns: the will message, the duplicated
// hostname and the read/write buffers allocated by init().
MQTTClient::~MQTTClient() {
  clearWill();

  if (hostname != nullptr) {
    free((void *)hostname);
  }

  free(writeBuf);
  free(readBuf);
}

/**
 * Attach the network client and wire the lwmqtt client to this object's
 * buffers, timers, network adapter and message callback.
 *
 * @param _client  transport used for all MQTT traffic; must outlive this object
 */
void MQTTClient::begin(Client &_client) {
  netClient = &_client;

  lwmqtt_client_t *const mqtt = &client;
  lwmqtt_init(mqtt, writeBuf, writeBufSize, readBuf, readBufSize);
  lwmqtt_set_timers(mqtt, &timer1, &timer2, lwmqtt_arduino_timer_set, lwmqtt_arduino_timer_get);
  lwmqtt_set_network(mqtt, &network, lwmqtt_arduino_network_read, lwmqtt_arduino_network_write);
  lwmqtt_set_callback(mqtt, static_cast<void *>(&callback), MQTTClientHandler);
}

/**
 * Allocate the read and write buffers if they are not allocated yet.
 *
 * Non-positive sizes are ignored. On allocation failure the corresponding size
 * is recorded as 0, so it never advertises capacity the (null) buffer lacks.
 *
 * @param readBufSize   usable read buffer size; one extra byte is allocated so
 *                      the message handler can NUL-terminate payloads in place
 * @param writeBufSize  write buffer size
 */
void MQTTClient::init(int readBufSize, int writeBufSize) {
  // allocate read buffer
  if ((readBufSize > 0) && (this->readBuf == nullptr)) {
    this->readBuf = (uint8_t *)ERA_MALLOC((size_t)readBufSize + 1);
    this->readBufSize = (this->readBuf != nullptr) ? (size_t)readBufSize : 0;
  }

  // allocate write buffer
  if ((writeBufSize > 0) && (this->writeBuf == nullptr)) {
    this->writeBuf = (uint8_t *)ERA_MALLOC((size_t)writeBufSize);
    this->writeBufSize = (this->writeBuf != nullptr) ? (size_t)writeBufSize : 0;
  }
}

// Install a plain (topic, payload) function-pointer callback. Every other
// callback slot is cleared, so this is the one the message handler dispatches to.
void MQTTClient::onMessage(MQTTClientCallbackSimple cb) {
  MQTTClientCallback &slot = this->callback;
  slot.advanced = nullptr;
#if MQTT_HAS_FUNCTIONAL
  slot.functionAdvanced = nullptr;
  slot.functionSimple = nullptr;
#endif
  slot.simple = cb;
  slot.client = this;
}

// Install an advanced function-pointer callback (client, topic, bytes, length).
// Every other callback slot is cleared.
void MQTTClient::onMessageAdvanced(MQTTClientCallbackAdvanced cb) {
  MQTTClientCallback &slot = this->callback;
  slot.simple = nullptr;
#if MQTT_HAS_FUNCTIONAL
  slot.functionAdvanced = nullptr;
  slot.functionSimple = nullptr;
#endif
  slot.advanced = cb;
  slot.client = this;
}

#if MQTT_HAS_FUNCTIONAL
// Install a functional (callable-object) simple callback; all other slots are cleared.
void MQTTClient::onMessage(MQTTClientCallbackSimpleFunction cb) {
  MQTTClientCallback &slot = this->callback;
  slot.advanced = nullptr;
  slot.functionAdvanced = nullptr;
  slot.simple = nullptr;
  slot.functionSimple = cb;
  slot.client = this;
}

// Install a functional (callable-object) advanced callback; all other slots are cleared.
void MQTTClient::onMessageAdvanced(MQTTClientCallbackAdvancedFunction cb) {
  MQTTClientCallback &slot = this->callback;
  slot.simple = nullptr;
  slot.functionSimple = nullptr;
  slot.advanced = nullptr;
  slot.functionAdvanced = cb;
  slot.client = this;
}
#endif

// Use `cb` instead of ERaMillis() as the time base for both timers (timer1 and
// timer2); pass nullptr to fall back to ERaMillis().
void MQTTClient::setClockSource(MQTTClientClockSource cb) {
  timer1.millis = timer2.millis = cb;
}

void MQTTClient::setSkipACK(bool skip) {
  this->skipACK = skip;
}

/**
 * Target the broker by IP address.
 *
 * Any hostname configured earlier is discarded. connect() prefers a hostname
 * over the address, so leaving it set would silently ignore this address.
 */
void MQTTClient::setHost(IPAddress _address, int _port) {
  // forget a previously configured hostname
  if (this->hostname != nullptr) {
    free((void *)this->hostname);
    this->hostname = nullptr;
  }

  // set address and port
  this->address = _address;
  this->port = _port;
}

/**
 * Target the broker by hostname. The string is copied; the caller keeps
 * ownership of `_hostname`. Passing nullptr clears the hostname, so
 * connect() uses the stored address instead.
 */
void MQTTClient::setHost(const char _hostname[], int _port) {
  // Duplicate first, then release the old copy. Passing the currently stored
  // hostname back in therefore never reads freed memory.
  char *copy = lwmqtt_strdup(_hostname);
  if (this->hostname != nullptr) {
    free((void *)this->hostname);
  }

  // set hostname and port
  this->hostname = copy;
  this->port = _port;
}

/**
 * Configure the last-will message sent with the next connect().
 *
 * Topic and payload are copied. A missing or empty topic leaves the current
 * will untouched. On allocation failure the previous will is still cleared
 * and no will is installed, so no partially built will is ever used.
 */
void MQTTClient::setWill(const char topic[], const char payload[], bool retained, int qos) {
  // return if topic is missing
  if (topic == nullptr || strlen(topic) == 0) {
    return;
  }

  // clear existing will
  this->clearWill();

  // allocate will (previously the result was passed unchecked to memset)
  auto newWill = (lwmqtt_will_t *)ERA_MALLOC(sizeof(lwmqtt_will_t));
  if (newWill == nullptr) {
    return;
  }
  memset(newWill, 0, sizeof(lwmqtt_will_t));

  // set topic
  char *topicCopy = lwmqtt_strdup(topic);
  if (topicCopy == nullptr) {
    free(newWill);
    return;
  }
  newWill->topic = lwmqtt_string(topicCopy);

  // set payload if available
  if (payload != nullptr && strlen(payload) > 0) {
    char *payloadCopy = lwmqtt_strdup(payload);
    if (payloadCopy == nullptr) {
      free(topicCopy);
      free(newWill);
      return;
    }
    newWill->payload = lwmqtt_string(payloadCopy);
  }

  // set flags
  newWill->retained = retained;
  newWill->qos = (lwmqtt_qos_t)qos;

  // install only once fully built
  this->will = newWill;
}

void MQTTClient::clearWill() {
  // return if not set
  if (this->will == nullptr) {
    return;
  }

  // free payload if set
  if (this->will->payload.len > 0) {
    free(this->will->payload.data);
  }

  // free topic if set
  if (this->will->topic.len > 0) {
    free(this->will->topic.data);
  }

  // free will
  free(this->will);
  this->will = nullptr;
}

void MQTTClient::setKeepAlive(int _keepAlive) { this->keepAlive = _keepAlive; }

void MQTTClient::setCleanSession(bool _cleanSession) { this->cleanSession = _cleanSession; }

void MQTTClient::setTimeout(int _timeout) { this->timeout = _timeout; }

void MQTTClient::dropOverflow(bool enabled) {
  // configure drop overflow
  lwmqtt_drop_overflow(&this->client, enabled, &this->_droppedMessages);
}

/**
 * Open the network connection (unless `skip`) and perform the MQTT CONNECT.
 *
 * @param clientID  MQTT client identifier
 * @param username  optional; omitted when nullptr
 * @param password  optional; omitted when nullptr
 * @param skip      reuse the already-open network connection instead of
 *                  closing/reopening it
 * @return true on success. On failure the error is in _lastError and the
 *         broker's return code in _returnCode (when a CONNACK was parsed).
 */
bool MQTTClient::connect(const char clientID[], const char username[], const char password[], bool skip) {
  // begin() must have attached a network client. Without one the code below
  // (and close()) would dereference null.
  if (this->netClient == nullptr) {
    this->_lastError = LWMQTT_NETWORK_FAILED_CONNECT;
    return false;
  }

  // close left open connection if still connected
  if (!skip && this->connected()) {
    this->close();
  }

  // save client
  this->network.client = this->netClient;

  // connect to host (hostname takes precedence over the IP address)
  if (!skip) {
    int ret;
    if (this->hostname != nullptr) {
      ret = this->netClient->connect(this->hostname, (uint16_t)this->port);
    } else {
      ret = this->netClient->connect(this->address, (uint16_t)this->port);
    }
    if (ret <= 0) {
      this->_lastError = LWMQTT_NETWORK_FAILED_CONNECT;
      return false;
    }
  }

  // prepare options
  lwmqtt_connect_options_t options = lwmqtt_default_connect_options;
  options.keep_alive = this->keepAlive;
  options.clean_session = this->cleanSession;
  options.client_id = lwmqtt_string(clientID);

  // set username and password if available
  if (username != nullptr) {
    options.username = lwmqtt_string(username);
  }
  if (password != nullptr) {
    options.password = lwmqtt_string(password);
  }

  // connect to broker
  this->_lastError = lwmqtt_connect(&this->client, &options, this->will, this->timeout);

  // copy return code
  this->_returnCode = options.return_code;

  // handle error
  if (this->_lastError != LWMQTT_SUCCESS) {
    // close connection
    this->close();

    return false;
  }

  // copy session present flag
  this->_sessionPresent = options.session_present;

  // set flag
  this->_connected = true;

  return true;
}

/**
 * Publish `length` bytes of `payload` to `topic`.
 *
 * If prepareDuplicate() was called since the last publish, that packet id is
 * used once as the duplicate id and then cleared. Any lwmqtt error closes the
 * connection and is stored in _lastError.
 *
 * @return true on success, false if not connected or publishing failed
 */
bool MQTTClient::publish(const char topic[], const char payload[], int length, bool retained, int qos) {
  // return immediately if not connected
  if (!this->connected()) {
    return false;
  }

  // prepare message
  lwmqtt_message_t message = lwmqtt_default_message;
  message.payload = (uint8_t *)payload;
  message.payload_len = (size_t)length;
  message.retained = retained;
  message.qos = lwmqtt_qos_t(qos);

  // prepare options
  lwmqtt_publish_options_t options = lwmqtt_default_publish_options;

  // Set the duplicate packet id if one was requested. It is copied into a
  // local before the member is reset. Previously dup_id pointed at the member,
  // which was zeroed before lwmqtt_publish read it, so lwmqtt saw 0 instead
  // of the requested id.
  auto dupPacketID = this->nextDupPacketID;
  if (dupPacketID > 0) {
    options.dup_id = &dupPacketID;
    this->nextDupPacketID = 0;
  }
  options.skip_ack = this->skipACK;

  // publish message
  this->_lastError = lwmqtt_publish(&this->client, &options, lwmqtt_string(topic), message, this->timeout);
  if (this->_lastError != LWMQTT_SUCCESS) {
    // close connection
    this->close();

    return false;
  }

  return true;
}

uint16_t MQTTClient::lastPacketID() {
  // get last packet id from client
  return this->client.last_packet_id;
}

void MQTTClient::prepareDuplicate(uint16_t packetID) {
  // set next duplicate packet id
  this->nextDupPacketID = packetID;
}

// Subscribe to a single topic. Any lwmqtt error is stored in _lastError and
// closes the connection. Returns false when not connected or on failure.
bool MQTTClient::subscribe(const char topic[], int qos) {
  if (!this->connected()) {
    return false;
  }

  this->_lastError = lwmqtt_subscribe_one(&this->client, lwmqtt_string(topic), (lwmqtt_qos_t)qos, this->timeout);
  const bool ok = (this->_lastError == LWMQTT_SUCCESS);
  if (!ok) {
    this->close();
  }
  return ok;
}

// Unsubscribe from a single topic. Any lwmqtt error is stored in _lastError
// and closes the connection. Returns false when not connected or on failure.
bool MQTTClient::unsubscribe(const char topic[]) {
  if (!this->connected()) {
    return false;
  }

  this->_lastError = lwmqtt_unsubscribe_one(&this->client, lwmqtt_string(topic), this->timeout);
  const bool ok = (this->_lastError == LWMQTT_SUCCESS);
  if (!ok) {
    this->close();
  }
  return ok;
}

/**
 * Service the connection: process incoming bytes if any are available, then
 * let lwmqtt handle keep-alive. Any error is stored in _lastError and closes
 * the connection.
 *
 * @return false when not connected or on error, true otherwise
 */
bool MQTTClient::loop() {
  if (!this->connected()) {
    return false;
  }

  const int pending = this->netClient->available();
  if (pending > 0) {
    this->_lastError = lwmqtt_yield(&this->client, pending, this->timeout);
    if (this->_lastError != LWMQTT_SUCCESS) {
      this->close();
      return false;
    }
  }

  this->_lastError = lwmqtt_keep_alive(&this->client, this->timeout);
  if (this->_lastError == LWMQTT_SUCCESS) {
    return true;
  }

  this->close();
  return false;
}

// Connected means: the MQTT handshake succeeded (_connected), a network client
// is attached, and that client reports an open connection.
bool MQTTClient::connected() {
  if (!this->_connected || this->netClient == nullptr) {
    return false;
  }
  return this->netClient->connected() == 1;
}

bool MQTTClient::disconnect() {
  // return immediately if not connected anymore
  if (!this->connected()) {
    return false;
  }

  // cleanly disconnect
  this->_lastError = lwmqtt_disconnect(&this->client, this->timeout);

  // close
  this->close();

  return this->_lastError == LWMQTT_SUCCESS;
}

// Mark the MQTT session as closed and stop the network client, if one has
// been attached with begin(). Previously this dereferenced a null netClient
// when called before begin().
void MQTTClient::close() {
  // set flag
  this->_connected = false;

  // close network
  if (this->netClient != nullptr) {
    this->netClient->stop();
  }
}
