#include "MPU6050_VibrationRMS.h"

// 7-bit I2C address of the MPU6050 (0x68 when AD0 is tied low, 0x69 when high).
#define MPU6050_I2C_ADDRESS 0x68
// Accelerometer scale in LSB per g. 16384 matches the +/-2 g range, which is the
// power-on default; MPU6050_Init() only writes PWR_MGMT_1, not ACCEL_CONFIG.
#define ACCEL_SENSITIVITY 16384.0
// Nominal rate (Hz) at which readVRMS() is assumed to be called.
// NOTE(review): nothing in this file enforces this timing; the caller must.
#define SAMPLING_RATE 100
// Nominal sample period in seconds, derived from SAMPLING_RATE.
#define DT (1.0 / SAMPLING_RATE)
// Length of the circular sample buffer (vReal/vImag) and of each FFT.
#define NUM_SAMPLES 128
// Number of taps in movingAverage().
#define MA_WINDOW_SIZE 7
// NOTE(review): not referenced anywhere in this file — possibly used by the
// header (e.g. together with historyIndex) or left over; confirm before removing.
#define HISTORY_SIZE 6

// Constructs the driver with offsets and high-pass filter state zeroed.
// This constructor does not touch the I2C bus; bus and sensor setup happen in begin().
// NOTE(review): each Kalman filter is built with (2, 2, 0.01) — presumably
// (measurement error, estimate error, process noise) as in SimpleKalmanFilter;
// confirm against the header.
// alpha (0.98) is the coefficient of the first-order high-pass filter in readVRMS().
// prevVelocityRMS and historyIndex are initialised here but not used elsewhere in this file.
// NOTE(review): vReal/vImag are not initialised here; they are only zero at
// startup if the object has static storage duration (e.g. a global) — confirm.
MPU6050_VibrationRMS::MPU6050_VibrationRMS()
  : kalmanX(2,2,0.01), kalmanY(2,2,0.01), kalmanZ(2,2,0.01),
    fftIndex(0), ax_offset(0), ay_offset(0), az_offset(0),
    ax_prev(0), ay_prev(0), az_prev(0),
    ax_hpf(0), ay_hpf(0), az_hpf(0),
    alpha(0.98), prevVelocityRMS(0), historyIndex(0) {}

// One-time initialisation; call before the first readVRMS().
// It does four things, in order:
//  1. starts the I2C bus;
//  2. assigns a fresh arduinoFFT helper;
//  3. wakes the MPU6050;
//  4. measures per-axis offsets.
// Blocks for at least ~2 s: calibration waits 1000 ms and then takes 1000
// samples with a 1 ms delay each.
// The sensor should be stationary during this call.
void MPU6050_VibrationRMS::begin() {
  Wire.begin();
  FFT = arduinoFFT();
  MPU6050_Init();
  calibrateAccelerometer();
}

/*
 * Takes one accelerometer sample, pushes the derived velocity-like value into
 * the NUM_SAMPLES circular buffer and returns an RMS of that buffer.
 *
 * Per-call pipeline:
 *   raw counts
 *   -> subtract calibration offset
 *   -> m/s^2 (9.81 / ACCEL_SENSITIVITY per count)
 *   -> per-axis Kalman estimate
 *   -> first-order high-pass (coefficient alpha)
 *   -> multiply by DT
 *   -> vector magnitude * 1000 (presumably m/s -> mm/s)
 *   -> movingAverage()
 *   -> buffer.
 *
 * Return value:
 *   - Until the buffer wraps for the first time: the RMS over the whole buffer.
 *     Slots not yet written keep whatever value vReal started with.
 *   - After that: performFFT() transforms vReal in place on every wrap, so the
 *     buffer would otherwise mix spectral magnitudes with new samples. Instead,
 *     the RMS of each complete window is captured just before the FFT and
 *     returned until the next window completes. The reported value therefore
 *     updates once every NUM_SAMPLES calls.
 *
 * Assumes it is called once every DT seconds; nothing here enforces that.
 */
