import copy
from functools import partial
import statistics
from concurrent.futures import ThreadPoolExecutor

import numpy as np

import synchronization_util

import copyreg

import lxml.etree as etree


# Set by register_copyreg() after the lxml pickle reducers are installed, so
# repeated calls (e.g. from every PlanningOnlyAgent constructor) do nothing.
copyreg_registered = False

# Shared parser used by get_xml(); remove_blank_text=True makes lxml discard
# blank (whitespace-only) text nodes that appear ignorable.
parser = etree.XMLParser(remove_blank_text=True)


def get_xml(file):
    """Parse *file* with the module-level ``parser`` and return the lxml tree.

    Args:
        file: Anything ``lxml.etree.parse`` accepts (path, URL or file object).

    Returns:
        The parsed ``lxml.etree._ElementTree``.
    """
    return etree.parse(file, parser)


def etree_unpickler(data):
    """Rebuild an lxml element from the bytes produced by ``etree_pickler``."""
    element = etree.fromstring(data)
    return element


def etree_tree_unpickler(data):
    """Rebuild an lxml document (``_ElementTree``) from ``etree_pickler`` bytes."""
    # getroottree() wraps the whole parsed document, so the result is an
    # _ElementTree again rather than a bare root element.
    return etree.fromstring(data).getroottree()


def etree_pickler(tree):
    """copyreg reducer for lxml elements and element trees.

    Serialises *tree* with ``etree.tostring`` and returns the
    ``(callable, args)`` pair that pickle uses to rebuild it. Documents
    (``_ElementTree``) are rebuilt as documents and elements as elements, so
    the object's type survives a pickle round trip.

    Args:
        tree: An ``lxml.etree._Element`` or ``lxml.etree._ElementTree``.

    Returns:
        ``(rebuild_function, (serialized_bytes,))``.
    """
    data = etree.tostring(tree)
    if isinstance(tree, etree._ElementTree):
        return etree_tree_unpickler, (data,)
    return etree_unpickler, (data,)


def register_copyreg():
    """Register ``etree_pickler`` as the pickle reducer for lxml trees/elements.

    The module-level ``copyreg_registered`` flag makes later calls no-ops.
    There is no lock, so concurrent first calls may both register (the
    registration itself is simply repeated).
    """
    global copyreg_registered

    if copyreg_registered:
        return

    for lxml_type in (etree._ElementTree, etree._Element):
        copyreg.pickle(lxml_type, etree_pickler, etree_unpickler)

    copyreg_registered = True


