Source code for sustaingym.envs.evcharging.train_gmm_model

"""
GMM training script.

The GMMs are fitted to 4 feature dimensions. The 4 features are, in order,

- ``'arrival_time'``: minute of day, normalized to [0, 1)
- ``'departure_time'``: minute of day, normalized to [0, 1)
- ``'estimated_departure_time'``: minute of day, normalized to [0, 1)
- ``'requested_energy'``: energy requested; multiply by 100 to get kWh

Example command line usage

.. code:: bash

    python -m sustaingym.envs.evcharging.train_gmm_model --site caltech --gmm_n_components 30 --date_ranges 2019-05-01 2019-08-31 2019-09-01 2019-12-31 2020-02-01 2020-05-31 2021-05-01 2021-08-31
    python -m sustaingym.envs.evcharging.train_gmm_model --site jpl --gmm_n_components 30 --date_ranges 2019-05-01 2019-08-31 2019-09-01 2019-12-31 2020-02-01 2020-05-31 2021-05-01 2021-08-31

Usage

.. code:: none

    usage: train_gmm_model.py [-h] [--site SITE] [--gmm_n_components GMM_N_COMPONENTS]
                            [--date_ranges DATE_RANGES [DATE_RANGES ...]]

    optional arguments:
    -h, --help            show this help message and exit
    --site SITE           Name of site: 'caltech' or 'jpl'
    --gmm_n_components GMM_N_COMPONENTS
    --date_ranges DATE_RANGES [DATE_RANGES ...]
                          Date ranges for GMM models to be trained on. Number
                          of dates must be divisible by 2, with the second
                          later than the first. Dates should be formatted as
                          YYYY-MM-DD. Supported ranges must lie between
                          2018-11-01 and 2021-08-31.
"""
from __future__ import annotations

import argparse
from collections.abc import Sequence
from datetime import datetime

import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture

from .utils import (AM_LA, DEFAULT_DATE_RANGES, DATE_FORMAT, MINS_IN_DAY,
                    REQ_ENERGY_SCALE, START_DATE, END_DATE, get_real_events,
                    save_gmm_model, site_str_to_site, SiteStr)


def preprocess(df: pd.DataFrame, filter: bool = True) -> pd.DataFrame:
    """Preprocessing script for real event sessions before GMM modeling.

    Optionally filters out EVs whose departure / estimated departure falls
    on a different calendar date than their arrival. The arrival, departure,
    and estimated departure are converted to the fraction of the day elapsed,
    ``(hour * 60 + minute) / MINS_IN_DAY``, and the requested energy is
    divided by ``REQ_ENERGY_SCALE``.

    The input DataFrame is never modified.

    Args:
        df: DataFrame of charging events, expected to be gotten from
            `utils.get_real_events()`. Must contain datetime columns
            'arrival', 'departure', 'estimated_departure' and a numeric
            column 'requested_energy (kWh)'.
        filter: option to filter cars staying overnight

    Returns:
        df: new DataFrame with columns 'arrival_time', 'departure_time',
            'estimated_departure_time', 'requested_energy (kWh)', indexed
            like the (possibly filtered) input.
    """
    if filter:
        # Filter cars staying overnight. Compare full calendar dates: the
        # day-of-month alone would let e.g. a Jan 5 -> Feb 5 session through.
        max_depart = np.maximum(df['departure'], df['estimated_departure'])
        df = df[df['arrival'].dt.date == max_depart.dt.date]

    # Build the result in a fresh frame so the caller's DataFrame is never
    # mutated, even when filter=False.
    out = pd.DataFrame(index=df.index)
    for col in ('arrival', 'departure', 'estimated_departure'):
        minutes = df[col].dt.hour * 60 + df[col].dt.minute
        out[col + '_time'] = minutes / MINS_IN_DAY

    # Normalize requested energy
    out['requested_energy (kWh)'] = (
        df['requested_energy (kWh)'] / REQ_ENERGY_SCALE)
    return out
