#include "mobileR_NMV.h"
#include "Arduino.h"

// ======================= Static instance =======================
// Object that the static encoder ISRs forward to. It is null until an object
// is constructed; the constructor overwrites it, so the most recently
// constructed object receives all encoder interrupts.
mobileR_NMV* mobileR_NMV::instance{nullptr};

// ======================= Constructor =======================
// Constructs the robot driver with default wiring, geometry and PID gains.
// Defaults (can be changed later via attach_motors / attach_encoders /
// caliODOM / caliEncoder; controller tunings via setPID_L / setPID_R):
//   - motor pins L=10, R=9; encoder pins L=2, R=3
//   - wheelRadius 0.033, wheelBase 0.29 (presumably metres — TODO confirm)
//   - gains Kp=2.2, Ki=0.5, Kd=0.0 for both wheels, DIRECT action
// PID_L / PID_R are bound by pointer to filtered_vel_* (input), pidOutput*
// (output) and target* (setpoint), so the controllers read and write those
// members directly.
// NOTE(review): members are initialized in header declaration order, not in
// the order written below. PID_L/PID_R copy Kp_*/Ki_*/Kd_* by value, so those
// gain members must be declared before PID_L/PID_R in the header — confirm.
mobileR_NMV::mobileR_NMV()
: motorL_pin(10), motorR_pin(9),
  encL_pin(2), encR_pin(3),
  wheelRadius(0.033), wheelBase(0.29),
  Kp_L(2.2), Ki_L(0.5), Kd_L(0.0),
  Kp_R(2.2), Ki_R(0.5), Kd_R(0.0),
  pidOutputL(0), filtered_vel_L(0), targetL(0),
  pidOutputR(0), filtered_vel_R(0), targetR(0),
  PID_L(&filtered_vel_L, &pidOutputL, &targetL, Kp_L, Ki_L, Kd_L, DIRECT),
  PID_R(&filtered_vel_R, &pidOutputR, &targetR, Kp_R, Ki_R, Kd_R, DIRECT),
  encoderCountL(0), encoderCountR(0), lastCountL(0), lastCountR(0),
  lastVelTimeL(0), lastVelTimeR(0),
  x(0), y(0), theta(0), lastOdomTime(0),
  recoveryActive(false), recoveryStart(0), recoveryStage(0),
  kP_omega(70)
{
    // Default encoder resolution in degrees per encoder count (see caliEncoder).
    degPerPulseL = 0.8006;
    degPerPulseR = 0.8006;
    // Register this object for the static encoder ISRs. Only one object can
    // receive interrupts: the most recently constructed one.
    instance = this; // set static instance
}

// ======================= Attach motors =======================
// Assigns the motor output pins and configures them as OUTPUT.
// A pin is only taken when its side tag matches ('L' for pinL, 'R' for pinR);
// otherwise the current (default or previously set) pin is kept.
// Both pins are configured regardless of whether they changed.
void mobileR_NMV::attach_motors(char left, int pinL, char right, int pinR) {
    motorL_pin = (left == 'L') ? pinL : motorL_pin;
    motorR_pin = (right == 'R') ? pinR : motorR_pin;

    pinMode(motorL_pin, OUTPUT);
    pinMode(motorR_pin, OUTPUT);
}

// ======================= Attach encoders =======================
// Assigns the encoder input pins, configures them as INPUT and attaches the
// static ISRs so that every edge (CHANGE) increments the matching counter.
// A pin is only taken when its side tag matches ('L' / 'R').
// lastCount* is synced to the current counts first, so the next velocity
// update measures only pulses counted after this call.
void mobileR_NMV::attach_encoders(char left, int pinL, char right, int pinR) {
    encL_pin = (left == 'L') ? pinL : encL_pin;
    encR_pin = (right == 'R') ? pinR : encR_pin;

    pinMode(encL_pin, INPUT);
    pinMode(encR_pin, INPUT);

    lastCountL = encoderCountL;
    lastCountR = encoderCountR;

    attachInterrupt(digitalPinToInterrupt(encL_pin), ISR_encoderL, CHANGE);
    attachInterrupt(digitalPinToInterrupt(encR_pin), ISR_encoderR, CHANGE);
}


// ======================= Static ISR =======================
// Static trampolines registered with attachInterrupt(). They forward each
// encoder edge to the registered object and ignore edges that arrive before
// any object exists.
void mobileR_NMV::ISR_encoderL() {
    if (instance != nullptr) {
        instance->handleEncoderL();
    }
}

void mobileR_NMV::ISR_encoderR() {
    if (instance != nullptr) {
        instance->handleEncoderR();
    }
}

// Per-object edge counters, called from the static ISRs.
// Each edge adds one count; rotation direction is not tracked.
void mobileR_NMV::handleEncoderL() {
    ++encoderCountL;
}

