#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plotting functions for pyCATHY (2D and 3D)
Plotting:
- output files
- input files
- petro/pedophysical relationships
- Data assimilation results
"""
import glob
import os
import pyvista as pv
pv.set_plot_theme("document")
pv.set_jupyter_backend('static')
import datetime
import math
import imageio
# import panel as pn
import ipywidgets as widgets
import matplotlib as mpl
import matplotlib.dates as mdates
import matplotlib.font_manager
import matplotlib.pyplot as plt
import matplotlib.style
import natsort
import numpy as np
import pandas as pd
from matplotlib import cm
from matplotlib.colors import LogNorm
from matplotlib.ticker import FormatStrFormatter
from scipy.interpolate import griddata
# from scipy.spatial import Delaunay
from shapely.geometry import Polygon, MultiPoint, Point
from pyCATHY.cathy_utils import (
change_x2date,
convert_time_units,
label_units,
transform2_time_delta,
)
from pyCATHY.importers import cathy_inputs as in_CT
from pyCATHY.importers import cathy_outputs as out_CT
import geopandas as gpd
import rioxarray as rio
# mpl.style.use('default')
mpl.rcParams["grid.color"] = "k"
mpl.rcParams["grid.linestyle"] = ":"
mpl.rcParams["grid.linewidth"] = 0.25
# mpl.rcParams['font.family'] = 'Avenir'
plt.rcParams["font.size"] = 12
plt.rcParams["axes.linewidth"] = 0.75
# nice_fonts = {
# # "font.family": "serif",
# "font.serif": "Times New Roman",
# }
# matplotlib.rcParams.update(nice_fonts)
#%% ---------------------------------------------------------------------------
# -----------------Plot OUTPUTS CATHY -----------------------------------------
# -----------------------------------------------------------------------------
def create_gridded_mask(x,y,**kwargs):
buffer = False
if 'buffer' in kwargs:
buffer = kwargs['buffer']
# points = MultiPoint(np.column_stack((x, y)))
points = np.column_stack((x, y))
multi_point = MultiPoint(points)
boundary = Polygon(points)
# interval =100
if buffer:
interval = int(x[1]- x[0])*4
boundary = multi_point.buffer(interval/2).buffer(-interval/2)
# boundary_coords
# Extract the boundary coordinates and create a boolean mask for the region inside the boundary
boundary_coords = np.array(boundary.exterior.coords)
xmin, ymin = boundary_coords.min(axis=0)
xmax, ymax = boundary_coords.max(axis=0)
xi, yi = np.meshgrid(np.linspace(xmin, xmax, 50), np.linspace(ymin, ymax, 50))
xx, yy = xi.flatten(), yi.flatten()
mask = [Point(coord).within(boundary) for coord in zip(xx, yy)]
mask = np.array(mask).reshape(xi.shape)
return xi, yi, mask
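# Minimal usage sketch for create_gridded_mask (hypothetical, regularly spaced surface
# coordinates; not part of the pyCATHY API):
# >>> import numpy as np
# >>> xg, yg = np.meshgrid(np.arange(0, 100, 10.), np.arange(0, 100, 10.))
# >>> xi, yi, mask = create_gridded_mask(xg.ravel(), yg.ravel(), buffer=True)
# >>> mask.shape  # boolean 50x50 grid, True inside the outline of the surface nodes
# (50, 50)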
def plot_WTD(XYZsurface,WT,**kwargs):
    # FLAG: describes the outcome of the water table search
    # - 1 water table calculated correctly
    # - 2 unsaturated zone above saturated zone above unsaturated zone
    # - 3 water table not encountered, fully saturated vertical profile
    # - 4 water table not encountered, unsaturated profile
if "ax" not in kwargs:
fig, ax = plt.subplots()
else:
ax = kwargs["ax"]
colorbar = True
if 'colorbar' in kwargs:
colorbar = kwargs['colorbar']
ti = 0
if 'ti' in kwargs:
ti = kwargs['ti']
scatter=False
if 'scatter' in kwargs:
scatter = kwargs['scatter']
if scatter:
valid_kwargs = ['vmin', 'vmax']
relevant_kwargs = {key: value for key, value in kwargs.items() if key in valid_kwargs}
cmap = ax.scatter(XYZsurface[:,0], XYZsurface[:,1], c=WT[ti],
**relevant_kwargs)
ax.grid()
# cbar = plt.colorbar(cmap,ax=ax)
# cbar.set_label('GWD (m)')
else:
x = XYZsurface[:,0]
y = XYZsurface[:,1]
cmap = 'viridis'
if type(ti) is list:
z = WT[ti[1]]-WT[ti[0]]
cmap = 'bwr'
else:
z = WT[ti]
xi, yi, mask = create_gridded_mask(x,y)
xx, yy = xi.flatten(), yi.flatten()
zi = griddata((x, y), z, (xx, yy), method='nearest')
# Apply the mask to the gridded data
zi_masked = np.ma.masked_where(~mask, zi.reshape(xi.shape))
# Plot the masked data
cmap = ax.contourf(xi, yi, zi_masked, cmap=cmap)
ax.set_xlabel('x (m)')
ax.set_ylabel('y (m)')
ax.axis('equal')
ax.grid(True, linestyle='-.')
if colorbar:
cbar = plt.colorbar(cmap,ax=ax)
if type(ti) is list:
cbar.set_label(r'$\Delta$ GWD (m)')
else:
cbar.set_label('GWD (m)')
return cmap
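# Minimal usage sketch for plot_WTD (hypothetical arrays): XYZsurface is an
# (n_surface_nodes, 3) array of surface node coordinates and WT is a
# (n_times, n_surface_nodes) array of groundwater depths.
# >>> plot_WTD(XYZsurface, WT, ti=0, scatter=True, vmin=0, vmax=5)  # scatter at time 0
# >>> plot_WTD(XYZsurface, WT, ti=[0, -1])  # contour of the GWD change between two times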
def show_spatialET(df_fort777,**kwargs):
unit = 'ms-1'
unit_str = '$ms^{-1}$'
if 'unit' in kwargs:
unit = kwargs.pop('unit')
cmap = 'Blues'
if 'cmap' in kwargs:
cmap = kwargs.pop('cmap')
crs = None
if 'crs' in kwargs:
crs = kwargs.pop('crs')
ti = 0
if 'ti' in kwargs:
ti = kwargs.pop('ti')
scatter=False
if 'scatter' in kwargs:
scatter = kwargs.pop('scatter')
mask = None
if 'mask_gpd' in kwargs:
mask = kwargs.pop('mask_gpd')
clim = None
if 'clim' in kwargs:
clim = kwargs.pop('clim')
colorbar = True
if 'colorbar' in kwargs:
colorbar = kwargs.pop('colorbar')
df_fort777_indexes = df_fort777.set_index('time_sec').index.unique().to_numpy()
df_fort777_select_t = df_fort777.set_index('time_sec').loc[df_fort777_indexes[ti]]
if "ax" not in kwargs:
fig, ax = plt.subplots()
else:
ax = kwargs.pop("ax")
# if scatter:
df_fort777_select_t = df_fort777_select_t.drop_duplicates()
    if unit == 'mmday-1':
        # convert m/s to mm/day (1e3 mm per m, 86400 s per day)
        df_fort777_select_t['ACT. ETRA'] = df_fort777_select_t['ACT. ETRA']*1e3*86400
        unit_str = '$mm.day^{-1}$'
if mask is not None:
polygon = mask.geometry.iloc[0] # Assuming a single polygon in the shapefile
filtered_data = []
for i in range(len(df_fort777_select_t)):
point = gpd.points_from_xy([df_fort777_select_t['X'].iloc[i]],
[df_fort777_select_t['Y'].iloc[i]])[0]
if polygon.contains(point):
filtered_data.append([df_fort777_select_t['X'].iloc[i],
df_fort777_select_t['Y'].iloc[i],
df_fort777_select_t['ACT. ETRA'].iloc[i]])
mask.crs
filtered_data = np.vstack(filtered_data)
df_fort777_select_t_xr = df_fort777_select_t.set_index(['X','Y']).to_xarray()
df_fort777_select_t_xr = df_fort777_select_t_xr.rio.set_spatial_dims('X','Y')
df_fort777_select_t_xr.rio.write_crs(mask.crs, inplace=True)
df_fort777_select_t_xr_clipped = df_fort777_select_t_xr.rio.clip(mask.geometry)
df_fort777_select_t_xr_clipped = df_fort777_select_t_xr_clipped.transpose('Y', 'X')
data_array = df_fort777_select_t_xr_clipped['ACT. ETRA'].values
df_fort777_select_t_xr_clipped.rio.bounds()
# Plot using plt.imshow
cmap = ax.imshow(data_array,
cmap=cmap,
origin='lower',
extent=[min(filtered_data[:,0]),max(filtered_data[:,0]),
min(filtered_data[:,1]),max(filtered_data[:,1])],
clim = clim,
)
else:
df_fort777_select_t_xr = df_fort777_select_t.set_index(['X','Y']).to_xarray()
df_fort777_select_t_xr = df_fort777_select_t_xr.rio.set_spatial_dims('X','Y')
if crs is not None:
df_fort777_select_t_xr.rio.write_crs(crs, inplace=True)
df_fort777_select_t_xr = df_fort777_select_t_xr.transpose('Y', 'X')
data_array = df_fort777_select_t_xr['ACT. ETRA'].values
# Plot using plt.imshow
cmap = ax.imshow(data_array,
cmap=cmap,
origin='lower',
extent=[min(df_fort777_select_t['X']),max(df_fort777_select_t['X']),
min(df_fort777_select_t['Y']),max(df_fort777_select_t['Y'])],
clim = clim,
)
title = 'ETa'
if 'title' in kwargs:
title = kwargs['title']
ax.set_title(title)
ax.set_xlabel('x (m)')
ax.set_ylabel('y (m)')
# ax.axis('equal')
ax.grid(True, linestyle='-.')
if colorbar:
cbar = plt.colorbar(cmap,ax=ax)
cbar.set_label('ETa ' + unit_str)
return cmap
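# Minimal usage sketch for show_spatialET (hedged: df_fort777 is assumed to be the
# DataFrame of the CATHY fort.777 actual ET output, with the 'time_sec', 'X', 'Y'
# and 'ACT. ETRA' columns used above):
# >>> fig, ax = plt.subplots()
# >>> show_spatialET(df_fort777, ti=2, unit='mmday-1', cmap='Blues', ax=ax)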
def show_wtdepth(df_wtdepth=[], workdir=[], project_name=[], **kwargs):
"""
plot NET SEEPFACE VOL and NET SEEPFACE FLX over the time t
"""
# read hgraph file if df_hgraph not existing
# ------------------------------------------------------------------------
if len(df_wtdepth) == 0:
df_wtdepth = out_CT.read_wtdepth(filename="wtdepth")
# fig, ax = plt.subplots(2,1)
# if 'delta_t' in kwargs:
# df_wtdepth['time'] = pd.to_timedelta(df_wtdepth['time'],unit='s')
if "ax" not in kwargs:
fig, ax = plt.subplots()
else:
ax = kwargs["ax"]
ax.plot(df_wtdepth["time (s)"], df_wtdepth["water table depth (m)"])
ax.set_title("Water table depth")
ax.set(xlabel="Time (s)", ylabel="Water table depth (m)")
# ax.legend(["Total flow volume", "nansfdir flow volume"])
pass
def show_hgsfdet(df_hgsfdeth=[], workdir=[], project_name=[], **kwargs):
"""
plot NET SEEPFACE VOL and NET SEEPFACE FLX over the time t.
"""
# read hgraph file if df_hgraph not existing
# ------------------------------------------------------------------------
if len(df_hgsfdeth) == 0:
df_hgsfdeth = out_CT.read_hgsfdet(filename="hgsfdeth")
fig, ax = plt.subplots(2, 1)
if "delta_t" in kwargs:
df_hgsfdeth["time"] = pd.to_timedelta(df_hgsfdeth["time"], unit="s")
df_hgsfdeth.pivot_table(values="NET SEEPFACE VOL", index="time").plot(
ax=ax[0], ylabel="NET SEEPFACE VOL", xlabel="time (s)"
)
df_hgsfdeth.pivot_table(values="NET SEEPFACE FLX", index="time").plot(
ax=ax[1], ylabel="NET SEEPFACE FLX", xlabel="time (s)"
)
return fig, ax
def show_dtcoupling(
df_dtcoupling=[], workdir=[], project_name=[], x="time", yprop="Atmact-vf", **kwargs
):
"""
plot dtcoupling
"""
if "yprop" in kwargs:
yprop = kwargs["yprop"]
# read hgraph file if df_hgraph not existing
# ------------------------------------------------------------------------
if "ax" not in kwargs:
fig, ax = plt.subplots()
else:
ax = kwargs["ax"]
if len(df_dtcoupling) == 0:
df_dtcoupling = out_CT.read_dtcoupling(filename="dtcoupling")
nstep = len(df_dtcoupling["Atmpot-d"])
jmax = max(df_dtcoupling["Atmact-vf"])
timeatm = [df_dtcoupling["Deltat"][0]]
for i in np.arange(1, nstep):
timeatm.append(timeatm[i - 1] + df_dtcoupling["Deltat"][i - 1])
len(timeatm)
ax.step(timeatm, df_dtcoupling[yprop],
color="green", where="post",
label=yprop,
# ax=ax
)
# plt.step(timeatm, df_dtcoupling[yprop[1]], color="blue", where="post", label=yprop[1])
ax.set_xlabel("time (s)")
ax.set_ylabel(label_units(yprop))
plt.legend()
def show_hgraph(df_hgraph=[], workdir=[], project_name=[], x="time", y="SW", **kwargs):
"""
plot hgraph
"""
# read hgraph file if df_hgraph not existing
# ------------------------------------------------------------------------
if len(df_hgraph) == 0:
df_hgraph = out_CT.read_hgraph(filename="hgraph")
if "ax" not in kwargs:
fig, ax = plt.subplots()
else:
ax = kwargs["ax"]
if "delta_t" in kwargs:
df_hgraph["time"] = pd.to_timedelta(df_hgraph["time"], unit="s")
df_hgraph.pivot_table(values="streamflow", index="time").plot(
ax=ax, ylabel="streamflow ($m^3/s$)", xlabel="time (s)", marker="."
)
def show_COCumflowvol(df_cumflowvol=[], workdir=None, project_name=None, **kwargs):
"""
plot COCumflowvol
"""
# read file if df_cumflowvol not existing
# ------------------------------------------------------------------------
if len(df_cumflowvol) == 0:
df_cumflowvol = out_CT.read_cumflowvol(
os.path.join(workdir, project_name, "output", "cumflowvol")
)
# plot Net Flow Volume (m^3) = f(time)
# ------------------------------------------------------------------------
if "ax" not in kwargs:
fig, ax = plt.subplots()
else:
ax = kwargs["ax"]
ax.plot(df_cumflowvol[:, 2], -df_cumflowvol[:, 7], "b-.")
ax.plot(df_cumflowvol[:, 2] / 3600, df_cumflowvol[:, 7])
ax.set_title("Cumulative flow volume")
ax.set(xlabel="Time (s)", ylabel="Net Flow Volume (m^3)")
ax.legend(["Total flow volume", "nansfdir flow volume"])
def show_vp_DEPRECATED(
df_vp=[], workdir=[], project_name=[], index="time", x="time", y="SW", **kwargs
):
"""
plot vp
DEPRECATED use psi and sw reader instead and plot using pandas after find_nearest_node search.
"""
# read vp file if df_atmbc not existing
# ------------------------------------------------------------------------
if len(df_vp) == 0:
df_vp = out_CT.read_vp(filename="vp")
node_uni = df_vp["node"].unique()
fig, ax = plt.subplots(1, len(node_uni))
if len(node_uni) < 2:
ax = [ax]
colors = cm.Reds(np.linspace(0.2, 0.8, len(df_vp["str_nb"].unique())))
for i, n in enumerate(node_uni):
df_vp_pivot = pd.pivot_table(df_vp, values=y, index=["time", "node", "str_nb"])
df_vp_pivot_nodei = df_vp_pivot.xs(n, level="node")
df_vp_pivot_nodei.unstack("str_nb").plot(
ax=ax[i], title="node: " + str(n), color=colors
)
ax[i].legend([])
label = label_units(y)
if i == 0:
ax[i].set_ylabel(label)
return fig, ax
def show_vtk(
filename=None,
unit="pressure",
timeStep=0,
notebook=False,
path=None,
savefig=False,
ax=None,
**kwargs,
):
"""
Plot vtk file using pyvista
Parameters
----------
filename : str, optional
Name of the file. The default is None.
    unit : str, optional
        One of ['pressure', 'saturation', 'ER', 'permeability', 'velocity']. The default is 'pressure'.
timeStep : int, optional
Time step to plot. The default is 0.
notebook : bool, optional
Option to plot in a notebook. The default is False.
path : str, optional
Path of the vtk file. The default is None.
savefig : bool, optional
Save figure. The default is False.
    **kwargs :
        Mainly plotting adjustments passed to PyVista.
"""
my_colormap = "viridis"
if path is None:
path = os.getcwd()
legend = True
if 'legend' in kwargs:
legend = kwargs.pop('legend')
# Parse physical attribute + cmap from unit
# -------------------------------------------------------------------------
if filename is None:
# for surface/subsurface hydrology
# --------------------------------------------------------------------
if unit == "pressure":
my_colormap = "autumn"
elif unit == "saturation":
my_colormap = "Blues"
if timeStep < 10:
filename = "10" + str(timeStep) + ".vtk"
elif timeStep < 100:
filename = "1" + str(timeStep) + ".vtk"
elif timeStep < 200:
newnb = [int(x) for x in str(timeStep)]
filename = "2" + str(newnb[1]) + str(newnb[2]) + ".vtk"
elif timeStep < 300:
newnb = [int(x) for x in str(timeStep)]
filename = "3" + str(newnb[1]) + str(newnb[2]) + ".vtk"
# for transport
# --------------------------------------------------------------------
elif unit == "celerity":
raise NotImplementedError("Transport vtk file output not yet implemented")
# for ERT
# --------------------------------------------------------------------
elif "ER" in unit:
filename = "ER_converted" + str(timeStep) + ".vtk"
my_colormap = "viridis"
unit = "ER_converted" + str(timeStep)
elif "ER_int" in unit:
filename = "ER_converted" + str(timeStep) + "_nearIntrp2_pg_msh.vtk"
my_colormap = "viridis"
unit = "ER_converted" + str(timeStep) + "_nearIntrp2_pg_msh"
if unit == 'PERMX':
my_colormap = None
print('No colormap')
print(my_colormap)
if 'cmap' in kwargs:
my_colormap = kwargs.pop('cmap')
show_edges = True
if "show_edges" in kwargs:
show_edges = kwargs['show_edges']
mesh = pv.read(os.path.join(path, filename))
if unit in list(mesh.array_names):
print("plot " + str(unit))
    else:
        print("physical property does not exist")
# notebook = activate widgets
# -------------------------------------------------------------------------
if notebook:
if ax is None:
ax = pv.Plotter(notebook=notebook)
# pn.extension('vtk') # this needs to be at the top of each cell for some reason
out = widgets.Output()
def on_value_change(change):
with out:
PhysScalars = "pressure"
time_step = timeStep
# print(change['new'])
if hasattr(change.owner, "options"): # =='Phys. prop:':
PhysScalars = change.owner.options[change["new"] - 1]
else:
time_step = change["new"]
out.clear_output()
mesh = pv.read("./my_cathy_prj/vtk/10" + str(time_step) + ".vtk")
_ = ax.add_mesh(
mesh, scalars=PhysScalars[0], cmap=my_colormap, **kwargs
)
if unit == "saturation":
ax.update_scalar_bar_range([0, 1])
if "clim" in kwargs:
ax.update_scalar_bar_range([kwargs["clim"][0], 1])
legend_entries = []
legend_entries.append(["Time=" + str(mesh["time"]), "w"])
_ = ax.add_legend(legend_entries)
ax.show_grid()
cpos = ax.show(True)
slider = widgets.IntSlider(
min=0,
max=10,
step=1,
continuous_update=True,
description="Time step #:",
)
# play = widgets.Play(min=1, interval=2000)
choice = widgets.Dropdown(
options=[("pressure", 1), ("saturation", 2)],
value=1,
description="Phys. prop:",
)
slider.observe(on_value_change, "value")
choice.observe(on_value_change, "value")
plotvtk = widgets.VBox([choice, slider, out])
plotvtk
# No notebook interaction
# -------------------------------------------------------------------------
else:
if ax is None:
ax = pv.Plotter(notebook=False)
_ = ax.add_mesh(mesh, scalars=unit, cmap=my_colormap, **kwargs)
if unit == "saturation":
ax.update_scalar_bar_range([0, 1])
if "clim" in kwargs:
ax.update_scalar_bar_range([kwargs["clim"][0], kwargs["clim"][1]])
# add time stamp as legend
if legend:
legend_entries = []
time_delta = transform2_time_delta(mesh["TIME"], "s")
legend_entries.append(["Time=" + str(time_delta[0]), "w"])
_ = ax.add_legend(legend_entries)
_ = ax.show_bounds(minor_ticks=True, font_size=1)
# add scatter points to the plot
# ---------------------------------------------------------------------
for key, value in kwargs.items():
# add electrodes positions
# -----------------------------------------------------------------
if key == "elecs":
poly_elecs = pv.PolyData(value)
poly_elecs["My Labels"] = [
f"Label {i}" for i in range(poly_elecs.n_points)
]
ax.add_point_labels(
poly_elecs, "My Labels", point_size=20, font_size=36
)
# add tensiometers positions
# -----------------------------------------------------------------
# add TDR probe positions
# -----------------------------------------------------------------
        if savefig is True:
            ax.view_xz()
            cpos = ax.show(screenshot=os.path.join(path, filename + unit + ".png"))
            print("figure saved: " + os.path.join(path, filename + unit + ".png"))
            ax.close()
pass
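# Minimal usage sketch for show_vtk (hypothetical project path; vtk files are assumed to
# follow the CATHY naming convention 100.vtk, 101.vtk, ... handled above):
# >>> show_vtk(unit="saturation",
# ...          timeStep=1,
# ...          path="./my_cathy_prj/vtk/",
# ...          clim=[0.4, 1],
# ...          savefig=True)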
def show_vtk_TL(
filename=None,
unit="pressure",
timeStep="all",
notebook=False,
path=None,
savefig=True,
show=True,
**kwargs,
):
"""
Time lapse animation of selected time steps
Parameters
----------
filename : str, optional
DESCRIPTION. The default is None.
unit : TYPE, optional
DESCRIPTION. The default is None.
timeStep : TYPE, optional
DESCRIPTION. The default is "all".
notebook : bool, optional
DESCRIPTION. The default is False.
path : TYPE, optional
DESCRIPTION. The default is None.
savefig : TYPE, optional
DESCRIPTION. The default is False.
show : TYPE, optional
DESCRIPTION. The default is True.
"""
x_units = None
xlabel = "s"
# for key, value in kwargs.items():
if "x_units" in kwargs:
x_units = kwargs.pop('x_units')
print(x_units)
if path is None:
path = os.getcwd()
if filename is None:
if unit == "pressure":
filename = "1*.vtk"
filename0 = "100.vtk"
my_colormap = "autumn"
elif unit == "saturation":
my_colormap = "Blues"
filename = "1*.vtk"
filename0 = "100.vtk"
elif "ER" in unit:
filename = "ER" + str(timeStep) + ".vtk"
my_colormap = "viridis"
mesh = pv.read(os.path.join(path, filename0))
if unit in list(mesh.array_names):
print("plot " + str(unit))
    else:
        print("physical property does not exist")
offscreen = True
if show == False:
offscreen = True
plotter = pv.Plotter(
notebook=notebook,
off_screen=offscreen,
)
# print('*'*10)
# print(unit)
plotter.add_mesh(mesh,
scalars=unit,
cmap=my_colormap,
# opacity=0.3,
**kwargs,
)
    if savefig:
        plotter.open_gif(os.path.join(path, unit + ".gif"))
plotter.add_scalar_bar(title=unit)
# options to colorbar
# ---------------------------------------------------------------------
if "clim" in kwargs:
plotter.update_scalar_bar_range([kwargs["clim"][0],
kwargs["clim"][1]]
)
legend_entry = "Time= " + str(mesh["TIME"])
print(legend_entry)
if x_units is not None:
xlabel, t_lgd = convert_time_units(mesh["TIME"], x_units)
legend_entry = "Time=" + str(t_lgd) + xlabel
plotter.show_grid()
cpos = plotter.show(auto_close=False)
plotter.add_text(legend_entry, name="time-label")
files = []
for file in glob.glob(os.path.join(path, filename)):
files.append(file)
for ff in natsort.natsorted(files, reverse=False):
print(ff)
mesh = pv.read(ff)
array_new = mesh.get_array(unit)
print(mesh["TIME"])
if x_units is not None:
xlabel, t_lgd = convert_time_units(mesh["TIME"], x_units)
legend_entry = "Time=" + str(t_lgd) + xlabel
# print(array_new)
plotter.update_scalars(array_new, render=True)
plotter.add_text(legend_entry, name="time-label")
plotter.render()
# plotter.write_frame(False)
if unit == "saturation":
plotter.update_scalar_bar_range([0, 1])
if "clim" in kwargs:
plotter.update_scalar_bar_range([kwargs["clim"][0], kwargs["clim"][1]])
plotter.write_frame()
plotter.close()
    if savefig:
        gif_original = os.path.join(path, unit + ".gif")
        gif_speed_down = os.path.join(path, unit + "_slow.gif")
        gif = imageio.mimread(gif_original)
        imageio.mimsave(gif_speed_down, gif, duration=1800)
        print("gif saved: " + gif_original)
return
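# Minimal usage sketch for show_vtk_TL (hypothetical path; when savefig=True the
# animation is written as <unit>.gif inside `path`, plus a slowed-down copy):
# >>> show_vtk_TL(unit="saturation",
# ...             path="./my_cathy_prj/vtk/",
# ...             x_units="days",
# ...             clim=[0.4, 1],
# ...             savefig=True)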
#%% ---------------------------------------------------------------------------
# -----------------Plot INPUTS CATHY ------------------------------------------
# -----------------------------------------------------------------------------
def show_atmbc(t_atmbc, v_atmbc, ax=None, **kwargs):
"""
Plot atmbc=f(time)
Parameters
----------
t_atmbc : np.array
time where atmbc change.
v_atmbc : list of 1 or 2 arrays (when available)
v_atmbc[0] is the array of Rain/Irrigation change over the time;
v_atmbc[1] is the array of ET demand over the time;
"""
    # NOT YET IMPLEMENTED:
    # read atmbc file if df_atmbc not existing
# ------------------------------------------------------------------------
# df_atmbc = []
# if len(df_atmbc)==0:
# df_atmbc = out_CT.read_atmbc(filename='vp')
xlabel = "time (s)"
for key, value in kwargs.items():
if key == "x_units":
x_units = value
if x_units == "days":
xlabel = "time (days)"
t_atmbc = [x / (24 * 60 * 60) for x in t_atmbc]
if x_units == "hours":
xlabel = "time (h)"
t_atmbc = [x / (60 * 60) for x in t_atmbc]
if "datetime" in kwargs:
t_atmbc = kwargs["datetime"]
xlabel = "date"
# https://matplotlib.org/stable/gallery/lines_bars_and_markers/stairs_demo.html#sphx-glr-gallery-lines-bars-and-markers-stairs-demo-py
v_atmbc_n = []
if isinstance(v_atmbc, list):
try:
np.shape(v_atmbc)[1]
            v_atmbc_p = v_atmbc[1]  # positive
            v_atmbc_n = v_atmbc[0]  # negative
v_atmbc = v_atmbc[0] - v_atmbc[1]
except:
pass
# if np.shape(t_atmbc) != np.shape(v_atmbc):
if len(v_atmbc_n) > 0:
if ax is None:
fig, ax = plt.subplots(2, 1, sharex=True)
(ax1, ax2) = (ax[0], ax[1])
color = "tab:blue"
ax1.set_xlabel("time (h)")
ax1.set_ylabel("Rain/Irr", color=color)
ax1.step(t_atmbc, v_atmbc_p, color="blue", where="post", label="Rain/Irr")
color = "tab:red"
ax2.set_ylabel("ET", color=color) # we already handled the x-label with ax1
ax2.step(t_atmbc, -v_atmbc_n, color=color, where="post", label="ET")
ax2.tick_params(axis="y", labelcolor=color)
fig.tight_layout() # otherwise the right y-label is slightly clipped
# plt.show(block=False)
else:
if ax is None:
fig, ax = plt.subplots(figsize=(6, 3))
ax.plot(t_atmbc, v_atmbc, "k.")
ax.set(xlabel=xlabel, ylabel="net Q (m/s)", title="atmbc inputs")
ax.grid()
# plt.show(block=False)
if "IETO" in kwargs:
if kwargs["IETO"] != 0:
plt.step(t_atmbc, v_atmbc, color="k", where="post")
elif kwargs["IETO"] == 0: # case of linear interpolation between points
ax.plot(t_atmbc, v_atmbc, "k.")
pass
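# Minimal usage sketch for show_atmbc (hypothetical forcing: 12 h of rain at ~1 mm/h
# plus a constant ET demand of 5e-8 m/s, following the docstring convention
# v_atmbc[0] = rain/irrigation, v_atmbc[1] = ET demand):
# >>> t_atmbc = np.arange(0, 86400, 3600)
# >>> rain = np.where(t_atmbc < 12 * 3600, 2.78e-7, 0.0)
# >>> et = np.full_like(rain, 5e-8)
# >>> show_atmbc(t_atmbc, [rain, et], x_units="hours")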
def show_atmbc_3d(df_atmbc):
"""
Temporary (must exist throught show_vtk only)
"""
# df_atmbc.set_index('time',inplace=True)
# df_atmbc.loc[0]
# filename = './vtk/100.vtk'
# mesh = pv.read(filename)
# mesh.add_field_data(df_atmbc.loc[0]['value'], 'atmbc')
# plotter = pv.Plotter(notebook=True)
# _ = plotter.add_mesh(mesh, show_edges=True, scalars='atmbc')
# # _ = plotter.add_legend(legend_entries)
# plotter.show_grid()
# cpos = plotter.show()
show_vtk(new_field="atmbc")
pass
def show_soil(soil_map, ax=None, **kwargs):
"""
View from top of the soil prop
Parameters
----------
soil_map : DataFrame()
Dataframe containing soil properties for the DEM
"""
yprop = "PERMX"
if "yprop" in kwargs:
yprop = kwargs.pop("yprop")
layer_nb = "1"
if "layer_nb" in kwargs:
layer_nb = kwargs.pop("layer_nb")
cmap = "tab10"
nb_of_zones = len(np.unique(soil_map))
cmap = mpl.colors.ListedColormap(plt.cm.tab10.colors[:nb_of_zones])
if "cmap" in kwargs:
cmap = kwargs.pop("cmap")
if ax is None:
fig, ax = plt.subplots()
# linewidth = 1
# if 'linewidth' in kwargs:
# linewidth = kwargs['linewidth']
cf = ax.pcolormesh(
soil_map,
edgecolors="black",
cmap=cmap,
**kwargs,
)
if "clim" in kwargs:
clim = kwargs["clim"]
else:
clim = [min(soil_map.flatten()), max(soil_map.flatten())]
cf.set_clim(clim[0], clim[1])
cax = plt.colorbar(
cf,
ticks=np.linspace(
clim[0], clim[1], nb_of_zones
),
ax=ax,
label=yprop,
)
# try:
# cax.set_ticklabels(range(int(min(soil_map.flatten())), nb_of_zones))
cax.ax.set_yticklabels(
[
"{:.1e}".format(x)
for x in np.linspace(
clim[0], clim[1], nb_of_zones
)
]
)
# ,
# fontsize=16,
# )
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("view from top (before extruding), layer nb" + str(layer_nb))
# plt.show(block=False)
# plt.close()
return cf
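# Minimal usage sketch for show_soil (hypothetical 2D property map; the PERMX values
# are illustrative only):
# >>> soil_map = np.array([[1e-5, 1e-5, 1e-4],
# ...                      [1e-5, 1e-4, 1e-4]])
# >>> cf = show_soil(soil_map, yprop="PERMX", layer_nb=1)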
def show_raster(
raster_map,
str_hd_raster={},
prop="",
hapin={},
ax=None,
cmap="gist_earth",
**kwargs,
):
x = np.zeros(raster_map.shape[0]) # + hapin['xllcorner']
y = np.zeros(raster_map.shape[1]) # + hapin['yllcorner']
if len(str_hd_raster) < 1:
str_hd_raster = {}
str_hd_raster["west"] = 0
str_hd_raster["south"] = 0
if len(hapin) < 1:
hapin = {}
hapin["delta_x"] = 1
hapin["delta_y"] = 1
for a in range(raster_map.shape[0]):
x[a] = float(str_hd_raster["west"]) + hapin["delta_x"] * a
for a in range(raster_map.shape[1]):
y[a] = float(str_hd_raster["south"]) + hapin["delta_y"] * a
if ax is None:
# fig = plt.figure()
# ax = plt.axes(projection="3d")
fig, ax = plt.subplots()
# X, Y = np.meshgrid(x, y)
# surf = ax.plot_surface(X, Y, raster_map.T, cmap="viridis")
pmesh = ax.pcolormesh(x, y, raster_map.T, **kwargs) # , cmap=cmap)
# cbar = plt.colorbar(pmesh, shrink=0.25, orientation='horizontal',
# label='Elevation (m)')
# ax.set(xlabel="Easting (m)", ylabel="Northing (m)", zlabel="Elevation (m)")
ax.set(xlabel="Easting (m)", ylabel="Northing (m)")
# plt.show(block=False)
# plt.close()
ax.set_title(prop)
return pmesh
def show_zone(zone_map, **kwargs):
"""
View from top of the vegetation type (equivalent somehow to root map)
Parameters
----------
veg_map : np.array([])
Indice of vegetation. The dimension of the vegetation map must match
the dimension of the DEM.
"""
# cmap='tab10'
nb_of_zones = len(np.unique(zone_map))
cmap = mpl.colors.ListedColormap(plt.cm.tab10.colors[:nb_of_zones])
if "cmap" in kwargs:
cmap = kwargs["cmap"]
fig, ax = plt.subplots()
cf = ax.pcolormesh(zone_map, edgecolors="black", cmap=cmap)
# cbar = fig.colorbar(cf, ax=ax, label='indice of zones')
cax = plt.colorbar(
cf,
ticks=np.linspace(
int(min(zone_map.flatten())), int(max(zone_map.flatten())), nb_of_zones + 1
),
ax=ax,
label="indice of zones",
)
cax.set_ticklabels(range(int(min(zone_map.flatten())), nb_of_zones + 2))
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("view from top (before extruding)")
plt.show(block=False)
plt.close()
return fig, ax
def show_indice_veg(veg_map, ax=None, **kwargs):
"""
View from top of the vegetation type (equivalent somehow to root map)
Parameters
----------
veg_map : np.array([])
Indice of vegetation. The dimension of the vegetation map must match
the dimension of the DEM.
"""
# cmap='tab10'
nb_of_zones = len(np.unique(veg_map))
cmap = mpl.colors.ListedColormap(plt.cm.tab10.colors[:nb_of_zones])
if "cmap" in kwargs:
cmap = kwargs.pop("cmap")
if ax is None:
fig, ax = plt.subplots()
cf = ax.pcolormesh(veg_map, edgecolors="black", cmap=cmap,
**kwargs)
# fig.colorbar(cf, ax=ax, label='indice of vegetation')
cax = plt.colorbar(
cf,
ticks=np.arange(
int(min(veg_map.flatten())), int(max(veg_map.flatten()))+1
),
ax=ax,
label="indice of vegetation",
)
# cax.set_ticklabels(range(int(min(veg_map.flatten())), nb_of_zones + 2))
cax.set_ticklabels(np.arange(
int(min(veg_map.flatten())),
int(max(veg_map.flatten())) +1
# nb_of_zones + 1
)
)
cax.ax.set_yticklabels(
[
"{:d}".format(int(x))
for x in np.arange(
min(veg_map.flatten())-1, max(veg_map.flatten())
)
]
)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("view from top (before extruding)")
# plt.show(block=False)
# plt.close()
return ax
def dem_plot_2d_top(parameter, label="", **kwargs):
"""
View of the DEM from top of the a given parameter
(Not yet implemented)
Possible parameters are:
- vegetation
- altitude
"""
if label == "all":
fig, axs = plt.subplots(
int(len(parameter.keys()) / 2),
int(len(parameter.keys()) / 3),
sharex=True,
sharey=True,
)
for ax, p in zip(axs.reshape(-1), parameter.keys()):
cf = ax.pcolormesh(parameter[p]) # edgecolors="black"
fig.colorbar(cf, ax=ax, label=p, fraction=0.046, pad=0.04, shrink=0.8)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_aspect("auto", "box")
plt.tight_layout()
plt.show(block=False)
else:
fig, ax = plt.subplots()
cf = ax.pcolormesh(parameter, edgecolors="black")
fig.colorbar(cf, ax=ax, label=label)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("view from top (before extruding)")
plt.show(block=False)
return fig, ax
def get_dem_coords(dem_mat=[], hapin={}):
# if len(dem_mat) == 0:
# # Read the Header
# dem_mat, str_hd_dem = in_CT.read_dem(
# os.path.join(workdir, project_name, "prepro/dem"),
# os.path.join(workdir, project_name, "prepro/dtm_13.val"),
# )
    # transpose because the values in the dtm_13 file (from which the DEM raster is extracted) are transposed
# dem_mat = dem_mat.T
# dem_mat = np.flipud(dem_mat)
# print(dem_mat)
x = np.zeros(dem_mat.shape[1]) # + hapin["xllcorner"]
y = np.zeros(dem_mat.shape[0]) # + hapin["yllcorner"]
# print(x)
# print(str_hd_dem)
for a in range(0, dem_mat.shape[1]):
x[a] = float(hapin["xllcorner"]) + hapin["delta_x"] * (a + 1) - hapin["delta_x"]/2
for a in range(0, dem_mat.shape[0]):
y[a] = float(hapin["yllcorner"]) + hapin["delta_y"] * (a + 1) - hapin["delta_y"]/2
return x, np.flipud(y) # np.flipud(dem_mat)
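# Minimal usage sketch for get_dem_coords (hypothetical header values; hapin only needs
# the 'xllcorner', 'yllcorner', 'delta_x' and 'delta_y' keys used above):
# >>> hapin = {"xllcorner": 0.0, "yllcorner": 0.0, "delta_x": 5.0, "delta_y": 5.0}
# >>> dem_mat = np.random.rand(10, 20)       # 10 rows (y) x 20 columns (x)
# >>> x, y = get_dem_coords(dem_mat, hapin)  # cell-centre coordinates, y flipped
# >>> len(x), len(y)
# (20, 10)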
def show_dem(
dem_mat=[],
hapin={},
ax=None,
**kwargs,
):
"""
Creates a 3D representation from a Grass DEM file
"""
# np.shape(dem_mat)
x, y = get_dem_coords(dem_mat, hapin)
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
X, Y = np.meshgrid(x, y)
dem_mat[dem_mat==-9999] = np.nan
surf = ax.plot_surface(X, Y, dem_mat, cmap="viridis", **kwargs)
# surf = ax.plot_surface(Y,X, dem_mat, cmap="viridis",**kwargs)
cbar = plt.colorbar(
surf, shrink=0.25, orientation="horizontal", label="Elevation (m)"
)
ax.set(xlabel="Easting (m)",
ylabel="Northing (m)",
zlabel="Elevation (m)")
# plt.show(block=False)
# plt.close()
ax.yaxis.set_major_formatter(FormatStrFormatter("%3.4e"))
ax.xaxis.set_major_formatter(FormatStrFormatter("%3.4e"))
pass
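# Minimal usage sketch for show_dem (hypothetical DEM; cells equal to -9999 are masked
# as NaN by the function before plotting):
# >>> hapin = {"xllcorner": 0.0, "yllcorner": 0.0, "delta_x": 5.0, "delta_y": 5.0}
# >>> dem_mat = np.random.rand(10, 20) * 100.
# >>> show_dem(dem_mat, hapin)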
def plot_mesh_bounds(BCtypName, mesh_bound_cond_df, time, ax=None):
m = np.array(["o", "+"])
mvalue = []
alpha = []
mesh_bound_cond_df_selec = mesh_bound_cond_df[mesh_bound_cond_df['time']==time]
for bound_val in mesh_bound_cond_df_selec[BCtypName]:
if bound_val == 0:
mvalue.append(1)
alpha.append(1)
else:
mvalue.append(0)
alpha.append(0.1)
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.scatter(
mesh_bound_cond_df_selec["x"],
mesh_bound_cond_df_selec["y"],
mesh_bound_cond_df_selec["z"],
c=mvalue,
)
ax.set_xlabel("X Label")
ax.set_ylabel("Y Label")
ax.set_zlabel("Z Label")
ax.set_title(BCtypName + " Time " + str(time))
# ax.set_title(BCtypName)
# plt.show(block=False)
# return fig, ax
pass
#%% ---------------------------------------------------------------------------
# -----------------Plot PETRO-PEDOPHYSICS relationships -----------------------
# -----------------------------------------------------------------------------
def plot_VGP(df_VGP, savefig=False, **kwargs):
if "ax" in kwargs:
ax = kwargs["ax"]
else:
fig = plt.figure(figsize=(6, 3), dpi=350)
ax = fig.add_subplot()
label = []
if "label" in kwargs:
label = kwargs["label"]
ax.plot(df_VGP["psi"], df_VGP["theta"], label=label, marker=".")
ax.set_xlabel("Soil Water Potential (cm)")
ax.set_ylabel(r"Water content ($m^3/m^3$)")
return ax
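# Minimal usage sketch for plot_VGP (hypothetical retention data; df_VGP only needs
# the 'psi' and 'theta' columns used above):
# >>> df_VGP = pd.DataFrame({"psi": [-1000., -100., -10., -1.],
# ...                        "theta": [0.08, 0.15, 0.30, 0.42]})
# >>> ax = plot_VGP(df_VGP, label="retention curve")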
def DA_plot_Archie(df_Archie, savefig=False, **kwargs):
if "ax" in kwargs:
ax = kwargs["ax"]
else:
fig = plt.figure(figsize=(6, 3), dpi=350)
ax = fig.add_subplot()
ax.scatter(df_Archie["sw"], df_Archie["ER_converted"])
ax.set_xlabel("Saturation (-)")
ax.set_ylabel("ER converted")
ax.set_title("[All ensemble]")
if "porosity" in kwargs:
ax.scatter(df_Archie["sw"] * kwargs["porosity"], df_Archie["ER_converted"])
ax.set_xlabel("swc")
ax.set_ylabel("ER_converted")
if savefig == True:
plt.savefig("Archie.png", dpi=300)
return ax
#%% ----------------------------------------------------------------------------
# -----------------Plot results DATA ASSIMILATION------------------------------
# -----------------------------------------------------------------------------
def plot_hist_perturbated_parm(parm, var_per, type_parm, parm_per_array, **kwargs):
fig = plt.figure(figsize=(6, 3), dpi=150)
if var_per[type_parm]["transf_type"] is not None:
if "log".casefold() in var_per[type_parm]["transf_type"].casefold():
w = 0.2
nbins = math.ceil((parm_per_array.max() - parm_per_array.min()) / w)
# plt.hist(np.log10(parm_sampling), ensemble_size, alpha=0.5, label='sampling')
plt.hist(
np.log10(parm_per_array), nbins, alpha=0.5, label="ini_perturbation"
)
plt.axvline(
x=np.log10(parm[type_parm + "_nominal"]), linestyle="--", color="red"
)
else:
# plt.hist(parm_sampling, ensemble_size/2, alpha=0.5, label='sampling')
plt.hist(parm_per_array, alpha=0.5, label="ini_perturbation")
plt.axvline(x=parm["nominal"], linestyle="--", color="red")
plt.legend(loc="upper right")
plt.xlabel(parm["units"])
plt.ylabel("Probability")
plt.title("Histogram of " + type_parm)
plt.tight_layout()
fig.savefig(os.path.join(os.getcwd(), kwargs["savefig"]), dpi=350)
def show_DA_process_ens(
EnsembleX, Data, DataCov, dD, dAS, B, Analysis, savefig=False, **kwargs
):
"""Plot result of Data Assimilation Analysis"""
label_sensor = "raw data"
if "label_sensor" in kwargs:
label_sensor = kwargs["label_sensor"]
fig = plt.figure(figsize=(12, 6), dpi=300)
ax1 = fig.add_subplot(2, 5, 1)
cax = ax1.matshow(EnsembleX, aspect="auto") # ,
# cmap=cm.rainbow, norm=colors.LogNorm())
ax1.set_title("Prior")
    ax1.set_ylabel(r"$\psi$ params #")
ax1.set_xlabel("Members #")
cbar = fig.colorbar(cax, location="bottom")
ax = fig.add_subplot(2, 5, 6)
cax = ax.matshow(np.cov(EnsembleX), aspect="auto", cmap="gray", norm=LogNorm())
ax.set_title("cov(Prior)")
    ax.set_xlabel(r"$\psi$ params #")
    ax.set_ylabel(r"$\psi$ params #")
# cbar = fig.colorbar(cax, location='bottom')
cbar = fig.colorbar(cax, format="$%.1f$", location="bottom")
ax.set_yticks([])
ax = fig.add_subplot(2, 5, 2)
cax = ax.matshow(np.tile(Data, (np.shape(EnsembleX)[1], 1)).T, aspect="auto")
ax.set_title(label_sensor)
ax.set_ylabel("Meas")
ax.set_xlabel("Members #")
cbar = fig.colorbar(cax, location="bottom")
ax.set_yticks([])
ax = fig.add_subplot(2, 5, 7)
cax = ax.matshow(DataCov, aspect="auto", cmap="gray_r")
# cmap=cm.rainbow, norm=colors.LogNorm())
# vmin=0, vmax=1e-29)
ax.set_title("cov(meas)")
ax.set_ylabel("Meas")
ax.set_xlabel("Meas")
cbar = fig.colorbar(cax, location="bottom")
ax.set_yticks([])
ax = fig.add_subplot(2, 5, 3)
cax = ax.matshow(dD, aspect="auto", cmap="jet")
ax.set_title("Meas - Sim")
ax.set_ylabel("Meas")
ax.set_xlabel("Members #")
cbar = fig.colorbar(cax, location="bottom")
ax.set_yticks([])
ax = fig.add_subplot(2, 5, 4, sharey=ax1)
cax = ax.matshow(np.dot(dAS, B), aspect="auto", cmap="jet")
ax.set_title("Correction")
    ax.set_ylabel(r"$\psi$ params #")
ax.set_xlabel("Members #")
cbar = fig.colorbar(cax, location="bottom")
ax.set_yticks([])
ax = fig.add_subplot(2, 5, 5, sharey=ax1)
cax = ax.matshow(Analysis, aspect="auto")
ax.set_title("Posterior")
    ax.set_ylabel(r"$\psi$ params #")
ax.set_xlabel("Members #")
cbar = fig.colorbar(cax, location="bottom")
ax.set_yticks([])
plt.tight_layout()
savename = "showDA_process_ens"
if "savename" in kwargs:
savename = kwargs["savename"]
if savefig == True:
fig.savefig(savename + ".png", dpi=300)
plt.close()
# fig, ax = plt.subplots()
# cax = ax.matshow(np.dot(dAS, B), aspect="auto", cmap="jet")
# ax.set_title("Correction")
# ax.set_ylabel("$\psi$ params #")
# ax.set_xlabel("Members #")
# cbar = fig.colorbar(cax, location="bottom")
# ax.set_yticks([])
# fig, ax = plt.subplots()
# cax = ax.matshow(DataCov, aspect="auto", cmap="gray_r")
# ax.set_title("cov(meas)")
# ax.set_ylabel("Meas")
# ax.set_xlabel("Meas")
# cbar = fig.colorbar(cax, location="bottom")
# ax.set_yticks([])
return fig, ax
def DA_RMS(df_performance, sensorName, **kwargs):
"""Plot result of Data Assimilation: RMS evolution over the time"""
if "ax" in kwargs:
ax = kwargs["ax"]
# else:
# if "start_date" in kwargs:
# fig, ax = plt.subplots(3, 1, sharex=True)
# else:
# fig, ax = plt.subplots(2, 1)
keytime = "time"
xlabel = "assimilation #"
start_date = None
keytime = "time"
xlabel = "assimilation #"
if kwargs.get("start_date") is not None:
keytime = "time_date"
xlabel = "date"
atmbc_times = None
if "atmbc_times" in kwargs:
atmbc_times = kwargs["atmbc_times"]
header = ["time", "ObsType", "RMSE" + sensorName, "NMRMSE" + sensorName, "OL"]
df_perf_plot = df_performance[header]
df_perf_plot["RMSE" + sensorName] = df_perf_plot["RMSE" + sensorName].astype(
"float64"
)
df_perf_plot["NMRMSE" + sensorName] = df_perf_plot["NMRMSE" + sensorName].astype(
"float64"
)
df_perf_plot.OL = df_perf_plot.OL.astype("str")
df_perf_plot.dropna(inplace=True)
# len(dates)
# len(df_perf_plot["time"])
# len(atmbc_times)
if start_date is not None:
dates = change_x2date(atmbc_times, start_date)
df_perf_plot["time_date"] = df_perf_plot["time"].replace(
list(df_perf_plot["time"].unique()), dates)
p0 = df_perf_plot.pivot(index=keytime, columns="OL", values="RMSE" + sensorName)
p1 = df_perf_plot.pivot(index=keytime, columns="OL", values="NMRMSE" + sensorName)
if "start_date" in kwargs:
p0.plot(xlabel=xlabel, ylabel="RMSE" + sensorName, ax=ax[1], style=[".-"])
p1.plot(xlabel=xlabel, ylabel="NMRMSE" + sensorName, ax=ax[2], style=[".-"])
else:
p0.plot(xlabel=xlabel, ylabel="RMSE" + sensorName, ax=ax[0], style=[".-"])
p1.plot(xlabel=xlabel, ylabel="NMRMSE" + sensorName, ax=ax[1], style=[".-"])
return ax, plt
def DA_plot_parm_dynamic(
parm="ks", dict_parm_pert={}, list_assimilation_times=[], savefig=False, **kwargs
):
"""Plot result of Data Assimilation: parameter estimation evolution over the time"""
ensemble_size = len(dict_parm_pert[parm]["ini_perturbation"])
# nb_times = len(df_DA['time'].unique())
# fig = plt.figure(figsize=(6, 3), dpi=150)
if 'ax' in kwargs:
ax = kwargs['ax']
else:
fig, ax = plt.subplots()
ax.hist(
dict_parm_pert[parm]["ini_perturbation"],
ensemble_size,
alpha=0.5,
label="ini_perturbation",
)
ax.legend(loc="upper right")
ax.set_ylabel("Probability")
for nt in list_assimilation_times:
try:
ax.hist(
dict_parm_pert[parm]["update_nb" + str(int(nt + 1))],
ensemble_size,
alpha=0.5,
label="update nb" + str(int(nt + 1)),
)
except:
pass
ax.legend(loc="upper right")
ax.set_ylabel("Probability")
if "log" in kwargs:
if kwargs["log"]:
ax.set_xscale("log")
# plt.show()
return ax
def DA_plot_parm_dynamic_scatter(
parm="ks", dict_parm_pert={}, list_assimilation_times=[], savefig=False, **kwargs
):
"""Plot result of Data Assimilation: parameter estimation evolution over the time"""
if "ax" in kwargs:
ax = kwargs["ax"]
else:
fig = plt.figure(figsize=(6, 3), dpi=350)
ax = fig.add_subplot()
ensemble_size = len(dict_parm_pert[parm]["ini_perturbation"])
mean_t = [np.mean(dict_parm_pert[parm]["ini_perturbation"])]
mean_t_yaxis = np.mean(dict_parm_pert[parm]["ini_perturbation"])
cloud_t = np.zeros([ensemble_size, len(list_assimilation_times)])
cloud_t[:, 0] = np.array(dict_parm_pert[parm]["ini_perturbation"])
np.shape(cloud_t)
try:
for nt in list_assimilation_times[1:]:
cloud_t[:, int(nt + 1)] = np.hstack(
dict_parm_pert[parm]["update_nb" + str(int(nt))]
)
except:
pass
dict_parm_new = {}
name = []
i = 0
for k in dict_parm_pert[parm].keys():
if "upd" in k:
dict_parm_new[parm + "_" + k] = np.hstack(dict_parm_pert[parm][k])
name.append(str(i + 1))
i = i + 1
df = pd.DataFrame()
df = pd.DataFrame(data=dict_parm_new)
df.index.name = "Ensemble_nb"
color = 'k'
if "color" in kwargs:
color = kwargs["color"]
if len(df.columns)>15:
# nii = [int(ni) for ni in np.arange(0,len(df.columns),6)]
name = [str(ni+1) for ni in list_assimilation_times]
df = df.iloc[:,list_assimilation_times]
boxplot = df.boxplot(
color=color,
ax=ax,
grid=False,
flierprops=dict(marker='o',
color='black',
markersize=5
)
)
boxplot.set_xticklabels(name,
# rotation=90
)
ax.set_ylabel(parm)
ax.set_xlabel("assimilation #")
if "log" in kwargs:
if kwargs["log"]:
boxplot.set_yscale("log")
return ax
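# Minimal usage sketch for DA_plot_parm_dynamic_scatter (hypothetical perturbation
# dictionary following the 'ini_perturbation'/'update_nbX' keys used above):
# >>> dict_parm_pert = {"ks": {"ini_perturbation": np.random.lognormal(-10, 1, 32),
# ...                          "update_nb1": np.random.lognormal(-10, 0.5, 32),
# ...                          "update_nb2": np.random.lognormal(-10, 0.3, 32)}}
# >>> ax = DA_plot_parm_dynamic_scatter("ks", dict_parm_pert,
# ...                                   list_assimilation_times=[0, 1, 2], log=True)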
def prepare_DA_plot_time_dynamic(DA, state="psi", nodes_of_interest=[], **kwargs):
"""Select data from DA_df dataframe"""
keytime = "time"
start_date = None
xlabel = "time (h)"
if kwargs.get("start_date") is not None:
start_date = kwargs["start_date"]
keytime = "time_date"
if "keytime" in kwargs:
keytime = kwargs['keytime']
atmbc_times = None
if "atmbc_times" in kwargs:
atmbc_times = kwargs["atmbc_times"]
if start_date is not None:
dates = change_x2date(atmbc_times, start_date)
if len(dates) < len(list(DA["time"].unique())):
DA = DA.drop(DA[DA.time + 1 >= len(list(DA["time"].unique()))].index)
if len(dates) > len(list(DA["time"].unique())):
            # check whether the dates contain a time-of-day difference; otherwise add
            # 1 s artificially to keep the formatting (else the plot crashes)
if all(dates[: len(list(DA["time"].unique()))].hour == 0):
time_delta_change = datetime.timedelta(seconds=1)
dates.values[0] = dates[0] + time_delta_change
dates = dates[: len(list(DA["time"].unique()))]
unique_times = DA["time"].unique()
DA["time_date"] = DA["time"].map(dict(zip(unique_times, dates)))
isOL = DA.loc[DA["OL"] == True]
isENS = DA.loc[DA["OL"] == False]
isENS_time_Ens = isENS.set_index(["time", "Ensemble_nb"])
isOL_time_Ens = isOL.set_index(["time", "Ensemble_nb"])
if type(nodes_of_interest) != list:
nodes_of_interest = [nodes_of_interest]
NENS = int(max(DA["Ensemble_nb"].unique()))
key2plot = "psi_bef_update"
if "sw" in state:
key2plot = "sw_bef_update_"
if "key2plot" in kwargs:
key2plot = kwargs["key2plot"]
prep_DA = {
"isENS": isENS,
"isOL": isOL,
}
# -----------------------------#
if len(nodes_of_interest) > 0:
if len(isOL) > 0:
nb_of_times = len(np.unique(isOL["time"]))
nb_of_ens = len(np.unique(isOL["Ensemble_nb"]))
nb_of_nodes = len(isOL[isOL["time"] == 0]) / nb_of_ens
test = len(isOL) / nb_of_nodes
try:
int(test)
except:
ValueError("inconsistent nb of OL data - no plot")
pass
else:
isOL.insert(
2,
"idnode",
np.tile(np.arange(nb_of_nodes), nb_of_times * nb_of_ens),
True,
)
select_isOL = isOL[isOL["idnode"].isin(nodes_of_interest)]
select_isOL = select_isOL.set_index([keytime, "idnode"])
select_isOL = select_isOL.reset_index()
# mean, min and max of Open Loop
# --------------------------------------------------------------------------
ens_mean_isOL_time = (
select_isOL.set_index([keytime, "Ensemble_nb", "idnode"])
.groupby(level=[keytime, "idnode"])[key2plot]
.mean()
)
ens_mean_isOL_time = ens_mean_isOL_time.reset_index(name="mean(ENS)_OL")
ens_min_isOL_time = (
select_isOL.set_index([keytime, "Ensemble_nb", "idnode"])
.groupby(level=[keytime, "idnode"])[key2plot]
.min()
)
ens_min_isOL_time = ens_min_isOL_time.reset_index(name="min(ENS)_OL")
ens_max_isOL_time = (
select_isOL.set_index([keytime, "Ensemble_nb", "idnode"])
.groupby(level=[keytime, "idnode"])[key2plot]
.max()
)
ens_max_isOL_time = ens_max_isOL_time.reset_index(name="max(ENS)_OL")
prep_DA.update(
{
"ens_mean_isOL_time": ens_mean_isOL_time,
"ens_max_isOL_time": ens_max_isOL_time,
"ens_min_isOL_time": ens_min_isOL_time,
}
)
if len(isENS) > 0:
nb_of_times = len(np.unique(isENS["time"]))
nb_of_ens = len(np.unique(isENS["Ensemble_nb"]))
nb_of_nodes = len(isENS[isENS["time"] == isENS["time"][0]]) / nb_of_ens
if len(isOL) > 0:
isENS.insert(
2,
"idnode",
np.tile(
np.arange(int(len(isENS) / (max(isENS["time"]) * (NENS)))),
int(max(isENS["time"])) * (NENS),
),
True,
)
else:
isENS.insert(
2,"idnode",
np.tile(np.arange(nb_of_nodes), nb_of_times*nb_of_ens),
True,
)
check_notrejected = isENS['rejected']==0
isENS = isENS[check_notrejected]
select_isENS = isENS[isENS["idnode"].isin(nodes_of_interest)]
select_isENS = select_isENS.set_index([keytime, "idnode"])
select_isENS = select_isENS.reset_index()
            # mean, min and max of the DA ensemble
            # --------------------------------------------------------------------------
ens_mean_isENS_time = (
select_isENS.set_index([keytime, "Ensemble_nb", "idnode"])
.groupby(level=[keytime, "idnode"])[key2plot]
.mean()
)
ens_mean_isENS_time = ens_mean_isENS_time.reset_index(name="mean(ENS)")
ens_min_isENS_time = (
select_isENS.set_index([keytime, "Ensemble_nb", "idnode"])
.groupby(level=[keytime, "idnode"])[key2plot]
.min()
)
ens_min_isENS_time = ens_min_isENS_time.reset_index(name="min(ENS)")
ens_max_isENS_time = (
select_isENS.set_index([keytime, "Ensemble_nb", "idnode"])
.groupby(level=[keytime, "idnode"])[key2plot]
.max()
)
ens_max_isENS_time = ens_max_isENS_time.reset_index(name="max(ENS)")
prep_DA.update(
{
"ens_mean_isENS_time": ens_mean_isENS_time,
"ens_max_isENS_time": ens_max_isENS_time,
"ens_min_isENS_time": ens_min_isENS_time,
}
)
else:
# take the spatial average mean
# -----------------------------#
spatial_mean_isOL_time_ens = isOL_time_Ens.groupby(
level=[keytime, "Ensemble_nb"]
)["analysis"].mean()
spatial_mean_isOL_time_ens = spatial_mean_isOL_time_ens.reset_index()
spatial_mean_isENS_time_ens = isENS_time_Ens.groupby(
level=[keytime, "Ensemble_nb"]
)["analysis"].mean()
spatial_mean_isENS_time_ens = spatial_mean_isENS_time_ens.reset_index()
select_isOL = spatial_mean_isOL_time_ens
select_isENS = spatial_mean_isENS_time_ens
return prep_DA
def DA_plot_time_dynamic(
DA, state="psi", nodes_of_interest=[], savefig=False, **kwargs
):
"""Plot result of Data Assimilation: state estimation evolution over the time"""
keytime = "time"
xlabel = "time (h)"
if kwargs.get("start_date") is not None:
start_date = kwargs["start_date"]
xlabel = "date"
keytime = "time_date"
prep_DA = prepare_DA_plot_time_dynamic(
DA,
state=state,
nodes_of_interest=nodes_of_interest,
**kwargs
)
if "ax" in kwargs:
ax = kwargs["ax"]
else:
fig = plt.figure(figsize=(6, 3), dpi=350)
ax = fig.add_subplot()
alpha = 0.5
colors_minmax = 'darkblue'
if "colors_minmax" in kwargs:
colors_minmax = kwargs["colors_minmax"]
if "keytime" in kwargs:
keytime = kwargs['keytime']
xlabel = "assimilation_times"
ylabel = r"pressure head $\psi$ (m)"
if "sw" in state:
ylabel = "water saturation (-)"
# --------------------------------------------------------------------------
if len(prep_DA["isENS"]) > 0:
prep_DA["ens_mean_isENS_time"].pivot(
index=keytime,
columns=["idnode"],
# columns=['idnode'],
values=["mean(ENS)"],
).plot(ax=ax, style=[".-"], color=colors_minmax,
)
prep_DA["ens_max_isENS_time"].pivot(
index=keytime,
columns=["idnode"],
# columns=['idnode'],
values=["max(ENS)"],
).plot(ax=ax, style=[".-"], color=colors_minmax,
)
prep_DA["ens_min_isENS_time"].pivot(
index=keytime,
columns=["idnode"],
# columns=['idnode'],
values=["min(ENS)"],
).plot(
ax=ax,
style=[".-"],
color=colors_minmax,
xlabel="(assimilation) time - (h)",
ylabel="pressure head $\psi$ (m)",
)
lgd = ax.fill_between(
prep_DA["ens_max_isENS_time"][keytime],
prep_DA["ens_min_isENS_time"]["min(ENS)"],
prep_DA["ens_max_isENS_time"]["max(ENS)"],
alpha=0.2,
color=colors_minmax,
label="minmax DA",
)
if "ens_mean_isOL_time" in prep_DA.keys():
prep_DA["ens_mean_isOL_time"].pivot(
index=keytime,
# columns=["Ensemble_nb",'idnode'],
columns=["idnode"],
values=["mean(ENS)_OL"],
).plot(
ax=ax,
style=["-"],
color="grey",
label=False,
ylabel="pressure head $\psi$ (m)",
) # water saturation (-)
prep_DA["ens_min_isOL_time"].pivot(
index=keytime,
columns=["idnode"],
# columns=['idnode'],
values=["min(ENS)_OL"],
).plot(ax=ax, style=["--"], color="grey", label=False)
prep_DA["ens_max_isOL_time"].pivot(
index=keytime,
columns=["idnode"],
# columns=['idnode'],
values=["max(ENS)_OL"],
).plot(ax=ax, style=["--"], color="grey", label=False)
ax.fill_between(
prep_DA["ens_mean_isOL_time"][keytime],
prep_DA["ens_min_isOL_time"]["min(ENS)_OL"],
prep_DA["ens_max_isOL_time"]["max(ENS)_OL"],
alpha=0.2,
color="grey",
label="minmax OL",
)
ax.set_xlabel(xlabel)
# ax.set_ylabel('pressure head (m)')
ax.set_ylabel(ylabel)
savename = "showDA_dynamic"
if "savename" in kwargs:
savename = kwargs["savename"]
if savefig == True:
plt.savefig(savename + ".png", dpi=300)
pass
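# Minimal usage sketch for DA_plot_time_dynamic (hedged: DA is assumed to be the
# pyCATHY DA results DataFrame with the 'time', 'Ensemble_nb', 'OL', 'rejected' and
# 'psi_bef_update'/'sw_bef_update_' columns used by the helpers above):
# >>> fig, ax = plt.subplots()
# >>> DA_plot_time_dynamic(DA, state="psi", nodes_of_interest=[0, 10], ax=ax)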