def station_id_cnts(df: pd.DataFrame, n2i: dict[str, int]) -> np.ndarray:
    """Returns the usage counts for a network's charging station ids.

    Station ids in ``df`` that are absent from ``n2i`` are ignored.

    Args:
        df: DataFrame of session observations; must have a 'station_id'
            column.
        n2i: dict mapping charging station id to position in numpy array.

    Returns:
        cnts: int32 array of length ``len(n2i)``; entry ``n2i[s]`` is the
            number of sessions associated with station id ``s``.

    Raises:
        ValueError: none of the DataFrame's station ids belong to the site.
    """
    counts = [0] * len(n2i)
    for station, n_sessions in df['station_id'].value_counts().items():
        if station in n2i:
            counts[n2i[station]] = n_sessions

    # Counts are non-negative, so "all zero" is the same as "sum is zero".
    if not any(counts):
        raise ValueError('No station ids in DataFrame found in site. ')
    return np.array(counts, dtype=np.int32)
def parse_string_date_list(date_range: Sequence[str]
                           ) -> Sequence[tuple[datetime, datetime]]:
    """Converts a sequence of string date ranges to datetimes.

    Args:
        date_range: an even-length sequence of string dates in the format
            ``DATE_FORMAT`` ('YYYY-MM-DD'). Each consecutive pair describes a
            date range, and should fall inside the range 2018-11-01 and
            2021-08-31 (``START_DATE`` to ``END_DATE``).

    Returns:
        A list of 2-tuples containing a begin and end datetime with tzinfo
        ``AM_LA``. A pair whose begin equals its end (a single-day range) is
        accepted.

    Raises:
        ValueError: length of date_range is odd
        ValueError: begin date of a pair is later than its end date
        ValueError: begin or end date outside the data's range
        ValueError: a string does not match ``DATE_FORMAT`` (from strptime)
    """
    if len(date_range) % 2 != 0:
        raise ValueError(
            'Number of dates must be divisible by 2, with the second later '
            f'than the first; found length {len(date_range)}.')

    # NOTE(review): if AM_LA is a pytz timezone, .replace(tzinfo=...) attaches
    # its LMT offset rather than PST/PDT; AM_LA.localize(...) would be needed
    # in that case. Confirm AM_LA's type in utils.
    date_range_dt = [
        datetime.strptime(date_str, DATE_FORMAT).replace(tzinfo=AM_LA)
        for date_str in date_range]

    date_ranges = []
    for i in range(0, len(date_range), 2):
        begin, end = date_range_dt[i], date_range_dt[i + 1]
        begin_str, end_str = date_range[i], date_range[i + 1]
        if begin < START_DATE:
            raise ValueError(
                f'beginning of date range {begin_str} before data '
                f'start date {START_DATE.strftime(DATE_FORMAT)}')
        if end > END_DATE:
            raise ValueError(
                f'end of date range {end_str} after data '
                f'end date {END_DATE.strftime(DATE_FORMAT)}')
        if begin > end:
            raise ValueError(
                f'beginning of date range {begin_str} later than '
                f'end {end_str}')
        date_ranges.append((begin, end))
    return date_ranges
