""" Module for interfacing with the PostGIS database
.. module:: dbio
:synopsis: Definition of the DBIO module
.. moduleauthor:: Kostas Andreadis <kandread@jpl.nasa.gov>
"""
import numpy as np
import tempfile
from osgeo import gdal, osr
import subprocess
import random
import psycopg2 as pg
import string
import rpath
import sys
import logging
def connect(dbname):
    """Connect to database *dbname*.

    A connection with psycopg2's default settings is attempted first. If
    that raises ``OperationalError`` a second attempt is made with
    ``host="/tmp/"``.

    :param dbname: name of the database to connect to
    :returns: an open psycopg2 connection
    :raises SystemExit: if both attempts fail; an error message with
        restart instructions is logged before exiting
    """
    log = logging.getLogger(__name__)
    try:
        return pg.connect(database=dbname)
    except pg.OperationalError:
        # Fall back to /tmp/ as host (a host starting with "/" is taken
        # by libpq to be a Unix-socket directory).
        try:
            return pg.connect(database=dbname, host="/tmp/")
        except pg.Error:
            # Only database errors are handled here, so Ctrl-C and
            # programming errors are no longer silently turned into an exit.
            log.error("Cannot connect to database {0}. Please restart it by running \n {1}/pg_ctl -D {2}/postgres restart".format(
                dbname, rpath.bins, rpath.data))
            # Non-zero status lets calling scripts detect the failure.
            sys.exit(1)
def columnExists(dbname, schemaname, tablename, colname):
    """Tests whether a column exists in a table.

    :param dbname: name of the database to query
    :param schemaname: schema that holds the table
    :param tablename: table to inspect
    :param colname: column to look for
    :returns: ``True`` if ``information_schema.columns`` has a matching
        row, ``False`` otherwise
    """
    # Values are passed as query parameters so that names containing quote
    # characters are matched literally instead of breaking the statement.
    sql = ("select column_name from information_schema.columns "
           "where table_schema=%s and table_name=%s and column_name=%s")
    db = connect(dbname)
    try:
        cur = db.cursor()
        cur.execute(sql, (schemaname, tablename, colname))
        column_exists = bool(cur.rowcount)
        cur.close()
    finally:
        # Release the connection even when the query fails.
        db.close()
    return column_exists
def tableExists(dbname, schemaname, tablename):
    """Check if table exists in the database.

    :param dbname: name of the database to query
    :param schemaname: schema expected to hold the table
    :param tablename: table to look for
    :returns: ``True`` if ``information_schema.tables`` has a matching
        row, ``False`` otherwise
    """
    # Values are passed as query parameters so that names containing quote
    # characters are matched literally instead of breaking the statement.
    sql = ("select * from information_schema.tables "
           "where table_schema=%s and table_name=%s")
    db = connect(dbname)
    try:
        cur = db.cursor()
        cur.execute(sql, (schemaname, tablename))
        table_exists = bool(cur.rowcount)
        cur.close()
    finally:
        # Release the connection even when the query fails.
        db.close()
    return table_exists
def schemaExists(dbname, schemaname):
    """Check if schema exists in database.

    :param dbname: name of the database to query
    :param schemaname: schema to look for
    :returns: ``True`` if ``information_schema.schemata`` has a matching
        row, ``False`` otherwise
    """
    # The name is passed as a query parameter so that quote characters are
    # matched literally instead of breaking the statement.
    sql = "select * from information_schema.schemata where schema_name=%s"
    db = connect(dbname)
    try:
        cur = db.cursor()
        cur.execute(sql, (schemaname,))
        schema_exists = bool(cur.rowcount)
        cur.close()
    finally:
        # Release the connection even when the query fails.
        db.close()
    return schema_exists