void mobileR_NMV::handleEncoderR() {
    ++encoderCountR;
}

//======================= Encoder calibration =======================
// Calibrates one wheel's encoder resolution from a measured run:
// `counts` encoder counts were observed while the wheel turned `revs`
// revolutions. Stores degrees-per-count = 360 * revs / counts for wheel
// 'L' or 'R'; any other wheel tag is ignored.
// A zero count is rejected and the previous calibration is kept. Dividing by
// zero would store an inf/NaN resolution that corrupts every later velocity
// estimate in updateVelocityL/R.
void mobileR_NMV::caliEncoder(char wheel, long counts, float revs) {
    if (counts == 0) return;

    const double degPerRev = 360.0;
    const double anglePerPulse = degPerRev * revs / counts;
    switch (wheel) {
        case 'L': degPerPulseL = anglePerPulse; break;
        case 'R': degPerPulseR = anglePerPulse; break;
        default: break;
    }
}

/* 
void mobileR_NMV::setPID_L(double Kp, double Ki, double Kd) {
    Kp_L = Kp;
    Ki_L = Ki;
    Kd_L = Kd;

    PID_L.SetMode(AUTOMATIC);
    PID_L.SetOutputLimits(minPulse, maxPulse);
    PID_L.SetSampleTime(5);
}

void mobileR_NMV::setPID_R(double Kp, double Ki, double Kd) {
    Kp_R = Kp;
    Ki_R = Ki;
    Kd_R = Kd;

    PID_R.SetMode(AUTOMATIC);
    PID_R.SetOutputLimits(minPulse, maxPulse);
    PID_R.SetSampleTime(5);
}

*/

// Retunes the left-wheel controller, switches it to AUTOMATIC and sets its
// sample time in milliseconds.
// NOTE(review): minOut/maxOut are accepted but not applied here; output
// limits for this controller are set by pulseLimit(). The stored Kp_L/Ki_L/
// Kd_L members are also left unchanged.
void mobileR_NMV::setPID_L(double Kp, double Ki, double Kd, double minOut, double maxOut, int sampleTime) {
    (void)minOut;
    (void)maxOut;

    auto& controller = PID_L;
    controller.SetTunings(Kp, Ki, Kd);
    controller.SetMode(AUTOMATIC);
    controller.SetSampleTime(sampleTime);
}

// Switches the right-wheel controller to AUTOMATIC, retunes it and sets its
// sample time in milliseconds.
// NOTE(review): minOut/maxOut are accepted but not applied here; output
// limits for this controller are set by pulseLimit(). The stored Kp_R/Ki_R/
// Kd_R members are also left unchanged.
void mobileR_NMV::setPID_R(double Kp, double Ki, double Kd, double minOut, double maxOut, int sampleTime) {
    (void)minOut;
    (void)maxOut;

    auto& controller = PID_R;
    controller.SetMode(AUTOMATIC);
    controller.SetTunings(Kp, Ki, Kd);
    controller.SetSampleTime(sampleTime);
}


// Updates the PID output limits from up to two (name, value) pairs, where the
// name is "min" or "max" (processed in order; unknown names are ignored).
// The resulting [minPulse, maxPulse] range is then applied to both wheel
// controllers. type1/type2 must be non-null C strings (they go to strcmp).
void mobileR_NMV::pulseLimit(const char* type1, int value1, const char* type2, int value2) {
    const char* names[2] = { type1, type2 };
    const int values[2] = { value1, value2 };

    for (int i = 0; i < 2; ++i) {
        if (strcmp(names[i], "max") == 0) {
            maxPulse = values[i];
        } else if (strcmp(names[i], "min") == 0) {
            minPulse = values[i];
        }
    }

    PID_L.SetOutputLimits(minPulse, maxPulse);
    PID_R.SetOutputLimits(minPulse, maxPulse);
}



// ======================= Odometer calibration =======================
void mobileR_NMV::caliODOM(float wheelR, float wheelB) {
    wheelRadius = wheelR;
    wheelBase = wheelB;
}

// ======================= Obstacle recovery =======================
// Sets how long (ms) each obstacle-recovery stage in odometer() lasts:
// stage N uses stageTimes[N-1].
void mobileR_NMV::navi_obst_recovery(unsigned long stage1_ms, unsigned long stage2_ms, unsigned long stage3_ms) {
    const unsigned long durations[] = { stage1_ms, stage2_ms, stage3_ms };
    for (int stage = 0; stage < 3; ++stage) {
        stageTimes[stage] = durations[stage];
    }
}

void mobileR_NMV::navi_caliP(float k_val) {
    kP_omega = k_val;
}

