import pandas as pd
import traci
from datetime import datetime
from scipy.optimize import fsolve
import numpy as np
from scipy.stats import truncnorm
import pickle
from collections import deque
import math

# Rows collected by run(), one list per recorded cyclist and simulation step:
# [id, time, speed, distance, grade, x, y, z].
data = []

# Main simulation loop
def run():
    """Run the SUMO simulation and impose a slope-dependent speed on cyclists.

    Starts SUMO through TraCI and steps the simulation until
    ``traci.simulation.getMinExpectedNumber()`` reports no remaining vehicles.
    On every step, each vehicle whose vehicle class is ``'bicycle'`` and which
    is not stopped gets its speed overridden with a value derived from the
    slope at its position. One row per controlled cyclist and step is
    collected and written to ``Traci_kraussps_data1.csv`` after the
    simulation has been closed.

    The rows of this run are also appended to the module-level ``data`` list.
    The TraCI connection is closed even if an error occurs during the run;
    in that case no CSV is written.
    """
    start = datetime.now()
    print(str(datetime.now()))

    seed = 111
    step = 0.25  # passed to SUMO as --step-length and used in the speed update

    # Parameters of the speed model (the original author describes it as
    # "like in the Krauss PS model"); they are constant for the whole run.
    accel_flat = 1.2
    decel_max = 3.0
    max_speed_flat = 5.5
    g = 9.81
    min_speed = accel_flat / 2  # lower bound applied to the commanded speed

    # Rows of this run only, so that calling run() a second time does not
    # write the previous run's rows into the CSV again.
    records = []

    traci.start(["sumo", "-n", "osm_reserveCopy.net.xml",
                 "-r", "route_res_PS.rou.xml",
                 "--step-length", f"{step}",
                 "--seed", f"{seed}",
                 "--no-warnings"])

    try:
        # Runs while vehicles are in (or still expected to enter) the simulation
        while traci.simulation.getMinExpectedNumber() > 0:
            traci.simulationStep()

            # The simulation time is the same for every vehicle in this step.
            sim_time = traci.simulation.getTime()

            for vehicle_id in traci.vehicle.getIDList():
                # Only cyclists are controlled and recorded.
                if traci.vehicle.getVehicleClass(vehicle_id) != 'bicycle':
                    continue

                # Stopped vehicles are neither controlled nor recorded.
                # NOTE(review): the original comment called this "teleport
                # handling", but isStopped() reports the stop state - confirm
                # the intended condition.
                if traci.vehicle.isStopped(vehicle_id):
                    continue

                # Slope at the vehicle position (degrees, per the original
                # comment), also expressed as a percentage grade and in radians.
                current_slope = traci.vehicle.getSlope(vehicle_id)
                slope_rad = math.radians(current_slope)
                slope = math.tan(slope_rad) * 100
                print(f"slope: {slope}")

                x, y, z = traci.vehicle.getPosition3D(vehicle_id)

                speed_current = traci.vehicle.getSpeed(vehicle_id)
                print(f"speed_current: {speed_current}")
                distance = traci.vehicle.getDistance(vehicle_id)

                # Acceleration available on this slope: the flat-ground value
                # reduced by g * sin(slope), never below zero.
                a_max = max(
                    0,
                    accel_flat - g * np.sin(slope_rad)
                )
                print(f"a_max: {a_max}")

                # Desired maximum speed: the larger of the slope-scaled flat
                # speed and the current speed reduced by one step of maximum
                # deceleration.
                slope_limited_speed = np.sqrt(a_max / accel_flat) * max_speed_flat
                braking_limited_speed = speed_current - decel_max * step
                v_max = max(slope_limited_speed, braking_limited_speed)
                print(f"v_max1: {slope_limited_speed}, v_max2: {braking_limited_speed}, v_max: {v_max}")

                # Next-step speed: accelerate by a_max for one step, capped at
                # v_max and floored at min_speed.
                accelerated_speed = speed_current + a_max * step
                v_next = max(min_speed, min(accelerated_speed, v_max))
                print(f"v_next1: {min_speed}, v_next2: {accelerated_speed}, v_next3: {v_max}, v_next: {v_next}")

                # Speed mode 0 turns off SUMO's own speed checks (see the
                # TraCI speed mode documentation), so v_next is applied as is.
                traci.vehicle.setSpeedMode(vehicle_id, 0)
                traci.vehicle.setSpeed(vehicle_id, v_next)

                records.append([vehicle_id, sim_time, speed_current, distance, slope, x, y, z])
    finally:
        # Keep the module-level list populated as before, then always release
        # the TraCI connection, even when the loop above raised.
        data.extend(records)
        traci.close()

    df = pd.DataFrame(records, columns=['id', 'time', 'speed', 'distance', 'grade', 'x', 'y', 'z'])
    df.to_csv('Traci_kraussps_data1.csv', index=False)

    # runtime
    end = datetime.now()
    runtime = end - start
    print("runtime is: " + str(runtime))


if __name__ == '__main__':
    # Time the complete script run around the single call to run().
    script_started = datetime.now()
    run()
    elapsed = datetime.now() - script_started
    print('Finished in: ' + str(elapsed))