def writeGeotif(lat, lon, res, data, filename=None):
    """Writes a single-band Float32 GeoTIFF so it can be imported into the PostGIS database.

    :param lat: latitudes; for 1-D *data* one value per data point,
        otherwise only the maximum is used
    :param lon: longitudes; for 1-D *data* one value per data point,
        otherwise only the minimum is used
    :param res: cell size, in the same units as *lat* and *lon*
    :param data: values to write. A 1-D array is scattered onto a regular
        grid using *lat*/*lon*; an array with two dimensions is written
        as is (assumed to have its first row at the maximum latitude --
        TODO confirm with callers). Masked arrays have their masked cells
        written as the array's fill value, which is also set as the
        raster's nodata value; otherwise nodata is -9999.
    :param filename: output path; when ``None`` a new ``.tif`` file is
        created in the temporary directory and is not deleted automatically
    :returns: path of the file that was written
    """
    if isinstance(data, np.ma.masked_array):
        nodata = np.double(data.fill_value)
        # filled() replaces masked cells by fill_value so that they really
        # hold the nodata value (the raw .data buffer may hold anything).
        data = data.filled()
    else:
        nodata = -9999.
    north = max(lat)
    west = min(lon)
    if len(data.shape) < 2:
        # Round rather than truncate: quotients such as (0.3 - 0.1) / 0.1
        # evaluate to 1.9999999999999998 and would lose a row/column.
        nrows = int(np.rint((north - min(lat)) / res)) + 1
        ncols = int(np.rint((max(lon) - west) / res)) + 1
        out = np.full((nrows, ncols), nodata)
        rows = np.rint((north - np.asarray(lat)) / res).astype(int)
        cols = np.rint((np.asarray(lon) - west) / res).astype(int)
        # One value per coordinate pair, as in a point-by-point loop.
        out[rows, cols] = data[:len(rows)]
    else:
        nrows, ncols = data.shape
        out = data
    if filename is None:
        # Only the unique name is needed; GDAL creates the file itself.
        with tempfile.NamedTemporaryFile(suffix=".tif", delete=False) as f:
            filename = f.name
    driver = gdal.GetDriverByName("GTiff")
    ods = driver.Create(filename, ncols, nrows, 1, gdal.GDT_Float32)
    # lat/lon are cell centres, so the raster origin (upper-left corner)
    # lies half a cell beyond the extreme coordinates.
    ods.SetGeoTransform([west - res / 2.0, res, 0,
                         north + res / 2.0, 0, -res])
    srs = osr.SpatialReference()
    srs.SetWellKnownGeogCS("WGS84")
    ods.SetProjection(srs.ExportToWkt())
    band = ods.GetRasterBand(1)
    band.WriteArray(out)
    band.SetNoDataValue(nodata)
    # Dropping the references lets GDAL flush and close the dataset.
    band = None
    ods = None
    return filename
def deleteRasters(dbname, tablename, dt, squery=""):
    """If date already exists delete associated rasters before
    ingesting, and optionally constrain with subquery.

    :param dbname: name of the database
    :param tablename: (schema-qualified) table to delete from
    :param dt: date whose rasters are removed (needs ``strftime``)
    :param squery: raw SQL appended to the ``delete`` statement's
        ``where`` clause; it is not applied to the existence check
    """
    log = logging.getLogger(__name__)
    fdate = dt.strftime("%Y-%m-%d")
    db = connect(dbname)
    try:
        cur = db.cursor()
        # Existence probe only: "select 1 ... limit 1" avoids shipping the
        # raster rows themselves to the client just to count them.
        cur.execute("select 1 from {0} where fdate='{1}' limit 1".format(tablename, fdate))
        if cur.rowcount:
            log.warning("Overwriting raster in {0} table for {1}".format(tablename, fdate))
            cur.execute("delete from {0} where fdate='{1}' {2}".format(tablename, fdate, squery))
            db.commit()
        cur.close()
    finally:
        # Release the connection even when a statement fails.
        db.close()
def _getResamplingMethod(dbname, tablename, res):
    """Choose the resampling method for bringing *tablename* to resolution *res*.

    The pixel height of one raster in the table is compared with *res*:
    ``"near"`` is returned when they are equal, ``"bilinear"`` when *res*
    is smaller than the pixel height, and ``"average"`` otherwise.
    """
    conn = connect(dbname)
    cursor = conn.cursor()
    query = "select st_pixelheight(rast) from {0} limit 1".format(tablename)
    cursor.execute(query)
    pixel_height = cursor.fetchone()[0]
    if res == pixel_height:
        method = "near"
    else:
        method = "bilinear" if res < pixel_height else "average"
    cursor.close()
    conn.close()
    return method
def getResampledTables(dbname, options, res):
    """Find names of resampled raster tables.

    For each of the ``precip``, ``tmax``, ``tmin`` and ``wind`` entries of
    ``options['vic']`` the ``raster_resampled`` catalog is searched for a
    row at resolution *res* whose ``tname`` starts with the configured
    name. Returns a dict mapping each variable to the second column of
    the first row found.
    """
    template = "select * from raster_resampled where sname='{0}' and tname like '{1}%' and resolution={2}"
    resampled = {}
    conn = connect(dbname)
    cursor = conn.cursor()
    for variable in ('precip', 'tmax', 'tmin', 'wind'):
        configured_name = options['vic'][variable]
        cursor.execute(template.format(variable, configured_name, res))
        first_row = cursor.fetchone()
        resampled[variable] = first_row[1]
    cursor.close()
    conn.close()
    return resampled
