Policy Evaluation

src.pcgym.policy_evaluation.policy_eval

Policy Evaluation Class for pc-gym.

This class provides methods for evaluating policies in a given environment, including rollouts, oracle comparisons, and data visualization.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| make_env | Callable | Function to create the environment. |
| env_params | dict | Parameters for the environment. |
| env | Environment | The environment instance. |
| policies | dict | Dictionary of policies to evaluate. |
| n_pi | int | Number of policies. |
| reps | int | Number of repetitions for evaluation. |
| oracle | bool | Whether to use oracle comparisons. |
| cons_viol | bool | Whether to plot constraint violations. |
| save_fig | bool | Whether to save generated figures. |
| MPC_params | dict or bool | Parameters for MPC, if applicable. |

__init__(make_env, policies, reps, env_params, oracle=False, MPC_params=False, cons_viol=False, save_fig=False)

Initialize the policy_eval class.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| make_env | callable | Function to create the environment. | required |
| policies | dict | Dictionary of policies to evaluate. | required |
| reps | int | Number of repetitions for evaluation. | required |
| env_params | dict | Parameters for the environment. | required |
| oracle | bool | Whether to use oracle comparisons. | False |
| MPC_params | dict or bool | Parameters for MPC, if applicable. | False |
| cons_viol | bool | Whether to plot constraint violations. | False |
| save_fig | bool | Whether to save generated figures. | False |
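As a rough illustration of how these arguments relate to the attributes listed earlier, the sketch below mirrors the documented signature in a small stand-in dataclass. It does not import pc-gym: `ToyEnv`, `EvalConfig`, and the `"N"` key are hypothetical names, and the way `env` and `n_pi` are derived (calling `make_env(env_params)` and taking `len(policies)`) is an assumption based on the attribute descriptions, not a confirmed detail of the library.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Union


class ToyEnv:
    """Hypothetical stand-in for an environment built by make_env."""

    def __init__(self, env_params: Dict[str, Any]):
        self.N = env_params["N"]  # "N" is an assumed key, not taken from pc-gym


def make_env(env_params: Dict[str, Any]) -> ToyEnv:
    # Factory in the shape the constructor expects: parameters in, environment out.
    return ToyEnv(env_params)


@dataclass
class EvalConfig:
    """Mirrors the documented arguments and defaults; not the library class."""

    make_env: Callable[[Dict[str, Any]], Any]
    policies: Dict[str, Any]
    reps: int
    env_params: Dict[str, Any]
    oracle: bool = False
    MPC_params: Union[dict, bool] = False
    cons_viol: bool = False
    save_fig: bool = False
    env: Any = field(init=False)
    n_pi: int = field(init=False)

    def __post_init__(self) -> None:
        # Assumed derivations of the two attributes that are not arguments.
        self.env = self.make_env(self.env_params)
        self.n_pi = len(self.policies)


config = EvalConfig(
    make_env=make_env,
    policies={"policy_a": object(), "policy_b": object()},  # placeholder policies
    reps=10,
    env_params={"N": 50},
)
print(config.n_pi, config.env.N, config.oracle)
```

With the real class, the same four required arguments would be passed in the same order, and the optional flags left at their documented defaults unless oracle comparison, constraint plots, or saved figures are wanted.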

rollout(policy_i)

Roll out the policy for N steps and return the total reward, states, actions, and constraint information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| policy_i | | Policy to be rolled out. | required |

Returns:

tuple containing:

- total_reward (list): Total reward obtained.
- s_rollout (np.ndarray): States obtained from rollout.
- actions (np.ndarray): Actions obtained from rollout.
- cons_info (np.ndarray): Constraint information.
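The shape of a rollout can be sketched with a toy loop. Everything here is an assumption made for illustration: `ToyEnv`, its `reset`/`step` interface, the `"cons_info"` key, and the choice to store one reward per step in `total_reward`. It uses plain lists rather than NumPy arrays and is not the library's implementation.

```python
from typing import Callable, List, Tuple


class ToyEnv:
    """Hypothetical scalar environment: the state drifts by the applied action."""

    def __init__(self, n_steps: int = 5, limit: float = 1.0):
        self.N = n_steps
        self.limit = limit
        self.x = 0.0

    def reset(self) -> float:
        self.x = 0.0
        return self.x

    def step(self, u: float) -> Tuple[float, float, bool, dict]:
        self.x += u
        reward = -abs(self.x - 1.0)      # reward for tracking a setpoint of 1.0
        violated = self.x > self.limit   # flag an upper-bound constraint violation
        return self.x, reward, False, {"cons_info": float(violated)}


def rollout(env: ToyEnv, policy: Callable[[float], float]):
    """Run env.N steps and collect the four documented outputs."""
    total_reward: List[float] = []
    s_rollout: List[float] = []
    actions: List[float] = []
    cons_info: List[float] = []

    obs = env.reset()
    for _ in range(env.N):
        u = policy(obs)
        s_rollout.append(obs)   # state seen before acting
        actions.append(u)
        obs, reward, _done, info = env.step(u)
        total_reward.append(reward)
        cons_info.append(info["cons_info"])
    return total_reward, s_rollout, actions, cons_info


# Proportional policy that pushes the state towards the setpoint of 1.0.
rewards, states, acts, cons = rollout(ToyEnv(), lambda x: 0.5 * (1.0 - x))
print(len(rewards), len(states), len(acts), len(cons))
```

All four sequences have one entry per step, which is what makes them easy to stack across repetitions later.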

oracle_reward_fn(x, u)

Calculate the oracle reward for given states and actions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | ndarray | State trajectory. | required |
| u | ndarray | Action trajectory. | required |

Returns:

| Type | Description |
| --- | --- |
| list | Oracle rewards for each time step. |
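The page does not state which reward the oracle uses, so the following is only a sketch of the input and output shapes: paired state and action trajectories in, one reward per time step out. The quadratic tracking cost, `setpoint`, and `action_weight` are hypothetical choices, not values taken from pc-gym.

```python
from typing import List, Sequence


def oracle_reward_sketch(
    x: Sequence[float],
    u: Sequence[float],
    setpoint: float = 1.0,        # hypothetical target state
    action_weight: float = 0.1,   # hypothetical penalty on control effort
) -> List[float]:
    """Return one reward per time step for paired state and action trajectories."""
    if len(x) != len(u):
        raise ValueError("state and action trajectories must have equal length")
    # Negative quadratic cost: tracking error plus weighted control effort.
    return [-((xi - setpoint) ** 2 + action_weight * ui ** 2) for xi, ui in zip(x, u)]


rewards = oracle_reward_sketch(x=[0.0, 0.5, 1.0], u=[1.0, 1.0, 0.0])
print(rewards)
```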

get_rollouts()

Perform rollouts for all policies and collect data.

Returns:

| Type | Description |
| --- | --- |
| dict | Dictionary containing rollout data for each policy and oracle (if applicable). |
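One plausible way to organise such a dictionary is sketched below: an outer key per policy, and under it one trajectory per repetition. The inner keys `"x"`, `"u"`, and `"r"`, the toy dynamics in `run_episode`, and the function names are assumptions for illustration; the actual keys returned by the library are not given on this page.

```python
import random
from typing import Callable, Dict, List


def run_episode(
    policy: Callable[[float], float], n_steps: int, rng: random.Random
) -> Dict[str, List[float]]:
    """Hypothetical single rollout of a noisy scalar system."""
    x = 0.0
    states: List[float] = []
    actions: List[float] = []
    rewards: List[float] = []
    for _ in range(n_steps):
        u = policy(x)
        states.append(x)
        actions.append(u)
        x = x + u + rng.gauss(0.0, 0.01)   # small process noise
        rewards.append(-abs(x - 1.0))      # tracking reward, setpoint 1.0
    return {"x": states, "u": actions, "r": rewards}


def get_rollouts_sketch(
    policies: Dict[str, Callable[[float], float]],
    reps: int,
    n_steps: int = 5,
    seed: int = 0,
) -> Dict[str, Dict[str, List[List[float]]]]:
    """Repeat each policy `reps` times and group the results by policy name."""
    rng = random.Random(seed)
    data: Dict[str, Dict[str, List[List[float]]]] = {}
    for name, policy in policies.items():
        episodes = [run_episode(policy, n_steps, rng) for _ in range(reps)]
        # Regroup so that each key holds one trajectory per repetition.
        data[name] = {key: [ep[key] for ep in episodes] for key in ("x", "u", "r")}
    return data


data = get_rollouts_sketch(
    {"slow": lambda x: 0.2 * (1.0 - x), "fast": lambda x: 0.8 * (1.0 - x)},
    reps=3,
)
print(sorted(data), len(data["slow"]["x"]), len(data["slow"]["x"][0]))
```

Keeping repetitions as the outer list makes it straightforward to compute per-step statistics across repetitions before plotting.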

plot_data(data, reward_dist=False)

Plot the rollout data for all policies.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | dict | Dictionary containing rollout data. | required |
| reward_dist | bool | Whether to plot reward distribution. | False |