float MPU6050_VibrationRMS::readVRMS() {
  // NOTE(review): function-local statics are shared by every instance (the same
  // pattern as movingAverage()); move them into the class if more than one
  // sensor object is used.
  static bool bufferHoldsSpectrum = false;  // true once performFFT() has run
  static float lastWindowRMS = 0;           // RMS of the last complete window

  int16_t ax, ay, az;
  readAccelerometer(ax, ay, az);

  float ax_mps2 = kalmanX.updateEstimate((ax - ax_offset) * (9.81 / ACCEL_SENSITIVITY));
  float ay_mps2 = kalmanY.updateEstimate((ay - ay_offset) * (9.81 / ACCEL_SENSITIVITY));
  float az_mps2 = kalmanZ.updateEstimate((az - az_offset) * (9.81 / ACCEL_SENSITIVITY));

  // First-order high-pass: y[n] = alpha * (y[n-1] + x[n] - x[n-1]).
  ax_hpf = alpha * (ax_hpf + ax_mps2 - ax_prev);
  ay_hpf = alpha * (ay_hpf + ay_mps2 - ay_prev);
  az_hpf = alpha * (az_hpf + az_mps2 - az_prev);
  ax_prev = ax_mps2;
  ay_prev = ay_mps2;
  az_prev = az_mps2;

  // NOTE(review): a*DT is only the per-sample velocity *increment*. It is not
  // accumulated, so these values are proportional to the high-passed
  // acceleration rather than an integrated velocity. A true velocity needs
  // v += a*DT with drift control. This is left unchanged because it would
  // alter the output scale callers see.
  float vx_new = ax_hpf * DT;
  float vy_new = ay_hpf * DT;
  float vz_new = az_hpf * DT;

  float velocityMagnitude = sqrt(vx_new * vx_new + vy_new * vy_new + vz_new * vz_new) * 1000;
  float velocityFiltered = movingAverage(velocityMagnitude);

  vReal[fftIndex] = velocityFiltered;
  vImag[fftIndex] = 0;
  fftIndex = (fftIndex + 1) % NUM_SAMPLES;

  if (fftIndex == 0) {
    // The buffer now holds exactly NUM_SAMPLES fresh time-domain samples.
    // Take their RMS before the in-place FFT overwrites them.
    lastWindowRMS = calculateVelocityRMS();
    performFFT();
    bufferHoldsSpectrum = true;
  }

  return bufferHoldsSpectrum ? lastWindowRMS : calculateVelocityRMS();
}

// Wakes the MPU6050 from its power-on sleep state.
// Writing 0x00 to PWR_MGMT_1 (register 0x6B) clears the SLEEP bit and selects
// the internal oscillator as the clock source.
// The 10 ms pause gives the device time to come up before it is read.
// The endTransmission() status is not checked.
void MPU6050_VibrationRMS::MPU6050_Init() {
  const uint8_t kRegPwrMgmt1 = 0x6B;
  const uint8_t kWakeValue = 0x00;

  Wire.beginTransmission(MPU6050_I2C_ADDRESS);
  Wire.write(kRegPwrMgmt1);  // register address
  Wire.write(kWakeValue);    // value written to that register
  Wire.endTransmission();

  delay(10);
}

// Reads the next two bytes from the Wire receive buffer as one big-endian
// (high byte first) signed 16-bit value.
// The two reads are separate statements so their order is guaranteed. In an
// expression like `Wire.read() << 8 | Wire.read()`, the operand evaluation
// order is unspecified.
static int16_t readBigEndianInt16() {
  const uint8_t hi = static_cast<uint8_t>(Wire.read());
  const uint8_t lo = static_cast<uint8_t>(Wire.read());
  return static_cast<int16_t>((static_cast<uint16_t>(hi) << 8) | lo);
}

// Burst-reads the raw X/Y/Z accelerometer registers (ACCEL_XOUT_H at 0x3B
// onward, 6 bytes).
// @param ax, ay, az  Receive the raw signed 16-bit counts.
// If the sensor returns fewer than 6 bytes (NACK or bus error), all three
// outputs are set to 0. Reading a short buffer would otherwise yield -1 values
// and shift a negative number, which is undefined behaviour.
void MPU6050_VibrationRMS::readAccelerometer(int16_t &ax, int16_t &ay, int16_t &az) {
  const int kBytesPerSample = 6;

  // Set the register pointer; false keeps the bus for a repeated start.
  Wire.beginTransmission(MPU6050_I2C_ADDRESS);
  Wire.write(0x3B);
  Wire.endTransmission(false);

  Wire.requestFrom(MPU6050_I2C_ADDRESS, 6, true);
  if (Wire.available() < kBytesPerSample) {
    // Drain any partial data so it cannot be misread on the next call.
    while (Wire.available() > 0) {
      Wire.read();
    }
    ax = 0;
    ay = 0;
    az = 0;
    return;
  }

  ax = readBigEndianInt16();
  ay = readBigEndianInt16();
  az = readBigEndianInt16();
}

