Policy Evaluation
src.pcgym.policy_evaluation.policy_eval
Policy Evaluation Class for pc-gym.
This class provides methods for evaluating policies in a given environment, including rollouts, oracle comparisons, and data visualization.
Attributes:
| Name | Type | Description |
|---|---|---|
| `make_env` | `Callable` | Function to create the environment. |
| `env_params` | `dict` | Parameters for the environment. |
| `env` | `Environment` | The environment instance. |
| `policies` | `dict` | Dictionary of policies to evaluate. |
| `n_pi` | `int` | Number of policies. |
| `reps` | `int` | Number of repetitions for evaluation. |
| `oracle` | `bool` | Whether to use oracle comparisons. |
| `cons_viol` | `bool` | Whether to plot constraint violations. |
| `save_fig` | `bool` | Whether to save generated figures. |
| `MPC_params` | `dict` or `bool` | Parameters for MPC, if applicable. |
__init__(make_env, policies, reps, env_params, oracle=False, MPC_params=False, cons_viol=False, save_fig=False)
Initialize the policy_eval class.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `make_env` | `callable` | Function to create the environment. | required |
| `policies` | `dict` | Dictionary of policies to evaluate. | required |
| `reps` | `int` | Number of repetitions for evaluation. | required |
| `env_params` | `dict` | Parameters for the environment. | required |
| `oracle` | `bool` | Whether to use oracle comparisons. | `False` |
| `MPC_params` | `dict` or `bool` | Parameters for MPC, if applicable. | `False` |
| `cons_viol` | `bool` | Whether to plot constraint violations. | `False` |
| `save_fig` | `bool` | Whether to save generated figures. | `False` |
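The constructor takes a factory function rather than an environment instance, together with a dictionary of named policies. The sketch below shows the shape of those inputs. It uses a hypothetical toy environment (`ToyTankEnv`) and a hypothetical proportional controller (`PController`) that exposes a Stable-Baselines3-style `predict(obs, deterministic=True)` method. It assumes that `make_env(env_params)` returns the environment. Real pc-gym environments, parameter keys, and the exact policy interface may differ.

```python
from typing import Dict, Tuple


class ToyTankEnv:
    """Hypothetical first-order process: x <- x + dt * (u - k * x)."""

    def __init__(self, params: Dict[str, float]) -> None:
        self.N = int(params["N"])          # episode length in steps
        self.dt = params["dt"]             # time step
        self.k = params["k"]               # outflow coefficient
        self.setpoint = params["setpoint"]
        self.x = 0.0

    def reset(self) -> Tuple[float, dict]:
        self.x = 0.0
        return self.x, {}

    def step(self, u: float):
        self.x += self.dt * (u - self.k * self.x)
        reward = -(self.x - self.setpoint) ** 2
        # Gymnasium-style 5-tuple: obs, reward, terminated, truncated, info
        return self.x, reward, False, False, {}


class PController:
    """Hypothetical policy with an SB3-style predict() method."""

    def __init__(self, gain: float, setpoint: float) -> None:
        self.gain, self.setpoint = gain, setpoint

    def predict(self, obs: float, deterministic: bool = True):
        return self.gain * (self.setpoint - obs), None


env_params = {"N": 20, "dt": 0.5, "k": 0.3, "setpoint": 1.0}


def make_env(params: Dict[str, float]) -> ToyTankEnv:
    return ToyTankEnv(params)


# Policies are keyed by a display name.
policies = {
    "P (gain 0.5)": PController(0.5, env_params["setpoint"]),
    "P (gain 2.0)": PController(2.0, env_params["setpoint"]),
}
reps = 10

# Attributes such as env and n_pi would be derived roughly like this.
env = make_env(env_params)
n_pi = len(policies)
```

These four objects correspond to the `make_env`, `policies`, `reps`, and `env_params` arguments of `__init__`.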
rollout(policy_i)
Roll out the policy for N steps and return the total reward, states, actions, and constraint information.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `policy_i` | | Policy to be rolled out. | required |
Returns:
| Type | Description |
|---|---|
| `tuple` | Contains `total_reward` (`list`), the total reward obtained; `s_rollout` (`np.ndarray`), the states obtained from the rollout; `actions` (`np.ndarray`), the actions obtained from the rollout; and `cons_info` (`np.ndarray`), the constraint information. |
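A rollout follows the usual reset/step loop. The sketch below assumes a Gymnasium-style environment, where `reset()` returns `(obs, info)` and `step(a)` returns a 5-tuple. It also assumes a policy with a Stable-Baselines3-style `predict(obs, deterministic=True)` method. `ToyTankEnv`, `PController`, and the `"cons_viol"` info key are hypothetical stand-ins, not pc-gym names. The sketch keeps per-step rewards in a list whose sum is the episode return; that is an assumption about what `total_reward` holds.

```python
from typing import List, Tuple


class ToyTankEnv:
    """Hypothetical first-order process standing in for a pc-gym environment."""

    def __init__(self, N: int = 20, dt: float = 0.5, k: float = 0.3,
                 setpoint: float = 1.0, x_max: float = 1.2) -> None:
        self.N, self.dt, self.k = N, dt, k
        self.setpoint, self.x_max = setpoint, x_max
        self.x = 0.0

    def reset(self):
        self.x = 0.0
        return self.x, {}

    def step(self, u: float):
        self.x += self.dt * (u - self.k * self.x)
        reward = -(self.x - self.setpoint) ** 2
        info = {"cons_viol": self.x > self.x_max}  # hypothetical upper-bound constraint flag
        return self.x, reward, False, False, info


class PController:
    """Hypothetical proportional policy with an SB3-style predict()."""

    def __init__(self, gain: float, setpoint: float = 1.0) -> None:
        self.gain, self.setpoint = gain, setpoint

    def predict(self, obs: float, deterministic: bool = True):
        return self.gain * (self.setpoint - obs), None


def rollout(env, policy_i) -> Tuple[List[float], List[float], List[float], List[bool]]:
    """Roll policy_i out for env.N steps; return (rewards, states, actions, cons_info)."""
    obs, _ = env.reset()
    rewards, states, actions, cons_info = [], [obs], [], []
    for _ in range(env.N):
        a, _ = policy_i.predict(obs, deterministic=True)
        obs, r, terminated, truncated, info = env.step(a)
        rewards.append(r)
        states.append(obs)            # states include the initial observation
        actions.append(a)
        cons_info.append(info["cons_viol"])
        if terminated or truncated:
            break
    return rewards, states, actions, cons_info


rewards, states, actions, cons = rollout(ToyTankEnv(), PController(2.0))
print(round(sum(rewards), 3), len(states), len(actions))
```

Storing the initial observation gives one more state than actions. This is a common convention, but it is a design choice in this sketch.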
oracle_reward_fn(x, u)
Calculate the oracle reward for given states and actions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `ndarray` | State trajectory. | required |
| `u` | `ndarray` | Action trajectory. | required |
Returns:
| Type | Description |
|---|---|
| `list` | Oracle rewards for each time step. |
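One common per-step reward for setpoint tracking is a negative weighted quadratic cost, $r_t = -\sum_i Q_i (x_{i,t} - x^{\mathrm{sp}}_{i,t})^2 - \sum_j R_j u_{j,t}^2$. The sketch below evaluates that cost along given state and action trajectories. Several details are assumptions for illustration: the quadratic form, the weights `Q` and `R`, the explicit `setpoints` argument, the (variables × time steps) layout, and the name `oracle_reward_sketch`. pc-gym's `oracle_reward_fn` takes only `x` and `u` and presumably applies the environment's own reward definition.

```python
from typing import List, Sequence


def oracle_reward_sketch(x: Sequence[Sequence[float]],
                         u: Sequence[Sequence[float]],
                         setpoints: Sequence[Sequence[float]],
                         Q: Sequence[float],
                         R: Sequence[float]) -> List[float]:
    """Per-step reward r_t = -sum_i Q_i (x_it - sp_it)^2 - sum_j R_j u_jt^2.

    x and setpoints have n_x rows; u has n_u rows; every row has N columns (time steps).
    """
    n_steps = len(x[0])
    rewards = []
    for t in range(n_steps):
        tracking = sum(Q[i] * (x[i][t] - setpoints[i][t]) ** 2 for i in range(len(x)))
        effort = sum(R[j] * u[j][t] ** 2 for j in range(len(u)))
        rewards.append(-(tracking + effort))
    return rewards


x = [[0.0, 0.5, 1.0]]       # one state, three time steps
u = [[1.0, 1.0, 0.0]]       # one action, three time steps
sp = [[1.0, 1.0, 1.0]]      # constant setpoint
r = oracle_reward_sketch(x, u, sp, Q=[1.0], R=[0.1])
```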
get_rollouts()
Perform rollouts for all policies and collect data.
Returns:
| Type | Description |
|---|---|
| `dict` | Dictionary containing rollout data for each policy and the oracle (if applicable). |
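Collecting data means looping over policies and repetitions and grouping each rollout's outputs under the policy's name. In the sketch below, `toy_rollout` is a hypothetical stand-in for `rollout()`: a noisy first-order process under proportional control. The `"x"`, `"u"`, and `"r"` keys and the list-of-repetitions layout are also assumptions, and pc-gym's actual key names and array shapes may differ. When `oracle=True`, an additional entry for the oracle would sit alongside the policies.

```python
import random
from typing import Dict, List


def toy_rollout(gain: float, rng: random.Random, N: int = 20, setpoint: float = 1.0):
    """Hypothetical stand-in for rollout(): noisy first-order process under P control."""
    x, states, actions, rewards = 0.0, [0.0], [], []
    for _ in range(N):
        u = gain * (setpoint - x)
        x += 0.5 * (u - 0.3 * x) + rng.gauss(0.0, 0.01)  # small process noise
        states.append(x)
        actions.append(u)
        rewards.append(-(x - setpoint) ** 2)
    return rewards, states, actions


def get_rollouts(gains: Dict[str, float], reps: int, seed: int = 0) -> Dict[str, Dict[str, List]]:
    """Collect `reps` rollouts per policy into data[name] = {"x": ..., "u": ..., "r": ...}."""
    rng = random.Random(seed)  # seeded so repeated runs give identical data
    data: Dict[str, Dict[str, List]] = {}
    for name, gain in gains.items():
        data[name] = {"x": [], "u": [], "r": []}
        for _ in range(reps):
            r, x, u = toy_rollout(gain, rng)
            data[name]["x"].append(x)
            data[name]["u"].append(u)
            data[name]["r"].append(sum(r))  # episode return for this repetition
    return data


data = get_rollouts({"P (gain 0.5)": 0.5, "P (gain 2.0)": 2.0}, reps=5)
```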
plot_data(data, reward_dist=False)
Plot the rollout data for all policies.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict` | Dictionary containing rollout data. | required |
| `reward_dist` | `bool` | Whether to plot the reward distribution. | `False` |
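Plots of repeated rollouts are often drawn as a central line per time step, with a shaded band spanning the repetitions. The sketch below uses only the standard library to compute a per-step median and a min/max envelope from a list of repetitions. Whether `plot_data` uses exactly these statistics is an assumption, and the function name `summarize_reps` is hypothetical.

```python
from statistics import median
from typing import List, Tuple


def summarize_reps(trajs: List[List[float]]) -> Tuple[List[float], List[float], List[float]]:
    """Per-time-step median, min and max across repetitions (each inner list is one rep)."""
    per_step = list(zip(*trajs))  # transpose: one tuple of rep values per time step
    med = [median(vals) for vals in per_step]
    lo = [min(vals) for vals in per_step]
    hi = [max(vals) for vals in per_step]
    return med, lo, hi


x_reps = [[0.0, 0.8, 1.0], [0.0, 0.9, 1.1], [0.0, 1.0, 0.9]]
med, lo, hi = summarize_reps(x_reps)
```

With matplotlib, `med` could be drawn with `plt.plot` and the `lo`/`hi` envelope with `plt.fill_between`.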