
import traci

class Intersection:
    """Track and control the signal phase of one traffic light via TraCI.

    The class knows the phases listed in ``self.phases`` and enforces a
    minimum number of measurement steps (``min_action_time``) between two
    accepted actions.  A caller is expected to invoke
    ``update_previous_measurements`` / ``update_current_measurements`` once
    per step so the duration counters and features stay up to date.
    """

    def __init__(self, execution_name):
        """Initialise the bookkeeping for a single intersection.

        Args:
            execution_name: Label handed to ``traci.getConnection`` whenever
                the simulation is queried or the light is changed.
        """
        self.traffic_light_id = 'gneJ0'
        self.execution_name = execution_name

        self.phases = ['0S_2S_0L_2L', '1S_3S_1L_3L']
        # Minimum number of measurement steps between two accepted actions.
        self.min_action_time = 10

        # Phase name -> signal state string sent to the traffic light.
        self.phase_traffic_lights = self.get_phase_traffic_lights()

        # Indices into self.phases.
        # NOTE(review): the original comment listed special values
        # (-1: all yellow, -2: all red, -3: none); they are never assigned
        # anywhere in this class.
        self.current_phase_index = 0
        self.previous_phase_index = 0
        self.next_phase_to_set_index = None
        # Both counters start at -1 and are advanced by
        # update_current_measurements().
        self.current_phase_duration = -1
        self.current_min_action_duration = -1

        self.current_step_lane_subscription = None
        self.current_step_vehicle_subscription = None
        self.current_step_lane_area_detector_vehicle_ids = None
        self.previous_step_vehicles = []
        self.current_step_vehicles = []

        # Features computed at the most recent measurement step.
        self.feature_dict = {}

        self.state_feature_list = ['current_phase', 'movement_number_of_vehicles']

        # Feature name -> zero-argument callable producing its current value.
        self.feature_dict_function = {
            'current_phase': lambda: [self.current_phase_index],
            'movement_number_of_vehicles': lambda: [0, 0, 0, 0, 0, 0, 0, 0]
        }

        # Reward name -> zero-argument callable producing its current value.
        self.reward_dict_function = {
            'sum_number_of_vehicles_been_stopped_threshold_1':
                lambda: 0,
        }

    def update_previous_measurements(self):
        """Copy the current phase index and vehicle list to the 'previous' slots."""
        self.previous_phase_index = self.current_phase_index
        self.previous_step_vehicles = self.current_step_vehicles

    def update_current_measurements(self):
        """Advance the duration counters by one step and refresh the features.

        The phase duration restarts at 1 when the phase differs from the one
        stored by ``update_previous_measurements``; otherwise it grows by one.
        """
        # NOTE(review): the original author flagged this method as
        # "need change, debug in seeing format".
        if self.current_phase_index == self.previous_phase_index:
            self.current_phase_duration += 1
        else:
            self.current_phase_duration = 1

        self.current_min_action_duration += 1

        self._update_feature()

    def set_signal(self, action):
        """Apply ``action`` to the traffic light if a new action is allowed.

        A new action is accepted when no phase has been requested yet, or
        when at least ``min_action_time`` steps have been counted since the
        last accepted action; otherwise the call does nothing.

        Args:
            action: ``'no_op'`` to leave everything untouched, or an index
                into ``self.phases`` selecting the phase to show.

        Raises:
            IndexError: If ``action`` is an out-of-range index.
        """
        action_allowed = (
            self.next_phase_to_set_index is None
            or self.current_min_action_duration >= self.min_action_time
        )
        if not action_allowed or action == 'no_op':
            return

        phase = self.phases[action]
        # Looking the name up again turns a negative index into its
        # non-negative equivalent.
        phase_index = self.phases.index(phase)
        self.next_phase_to_set_index = phase_index

        if phase_index != self.current_phase_index:
            traci_connection = traci.getConnection(self.execution_name)
            traci_connection.trafficlight.setRedYellowGreenState(
                self.traffic_light_id, self.phase_traffic_lights[phase])
            # Remember the phase now shown; without this the tracked phase
            # (and the 'current_phase' feature and phase duration derived
            # from it) would stay at its initial value forever.
            self.current_phase_index = phase_index

        # Every accepted action restarts the minimum-action-time window.
        self.current_min_action_duration = 0

    # ================= update current step measurements ======================

    def _update_feature(self):
        """Recompute every feature named in ``state_feature_list``."""
        self.feature_dict = {
            name: self.feature_dict_function[name]()
            for name in self.state_feature_list
        }

    # ================= calculate features from current observations ======================

    def get_current_time(self):
        """Return the simulation time reported by this run's TraCI connection."""
        traci_connection = traci.getConnection(self.execution_name)
        return traci_connection.simulation.getTime()

    def get_state(self, state_feature_list):
        """Return the stored values of the requested features.

        Args:
            state_feature_list: Iterable of feature names to extract.

        Returns:
            Dict mapping each requested name to its value in
            ``self.feature_dict``.

        Raises:
            KeyError: If a requested feature has not been computed.
        """
        return {name: self.feature_dict[name] for name in state_feature_list}

    def get_phase_traffic_lights(self):
        """Return a new dict mapping each phase name to its signal state string."""
        return {
            '0S_2S_0L_2L': 'gGgrrrgGgrrrrGrG',
            '1S_3S_1L_3L': 'rrrgGgrrrgGgGrGr',
        }

    def select_active_action_time_action(self):
        """Return the phase index still locked in by the minimum action time.

        Returns:
            ``next_phase_to_set_index`` while fewer than ``min_action_time``
            steps have been counted since the last accepted action and a
            phase has been requested; ``-1`` otherwise.
        """
        within_min_action_window = self.current_min_action_duration < self.min_action_time
        # The next phase only changes when a new action is accepted.
        if within_min_action_window and self.next_phase_to_set_index is not None:
            return self.next_phase_to_set_index

        return -1