// ======================= Update velocities =======================
double mobileR_NMV::updateVelocityL() {
    unsigned long now = millis();
    if (now - lastVelTimeL >= 200) {
        long diff = encoderCountL - lastCountL;
        double vel_meas = (degPerPulseL * 1000 * diff) / (now - lastVelTimeL);
        filtered_vel_L = 0.1 * vel_meas + 0.9 * filtered_vel_L;
        lastVelTimeL = now;
        lastCountL = encoderCountL;
    }
    return filtered_vel_L;
}

double mobileR_NMV::updateVelocityR() {
    unsigned long now = millis();
    if (now - lastVelTimeR >= 200) {
        long diff = encoderCountR - lastCountR;
        double vel_meas = (degPerPulseR * 1000 * diff) / (now - lastVelTimeR);
        filtered_vel_R = 0.1 * vel_meas + 0.9 * filtered_vel_R;
        lastVelTimeR = now;
        lastCountR = encoderCountR;
    }
    return filtered_vel_R;
}

// ======================= PID compute =======================
// Runs one step of the left-wheel PID and returns its current output.
// PID_L reads its input/setpoint through pointers to filtered_vel_L/targetL,
// so both values are written into those members first. Writing currentVel
// also replaces the EMA state that updateVelocityL() blends into, so callers
// should normally pass the value updateVelocityL() returned.
double mobileR_NMV::computePID_L(double currentVel, double targetVel) {
    targetL = targetVel;
    filtered_vel_L = currentVel;
    PID_L.Compute();
    return pidOutputL;
}

// Runs one step of the right-wheel PID and returns its current output.
// PID_R reads its input/setpoint through pointers to filtered_vel_R/targetR,
// so both values are written into those members first. Writing currentVel
// also replaces the EMA state that updateVelocityR() blends into, so callers
// should normally pass the value updateVelocityR() returned.
double mobileR_NMV::computePID_R(double currentVel, double targetVel) {
    targetR = targetVel;
    filtered_vel_R = currentVel;
    PID_R.Compute();
    return pidOutputR;
}

// ======================= Odometry =======================
// Integrates the wheel angular velocities into the robot pose. When
// obstacleMode is enabled, it also runs a timed three-stage recovery sequence.
//
// velL_deg, velR_deg : wheel angular velocities in degrees per second
//                      (converted to rad/s and multiplied by dt in seconds).
// obstacleMode       : enables the recovery trigger.
// sensorValue        : compared against `threshold`; a value above it starts
//                      recovery when none is already running.
// Returns the updated pose {x, y, heading}; heading is kept in [-PI, PI].
//
// Side effects: while recovery is active, this writes PWM values directly to
// the motor pins and substitutes fixed per-stage velocities for the odometry
// step. Stage durations come from stageTimes (see navi_obst_recovery).
Pose mobileR_NMV::odometer(float velL_deg, float velR_deg, bool obstacleMode, int sensorValue, int threshold) {
    unsigned long now = millis();
    // First call: start timing here so the first dt is 0, not time-since-boot.
    if (lastOdomTime == 0) lastOdomTime = now;
    float dt = (now - lastOdomTime)/1000.0;

    // Obstacle recovery: start at stage 1 when the trigger fires.
    if (obstacleMode && sensorValue > threshold && !recoveryActive) {
        recoveryActive = true;
        recoveryStage = 1;
        recoveryStart = now;
    }

    if (recoveryActive) {
        switch(recoveryStage){
            case 1:
                // Stage 1: PWM 30 on both motors; odometry treats the robot as stationary.
                velL_deg = 0; velR_deg = 0;
                analogWrite(motorL_pin, 30);
                analogWrite(motorR_pin, 30);
                if(now - recoveryStart >= stageTimes[0]) { recoveryStage = 2; recoveryStart = now; }
                break;
            case 2:
                // Stage 2: right motor only; odometry assumes only the right wheel turns.
                velL_deg = 0; velR_deg = 120;
                analogWrite(motorL_pin, 0);
                analogWrite(motorR_pin, 60);
                if(now - recoveryStart >= stageTimes[1]) { recoveryStage = 3; recoveryStart = now; }
                break;
            case 3:
                // Stage 3: both motors at PWM 60; odometry assumes equal wheel speeds.
                velL_deg = 120; velR_deg = 120;
                analogWrite(motorL_pin, 60);
                analogWrite(motorR_pin, 60);
                if(now - recoveryStart >= stageTimes[2]) { recoveryActive=false; recoveryStage=0; }
                break;
        }
    }

    // Differential-drive kinematics.
    float velL_rad = velL_deg * (PI/180.0);
    float velR_rad = velR_deg * (PI/180.0);
    float v = wheelRadius*(velR_rad + velL_rad)/2.0;
    float omega = wheelRadius*(velR_rad - velL_rad)/wheelBase;

    x += v*cos(theta)*dt;
    y += v*sin(theta)*dt;
    theta += omega*dt;

    // Wrap theta back into [-PI, PI]. A single +/-2*PI correction is not
    // enough when omega*dt exceeds 2*PI (e.g. after a long gap between
    // calls), so reduce modulo 2*PI. In-range values are left untouched to
    // avoid adding rounding error on every step.
    if (theta > PI || theta < -PI) {
        theta = fmod(theta + PI, 2*PI);
        if (theta < 0) theta += 2*PI;
        theta -= PI;
    }

    lastOdomTime = now;
    return {x, y, theta};
}

