Source code for ensemble

""" Class definition for the ensemble interface

.. module:: ensemble
   :synopsis: Definition of the ensemble class

.. moduleauthor:: Kostas Andreadis <kandread@jpl.nasa.gov>

"""

import vic
from vic import state
import tempfile
import sys
import random
from datetime import date, timedelta
from multiprocessing import Process
import numpy as np
import shutil
import os
from dateutil.relativedelta import relativedelta
import rpath
import dbio
import logging


[docs]class Ensemble: def __init__(self, nens, dbname, resolution, startyear, startmonth, startday, endyear, endmonth, endday, name=""): """Create an ensemble of models with size *nens*.""" self.nens = nens self.models = [] self.name = name self.statefiles = [] self.res = resolution self.startyear, self.startmonth, self.startday = startyear, startmonth, startday self.endyear, self.endmonth, self.endday = endyear, endmonth, endday self.dbname = dbname for e in range(nens): modelpath = tempfile.mkdtemp(dir=".") model = vic.VIC(modelpath, dbname, resolution, startyear, startmonth, startday, endyear, endmonth, endday, name=name) self.models.append(model) def _ensembleTable(self, write, e): def write_wrapper(data, dates, tablename, initialize, skipsave): return write(data, dates, tablename, initialize, e, skipsave=skipsave) return write_wrapper
[docs] def setStateFiles(self, statefiles): """Set initial state files for each ensemble member.""" for e in range(len(statefiles)): try: shutil.copy(statefiles[e], self.models[e].model_path) except: pass filename = statefiles[e].split("/")[-1] statefiles[ e] = "{0}/{1}".format(self.models[e].model_path, filename) self.statefiles = statefiles
[docs] def readStateFiles(self): """Read initial state files for each ensemble member.""" cells = [] _, vegfile, snowbandfile = self.models[0].paramFromDB() veg = state.readVegetation("{0}/{1}".format(rpath.data, vegfile)) bands, _ = state.readSnowbands( "{0}/{1}".format(rpath.data, snowbandfile)) for filename in self.statefiles: c, _, _, _ = state.readStateFile(filename) cells.append(c) return cells, veg, bands
[docs] def updateStateFiles(self, data, alat, alon, agid): """Update initial state files with *data*.""" _, vegparam, snowbands = self.models[0].paramFromDB() veg = state.readVegetation("{0}/{1}".format(rpath.data, vegparam)) bands, _ = state.readSnowbands("{0}/{1}".format(rpath.data, snowbands)) for e, statefile in enumerate(self.statefiles): states, nlayer, nnodes, dateline = state.readStateFile(statefile) for var in data: x = state.readVariable(self.models[e], states, alat[var], alon[ var], veg, bands, nlayer, var) states = state.updateVariable(self.models[e], states, x, data[var][:, e], alat[ var], alon[var], agid, veg, bands, nlayer, var) state.writeStateFile(statefile, states, "{0}\n{1} {2}".format( dateline.strip(), nlayer, nnodes))
[docs] def setDates(self, startyear, startmonth, startday, endyear, endmonth, endday): """Set simulation dates for entire ensemble.""" self.startyear, self.startmonth, self.startday = startyear, startmonth, startday self.endyear, self.endmonth, self.endday = endyear, endmonth, endday for m in self.models: m.startyear = startyear m.startmonth = startmonth m.startday = startday m.endyear = endyear m.endmonth = endmonth m.endday = endday
def __getitem__(self, m): """Return a model instance.""" return self.models[m] def __len__(self): """Return ensemble size.""" return len(self.models) def __iter__(self): """Return an iterator to model ensemble members.""" return iter(self.models)
[docs] def writeParamFiles(self, savestate=""): """Write model parameter file for each ensemble member.""" for e, model in enumerate(self.models): if len(self.statefiles) > 0: model.writeParamFile(state_file=self.statefiles[ e], save_state=savestate) else: model.writeParamFile(save_state=savestate)
[docs] def writeSoilFiles(self, shapefile): """Write soil parameter files based on domain shapefile.""" self.models[0].writeSoilFile(shapefile) for model in self.models[1:]: shutil.copy( "{0}/soil.txt".format(self.models[0].model_path), "{0}/".format(model.model_path)) model.lat = self.models[0].lat model.lon = self.models[0].lon model.gid = self.models[0].gid model.lgid = self.models[0].lgid model.depths = self.models[0].depths model.elev = self.models[0].elev
[docs] def writeForcings(self, method, options): """Write forcings for the ensemble based on method (ESP, BCSD).""" log = logging.getLogger(__name__) if method.lower() == "esp": self._ESP(options) elif method.lower() == "bcsd": pass elif method.lower() == "iri": self.__fromDataset("iri", options) elif method.lower() == "nmme": self.__fromDataset("nmme", options) else: log.error("No appropriate method for generating meteorological forecast ensemble, exiting!") sys.exit()
def __fromDataset(self, dataset, options): """Generate and write forcings by using a dataset-specific function.""" dsmod = __import__("datasets." + dataset, fromlist=[dataset]) dsmod.generate(options, self)
[docs] def perturb(self, prec, tmax, tmin, wind, nens=None, perr=0.25, terr=2.0): """Perturb meteorological forcings.""" if nens is None: nens = self.nens ensprec = [] enstmax = [] enstmin = [] enswind = [] for e in range(nens): p = [] tx = [] tn = [] w = [] for i in range(len(prec)): p.append(list(prec[i])) if prec[i][2] > 0.0: p[-1][2] = np.log(np.random.lognormal(prec[i] [2], abs(perr * prec[i][2]))) tx.append(list(tmax[i])) tn.append(list(tmin[i])) tavgp = 0.5 * (tmax[i][2] + tmin[i][2]) + \ np.random.normal(0., terr) tx[-1][2] = (tavgp - 0.5 * tmin[i][2]) / 0.5 tn[-1][2] = (tavgp - 0.5 * tmax[i][2]) / 0.5 w.append(list(wind[i])) ensprec.append(p) enstmax.append(tx) enstmin.append(tn) enswind.append(w) return ensprec, enstmax, enstmin, enswind
def _ESP(self, options): """Generate meteorological forcings using the Ensemble Streamflow Prediction method.""" ndays = (date(self.endyear, self.endmonth, self.endday) - date(self.startyear, self.startmonth, self.startday)).days db = dbio.connect(self.models[0].dbname) cur = db.cursor() if self.startmonth < self.endmonth: sql = "select distinct (date_part('year', fdate)) as year from precip.{0} where date_part('month', fdate) >= {1} and date_part('month', fdate) <= {2}".format(options['vic']['precip'], self.startmonth, self.endmonth) else: sql = "select distinct (date_part('year', fdate)) as year from precip.{0} where date_part('month', fdate) >= {1} or date_part('month', fdate) <= {2}".format(options['vic']['precip'], self.startmonth, self.endmonth) cur.execute(sql) years = map(lambda y: int(y[0]), cur.fetchall()) random.shuffle(years) while len(years) < self.nens: years += years for e in range(self.nens): model = self.models[e] model.startyear = years[e] t = date(model.startyear, model.startmonth, model.startday) + timedelta(ndays) model.endyear, model.endmonth, model.endday = t.year, t.month, t.day prec, tmax, tmin, wind = model.getForcings(options['vic']) model.writeForcings(prec, tmax, tmin, wind) cur.close() db.close()
[docs] def run(self, vicexe): """Run ensemble of VIC models using multi-threading.""" procs = [Process(target=self.models[e].run, args=(vicexe,)) for e in range(self.nens)] for p in procs: p.start() for p in procs: p.join()
def _initializeDeterm(self, basin, forcings, vicexe): """Initialize ensemble of VIC models deterministically.""" db = dbio.connect(self.dbname) cur = db.cursor() dt = "{0}-{1}-{2}".format(self.startyear, self.startmonth, self.startday) cur.execute( "select * from information_schema.tables where table_schema='{0}' and table_name='state'".format(self.name)) if bool(cur.rowcount): sql = "select filename, fdate from {0}.state order by abs(date '{1}' - fdate)".format( self.name, dt) cur.execute(sql) if bool(cur.rowcount): statefile, t = cur.fetchone() else: statefile = "" else: statefile = "" if statefile == "": t = date(self.startyear - 1, self.startmonth, self.startday) # checks if statefile corresponds to requested forecast start date if (t - date(self.startyear, self.startmonth, self.startday)).days < 0: if (t - date(self.startyear - 1, self.startmonth, self.startday)).days < 0: # if statefile is older than a year, start the model # uninitialized for 1 year t = date(self.startyear - 1, self.startmonth, self.startday) modelpath = tempfile.mkdtemp(dir=".") model = vic.VIC(modelpath, self.dbname, self.res, t.year, t.month, t.day, self.startyear, self.startmonth, self.startday, self.name) model.writeParamFile(save_state=modelpath, init_state=bool(statefile)) model.writeSoilFile(basin) prec, tmax, tmin, wind = model.getForcings(forcings) model.writeForcings(prec, tmax, tmin, wind) model.run(vicexe) statefile = model.model_path + \ "/vic.state_{0:04d}{1:02d}{2:02d}".format( self.startyear, self.startmonth, self.startday) for emodel in self.models: shutil.copy(statefile, emodel.model_path) shutil.rmtree(model.model_path) statefiles = [statefile] * self.nens self.setStateFiles(statefiles) cur.close() db.close() return statefiles def _initializeRandom(self, basin, forcings, vicexe, initdays=90, saveindb=False, saveto="db", saveargs=[], overwrite=True, skipsave=0): """Initialize ensemble of VIC models by sampling the meterological forcings and running them *initmonths* prior to simulation start date.""" db = dbio.connect(self.dbname) cur = db.cursor() sql = "select distinct (date_part('year', fdate)) as year from precip.{0}".format( forcings['precip']) cur.execute(sql) years = map(lambda y: int(y[0]), cur.fetchall()) years.remove(max(years)) cur.close() db.close() statefiles = [] t = date(self.startyear, self.startmonth, self.startday) - \ relativedelta(days=initdays) ndays = (date(self.startyear, self.startmonth, self.startday) - t).days modelpath = tempfile.mkdtemp() # (dir=".") model = vic.VIC(modelpath, self.dbname, self.res, t.year, t.month, t.day, self.startyear, self.startmonth, self.startday, self.name) years = np.random.choice(years, self.nens) pmodels = [] for e in range(self.nens): modelpath = tempfile.mkdtemp() # (dir=".") model = vic.VIC(modelpath, self.dbname, self.res, t.year, t.month, t.day, self.startyear, self.startmonth, self.startday, self.name) model.writeParamFile(save_state=modelpath, init_state=False) model.writeSoilFile(basin) model.startyear = years[e] model.endyear = years[e] + (model.endyear - t.year) ddays = (date(model.endyear, model.endmonth, model.endday) - date(model.startyear, model.startmonth, model.startday)).days - ndays t1 = date(model.endyear, model.endmonth, model.endday) - \ relativedelta(days=ddays) model.endyear, model.endmonth, model.endday = t1.year, t1.month, t1.day prec, tmax, tmin, wind = model.getForcings(forcings) model.writeForcings(prec, tmax, tmin, wind) model.startyear, model.startmonth, model.startday = t.year, t.month, t.day model.endyear, model.endmonth, model.endday = self.startyear, self.startmonth, self.startday pmodels.append(model) procs = [Process(target=pmodels[e].run, args=(vicexe,)) for e in range(self.nens)] for p in procs: p.start() for p in procs: p.join() if saveindb: if skipsave < 0: skipdays = (date(self.startyear, self.startmonth, self.startday) - t).days + 1 + skipsave else: skipdays = skipsave init = overwrite for e in range(len(pmodels)): model = pmodels[e] if getattr(model.writeToDB, 'func_name') is not 'write_wrapper': model.writeToDB = self._ensembleTable( model.writeToDB, e + 1) if e > 0: init = False model.save(saveto, saveargs, init, skipsave=skipdays) for e in range(self.nens): statefile = pmodels[e].model_path + "/vic.state_{0:04d}{1:02d}{2:02d}".format( self.startyear, self.startmonth, self.startday) statefiles.append(statefile) return statefiles def _initializePerturb(self, basin, forcings, vicexe, initdays=90, saveindb=False, saveto="db", saveargs=[], overwrite=True, skipsave=0): """Initialize ensemble of VIC models by perturbing the meterological forcings and running them *initmonths* prior to simulation start date.""" statefiles = [] t = date(self.startyear, self.startmonth, self.startday) - \ relativedelta(days=initdays) modelpath = tempfile.mkdtemp() model = vic.VIC(modelpath, self.dbname, self.res, t.year, t.month, t.day, self.startyear, self.startmonth, self.startday, self.name) prec, tmax, tmin, wind = model.getForcings(forcings) eprec, etmax, etmin, ewind = self.perturb(prec, tmax, tmin, wind) pmodels = [] for e in range(self.nens): modelpath = tempfile.mkdtemp() model = vic.VIC(modelpath, self.dbname, self.res, t.year, t.month, t.day, self.startyear, self.startmonth, self.startday, self.name) model.writeParamFile(save_state=modelpath, init_state=False) model.writeSoilFile(basin) model.writeForcings(eprec[e], etmax[e], etmin[e], ewind[e]) pmodels.append(model) procs = [Process(target=pmodels[e].run, args=(vicexe,)) for e in range(self.nens)] for p in procs: p.start() for p in procs: p.join() if saveindb: if skipsave < 0: skipdays = (date(self.startyear, self.startmonth, self.startday) - t).days + 1 + skipsave else: skipdays = skipsave init = overwrite for e in range(len(pmodels)): model = pmodels[e] if getattr(model.writeToDB, 'func_name') is not 'write_wrapper': model.writeToDB = self._ensembleTable( model.writeToDB, e + 1) if e > 0: init = False model.save(saveto, saveargs, init, skipsave=skipdays) for e in range(self.nens): statefile = pmodels[e].model_path + "/vic.state_{0:04d}{1:02d}{2:02d}".format( self.startyear, self.startmonth, self.startday) statefiles.append(statefile) return statefiles
[docs] def initialize(self, options, basin, method, vicexe, saveindb=False, saveto="db", saveargs=[], overwrite=True, skipsave=0, initdays=90): """Initialize ensemble of VIC models using one of three methods: 1) deterministic (default): each ensemble member has an identical state 2) random: each ensemble member gets a random day from climatology 3) perturb: perturb precipitation and temperature""" log = logging.getLogger(__name__) forcings = {'temperature': options['vic'][ 'temperature'], 'wind': options['vic']['wind']} if 'lai' in options['vic']: forcings['lai'] = options['vic']['lai'] forcings['precip'] = options['vic']['precip'].split(",")[0] self.writeParamFiles() # write soil file for each ensemble member and populate # latitude/longitude arrays self.writeSoilFiles(basin) if method.find("determ") == 0: statefiles = self._initializeDeterm(basin, forcings, vicexe) elif method.find("states") == 0: db = dbio.connect(self.dbname) cur = db.cursor() cur.execute( "select * from information_schema.tables where table_name='state' and table_schema=%s", (self.name,)) if bool(cur.rowcount): cur.execute("select filename from {0}.state where date_part('month', fdate) = {1}".format( self.name, self.startmonth)) statefiles = map(lambda q: q[0], cur.fetchall()) statefiles = list(np.random.choice(statefiles, self.nens)) else: log.warning("No statefiles found in the database. Not initializing ensemble!") statefiles = [] cur.close() db.close() elif method.find("random") == 0: statefiles = self._initializeRandom( basin, forcings, vicexe, initdays=initdays, saveindb=saveindb, saveto=saveto, saveargs=saveargs, skipsave=skipsave, overwrite=overwrite) elif method.find("perturb") == 0: statefiles = self._initializePerturb( basin, forcings, vicexe, initdays=initdays, saveindb=saveindb, saveto=saveto, saveargs=saveargs, skipsave=skipsave, overwrite=overwrite) else: log.error("No appropriate method to initialize the ensemble found!") sys.exit() self.setStateFiles(statefiles)
[docs] def save(self, saveto, args, initialize=True): """Reads and saves selected output data variables from the ensemble into the database or a user-defined directory.""" def ensembleTable(write, e): def write_wrapper(data, dates, tablename, initialize, skipsave): return write(data, dates, tablename, initialize, e) return write_wrapper for e in range(self.nens): model = self.models[e] if getattr(model.writeToDB, 'func_name') is not 'write_wrapper': # decorate function to add ensemble information model.writeToDB = ensembleTable(model.writeToDB, e + 1) if saveto == "db": if e > 0: initialize = False model.save(saveto, args, initialize) else: if e < 1: if os.path.isdir(saveto): shutil.rmtree(saveto) elif os.path.isfile(saveto): os.remove(saveto) os.makedirs(saveto) model.save(saveto + "/{0}".format(e + 1), args, False)