def _createRasterTable(dbname, stname):
    """Create the raster table *stname* in database *dbname*.

    The table gets a serial primary key ``rid``, a ``rast`` raster column
    and a non-null ``fdate`` date column; the change is committed.
    """
    ddl = "create table {0} (rid serial primary key, rast raster, fdate date not null)".format(stname)
    conn = connect(dbname)
    cursor = conn.cursor()
    cursor.execute(ddl)
    conn.commit()
    cursor.close()
    conn.close()
def _createDateIndex(dbname, schemaname, tablename):
    """Index the ``fdate`` column of *schemaname*.*tablename*.

    The index is named ``<tablename>_t`` and the change is committed.
    """
    ddl = "create index {1}_t on {0}.{1}(fdate)".format(schemaname, tablename)
    conn = connect(dbname)
    cursor = conn.cursor()
    cursor.execute(ddl)
    conn.commit()
    cursor.close()
    conn.close()
def createResampledCatalog(dbname):
    """Create the ``raster_resampled`` catalog unless it already exists.

    When ``information_schema.tables`` has no ``raster_resampled`` entry,
    a ``resampled(schema, table)`` helper function returning the
    ``st_scalex`` of one raster is created, followed by the
    ``raster_resampled`` view listing schema, table and resolution for
    every row of ``raster_columns``. Both are committed together.
    """
    function_ddl = """create or replace function resampled(_s text, _t text, out result double precision) as
$func$
begin
execute format('select st_scalex(rast) from %s.%s limit 1',quote_ident(_s),quote_ident(_t)) into result;
end
$func$ language plpgsql;"""
    view_ddl = "create or replace view raster_resampled as (select r_table_schema as sname,r_table_name as tname,resampled(r_table_schema,r_table_name) as resolution from raster_columns)"
    conn = connect(dbname)
    cursor = conn.cursor()
    cursor.execute("select * from information_schema.tables where table_name='raster_resampled'")
    if not cursor.rowcount:
        # The view depends on the function, so the function goes first.
        for statement in (function_ddl, view_ddl):
            cursor.execute(statement)
        conn.commit()
    cursor.close()
    conn.close()
def resampleRaster(dbname, sname, tname, dt, res, method, tilesize, overwrite, squery=""):
    """Resample raster to target resolution.

    The rasters of *sname*.*tname* dated *dt* are rescaled to *res*,
    tiled, and stored in the table ``<sname>.<tname>_<int(1/res)>``. That
    table is created (with indexes on ``fdate`` and ``rid``) when it does
    not exist yet; otherwise the new tiles are inserted into it.

    :param dbname: name of the database
    :param sname: schema of the source table
    :param tname: name of the source table
    :param dt: date of the rasters to resample (needs ``strftime``)
    :param res: target resolution passed to ``st_rescale``
    :param method: resampling algorithm name passed to ``st_rescale``
    :param tilesize: ``(width, height)`` passed to ``st_tile``
    :param overwrite: if true, rasters already stored for *dt* in an
        existing resampled table are deleted first
    :param squery: raw SQL appended to the ``where`` clauses
    """
    suffix = int(1.0 / res)
    fdate = dt.strftime("%Y-%m-%d")
    db = connect(dbname)
    try:
        cur = db.cursor()
        # check if resampled table exists
        cur.execute("select * from pg_catalog.pg_class c inner join pg_catalog.pg_namespace n on c.relnamespace=n.oid where n.nspname='{0}' and c.relname='{1}_{2}'".format(sname, tname, suffix))
        # if it exists insert data, if not create it
        if cur.rowcount:
            # check if date already exists and delete it before ingesting
            if overwrite:
                deleteRasters(dbname, "{0}.{1}_{2}".format(sname, tname, suffix), dt, squery)
            sql = "insert into {0}.{1}_{2} (with dt as (select max(fdate) as maxdate from {0}.{1}_{2}), f as (select fdate,st_tile(st_rescale(rast,{3},'{4}'),{5},{6}) as rast from {0}.{1} where fdate=date'{7}' {8}) select fdate,rast,dense_rank() over (order by st_upperleftx(rast),st_upperlefty(rast)) as rid from f)".format(sname, tname, suffix, res, method, tilesize[0], tilesize[1], fdate, squery)
            cur.execute(sql)
        else:
            sql = "create table {0}.{1}_{2} as (with f as (select fdate,st_tile(st_rescale(rast,{3},'{4}'),{5},{6}) as rast from {0}.{1} where fdate=date'{7}' {8}) select fdate,rast,dense_rank() over (order by st_upperleftx(rast),st_upperlefty(rast)) as rid from f)".format(
                sname, tname, suffix, res, method, tilesize[0], tilesize[1], fdate, squery)
            cur.execute(sql)
            cur.execute("create index {1}_{2}_t on {0}.{1}_{2}(fdate)".format(
                sname, tname, suffix))
            cur.execute("create index {1}_{2}_r on {0}.{1}_{2}(rid)".format(
                sname, tname, suffix))
        db.commit()
        cur.close()
    finally:
        # The connection used to be left open on every call; close it on
        # success and on failure alike.
        db.close()