// ======================= Navigation =======================
// Proportional go-to-goal controller.
// Computes wheel velocity commands (velL_deg / velR_deg) that steer the robot
// from rPose toward (goalX, goalY). Each wheel gets a shared forward term plus
// or minus thetaErr * kP_omega. When within distance 0.1 (pose units) of the
// goal, both commands are 0 and goalReached is set.
WheelVelocities mobileR_NMV::navigate(Pose rPose, float goalX, float goalY){
    WheelVelocities w = {0,0,false};
    float dx = goalX - rPose.x;
    float dy = goalY - rPose.y;
    float dist = sqrt(dx*dx + dy*dy);

    if(dist < 0.1){ w.velL_deg=0; w.velR_deg=0; w.goalReached=true; return w; }

    float desiredTheta = atan2(dy, dx);
    // Wrap the heading error into [-PI, PI] so the robot turns the short way:
    // desired -170 deg vs heading +170 deg must be +20 deg, not -340 deg.
    float thetaErr = desiredTheta - rPose.heading;
    thetaErr = atan2(sin(thetaErr), cos(thetaErr));

    // NOTE(review): forward speed grows with distance (vmax + 40*dist), so
    // "vmax" acts as a floor rather than a cap, and the < 0 clamp cannot fire
    // for dist >= 0. Kept unchanged; confirm the intended speed profile.
    float vmax = 80.0;
    float vForward = vmax + 40*dist;
    if(vForward < 0) vForward = 0;

    w.velL_deg = vForward - thetaErr*kP_omega;
    w.velR_deg = vForward + thetaErr*kP_omega;
    return w;
}



// Parses a path message of the form "[PATH,]x1,y1;x2,y2;...;xn,yn" into
// PathData, storing at most MAX_POINTS points in order.
// A segment whose comma does not fall before its own ';' is skipped. Without
// this check, the parser would borrow the next segment's comma and produce a
// garbage point. Parsing stops once no comma remains in the rest of the input.
PathData mobileR_NMV::handlePath(const String& dataIn) {
    const int len = dataIn.length();

    // Skip the optional "PATH," header by offset instead of copying the string.
    int start = dataIn.startsWith("PATH,") ? 5 : 0;

    PathData path;
    path.length = 0;

    while (start < len && path.length < MAX_POINTS) {
        int semi  = dataIn.indexOf(';', start);
        int end   = (semi == -1) ? len : semi;   // end of this "x,y" segment
        int comma = dataIn.indexOf(',', start);

        if (comma == -1) break;                  // no more "x,y" pairs anywhere

        if (comma < end) {
            path.goalX[path.length] = dataIn.substring(start, comma).toFloat();
            path.goalY[path.length] = dataIn.substring(comma + 1, end).toFloat();
            path.length++;
        }
        start = end + 1;
    }

    return path;
}


// Averages accelerometer samples from a message containing
// "ACC,ax,ay,az;ax,ay,az;...". Everything before the first "ACC," is ignored.
// Returns per-axis means and the number of samples used. All fields are 0
// when the header is missing or no well-formed sample is found.
// Each sample is parsed only within its own ';'-delimited segment. A segment
// with fewer than three fields is skipped instead of borrowing fields from
// the next sample. A fourth field, if present, is ignored because toFloat
// stops at the comma. Parsing works by index, so the input String is not
// re-copied for every sample.
AccAvg mobileR_NMV::handleACC(const String& data) {
    AccAvg out;
    out.x = 0; out.y = 0; out.z = 0; out.samples = 0;

    int header = data.indexOf("ACC,");
    if (header == -1) return out;

    const int len = data.length();
    int start = header + 4;

    while (start < len) {
        int sep = data.indexOf(';', start);
        int end = (sep == -1) ? len : sep;       // end of this sample's segment

        int c1 = data.indexOf(',', start);
        int c2 = (c1 == -1) ? -1 : data.indexOf(',', c1 + 1);

        if (c1 != -1 && c2 != -1 && c2 < end) {
            float ax = data.substring(start, c1).toFloat();
            float ay = data.substring(c1 + 1, c2).toFloat();
            float az = data.substring(c2 + 1, end).toFloat();

            out.x += ax;
            out.y += ay;
            out.z += az;
            out.samples++;
        }
        start = end + 1;
    }

    if (out.samples > 0) {
        out.x /= out.samples;
        out.y /= out.samples;
        out.z /= out.samples;
    }

    return out;
}