def create_gmm(site: SiteStr, n_components: int,
               date_range: tuple[datetime, datetime]) -> None:
    """Creates a custom GMM and saves it in the ``gmms`` folder.

    Only claimed sessions are used. Training is skipped (with a printed
    message) when no sessions remain, either before or after preprocessing.

    Args:
        site: either 'caltech' or 'jpl'
        n_components: number of components of Gaussian mixture model
        date_range: a (begin, end) range of dates that falls inside
            2018-11-01 and 2021-08-31. Both endpoints are counted when
            computing the number of days in the range.
    """
    # Map each of the site's charging station ids to an array position
    cn = site_str_to_site(site)
    n2i = {station_id: i for i, station_id in enumerate(cn.station_ids)}

    # Retrieve events and keep only claimed sessions
    df = get_real_events(date_range[0], date_range[1], site=site)
    df = df[df['claimed']]
    if len(df) == 0:
        print('Empty dataframe, abort GMM training. ')
        return

    # Per-day arrival counts; pad with zeros for days with no EVs
    num_days_total = (date_range[1] - date_range[0]).days + 1
    cnt = df.arrival.dt.date.value_counts().to_numpy()
    num_unseen_days = num_days_total - len(cnt)
    cnt = np.concatenate((cnt, np.zeros(num_unseen_days)))
    sid = station_id_cnts(df, n2i)

    # Preprocess DataFrame for GMM training
    df = preprocess(df)
    if len(df) == 0:
        # Every session was removed by preprocessing (e.g. all span
        # midnight); GaussianMixture.fit cannot be fitted on zero samples.
        print('No sessions left after preprocessing, abort GMM training. ')
        return

    # Train and save
    print(f'Fitting GMM ({n_components} components, {len(df.columns)} '
          f'dimensions) on data from {site} site from '
          f'{date_range[0].strftime(DATE_FORMAT)} to '
          f'{date_range[1].strftime(DATE_FORMAT)}... ')
    gmm = GaussianMixture(n_components=n_components)
    gmm.fit(df)
    save_gmm_model(site, gmm, cnt, sid, *date_range, n_components)
def create_gmms(site: SiteStr, n_components: int,
                date_ranges: Sequence[tuple[str, str]] = DEFAULT_DATE_RANGES
                ) -> None:
    """Creates multiple GMMs and saves them in the ``gmms`` folder.

    All date ranges are parsed and validated before any training starts, so
    a malformed range fails immediately instead of after earlier models have
    already been trained.

    Args:
        site: either 'caltech' or 'jpl'
        n_components: number of components of Gaussian mixture model
        date_ranges: a sequence of 2-tuples of string dates in the format
            'YYYY-MM-DD'. Each tuple describes a date range, and should fall
            inside the range 2018-11-01 and 2021-08-31. Defaults to
            ``DEFAULT_DATE_RANGES``.

    Raises:
        ValueError: a date range is malformed or out of the data's range
            (raised by ``parse_string_date_list`` before any training).
    """
    # Validate everything up front; training a single GMM can be slow.
    parsed_ranges = [parse_string_date_list(date_range)[0]
                     for date_range in date_ranges]

    print('\n--- Training GMMs ---\n')
    for date_range_dt in parsed_ranges:
        create_gmm(site, n_components, date_range=date_range_dt)
    print('--- Training complete. ---\n')
if __name__ == '__main__':
    # Command-line entry point: parse arguments and train the requested GMMs.
    parser = argparse.ArgumentParser(
        description='GMM Training Script',
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--site', default='caltech',
        help='Name of site: "caltech" or "jpl"')
    parser.add_argument(
        '--gmm_n_components', type=int, default=30)
    parser.add_argument(
        '--date_ranges', nargs='+',
        help='Date ranges for GMM models to be trained on. Number of dates '
             'must be a multiple of 2, with the second later than the first. '
             'Dates should be formatted as YYYY-MM-DD. Supported ranges '
             f'should be between {START_DATE.strftime(DATE_FORMAT)} and '
             f'{END_DATE.strftime(DATE_FORMAT)}.')
    args = parser.parse_args()

    if args.date_ranges is None:
        create_gmms(args.site, args.gmm_n_components)
    else:
        dates = args.date_ranges
        if len(dates) % 2 != 0:
            # Report bad CLI input via argparse (usage + exit status 2)
            # rather than an uncaught traceback.
            parser.error('Number of dates given must be a multiple of 2.')
        # Pair consecutive dates: [b1, e1, b2, e2] -> [(b1, e1), (b2, e2)]
        date_ranges = list(zip(dates[::2], dates[1::2]))
        create_gmms(args.site, args.gmm_n_components, date_ranges)