def createResampledTables(dbname, sname, tname, dt, tilesize, overwrite, squery=""):
    """Resample *sname*.*tname* for *dt* to every resolution found in ``vic.soils``.

    The resampled-raster catalog is created first if necessary. Then, for
    each distinct ``resolution`` in ``vic.soils``, a resampling method is
    chosen and the rasters are resampled through ``resampleRaster``.
    """
    conn = connect(dbname)
    cursor = conn.cursor()
    # make sure the catalog of resampled rasters is in place
    createResampledCatalog(dbname)
    cursor.execute("select distinct(resolution) from vic.soils")
    if cursor.rowcount:
        source_table = "{0}.{1}".format(sname, tname)
        for row in cursor.fetchall():
            resolution = row[0]
            # the method depends on how the target compares to the source
            chosen_method = _getResamplingMethod(dbname, source_table, resolution)
            resampleRaster(dbname, sname, tname, dt, resolution, chosen_method,
                           tilesize, overwrite, squery)
    cursor.close()
    conn.close()
def ingest(dbname, filename, dt, stname, resample=True, overwrite=True):
    """Imports Geotif *filename* into database *dbname*.

    The file is loaded into a randomly named temporary table with
    ``raster2pgsql | psql``, stamped with date *dt*, copied into the
    raster table *stname* (schema and table are created when missing)
    and the temporary table is dropped again.

    :param dbname: name of the target database
    :param filename: path of the GeoTIFF to import
    :param dt: date assigned to the imported rasters
    :param stname: target table as ``schema.table``
    :param resample: if true, also build the resampled tables for *dt*
    :param overwrite: if true, rasters already stored for *dt* are
        deleted before the new ones are inserted
    :raises ValueError: if *stname* is not of the form ``schema.table``
    """
    try:
        from shlex import quote  # Python 3
    except ImportError:
        from pipes import quote  # Python 2
    log = logging.getLogger(__name__)
    tilesize = (10, 10)
    # Validate the target name before anything is written to the database.
    schemaname, tablename = stname.split(".")
    temptable = ''.join(random.SystemRandom().choice(
        string.ascii_letters) for _ in range(8))
    db = connect(dbname)
    try:
        cur = db.cursor()
        # import temporary table; every path is shell-quoted so that names
        # containing spaces or shell metacharacters are passed through intact
        cmd = "{0} -d -s 4326 {1} {2} | {3} -d {4}".format(
            quote("{0}/raster2pgsql".format(rpath.bins)), quote(str(filename)),
            temptable, quote("{0}/psql".format(rpath.bins)), quote(str(dbname)))
        proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        # stderr is merged into stdout, so only the first item carries data
        out, _ = proc.communicate()
        log.debug(out)
        cur.execute("alter table {0} add column fdate date".format(temptable))
        cur.execute(
            "update {3} set fdate = date '{0}-{1}-{2}'".format(dt.year, dt.month, dt.day, temptable))
        # check if table exists
        if not schemaExists(dbname, schemaname):
            cur.execute("create schema {0}".format(schemaname))
            db.commit()
        if not tableExists(dbname, schemaname, tablename):
            _createRasterTable(dbname, stname)
            _createDateIndex(dbname, schemaname, tablename)
        # check if date already exists and delete it before ingesting
        if overwrite:
            deleteRasters(dbname, "{0}.{1}".format(schemaname, tablename), dt)
        # copy the imported raster into the target table
        cur.execute("insert into {0}.{1} (fdate,rast) select fdate,rast from {2}".format(
            schemaname, tablename, temptable))
        # commit before resampling: it runs on separate connections that
        # must be able to see the rows inserted above
        db.commit()
        # create resampled rasters
        if resample:
            log.info("Creating resampled table for {0}.{1}".format(schemaname, tablename))
            createResampledTables(dbname, schemaname, tablename, dt, tilesize, overwrite)
        # delete temporary table
        cur.execute("drop table {0}".format(temptable))
        db.commit()
        log.info("Imported {0} in {1}".format(dt.strftime("%Y-%m-%d"), stname))
        cur.close()
    finally:
        # Release the connection even when a step fails.
        db.close()