# Source code for mushroom_rl.policy.td_policy

import numpy as np
from scipy.optimize import brentq
from scipy.special import logsumexp
from .policy import Policy

from mushroom_rl.utils.parameters import Parameter

[docs]class TDPolicy(Policy):
[docs]    def __init__(self):
"""
Constructor.

"""
self._approximator = None

[docs]    def set_q(self, approximator):
"""
Args:
approximator (object): the approximator to use.

"""
self._approximator = approximator

[docs]    def get_q(self):
"""
Returns:
The approximator used by the policy.

"""
return self._approximator

[docs]class EpsGreedy(TDPolicy):
"""
Epsilon greedy policy.

"""
[docs]    def __init__(self, epsilon):
"""
Constructor.

Args:
epsilon (Parameter): the exploration coefficient. It indicates
the probability of performing a random actions in the current
step.

"""
super().__init__()

assert isinstance(epsilon, Parameter)
self._epsilon = epsilon

[docs]    def __call__(self, *args):
state = args[0]
q = self._approximator.predict(np.expand_dims(state, axis=0)).ravel()
max_a = np.argwhere(q == np.max(q)).ravel()

p = self._epsilon.get_value(state) / self._approximator.n_actions

if len(args) == 2:
action = args[1]
if action in max_a:
return p + (1. - self._epsilon.get_value(state)) / len(max_a)
else:
return p
else:
probs = np.ones(self._approximator.n_actions) * p
probs[max_a] += (1. - self._epsilon.get_value(state)) / len(max_a)

return probs

[docs]    def draw_action(self, state):
if not np.random.uniform() < self._epsilon(state):
q = self._approximator.predict(state)
max_a = np.argwhere(q == np.max(q)).ravel()

if len(max_a) > 1:
max_a = np.array([np.random.choice(max_a)])

return max_a

return np.array([np.random.choice(self._approximator.n_actions)])

[docs]    def set_epsilon(self, epsilon):
"""
Setter.

Args:
epsilon (Parameter): the exploration coefficient. It indicates the
probability of performing a random actions in the current step.

"""
assert isinstance(epsilon, Parameter)

self._epsilon = epsilon

[docs]    def update(self, *idx):
"""
Update the value of the epsilon parameter at the provided index (e.g. in
case of different values of epsilon for each visited state according to
the number of visits).

Args:
*idx (list): index of the parameter to be updated.

"""
self._epsilon.update(*idx)

[docs]class Boltzmann(TDPolicy):
"""
Boltzmann softmax policy.

"""
[docs]    def __init__(self, beta):
"""
Constructor.

Args:
beta (Parameter): the inverse of the temperature distribution. As
the temperature approaches infinity, the policy becomes more and
more random. As the temperature approaches 0.0, the policy becomes
more and more greedy.

"""
super().__init__()
self._beta = beta

[docs]    def __call__(self, *args):
state = args[0]
q_beta = self._approximator.predict(state) * self._beta(state)
q_beta -= q_beta.max()
qs = np.exp(q_beta)

if len(args) == 2:
action = args[1]

return qs[action] / np.sum(qs)
else:
return qs / np.sum(qs)

[docs]    def draw_action(self, state):
return np.array([np.random.choice(self._approximator.n_actions,
p=self(state))])

[docs]    def set_beta(self, beta):
"""
Setter.

Args:
beta (Parameter): the inverse of the temperature distribution.

"""
assert isinstance(beta, Parameter)

self._beta = beta

[docs]    def update(self, *idx):
"""
Update the value of the beta parameter at the provided index (e.g. in
case of different values of beta for each visited state according to
the number of visits).

Args:
*idx (list): index of the parameter to be updated.

"""
self._beta.update(*idx)

[docs]class Mellowmax(Boltzmann):
"""
Mellowmax policy.
"An Alternative Softmax Operator for Reinforcement Learning". Asadi K. and
Littman M.L.. 2017.

"""

class MellowmaxParameter:
def __init__(self, outer, omega, beta_min, beta_max):
self._omega = omega
self._outer = outer
self._beta_min = beta_min
self._beta_max = beta_max

def __call__(self, state):
q = self._outer._approximator.predict(state)
mm = (logsumexp(q * self._omega(state)) - np.log(
q.size)) / self._omega(state)

def f(beta):
v = q - mm
beta_v = beta * v
beta_v -= beta_v.max()

return np.sum(np.exp(beta_v) * v)

try:
beta = brentq(f, a=self._beta_min, b=self._beta_max)
assert not (np.isnan(beta) or np.isinf(beta))

return beta
except ValueError:
return 0.

[docs]    def __init__(self, omega, beta_min=-10., beta_max=10.):
"""
Constructor.

Args:
omega (Parameter): the omega parameter of the policy from which beta
of the Boltzmann policy is computed;
beta_min (float, -10.): one end of the bracketing interval for
minimization with Brent's method;
beta_max (float, 10.): the other end of the bracketing interval for
minimization with Brent's method.

"""
beta_mellow = self.MellowmaxParameter(self, omega, beta_min, beta_max)

super().__init__(beta_mellow)

[docs]    def set_beta(self, beta):
raise RuntimeError('Cannot change the beta parameter of Mellowmax policy')

[docs]    def update(self, *idx):
raise RuntimeError('Cannot update the beta parameter of Mellowmax policy')