Source code for sustaingym.algorithms.base

from __future__ import annotations

from copy import deepcopy
from collections import defaultdict
from collections.abc import Sequence
from typing import Any

import gymnasium as gym
import numpy as np
import pandas as pd
from pettingzoo import ParallelEnv
from ray.rllib.algorithms.algorithm import Algorithm
from tqdm import tqdm


class BaseAlgorithm:
    """Base abstract class for running an agent in an environment.

    Subclasses are expected to implement the `get_action()` method, and may
    override `reset()` to clear per-episode state.

    Args:
        env: environment to run algorithm on
        multiagent: whether the environment is multiagent (i.e. ``env.step``
            returns per-agent dicts for reward / terminated / truncated / info)
    """
    def __init__(self, env: gym.Env | ParallelEnv, multiagent: bool = False):
        self.env = env
        self.multiagent = multiagent

    def get_action(self, observation: dict[str, Any]
                   ) -> np.ndarray | dict[str, np.ndarray]:
        """Returns an action based on gym observations.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError

    def reset(self) -> None:
        """Resets the algorithm at the end of each episode.

        Called by `run()` right after each ``env.reset``. No-op by default.
        """
        pass

    def run(self, seeds: Sequence[int] | int) -> pd.DataFrame:
        """Runs the algorithm for one episode per seed and collects results.

        Each episode resets ``self.env`` with the seed, resets the algorithm,
        then steps the env with `get_action()` until it is done. In the
        multiagent setting an episode ends as soon as ANY agent is terminated
        or truncated, and the per-step reward is the sum over agents.

        Args:
            seeds: if a list, on each episode run, ``self.env`` is reset using
                the seed. If an integer, a list is created using
                ``range(seeds)`` and used to reset the env instead.

        Returns:
            results: DataFrame with one row per seed

            .. code:: none

                column  dtype
                seed    int
                return  float64
                <key>   one column per key of the info dict returned by the
                        final step of the episode (in the multiagent setting,
                        the info dict of the first agent listed)

        Raises:
            AssertionError: if the env's reward type (dict vs. scalar) does
                not match ``self.multiagent``.
        """
        if isinstance(seeds, int):
            seeds = list(range(seeds))

        results: defaultdict[str, list] = defaultdict(list)
        for seed in tqdm(seeds):
            results['seed'].append(seed)
            ep_return = 0.0

            # Reset environment (the reset-time info dict is not recorded)
            obs, _ = self.env.reset(seed=seed)

            # Reset algorithm
            self.reset()

            # Run episode until finished
            done = False
            while not done:
                action = self.get_action(obs)
                obs, reward, terminated, truncated, info = self.env.step(action)

                # isinstance (not `type(...) == dict`) so that dict subclasses
                # such as OrderedDict are recognized as multiagent rewards.
                assert isinstance(reward, dict) == self.multiagent, (
                    f'multiagent={self.multiagent} but env returned a reward '
                    f'of type {type(reward).__name__}')
                if self.multiagent:
                    assert isinstance(reward, dict)
                    assert isinstance(terminated, dict)
                    assert isinstance(truncated, dict)
                    reward = sum(reward.values())
                    done = any(terminated.values()) or any(truncated.values())
                else:
                    done = terminated or truncated
                ep_return += reward

            results['return'].append(ep_return)

            # In the multiagent setting, assume that all agents get the same
            # info and keep the first agent's; an empty info dict records
            # nothing instead of raising IndexError.
            if self.multiagent:
                info = next(iter(info.values()), {})
            for key, value in info.items():
                # deepcopy so that later in-place mutation of these objects
                # by the env cannot alter already-recorded results
                results[key].append(deepcopy(value))

        return pd.DataFrame(results)
class RLLibAlgorithm(BaseAlgorithm):
    """Wrapper for RLLib RL agent.

    Args:
        env: environment to run algorithm on
        algo: RLLib algorithm whose policies compute the actions
        multiagent: whether the environment is multiagent
    """
    def __init__(self, env: gym.Env | ParallelEnv, algo: Algorithm,
                 multiagent: bool = False):
        super().__init__(env, multiagent=multiagent)
        self.algo = algo

    def get_action(self, observation: dict[str, Any]
                   ) -> np.ndarray | dict[str, np.ndarray]:
        """Returns output of RL model.

        Actions are computed with ``explore=False``. In the multiagent
        setting, ``observation`` maps agent IDs to per-agent observations
        and a dict mapping the same agent IDs to actions is returned.
        """
        if not self.multiagent:
            return self.algo.compute_single_action(observation, explore=False)

        multiagent_config = self.algo.config['multiagent']
        policies = multiagent_config['policies']

        if len(policies) == 1:
            # Query the sole policy by its own ID rather than relying on
            # compute_single_action's default ``policy_id``, which only
            # matches when that policy happens to be named "default_policy".
            # (Iterating a dict of policies yields its keys, i.e. the IDs.)
            sole_policy_id = next(iter(policies))
            return {
                agent_id: self.algo.compute_single_action(
                    agent_obs, policy_id=sole_policy_id, explore=False)
                for agent_id, agent_obs in observation.items()
            }

        # Several policies: route each agent through the configured mapping.
        policy_mapping_fn = multiagent_config['policy_mapping_fn']
        return {
            agent_id: self.algo.compute_single_action(
                agent_obs, policy_id=policy_mapping_fn(agent_id),
                explore=False)
            for agent_id, agent_obs in observation.items()
        }
class RandomAlgorithm(BaseAlgorithm):
    """Random action.

    Samples each action from the environment's action space(s).
    """

    def get_action(self, observation: dict[str, Any]) -> Any:
        """Returns random action.

        In the multiagent setting, returns a dict mapping each agent ID in
        ``observation`` to a sample from that agent's action space;
        otherwise returns a single sample from ``env.action_space``.
        """
        if not self.multiagent:
            return self.env.action_space.sample()

        assert isinstance(self.env, ParallelEnv)
        # Use PettingZoo's per-agent ``action_space(agent)`` method: the
        # ``action_spaces`` dict attribute is deprecated and envs are not
        # required to define it (the base-class method falls back to it).
        return {
            agent: self.env.action_space(agent).sample()
            for agent in observation
        }