import os
import uuid

import traci
import traci.constants as tc
from sumolib import checkBinary

from intersection import Intersection
import synchronization_util


class SumoEnv:
    """SUMO environment for a single signalised intersection, driven over TraCI.

    Each :meth:`reset` launches a fresh SUMO process under a unique TraCI
    connection label, subscribes to lane, simulation and vehicle variables,
    and returns the initial intersection state. :meth:`step` applies a signal
    action and keeps advancing the simulation one step at a time for as long
    as the intersection reports an active action-time action.
    """

    LANE_VARIABLES_TO_SUBSCRIBE = [
        tc.LAST_STEP_VEHICLE_NUMBER,
        tc.LAST_STEP_VEHICLE_ID_LIST,
        tc.LAST_STEP_VEHICLE_HALTING_NUMBER,
        tc.VAR_WAITING_TIME,
        tc.VAR_LENGTH,
        tc.LAST_STEP_MEAN_SPEED,
        tc.VAR_MAXSPEED
    ]

    VEHICLE_VARIABLES_TO_SUBSCRIBE = [
        tc.VAR_POSITION,
        tc.VAR_SPEED,
        tc.VAR_WAITING_TIME,
        tc.VAR_LANEPOSITION,
        tc.VAR_ALLOWED_SPEED,
        tc.VAR_MINGAP,
        tc.VAR_TAU,
        # tc.VAR_LEADER,  # Problems with subscription
        # tc.VAR_SECURE_GAP,  # Problems with subscription
        tc.VAR_LENGTH,
        tc.VAR_LANE_ID,
        tc.VAR_DECEL,
        tc.VAR_WIDTH,
        tc.VAR_ANGLE,
        tc.VAR_STOPSTATE,
    ]

    SIMULATION_VARIABLES_TO_SUBSCRIBE = [
        tc.VAR_DEPARTED_VEHICLES_NUMBER,
        tc.VAR_PENDING_VEHICLES,
    ]

    # IDs of every lane that is subscribed to and observed. These are hard-coded
    # and must exist in ``net_file``. IDs starting with ':' are SUMO internal
    # (junction) lanes. The order is significant: it fixes the order of
    # ``current_step_vehicles``.
    DEFAULT_LANES = (
        '-gneE0_0', '-gneE0_1', '-gneE0_2',
        '-gneE1_0', '-gneE1_1', '-gneE1_2',
        '-gneE2_0', '-gneE2_1', '-gneE2_2',
        '-gneE3_0', '-gneE3_1', '-gneE3_2',
        'gneE0_0', 'gneE0_1', 'gneE0_2',
        'gneE1_0', 'gneE1_1', 'gneE1_2',
        'gneE2_0', 'gneE2_1', 'gneE2_2',
        'gneE3_0', 'gneE3_1', 'gneE3_2',
        ':gneJ0_0_0', ':gneJ0_1_0', ':gneJ0_2_0', ':gneJ0_12_0', ':gneJ0_13_0',
        ':gneJ0_3_0', ':gneJ0_4_0', ':gneJ0_5_0', ':gneJ0_14_0', ':gneJ0_15_0',
        ':gneJ0_6_0', ':gneJ0_7_0', ':gneJ0_8_0', ':gneJ0_16_0', ':gneJ0_17_0',
        ':gneJ0_9_0', ':gneJ0_10_0', ':gneJ0_11_0', ':gneJ0_18_0', ':gneJ0_19_0',
    )

    def __init__(self, net_file, route_file, path_to_log):
        """Store configuration; no SUMO process is started until :meth:`reset`.

        Args:
            net_file: path to the SUMO network file (passed as ``-n``).
            route_file: path to the SUMO route file (passed as ``-r``).
            path_to_log: directory for the TraCI trace log and saved states.
        """
        self.path_to_log = path_to_log

        self.net_file = net_file
        self.route_file = route_file

        # Options passed verbatim to the SUMO binary by _get_sumo_cmd.
        self.external_configurations = {}
        self.external_configurations['SUMOCFG_PARAMETERS'] = {
            '-n': self.net_file,
            '-r': self.route_file,
            '--time-to-teleport': -1,
            # '--collision.stoptime': 10,
            '--collision.mingap-factor': 0,
            '--collision.action': 'warn',
            '--collision.check-junctions': False,
            '--device.rerouting.threads': 4,
            '--save-state.rng': True,
            '--ignore-junction-blocker': 10  # working in Sumo 1.8.0
        }

        # Directory where save_state writes state files.
        self.environment_state_path = os.path.join(self.path_to_log, 'environment', 'temp')

        self.lanes_list = None

        # Unique TraCI connection label; assigned by reset().
        self.execution_name = None
        self.intersection = None

        self.current_step_lane_subscription = None
        self.current_step_vehicle_subscription = None
        self.current_step_lane_vehicle_subscription = None
        self.current_step_vehicles = []
        self.previous_step_vehicles = []

    def reset(self, execution_name):
        """Start a fresh SUMO run and return its initial state.

        Args:
            execution_name: prefix for the TraCI connection label. A random
                UUID is appended so that concurrent environments never share a label.

        Returns:
            ``(state, next_action)``, where ``state`` is the list returned by
            :meth:`get_state` and ``next_action`` is ``[None]``.

        Raises:
            Whatever ``traci.start`` raises if SUMO cannot be launched. The
            start lock is released in every case.
        """
        self.execution_name = execution_name + '__' + str(uuid.uuid4())

        self.intersection = Intersection(execution_name=self.execution_name)

        self.lanes_list = list(self.DEFAULT_LANES)

        self.current_step_lane_subscription = None
        self.current_step_vehicle_subscription = None
        self.current_step_lane_vehicle_subscription = None
        self.current_step_vehicles = []
        # Also clear the previous episode's vehicles. Otherwise vehicles whose
        # IDs recur in this run would be treated as already subscribed and
        # would never get a subscription on the new connection.
        self.previous_step_vehicles = []

        traci_connection = self._start_sumo()

        print('SUMO VERSION', traci_connection.getVersion()[1])

        self._subscribe(traci_connection)

        # get new measurements
        self.update_current_measurements()
        self.intersection.update_current_measurements()

        state, _ = self.get_state()

        return state, [None]

    def _start_sumo(self):
        """Launch SUMO under ``self.execution_name`` and return its TraCI connection.

        ``synchronization_util.traci_start_lock`` is held while SUMO starts and
        is released even when start-up fails.
        """
        sumo_cmd_str = self._get_sumo_cmd()
        trace_file_path = os.path.join(self.path_to_log, 'trace_file_log.txt')

        synchronization_util.traci_start_lock.acquire()
        try:
            print("start sumo")
            try:
                traci.start(sumo_cmd_str, label=self.execution_name,
                            traceFile=trace_file_path, traceGetters=False)
            except Exception:
                # Best-effort cleanup of a connection that may have been
                # registered under our own label. A bare traci.close() acts on
                # traci's currently selected connection, which need not be ours.
                # The original start-up error is then re-raised.
                try:
                    traci.getConnection(self.execution_name).close()
                except Exception:
                    pass  # nothing was registered under our label
                raise
            traci_connection = traci.getConnection(self.execution_name)
            print("succeed in start sumo")
        finally:
            synchronization_util.traci_start_lock.release()

        return traci_connection

    def _subscribe(self, traci_connection):
        """Subscribe to lane and simulation variables and to already-loaded vehicles."""
        lane_variables = list(self.LANE_VARIABLES_TO_SUBSCRIBE)
        for lane in self.lanes_list:
            traci_connection.lane.subscribe(lane, lane_variables)

        traci_connection.simulation.subscribe(list(self.SIMULATION_VARIABLES_TO_SUBSCRIBE))

        vehicle_variables = list(self.VEHICLE_VARIABLES_TO_SUBSCRIBE)
        for vehicle_id in traci_connection.simulation.getLoadedIDList():
            traci_connection.vehicle.subscribe(vehicle_id, vehicle_variables)

    def end_sumo(self):
        """Close this environment's TraCI connection."""
        traci_connection = traci.getConnection(self.execution_name)
        traci_connection.close()

    def update_previous_measurements(self):
        """Remember the current step's vehicle list as the previous step's."""
        self.previous_step_vehicles = self.current_step_vehicles

    def update_current_measurements(self):
        """Refresh lane subscription results and subscribe newly seen vehicles.

        ``current_step_vehicles`` becomes the concatenation of the vehicle-ID
        lists of all tracked lanes, in ``lanes_list`` order. Every vehicle not
        present in ``previous_step_vehicles`` gets a vehicle subscription.
        """
        traci_connection = traci.getConnection(self.execution_name)

        # ====== lane level observations =======
        self.current_step_lane_subscription = {
            lane_id: traci_connection.lane.getSubscriptionResults(lane_id)
            for lane_id in self.lanes_list
        }

        # ====== vehicle level observations =======
        self.current_step_vehicles = [
            vehicle_id
            for lane_results in self.current_step_lane_subscription.values()
            for vehicle_id in lane_results[tc.LAST_STEP_VEHICLE_ID_LIST]
        ]

        recently_arrived_vehicles = set(self.current_step_vehicles) - set(self.previous_step_vehicles)

        # update subscriptions
        vehicle_variables = list(self.VEHICLE_VARIABLES_TO_SUBSCRIBE)
        for vehicle_id in recently_arrived_vehicles:
            traci_connection.vehicle.subscribe(vehicle_id, vehicle_variables)

    def get_current_time(self):
        """Return the current simulation time reported by SUMO."""
        traci_connection = traci.getConnection(self.execution_name)
        return traci_connection.simulation.getTime()

    def get_state(self):
        """Return ``(state_list, done)``.

        ``state_list`` holds the intersection's current phase and
        movement vehicle counts. ``done`` is always ``False``.
        """
        state_list = [self.intersection.get_state(['current_phase', 'movement_number_of_vehicles'])]
        done = False

        return state_list, done

    def save_state(self, name=None):
        """Save the SUMO simulation state into ``environment_state_path``.

        Args:
            name: file name for the state file. If ``None``, it defaults to
                ``<execution_name>_save_state_<current time>.sbx``.

        Returns:
            Full path of the written state file.
        """
        os.makedirs(self.environment_state_path, exist_ok=True)

        if name is None:
            state_name = self.execution_name + '_' + 'save_state' + '_' + str(self.get_current_time()) + '.sbx'
        else:
            state_name = name

        filepath = os.path.join(self.environment_state_path, state_name)

        traci_connection = traci.getConnection(self.execution_name)
        traci_connection.simulation.saveState(filepath)

        return filepath

    def check_for_active_action_time_actions(self, action):
        """Override ``action[0]`` with the intersection's active action-time action.

        The intersection's ``select_active_action_time_action`` result replaces
        ``action[0]`` unless it is ``-1``, which is treated as "no active
        action". ``action`` is modified in place and also returned.
        """
        intersection_index = 0

        action_time_action = self.intersection.select_active_action_time_action()

        if action_time_action != -1:
            action[intersection_index] = action_time_action

        return action

    def step(self, action):
        """Apply ``action`` and advance the simulation.

        The method performs one simulation step with ``action``. It then keeps
        stepping, one step at a time, while the intersection supplies an
        active action-time action.

        Args:
            action: list of per-intersection actions; only index 0 is used.
                The list is modified in place.

        Returns:
            ``(next_state, reward, done, step, next_action)``:
                - ``step`` is the number of simulation steps taken.
                - ``reward`` is the placeholder ``[0]``.
                - ``next_action`` is always ``[None]`` on return.

        Raises:
            ValueError: if ``action`` contains ``None`` and no action-time
                action is active to replace it.
        """
        action = self.check_for_active_action_time_actions(action)

        if None in action:
            raise ValueError('Action cannot be None')

        step = 0
        while None not in action:

            current_time = self.get_current_time()

            self._inner_step(action)

            # Reward is not computed yet; a constant placeholder is returned.
            reward = [0]

            if step == 0:
                print("time: {0}, action: {1}, reward: {2}".
                      format(current_time,
                             action[0],
                             reward[0]))

            next_state, done = self.get_state()

            step += 1

            # Continue only if the intersection supplies a follow-up action.
            action = self.check_for_active_action_time_actions([None])

        next_action = action

        return next_state, reward, done, step, next_action

    def _inner_step(self, action):
        """Set the signal from ``action[0]``, advance SUMO one step, and refresh measurements."""
        self.update_previous_measurements()
        self.intersection.update_previous_measurements()

        intersection_index = 0

        self.intersection.set_signal(action[intersection_index])

        traci_connection = traci.getConnection(self.execution_name)
        traci_connection.simulationStep()

        self.update_current_measurements()
        self.intersection.update_current_measurements()

    def get_sumo_binary(self, gui=False):
        """Return the path of the ``sumo-gui`` binary if ``gui``, else of ``sumo``."""
        return checkBinary('sumo-gui' if gui else 'sumo')

    def _get_sumo_cmd(self):
        """Build the SUMO command line: binary followed by each ``option value`` pair as strings."""
        sumo_cmd = [self.get_sumo_binary(gui=False)]
        for option, value in self.external_configurations['SUMOCFG_PARAMETERS'].items():
            sumo_cmd.extend((str(option), str(value)))

        return sumo_cmd