// Estimates a per-axis zero offset as the integer mean of 1000 raw readings,
// taken about 1 ms apart after an initial 1 s wait.
// The device is assumed to be stationary throughout.
// Each resting mean includes whatever gravity component that axis sees, so
// gravity is removed together with the sensor bias.
// The int32_t accumulators comfortably hold 1000 int16_t readings.
void MPU6050_VibrationRMS::calibrateAccelerometer() {
  const int samples = 1000;
  int32_t sumX = 0;
  int32_t sumY = 0;
  int32_t sumZ = 0;

  delay(1000);  // presumably lets the sensor settle after MPU6050_Init()

  for (int remaining = samples; remaining > 0; --remaining) {
    int16_t rawX, rawY, rawZ;
    readAccelerometer(rawX, rawY, rawZ);
    sumX += rawX;
    sumY += rawY;
    sumZ += rawZ;
    delay(1);
  }

  // Integer division: the offsets are the truncated means.
  ax_offset = sumX / samples;
  ay_offset = sumY / samples;
  az_offset = sumZ / samples;
}

// Applies a Hamming window and a forward FFT to the NUM_SAMPLES buffer, then
// converts the complex result (vReal, vImag) to magnitudes.
// NOTE(review): arduinoFFT's Windowing/Compute/ComplexToMagnitude operate on
// the passed arrays in place. After this call, vReal holds windowed spectral
// magnitudes rather than the time-domain samples that calculateVelocityRMS()
// reads. Nothing in this file consumes the spectrum; confirm whether the
// header or other code does.
void MPU6050_VibrationRMS::performFFT() {
  FFT.Windowing(vReal, NUM_SAMPLES, FFT_WIN_TYP_HAMMING, FFT_FORWARD);
  FFT.Compute(vReal, vImag, NUM_SAMPLES, FFT_FORWARD);
  FFT.ComplexToMagnitude(vReal, vImag, NUM_SAMPLES);
}

// Returns the mean of the most recent MA_WINDOW_SIZE values passed in,
// including newValue.
//  * Warm-up: until the window is full, only the samples received so far are
//    averaged. Dividing by MA_WINDOW_SIZE would bias the first
//    MA_WINDOW_SIZE-1 outputs toward zero.
//  * The sum is rebuilt from the buffer on every call (only MA_WINDOW_SIZE
//    additions) rather than maintained incrementally. This keeps float rounding
//    error from accumulating over long run times.
// NOTE(review): the state lives in function-local statics, so every instance
// of MPU6050_VibrationRMS shares one window. Move buffer/index/count into the
// class if more than one sensor is used.
float MPU6050_VibrationRMS::movingAverage(float newValue) {
  static float buffer[MA_WINDOW_SIZE] = {0};
  static int index = 0;  // slot to overwrite next
  static int count = 0;  // valid entries; saturates at MA_WINDOW_SIZE

  buffer[index] = newValue;
  index = (index + 1) % MA_WINDOW_SIZE;
  if (count < MA_WINDOW_SIZE) {
    count++;
  }

  // While filling, the valid entries are exactly buffer[0..count-1] because
  // index starts at 0. Once the window is full, every slot is valid.
  float sum = 0;
  for (int i = 0; i < count; i++) {
    sum += buffer[i];
  }
  return sum / count;
}

// Root-mean-square of all NUM_SAMPLES entries currently in vReal:
//   sqrt((vReal[0]^2 + ... + vReal[NUM_SAMPLES-1]^2) / NUM_SAMPLES)
// The squares are summed into a float accumulator in ascending index order.
float MPU6050_VibrationRMS::calculateVelocityRMS() {
  float total = 0;
  int k = 0;
  while (k < NUM_SAMPLES) {
    total += vReal[k] * vReal[k];
    ++k;
  }
  const float meanOfSquares = total / NUM_SAMPLES;
  return sqrt(meanOfSquares);
}