class PlanningOnlyAgent:
    """Traffic-signal agent that chooses actions purely by look-ahead simulation.

    For each candidate action, the agent restores a saved SUMO state in a
    fresh ``SumoEnv`` (one worker thread per candidate) and applies the
    action. While planning iterations remain and the simulation is not done,
    it recursively plans the following step. The candidate whose simulated
    reward sequence has the highest mean is returned. Ties are broken with a
    NumPy generator seeded from ``RNG_SEED``.
    """

    # Seed of the tie-breaking generator. Every choose_action call starts
    # from this seed, so tie-breaking is reproducible across calls.
    RNG_SEED = 23423

    def __init__(self):

        self.env = None

        # Phase identifiers. In this class only their count is used: it
        # defines the default candidate actions 0 .. len(phases) - 1.
        self.phases = ['0S_2S_0L_2L', '1S_3S_1L_3L']
        # Look-ahead depth: number of simulated decisions along each branch.
        self.planning_iterations = 2

        register_copyreg()

    def set_simulation_environment(self, env):
        """Set the environment whose saved states seed every planning branch."""
        self.env = env

    def choose_action(self, step, state, *args, **kwargs):
        """Return the best action for ``kwargs['intersection_index']`` at *step*.

        Args:
            step: Current simulation step. It is used both as the planning
                start step and as the "original" step in simulation run names.
            state: Current observation, forwarded to the planner as
                ``one_state``.
            intersection_index: Required keyword argument selecting which
                intersection's action is planned.

        Returns:
            An element of ``range(len(self.phases))``.

        Raises:
            ValueError: If ``intersection_index`` is missing or ``None``.
        """
        intersection_index = kwargs.get('intersection_index', None)

        if intersection_index is None:
            raise ValueError('intersection_index must be declared')

        # Fresh generator per call so tie-breaking does not depend on history.
        rng = np.random.Generator(np.random.MT19937(self.RNG_SEED))

        action, _ = self._choose_action(step, state, step, intersection_index, rng,
                                        self.planning_iterations)

        return action

    def _choose_action(self, initial_step, one_state, original_step, intersection_index, rng, planning_iterations,
                       possible_actions=None, env=None, *args, **kwargs):
        """Simulate every candidate action in parallel and pick the best one.

        Args:
            initial_step: Simulation step at which the branches start.
            one_state: Observation at *initial_step* (passed through to the
                branches).
            original_step: Step of the top-level decision; only used for
                naming simulation runs.
            intersection_index: Intersection whose action is planned.
            rng: ``numpy.random.Generator`` used for tie-breaking. Its state
                is also handed to the branches for their own tie-breaking.
            planning_iterations: Remaining look-ahead depth.
            possible_actions: Candidate actions. Defaults to
                ``range(len(self.phases))``.
            env: Environment whose current state is saved and restored in
                every branch. Defaults to ``self.env``.
            **kwargs: Forwarded unchanged to ``_run_simulation_possibility``.

        Returns:
            ``(action, rewards)``: the chosen element of *possible_actions*
            and the reward list simulated for it.

        Raises:
            ValueError: If *possible_actions* is empty.
        """
        if possible_actions is None:
            possible_actions = range(0, len(self.phases))
        # Materialise so any iterable works and results can be mapped back.
        possible_actions = list(possible_actions)

        if not possible_actions:
            raise ValueError('possible_actions must not be empty')

        if env is None:
            env = self.env

        save_state_filepath = env.save_state()

        # Mutable values shared by all branches (rng_state) are deep-copied
        # inside the target function, once per thread.
        simulation_possibility_kwargs = {
            'initial_step': initial_step,
            'one_state': one_state,
            'original_step': original_step,
            'intersection_index': intersection_index,
            'save_state_filepath': save_state_filepath,
            'rng_state': rng.bit_generator.state,
            'planning_iterations': planning_iterations,
            'possible_actions': possible_actions,
        }
        simulation_possibility_kwargs.update(kwargs)

        with ThreadPoolExecutor(max_workers=len(possible_actions)) as executor:
            possible_future_rewards = list(executor.map(
                partial(self._run_simulation_possibility, **simulation_possibility_kwargs),
                possible_actions,
            ))

        mean_rewards = np.array(
            [statistics.mean(future_rewards) for future_rewards in possible_future_rewards]
        )

        # Positions (not action values) of all candidates sharing the best mean.
        best_indices = np.flatnonzero(mean_rewards == mean_rewards.max())

        if len(best_indices) > 1:
            best_index = int(rng.choice(best_indices))
        else:
            best_index = int(best_indices[0])

        # Map the winning position back to the actual action value.
        return possible_actions[best_index], possible_future_rewards[best_index]

    def _run_simulation_possibility(
            self,
            action,
            initial_step,
            one_state,
            original_step,
            intersection_index,
            save_state_filepath,
            rng_state,
            planning_iterations,
            possible_actions,
            **kwargs):
        """Simulate one candidate *action* from a saved state and return its rewards.

        Runs in a worker thread started by ``_choose_action``. The method:

        1. Builds a new ``SumoEnv`` from ``self.env``'s net/route/log settings.
        2. Starts it at *initial_step* from *save_state_filepath*.
        3. Applies *action* to *intersection_index* for one ``step``.
        4. While planning iterations remain and the simulation is not done,
           plans the continuation recursively from the resulting state.

        The branch's simulator is always shut down, also on failure.

        Args:
            one_state: Not read here (the branch observes its own next
                state); kept for interface compatibility.
            rng_state: ``bit_generator.state`` of the parent's generator.

        Returns:
            List of rewards: this step's reward followed by the rewards of
            the chosen continuation (if any).
        """
        env = None
        try:
            # Private copy: every thread receives the same state dict.
            rng_state = copy.deepcopy(rng_state)

            from sumo_env import SumoEnv

            # Only these three settings are needed from the planner's env, so
            # read them directly instead of deep-copying the live environment.
            env = SumoEnv(
                net_file=self.env.net_file,
                route_file=self.env.route_file,
                path_to_log=self.env.path_to_log)

            # NOTE(review): if external_configurations is shared between SumoEnv
            # instances (e.g. a class-level dict), concurrent branches overwrite
            # each other's '--begin'/'--load-state' — confirm it is per-instance.
            env.external_configurations['SUMOCFG_PARAMETERS'].update(
                {
                    '--begin': initial_step,
                    '--load-state': save_state_filepath
                }
            )

            execution_name = (f'planning_for_step_{original_step}'
                              f'__initial_step_{initial_step}'
                              f'__phase_{action}')

            env.reset(execution_name)

            # NOTE(review): the action vector has a single entry, so only
            # intersection_index 0 (or -1) works; size it by the number of
            # intersections if multi-intersection planning is required.
            action_list = ['no_op']
            action_list[intersection_index] = action

            next_state, reward, done, steps_iterated, _ = env.step(action_list)

            rewards = [reward[intersection_index]]

            planning_iterations -= 1
            # Stop once the look-ahead budget is spent or the simulation ended;
            # planning further on a finished simulation is meaningless.
            if planning_iterations > 0 and not done:

                rng = np.random.Generator(np.random.MT19937(self.RNG_SEED))
                rng.bit_generator.state = rng_state

                _, future_rewards = self._choose_action(
                    initial_step + steps_iterated,
                    next_state[intersection_index],
                    original_step,
                    intersection_index,
                    rng,
                    planning_iterations,
                    possible_actions,
                    env,
                    **kwargs
                )

                rewards.extend(future_rewards)

        except Exception as e:
            print(e)
            raise
        finally:
            if env is not None:
                env.end_sumo()

        return rewards
