#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plotting functions for pyCATHY (2D and 3D)
Plotting:
- output files
- input files
- petro/pedophysical relationships
- Data assimilation results
"""
import glob
import os
import pyvista as pv
#pv.set_plot_theme("document")
#pv.set_jupyter_backend('static')
import datetime
import math
import imageio
# import panel as pn
import ipywidgets as widgets
import matplotlib as mpl
import matplotlib.dates as mdates
import matplotlib.font_manager
import matplotlib.pyplot as plt
import matplotlib.style
import natsort
import numpy as np
import pandas as pd
from matplotlib import cm
from matplotlib.colors import LogNorm
from matplotlib.ticker import FormatStrFormatter
from scipy.interpolate import griddata
# from scipy.spatial import Delaunay
from shapely.geometry import Polygon, MultiPoint, Point
from pyCATHY.cathy_utils import (
change_x2date,
convert_time_units,
label_units,
transform2_time_delta,
)
from pyCATHY.importers import cathy_inputs as in_CT
from pyCATHY.importers import cathy_outputs as out_CT
import geopandas as gpd
import rioxarray as rio
import scipy.stats as stats
import seaborn
# mpl.style.use('default')
# Module-wide matplotlib defaults, applied as a side effect of importing this
# module: grid lines are black, dotted and thin.
mpl.rcParams["grid.color"] = "k"
mpl.rcParams["grid.linestyle"] = ":"
mpl.rcParams["grid.linewidth"] = 0.25
# mpl.rcParams['font.family'] = 'Avenir'
# Default font size and axes frame line width for every figure.
plt.rcParams["font.size"] = 12
plt.rcParams["axes.linewidth"] = 0.75
# nice_fonts = {
# # "font.family": "serif",
# "font.serif": "Times New Roman",
# }
# matplotlib.rcParams.update(nice_fonts)
#%% ---------------------------------------------------------------------------
# -----------------Plot OUTPUTS CATHY -----------------------------------------
# -----------------------------------------------------------------------------
def create_gridded_mask(x,y,**kwargs):
    """
    Build a regular 50 x 50 grid over a set of points and flag the grid nodes
    that fall inside the boundary drawn from these points.

    Parameters
    ----------
    x, y : array-like
        Point coordinates. Taken in the given order, they are used as the
        vertices of the boundary polygon.
    **kwargs :
        buffer : bool, optional
            If True, the boundary is obtained instead by buffering the point
            cloud outwards and then inwards by ``int(x[1] - x[0]) * 4 / 2``.
            The default is False.

    Returns
    -------
    xi, yi : np.ndarray
        Meshgrid coordinates spanning the bounding box of the boundary.
    mask : np.ndarray of bool
        Same shape as ``xi``; True where the grid node is within the boundary.
    """
    use_buffer = kwargs.get('buffer', False)

    xy = np.column_stack((x, y))
    cloud = MultiPoint(xy)
    outline = Polygon(xy)
    if use_buffer:
        spacing = int(x[1] - x[0]) * 4
        # dilate then erode the point cloud to get a closed outline
        outline = cloud.buffer(spacing / 2).buffer(-spacing / 2)

    # bounding box of the outline drives the extent of the regular grid
    ring = np.array(outline.exterior.coords)
    xmin, ymin = ring.min(axis=0)
    xmax, ymax = ring.max(axis=0)
    grid_x = np.linspace(xmin, xmax, 50)
    grid_y = np.linspace(ymin, ymax, 50)
    xi, yi = np.meshgrid(grid_x, grid_y)

    inside = np.array(
        [Point(node).within(outline) for node in zip(xi.flatten(), yi.flatten())]
    )
    return xi, yi, inside.reshape(xi.shape)
def plot_WTD(XYZsurface,WT,**kwargs):
    """
    Map the water table values ``WT`` over the surface nodes.

    Parameters
    ----------
    XYZsurface : np.ndarray
        Surface node coordinates; columns 0 and 1 are used as x and y.
    WT : sequence
        Water table values indexed by time step (``WT[ti]`` is plotted).
    **kwargs :
        ax : matplotlib axes, optional
            Axes to draw on. A new figure is created when missing.
        colorbar : bool, optional
            Add a colorbar. The default is True.
        ti : int or list, optional
            Time index. When a list is given (gridded mode), the difference
            ``WT[ti[1]] - WT[ti[0]]`` is plotted. The default is 0.
        scatter : bool, optional
            Draw a scatter plot instead of masked filled contours.
            The default is False.
        vmin, vmax : float, optional
            Forwarded to ``ax.scatter`` in scatter mode only.

    Returns
    -------
    The object returned by ``ax.scatter`` or ``ax.contourf``.
    """
    # FLAG, flag for what has happened
    # - 1 water table calculated correctly
    # - 2 unsaturated zone above saturated zone above unsaturated zone
    # - 3 water table not encountered, fully saturated vertical profile
    # - 4 water table not encountered, unsaturated profile
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        _, ax = plt.subplots()

    colorbar = kwargs.get('colorbar', True)
    ti = kwargs.get('ti', 0)
    scatter = kwargs.get('scatter', False)
    is_difference = type(ti) is list

    if scatter:
        color_limits = {
            name: kwargs[name] for name in kwargs if name in ('vmin', 'vmax')
        }
        cmap = ax.scatter(
            XYZsurface[:,0], XYZsurface[:,1], c=WT[ti], **color_limits
        )
        ax.grid()
    else:
        x, y = XYZsurface[:,0], XYZsurface[:,1]
        if is_difference:
            z = WT[ti[1]] - WT[ti[0]]
            colormap_name = 'bwr'
        else:
            z = WT[ti]
            colormap_name = 'viridis'

        xi, yi, mask = create_gridded_mask(x, y)
        zi = griddata((x, y), z, (xi.flatten(), yi.flatten()), method='nearest')
        # hide the interpolated nodes lying outside the boundary
        zi_masked = np.ma.masked_where(~mask, zi.reshape(xi.shape))
        cmap = ax.contourf(xi, yi, zi_masked, cmap=colormap_name)

    ax.set_xlabel('x (m)')
    ax.set_ylabel('y (m)')
    ax.axis('equal')
    ax.grid(True, linestyle='-.')

    if colorbar:
        cbar = plt.colorbar(cmap, ax=ax)
        cbar.set_label(r'$\Delta$ GWD (m)' if is_difference else 'GWD (m)')

    return cmap
def show_spatialET(df_fort777,**kwargs):
    """
    Map the actual evapotranspiration (column 'ACT. ETRA') at one output time.

    Parameters
    ----------
    df_fort777 : pd.DataFrame
        Table holding at least the columns 'time_sec', 'x', 'y' and
        'ACT. ETRA'.
    **kwargs :
        unit : str, optional
            'ms-1' (default, values plotted as stored) or 'mmday-1'
            (values multiplied by 1e3 * 86400).
        cmap : optional
            Colormap passed to ``imshow``. The default is 'Blues'.
        crs : optional
            CRS written to the gridded data when no mask is given.
        ti : int, optional
            Position of the time to plot among the unique 'time_sec' values
            (in order of appearance). The default is 0.
        scatter : optional
            Accepted for backward compatibility; not used.
        mask_gpd : geopandas.GeoDataFrame, optional
            The gridded data are clipped with its geometries; the plot extent
            is the bounding box of the nodes contained in its first geometry.
        clim : optional
            Color limits passed to ``imshow``.
        colorbar : bool, optional
            Add a colorbar. The default is True.
        ax : matplotlib axes, optional
            Axes to draw on. A new figure is created when missing.
        title : str, optional
            Axes title. The default is 'ETa'.

    Returns
    -------
    The image returned by ``ax.imshow``.

    Raises
    ------
    ValueError
        If a mask is given and none of the nodes lies inside its first
        geometry.
    """
    unit = kwargs.pop('unit', 'ms-1')
    unit_str = '$ms^{-1}$'
    cmap = kwargs.pop('cmap', 'Blues')
    crs = kwargs.pop('crs', None)
    ti = kwargs.pop('ti', 0)
    kwargs.pop('scatter', None)  # historical option, never used by the plot
    mask = kwargs.pop('mask_gpd', None)
    clim = kwargs.pop('clim', None)
    colorbar = kwargs.pop('colorbar', True)

    by_time = df_fort777.set_index('time_sec')
    times = by_time.index.unique().to_numpy()
    selected = by_time.loc[times[ti]]

    if "ax" not in kwargs:
        fig, ax = plt.subplots()
    else:
        ax = kwargs.pop("ax")

    selected = selected.drop_duplicates()

    if unit == 'mmday-1':
        # m/s -> mm/day: 1e3 mm per m and 86400 s per day
        selected['ACT. ETRA'] = selected['ACT. ETRA'] * 1e3 * 86400
        unit_str = '$mm.day^{-1}$'

    gridded = selected.set_index(['x', 'y']).to_xarray()
    gridded = gridded.rio.set_spatial_dims('x', 'y')

    if mask is not None:
        polygon = mask.geometry.iloc[0]  # Assuming a single polygon in the shapefile
        nodes = gpd.points_from_xy(selected['x'], selected['y'])
        inside = np.array([polygon.contains(node) for node in nodes], dtype=bool)
        if not inside.any():
            raise ValueError("no node of df_fort777 lies inside the mask polygon")
        xs = selected['x'].to_numpy()[inside]
        ys = selected['y'].to_numpy()[inside]
        extent = [xs.min(), xs.max(), ys.min(), ys.max()]

        gridded.rio.write_crs(mask.crs, inplace=True)
        gridded = gridded.rio.clip(mask.geometry)
    else:
        if crs is not None:
            gridded.rio.write_crs(crs, inplace=True)
        extent = [min(selected['x']), max(selected['x']),
                  min(selected['y']), max(selected['y'])]

    # imshow expects rows = y and columns = x
    gridded = gridded.transpose('y', 'x')
    cmap = ax.imshow(gridded['ACT. ETRA'].values,
                     cmap=cmap,
                     origin='lower',
                     extent=extent,
                     clim=clim,
                     )

    title = 'ETa'
    if 'title' in kwargs:
        title = kwargs['title']
    ax.set_title(title)
    ax.set_xlabel('x (m)')
    ax.set_ylabel('y (m)')
    # ax.axis('equal')
    ax.grid(True, linestyle='-.')

    if colorbar:
        cbar = plt.colorbar(cmap, ax=ax)
        cbar.set_label('ETa ' + unit_str)

    return cmap
[docs]
def show_wtdepth(df_wtdepth=[], workdir=[], project_name=[], **kwargs):
    """
    Plot the water table depth as a function of time.

    Parameters
    ----------
    df_wtdepth : pd.DataFrame, optional
        Table with the columns "time (s)" and "water table depth (m)". When
        empty, it is read with ``out_CT.read_wtdepth(filename="wtdepth")``.
    workdir, project_name : optional
        Not used by this function.
    **kwargs :
        ax : matplotlib axes, optional
            Axes to draw on. A new figure is created when missing.
    """
    # read the wtdepth file when no table is supplied
    # ------------------------------------------------------------------------
    if len(df_wtdepth) == 0:
        df_wtdepth = out_CT.read_wtdepth(filename="wtdepth")

    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        _, ax = plt.subplots()

    elapsed = df_wtdepth["time (s)"]
    depth = df_wtdepth["water table depth (m)"]
    ax.plot(elapsed, depth)
    ax.set_title("Water table depth")
    ax.set(xlabel="Time (s)", ylabel="Water table depth (m)")
[docs]
def show_hgsfdet(df_hgsfdeth=[], workdir=[], project_name=[], **kwargs):
    """
    Plot NET SEEPFACE VOL and NET SEEPFACE FLX against time, one panel each.

    Parameters
    ----------
    df_hgsfdeth : pd.DataFrame, optional
        Table with the columns "time", "NET SEEPFACE VOL" and
        "NET SEEPFACE FLX". When empty, it is read with
        ``out_CT.read_hgsfdet(filename="hgsfdeth")``.
    workdir, project_name : optional
        Not used by this function.
    **kwargs :
        delta_t : optional
            When present, the "time" column of ``df_hgsfdeth`` is converted
            in place to timedeltas (seconds).

    Returns
    -------
    fig, ax : the figure and its array of two axes.
    """
    # read the hgsfdeth file when no table is supplied
    # ------------------------------------------------------------------------
    if len(df_hgsfdeth) == 0:
        df_hgsfdeth = out_CT.read_hgsfdet(filename="hgsfdeth")

    fig, ax = plt.subplots(2, 1)

    if "delta_t" in kwargs:
        df_hgsfdeth["time"] = pd.to_timedelta(df_hgsfdeth["time"], unit="s")

    for panel, column in zip(ax, ("NET SEEPFACE VOL", "NET SEEPFACE FLX")):
        series = df_hgsfdeth.pivot_table(values=column, index="time")
        series.plot(ax=panel, ylabel=column, xlabel="time (s)")

    return fig, ax
[docs]
def show_dtcoupling(
    df_dtcoupling=[], workdir=[], project_name=[], x="time", yprop="Atmact-vf", **kwargs
):
    """
    Plot one column of the dtcoupling output as a step function of time.

    Parameters
    ----------
    df_dtcoupling : pd.DataFrame, optional
        Table with at least the columns "Deltat" and ``yprop``. When empty, it
        is read with ``out_CT.read_dtcoupling(filename="dtcoupling")``.
    workdir, project_name, x : optional
        Not used by this function.
    yprop : str, optional
        Name of the column to plot. The default is "Atmact-vf".
    **kwargs :
        ax : matplotlib axes, optional
            Axes to draw on. A new figure is created when missing.
    """
    if "ax" not in kwargs:
        fig, ax = plt.subplots()
    else:
        ax = kwargs["ax"]

    # read the dtcoupling file when no table is supplied
    # ------------------------------------------------------------------------
    if len(df_dtcoupling) == 0:
        df_dtcoupling = out_CT.read_dtcoupling(filename="dtcoupling")

    # Time axis: t[0] = Deltat[0] and t[i] = t[i-1] + Deltat[i-1], i.e. the
    # running sum of [Deltat[0], Deltat[0], Deltat[1], ..., Deltat[n-2]].
    deltat = np.asarray(df_dtcoupling["Deltat"])
    timeatm = np.cumsum(np.concatenate((deltat[:1], deltat[:-1])))

    ax.step(
        timeatm,
        df_dtcoupling[yprop],
        color="green",
        where="post",
        label=yprop,
    )
    ax.set_xlabel("time (s)")
    ax.set_ylabel(label_units(yprop))
    # legend on the axes that was drawn on, not on whichever axes is current
    ax.legend()
[docs]
def show_hgraph(df_hgraph=[], workdir=[], project_name=[], x="time", y="SW", **kwargs):
    """
    Plot the streamflow hydrograph (column "streamflow" against "time").

    Parameters
    ----------
    df_hgraph : pd.DataFrame, optional
        Table with the columns "time" and "streamflow". When empty, it is read
        with ``out_CT.read_hgraph(filename="hgraph")``.
    workdir, project_name, x, y : optional
        Not used by this function.
    **kwargs :
        ax : matplotlib axes, optional
            Axes to draw on. A new figure is created when missing.
        delta_t : optional
            When present, the "time" column of ``df_hgraph`` is converted in
            place to timedeltas (seconds).
    """
    # read the hgraph file when no table is supplied
    # ------------------------------------------------------------------------
    if len(df_hgraph) == 0:
        df_hgraph = out_CT.read_hgraph(filename="hgraph")

    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        _, ax = plt.subplots()

    if "delta_t" in kwargs:
        df_hgraph["time"] = pd.to_timedelta(df_hgraph["time"], unit="s")

    streamflow = df_hgraph.pivot_table(values="streamflow", index="time")
    streamflow.plot(
        ax=ax, ylabel="streamflow ($m^3/s$)", xlabel="time (s)", marker="."
    )
[docs]
def show_COCumflowvol(df_cumflowvol=[], workdir=None, project_name=None, **kwargs):
    """
    Plot the cumulative net flow volume against time.

    Parameters
    ----------
    df_cumflowvol : np.ndarray, optional
        2D array; column 2 is used as time and column 7 as net flow volume.
        When empty, it is read from ``<workdir>/<project_name>/output/cumflowvol``.
    workdir, project_name : str, optional
        Location of the project, only used when the file has to be read.
    **kwargs :
        ax : matplotlib axes, optional
            Axes to draw on. A new figure is created when missing.
        Remaining keyword arguments are forwarded to the first ``ax.plot``.
    """
    # read the cumflowvol file when no array is supplied
    # ------------------------------------------------------------------------
    if len(df_cumflowvol) == 0:
        cumflowvol_path = os.path.join(workdir, project_name, "output", "cumflowvol")
        df_cumflowvol = out_CT.read_cumflowvol(cumflowvol_path)

    # plot Net Flow Volume (m^3) = f(time)
    # ------------------------------------------------------------------------
    if "ax" in kwargs:
        ax = kwargs.pop("ax")
    else:
        _, ax = plt.subplots()

    elapsed = df_cumflowvol[:, 2]
    net_volume = df_cumflowvol[:, 7]

    ax.plot(elapsed, -net_volume, marker='.', **kwargs)
    # NOTE(review): this second curve uses time / 3600 on the same x axis,
    # while the axis is labelled in seconds - confirm the intended units.
    ax.plot(elapsed / 3600, net_volume)

    ax.set_title("Cumulative flow volume")
    ax.set(xlabel="Time (s)", ylabel="Net Flow Volume (m^3)")
    ax.legend(["Total flow volume", "nansfdir flow volume"])
[docs]
def show_vp_DEPRECATED(
    df_vp=[], workdir=[], project_name=[], index="time", x="time", y="SW", **kwargs
):
    """
    Plot the vp output, one panel per node and one curve per "str_nb" value.

    DEPRECATED: use the psi and sw readers instead and plot with pandas after
    a find_nearest_node search.

    Parameters
    ----------
    df_vp : pd.DataFrame, optional
        Table with the columns "time", "node", "str_nb" and ``y``. When empty,
        it is read with ``out_CT.read_vp(filename="vp")``.
    workdir, project_name, index, x : optional
        Not used by this function.
    y : str, optional
        Name of the column to plot. The default is "SW".

    Returns
    -------
    fig, ax : the figure and the list/array of axes (one per node).
    """
    # read the vp file when no table is supplied
    # ------------------------------------------------------------------------
    if len(df_vp) == 0:
        df_vp = out_CT.read_vp(filename="vp")

    nodes = df_vp["node"].unique()
    fig, ax = plt.subplots(1, len(nodes))
    if len(nodes) < 2:
        # a single panel is returned as a bare axes: wrap it for indexing
        ax = [ax]

    n_curves = len(df_vp["str_nb"].unique())
    colors = cm.Reds(np.linspace(0.2, 0.8, n_curves))

    for panel_idx, node in enumerate(nodes):
        panel = ax[panel_idx]
        table = pd.pivot_table(df_vp, values=y, index=["time", "node", "str_nb"])
        node_table = table.xs(node, level="node")
        node_table.unstack("str_nb").plot(
            ax=panel, title="node: " + str(node), color=colors
        )
        panel.legend([])

        label = label_units(y)
        if panel_idx == 0:
            panel.set_ylabel(label)

    return fig, ax
[docs]
def show_vtk(
    filename=None,
    unit="pressure",
    timeStep=0,
    notebook=False,
    path=None,
    savefig=False,
    ax=None,
    **kwargs,
):
    """
    Plot vtk file using pyvista

    Parameters
    ----------
    filename : str, optional
        Name of the file. When None, it is derived from ``unit`` and
        ``timeStep``: "ER_converted<timeStep>.vtk" (or its
        "_nearIntrp2_pg_msh" variant) for units starting with "ER" (resp.
        "ER_int"), and "<100 + timeStep>.vtk" for any other unit.
        The default is None.
    unit : str, optional
        ['pressure', 'saturation', 'ER', 'permeability', 'velocity'] . The default is pressure.
    timeStep : int, optional
        Time step to plot. The default is 0.
    notebook : bool, optional
        Option to plot in a notebook. The default is False.
    path : str, optional
        Path of the vtk file. The default is None (current working directory).
    savefig : bool, optional
        Save figure. The default is False.
    ax : pyvista.Plotter, optional
        Plotter to draw on. A new one is created when None.
    **kwargs :
        mainly plotting adjustements to pass into Pyvista. The following keys
        are consumed here and not forwarded to ``add_mesh``:
        legend : bool, add the time stamp legend and bounds (default True).
        cmap : colormap overriding the one derived from ``unit``.
        elecs : electrode positions, drawn as labelled points.

    Raises
    ------
    NotImplementedError
        If ``filename`` is None and ``unit`` is "celerity".
    """
    my_colormap = "viridis"

    if path is None:
        path = os.getcwd()

    legend = kwargs.pop('legend', True)
    # 'elecs' is an option of this function: it must not reach add_mesh
    elecs = kwargs.pop("elecs", None)

    # Parse physical attribute + cmap from unit
    # -------------------------------------------------------------------------
    if filename is None:
        unit_is_str = isinstance(unit, str)

        # for transport
        # --------------------------------------------------------------------
        if unit == "celerity":
            raise NotImplementedError("Transport vtk file output not yet implemented")

        # for ERT (the more specific "ER_int" prefix is tested first)
        # --------------------------------------------------------------------
        elif unit_is_str and unit.startswith("ER_int"):
            filename = "ER_converted" + str(timeStep) + "_nearIntrp2_pg_msh.vtk"
            my_colormap = "viridis"
            unit = "ER_converted" + str(timeStep) + "_nearIntrp2_pg_msh"

        # prefix test (not substring) so that e.g. 'PERMX' is not caught here
        elif unit_is_str and unit.startswith("ER"):
            filename = "ER_converted" + str(timeStep) + ".vtk"
            my_colormap = "viridis"
            unit = "ER_converted" + str(timeStep)

        # for surface/subsurface hydrology
        # --------------------------------------------------------------------
        else:
            if unit == "pressure":
                my_colormap = "autumn"
            elif unit == "saturation":
                my_colormap = "Blues"
            # output files are numbered from 100: step 0 -> 100.vtk, 42 -> 142.vtk
            filename = str(100 + timeStep) + ".vtk"

    if unit == 'PERMX':
        my_colormap = None
        print('No colormap')
        print(my_colormap)

    if 'cmap' in kwargs:
        my_colormap = kwargs.pop('cmap')

    mesh = pv.read(os.path.join(path, filename))

    if unit in list(mesh.array_names):
        print("plot " + str(unit))
    else:
        print("physcial property not existing")

    # notebook = activate widgets
    # -------------------------------------------------------------------------
    if notebook:
        if ax is None:
            ax = pv.Plotter(notebook=notebook)
        # pn.extension('vtk') # this needs to be at the top of each cell for some reason
        out = widgets.Output()

        # NOTE(review): this callback reads from a hard-coded
        # "./my_cathy_prj/vtk/" folder, indexes the scalar name with [0] even
        # when it is the plain string "pressure", and the widget box built
        # below is neither displayed nor returned. Left as is - to be
        # confirmed against the intended notebook workflow.
        def on_value_change(change):
            with out:
                PhysScalars = "pressure"
                time_step = timeStep
                if hasattr(change.owner, "options"):  # =='Phys. prop:':
                    PhysScalars = change.owner.options[change["new"] - 1]
                else:
                    time_step = change["new"]

                out.clear_output()
                mesh = pv.read("./my_cathy_prj/vtk/10" + str(time_step) + ".vtk")
                _ = ax.add_mesh(
                    mesh, scalars=PhysScalars[0], cmap=my_colormap, **kwargs
                )
                if unit == "saturation":
                    ax.update_scalar_bar_range([0, 1])
                if "clim" in kwargs:
                    ax.update_scalar_bar_range([kwargs["clim"][0], 1])

                legend_entries = []
                legend_entries.append(["Time=" + str(mesh["time"]), "w"])
                _ = ax.add_legend(legend_entries)
                ax.show_grid()
                cpos = ax.show(True)

        slider = widgets.IntSlider(
            min=0,
            max=10,
            step=1,
            continuous_update=True,
            description="Time step #:",
        )
        # play = widgets.Play(min=1, interval=2000)
        choice = widgets.Dropdown(
            options=[("pressure", 1), ("saturation", 2)],
            value=1,
            description="Phys. prop:",
        )
        slider.observe(on_value_change, "value")
        choice.observe(on_value_change, "value")
        plotvtk = widgets.VBox([choice, slider, out])
        plotvtk

    # No notebook interaction
    # -------------------------------------------------------------------------
    else:
        if ax is None:
            ax = pv.Plotter(notebook=False)

        _ = ax.add_mesh(mesh, scalars=unit, cmap=my_colormap, **kwargs)

        if unit == "saturation":
            ax.update_scalar_bar_range([0, 1])
        if "clim" in kwargs:
            ax.update_scalar_bar_range([kwargs["clim"][0], kwargs["clim"][1]])

        # add time stamp as legend
        if legend:
            time_delta = transform2_time_delta(mesh["TIME"], "s")
            legend_entries = [["Time=" + str(time_delta[0]), "w"]]
            _ = ax.add_legend(legend_entries)
            _ = ax.show_bounds(minor_ticks=True, font_size=1)

        # add scatter points to the plot
        # ---------------------------------------------------------------------
        # add electrodes positions
        if elecs is not None:
            poly_elecs = pv.PolyData(elecs)
            poly_elecs["My Labels"] = [
                f"Label {i}" for i in range(poly_elecs.n_points)
            ]
            ax.add_point_labels(
                poly_elecs, "My Labels", point_size=20, font_size=36
            )
        # add tensiometers positions
        # -----------------------------------------------------------------
        # add TDR probe positions
        # -----------------------------------------------------------------

    if savefig is True:
        screenshot_path = os.path.join(path, filename + unit + ".png")
        ax.view_xz()
        cpos = ax.show(screenshot=screenshot_path)
        print("figure saved" + screenshot_path)
        ax.close()
[docs]
def show_vtk_TL(
    filename=None,
    unit="pressure",
    timeStep="all",
    notebook=False,
    path=None,
    savefig=True,
    show=True,
    **kwargs,
):
    """
    Time lapse animation of selected time steps

    Parameters
    ----------
    filename : str, optional
        Glob pattern (relative to ``path``) of the vtk files to animate. When
        None, it is "1*.vtk" for hydrological units and
        "ER<timeStep>.vtk" for units starting with "ER". The default is None.
    unit : str, optional
        Name of the mesh array to display. The default is "pressure".
    timeStep : optional
        Only used to build the ER file name. The default is "all".
    notebook : bool, optional
        Passed to the pyvista plotter. The default is False.
    path : str, optional
        Folder of the vtk files. The default is None (current working
        directory).
    savefig : bool, optional
        Write the frames to "<path>/<unit>.gif" and a slowed-down copy
        "<path>/<unit>_slow.gif". The default is True.
    show : bool, optional
        Currently without effect: the plotter is always created off screen.
        The default is True.
    **kwargs :
        x_units : time unit used in the time label (see convert_time_units).
        pl : existing pyvista.Plotter to draw on.
        cmap : colormap overriding the one derived from ``unit``.
        Remaining keyword arguments are forwarded to ``add_mesh``.

    Raises
    ------
    FileNotFoundError
        If no file matches ``filename`` in ``path`` and no initial file name
        is known.
    """
    x_units = None
    if "x_units" in kwargs:
        x_units = kwargs.pop('x_units')
        print(x_units)

    if path is None:
        path = os.getcwd()

    # defaults so that a user-supplied filename or an ER unit does not leave
    # the colormap / first-frame file undefined
    my_colormap = "viridis"
    filename0 = None
    if filename is None:
        if isinstance(unit, str) and unit.startswith("ER"):
            filename = "ER" + str(timeStep) + ".vtk"
        else:
            filename = "1*.vtk"
            filename0 = "100.vtk"
            if unit == "pressure":
                my_colormap = "autumn"
            elif unit == "saturation":
                my_colormap = "Blues"
    if 'cmap' in kwargs:
        my_colormap = kwargs.pop('cmap')

    frames = natsort.natsorted(glob.glob(os.path.join(path, filename)), reverse=False)
    if filename0 is not None:
        first_frame = os.path.join(path, filename0)
    elif frames:
        first_frame = frames[0]
    else:
        raise FileNotFoundError(
            "no vtk file matching " + os.path.join(path, filename)
        )

    mesh = pv.read(first_frame)

    if unit in list(mesh.array_names):
        print("plot " + str(unit))
    else:
        print("physcial property not existing")

    def _time_label(frame_mesh):
        # text of the time stamp for one frame, converted when x_units is set
        if x_units is not None:
            xlabel, t_lgd = convert_time_units(frame_mesh["TIME"], x_units)
            return "Time=" + str(t_lgd) + xlabel
        return "Time= " + str(frame_mesh["TIME"])

    # use the plotter handed over by the caller, otherwise create one
    plotter = kwargs.pop('pl', None)
    if plotter is None:
        plotter = pv.Plotter(
            notebook=notebook,
            off_screen=True,
        )

    plotter.add_mesh(mesh,
                     scalars=unit,
                     cmap=my_colormap,
                     **kwargs,
                     )

    gif_original = os.path.join(path, unit + ".gif")
    gif_speed_down = os.path.join(path, unit + "_slow.gif")
    if savefig:
        plotter.open_gif(gif_original)

    plotter.add_scalar_bar(title=unit)

    # options to colorbar
    # ---------------------------------------------------------------------
    if "clim" in kwargs:
        plotter.update_scalar_bar_range([kwargs["clim"][0],
                                         kwargs["clim"][1]]
                                        )

    legend_entry = _time_label(mesh)
    print(legend_entry)

    plotter.show_grid()
    cpos = plotter.show(auto_close=False)
    plotter.add_text(legend_entry, name="time-label")

    for ff in frames:
        print(ff)
        mesh = pv.read(ff)
        array_new = mesh.get_array(unit)
        print(mesh["TIME"])
        legend_entry = _time_label(mesh)

        plotter.update_scalars(array_new, render=True)
        plotter.add_text(legend_entry, name="time-label")
        plotter.render()

        if unit == "saturation":
            plotter.update_scalar_bar_range([0, 1])
        if "clim" in kwargs:
            plotter.update_scalar_bar_range([kwargs["clim"][0], kwargs["clim"][1]])

        # frames can only be written once a gif has been opened
        if savefig:
            plotter.write_frame()

    plotter.close()

    if savefig:
        gif = imageio.mimread(gif_original)
        imageio.mimsave(gif_speed_down, gif, duration=1800)
        print("gif saved" + gif_original)

    return
#%% ---------------------------------------------------------------------------
# -----------------Plot INPUTS CATHY ------------------------------------------
# -----------------------------------------------------------------------------
[docs]
def show_atmbc(t_atmbc, v_atmbc, ax=None, **kwargs):
    """
    Plot atmbc=f(time)

    Parameters
    ----------
    t_atmbc : np.array
        time where atmbc change.
    v_atmbc : array, or list of 1 or 2 arrays (when available)
        A single array (or a one-element list) is plotted as the net flux.
        With two arrays, the net flux is ``v_atmbc[0] - v_atmbc[1]`` and two
        panels are drawn: ``v_atmbc[1]`` in the "Rain/Irr" panel and
        ``-v_atmbc[0]`` in the "ET" panel.
        NOTE(review): the panel assignment is kept from the previous
        implementation; confirm which of the two arrays is rain and which is
        the ET demand.
    ax : matplotlib axes (or sequence of two axes for the two-panel case), optional
        Axes to draw on. A new figure is created when None.
    **kwargs :
        x_units : "days" or "hours", rescales ``t_atmbc`` given in seconds.
        datetime : values replacing ``t_atmbc`` on the x axis.
        IETO : single-panel case only; when different from 0 the values are
            additionally drawn as a step function.
    """
    xlabel = "time (s)"
    x_units = kwargs.get("x_units")
    if x_units == "days":
        xlabel = "time (days)"
        t_atmbc = [x / (24 * 60 * 60) for x in t_atmbc]
    if x_units == "hours":
        xlabel = "time (h)"
        t_atmbc = [x / (60 * 60) for x in t_atmbc]
    if "datetime" in kwargs:
        t_atmbc = kwargs["datetime"]
        xlabel = "date"

    # https://matplotlib.org/stable/gallery/lines_bars_and_markers/stairs_demo.html#sphx-glr-gallery-lines-bars-and-markers-stairs-demo-py
    v_atmbc_p = []
    v_atmbc_n = []
    # a list whose first item is itself array-like holds separate components
    if isinstance(v_atmbc, list) and len(v_atmbc) > 0 and np.ndim(v_atmbc[0]) > 0:
        if len(v_atmbc) >= 2:
            v_atmbc_p = np.asarray(v_atmbc[1])
            v_atmbc_n = np.asarray(v_atmbc[0])
            v_atmbc = v_atmbc_n - v_atmbc_p
        else:
            v_atmbc = np.asarray(v_atmbc[0])

    if len(v_atmbc_n) > 0:
        if ax is None:
            fig, ax = plt.subplots(2, 1, sharex=True)
        (ax1, ax2) = (ax[0], ax[1])

        color = "tab:blue"
        ax1.set_xlabel(xlabel)
        ax1.set_ylabel("Rain/Irr", color=color)
        ax1.step(t_atmbc, v_atmbc_p, color="blue", where="post", label="Rain/Irr")

        color = "tab:red"
        ax2.set_ylabel("ET", color=color)  # we already handled the x-label with ax1
        ax2.step(t_atmbc, -v_atmbc_n, color=color, where="post", label="ET")
        ax2.tick_params(axis="y", labelcolor=color)

        # works for both a freshly created figure and caller-supplied axes
        ax1.figure.tight_layout()  # otherwise the right y-label is slightly clipped
    else:
        if ax is None:
            fig, ax = plt.subplots(figsize=(6, 3))

        ax.plot(t_atmbc, v_atmbc, "k.")
        ax.set(xlabel=xlabel, ylabel="net Q (m/s)", title="atmbc inputs")
        ax.grid()

        # IETO == 0 means linear interpolation between points: markers only
        if "IETO" in kwargs and kwargs["IETO"] != 0:
            ax.step(t_atmbc, v_atmbc, color="k", where="post")
[docs]
def show_atmbc_3d(df_atmbc):
    """
    Temporary entry point (must exist through show_vtk only).

    Delegates to ``show_vtk(new_field="atmbc")``; ``df_atmbc`` is currently
    not used.
    """
    # Sketch of the intended implementation, kept for reference:
    # df_atmbc.set_index('time',inplace=True)
    # df_atmbc.loc[0]
    # filename = './vtk/100.vtk'
    # mesh = pv.read(filename)
    # mesh.add_field_data(df_atmbc.loc[0]['value'], 'atmbc')
    # plotter = pv.Plotter(notebook=True)
    # _ = plotter.add_mesh(mesh, show_edges=True, scalars='atmbc')
    # # _ = plotter.add_legend(legend_entries)
    # plotter.show_grid()
    # cpos = plotter.show()
    show_vtk(new_field="atmbc")
    return None
[docs]
def show_soil(soil_map, ax=None, **kwargs):
    """
    View from top of a soil property map.

    Parameters
    ----------
    soil_map : array
        2D map of the soil property (must support ``.flatten()``).
    ax : matplotlib axes, optional
        Axes to draw on. A new figure is created when None.
    **kwargs :
        yprop : str, colorbar label. The default is "PERMX".
        layer_nb : layer number shown in the title. The default is "1".
        cmap : colormap; by default the first ``nb_of_zones`` tab10 colors.
        clim : color limits; by default the min/max of the map.
        Remaining keyword arguments (``clim`` included) are forwarded to
        ``ax.pcolormesh``.

    Returns
    -------
    The mesh returned by ``ax.pcolormesh``.
    """
    yprop = kwargs.pop("yprop", "PERMX")
    layer_nb = kwargs.pop("layer_nb", "1")

    nb_of_zones = len(np.unique(soil_map))
    zone_colors = mpl.colors.ListedColormap(plt.cm.tab10.colors[:nb_of_zones])
    cmap = kwargs.pop("cmap", zone_colors)

    if ax is None:
        _, ax = plt.subplots()

    cf = ax.pcolormesh(soil_map, edgecolors="black", cmap=cmap, **kwargs)

    if "clim" in kwargs:
        clim = kwargs["clim"]
    else:
        values = soil_map.flatten()
        clim = [min(values), max(values)]
    cf.set_clim(clim[0], clim[1])

    # one tick per zone, evenly spread over the color range
    tick_values = np.linspace(clim[0], clim[1], nb_of_zones)
    cax = plt.colorbar(cf, ticks=tick_values, ax=ax, label=yprop)
    cax.ax.set_yticklabels(["{:.1e}".format(tick) for tick in tick_values])

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title("view from top (before extruding), layer nb" + str(layer_nb))

    return cf
[docs]
def show_raster(
    raster_map,
    str_hd_raster={},
    prop="",
    hapin={},
    ax=None,
    cmap="gist_earth",
    **kwargs,
):
    """
    Plot a raster map from top with ``pcolormesh``.

    Parameters
    ----------
    raster_map : np.ndarray
        2D raster; axis 0 is mapped to x and axis 1 to y.
    str_hd_raster : dict, optional
        Raster header; keys "west" and "south" give the origin. When empty,
        the origin is (0, 0).
    prop : str, optional
        Title of the axes.
    hapin : dict, optional
        Keys "delta_x" and "delta_y" give the cell size. When empty, both
        are 1.
    ax : matplotlib axes, optional
        Axes to draw on. A new figure is created when None.
    cmap : optional
        Not used by this function.
    **kwargs :
        Forwarded to ``ax.pcolormesh``.

    Returns
    -------
    The mesh returned by ``ax.pcolormesh``.
    """
    n_x, n_y = raster_map.shape[0], raster_map.shape[1]

    if len(str_hd_raster) < 1:
        str_hd_raster = {"west": 0, "south": 0}
    if len(hapin) < 1:
        hapin = {"delta_x": 1, "delta_y": 1}

    # node coordinates: origin + cell size * index
    x = np.zeros(n_x)  # + hapin['xllcorner']
    y = np.zeros(n_y)  # + hapin['yllcorner']
    x[:] = float(str_hd_raster["west"]) + hapin["delta_x"] * np.arange(n_x)
    y[:] = float(str_hd_raster["south"]) + hapin["delta_y"] * np.arange(n_y)

    if ax is None:
        _, ax = plt.subplots()

    pmesh = ax.pcolormesh(x, y, raster_map.T, **kwargs)  # , cmap=cmap)
    ax.set(xlabel="Easting (m)", ylabel="Northing (m)")
    ax.set_title(prop)

    return pmesh
[docs]
def show_zone(zone_map, **kwargs):
    """
    View from top of the zone map.

    Parameters
    ----------
    zone_map : np.array([])
        Indice of zones (must support ``.flatten()``).
    **kwargs :
        cmap : colormap; by default the first ``nb_of_zones`` tab10 colors.

    Returns
    -------
    fig, ax : the figure (already closed by this function) and its axes.
    """
    nb_of_zones = len(np.unique(zone_map))
    cmap = mpl.colors.ListedColormap(plt.cm.tab10.colors[:nb_of_zones])
    if "cmap" in kwargs:
        cmap = kwargs["cmap"]

    fig, ax = plt.subplots()
    cf = ax.pcolormesh(zone_map, edgecolors="black", cmap=cmap)

    zone_min = int(min(zone_map.flatten()))
    zone_max = int(max(zone_map.flatten()))
    # NOTE(review): nb_of_zones + 1 ticks are labelled with
    # range(zone_min, nb_of_zones + 2); the two lengths only match when
    # zone_min == 1 - confirm the intended tick layout.
    tick_positions = np.linspace(zone_min, zone_max, nb_of_zones + 1)
    cax = plt.colorbar(
        cf,
        ticks=tick_positions,
        ax=ax,
        label="indice of zones",
    )
    cax.set_ticklabels(range(zone_min, nb_of_zones + 2))

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title("view from top (before extruding)")
    plt.show(block=False)
    plt.close()

    return fig, ax
[docs]
def show_indice_veg(veg_map, ax=None, **kwargs):
    """
    View from top of the vegetation type (equivalent somehow to root map)

    Parameters
    ----------
    veg_map : np.array([])
        Indice of vegetation. The dimension of the vegetation map must match
        the dimension of the DEM.
    ax : matplotlib axes, optional
        Axes to draw on. A new figure is created when None.
    **kwargs :
        cmap : colormap; by default the first ``nb_of_zones`` tab10 colors.
        Remaining keyword arguments are forwarded to ``ax.imshow``.

    Returns
    -------
    ax : the axes the map was drawn on.
    """
    nb_of_zones = len(np.unique(veg_map))
    cmap = mpl.colors.ListedColormap(plt.cm.tab10.colors[:nb_of_zones])
    if "cmap" in kwargs:
        cmap = kwargs.pop("cmap")

    if ax is None:
        fig, ax = plt.subplots()

    cf = ax.imshow(veg_map, cmap=cmap, **kwargs)

    # one colorbar tick per integer indice, labelled with that same indice
    flat = veg_map.flatten()
    indices = np.arange(int(min(flat)), int(max(flat)) + 1)
    cax = plt.colorbar(
        cf,
        ticks=indices,
        ax=ax,
        label="indice of vegetation",
    )
    cax.set_ticklabels(["{:d}".format(int(indice)) for indice in indices])

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title("view from top (before extruding)")

    return ax
[docs]
def dem_plot_2d_top(parameter, label="", **kwargs):
    """
    View of the DEM from top of the a given parameter

    Possible parameters are:
    - vegetation
    - altitude

    Parameters
    ----------
    parameter : 2D array, or dict of 2D arrays when ``label == "all"``
        Map(s) to display with ``imshow``.
    label : str, optional
        Colorbar label of the single map, or "all" to draw every entry of the
        ``parameter`` dict in its own panel (labelled with its key).
    **kwargs :
        Not used by this function.

    Returns
    -------
    fig, ax : the figure and the axes of the single map; with
        ``label == "all"``, ``ax`` is the last panel that was drawn.

    Raises
    ------
    ValueError
        If ``label == "all"`` and ``parameter`` is empty.
    """
    if label == "all":
        names = list(parameter.keys())
        if not names:
            raise ValueError("parameter is empty: nothing to plot")

        # keep about n/2 rows and add as many columns as needed so that every
        # parameter gets a panel
        n_rows = max(1, len(names) // 2)
        n_cols = math.ceil(len(names) / n_rows)
        fig, axs = plt.subplots(
            n_rows,
            n_cols,
            sharex=True,
            sharey=True,
            squeeze=False,  # always a 2D array of axes, even for one panel
        )
        panels = axs.reshape(-1)

        for ax, p in zip(panels, names):
            cf = ax.imshow(parameter[p])
            fig.colorbar(cf, ax=ax, label=p, fraction=0.046, pad=0.04, shrink=0.8)
            ax.set_xlabel("x")
            ax.set_ylabel("y")
            ax.set_aspect("auto", "box")
        # hide the panels left without data
        for unused in panels[len(names):]:
            unused.set_visible(False)

        plt.tight_layout()
        plt.show(block=False)
    else:
        fig, ax = plt.subplots()
        # imshow has no 'edgecolors' option (that is a pcolormesh keyword)
        cf = ax.imshow(parameter)
        fig.colorbar(cf, ax=ax, label=label)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_title("view from top (before extruding)")
        plt.show(block=False)

    return fig, ax
[docs]
def get_dem_coords(dem_mat=[], hapin={}):
    """
    Compute the cell-centre coordinates of a DEM grid.

    Parameters
    ----------
    dem_mat : np.ndarray
        2D DEM; ``shape[1]`` cells along x and ``shape[0]`` cells along y.
    hapin : dict
        Must hold "xllcorner", "yllcorner" (lower-left corner) and
        "delta_x", "delta_y" (cell size).

    Returns
    -------
    x : np.ndarray
        Cell-centre x coordinates, increasing.
    y : np.ndarray
        Cell-centre y coordinates, flipped (decreasing for a positive
        ``delta_y``).
    """
    n_rows, n_cols = dem_mat.shape[0], dem_mat.shape[1]
    dx, dy = hapin["delta_x"], hapin["delta_y"]

    # centre of cell k (1-based) = corner + k * size - size / 2
    x = np.zeros(n_cols)
    y = np.zeros(n_rows)
    x[:] = float(hapin["xllcorner"]) + dx * np.arange(1, n_cols + 1) - dx / 2
    y[:] = float(hapin["yllcorner"]) + dy * np.arange(1, n_rows + 1) - dy / 2

    return x, np.flipud(y)
[docs]
def show_dem(
    dem_mat=[],
    hapin={},
    ax=None,
    **kwargs,
):
    """
    Creates a 3D representation from a Grass DEM file

    Parameters
    ----------
    dem_mat : np.ndarray
        2D elevation map; cells equal to -9999 are treated as no-data and
        not drawn. The array is not modified.
    hapin : dict
        Grid description passed to ``get_dem_coords``.
    ax : matplotlib 3D axes, optional
        Axes to draw on. A new 3D figure is created when None.
    **kwargs :
        Forwarded to ``ax.plot_surface``.
    """
    x, y = get_dem_coords(dem_mat, hapin)

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(projection="3d")

    X, Y = np.meshgrid(x, y)

    # work on a float copy: NaN cannot be stored in an integer DEM and the
    # caller's array must not be altered
    elevation = np.array(dem_mat, dtype=float)
    elevation[elevation == -9999] = np.nan

    surf = ax.plot_surface(X, Y, elevation, cmap="viridis", **kwargs)
    cbar = plt.colorbar(
        surf, ax=ax, shrink=0.25, orientation="horizontal", label="Elevation (m)"
    )
    ax.set(xlabel="Easting (m)",
           ylabel="Northing (m)",
           zlabel="Elevation (m)")
    ax.yaxis.set_major_formatter(FormatStrFormatter("%3.4e"))
    ax.xaxis.set_major_formatter(FormatStrFormatter("%3.4e"))
[docs]
def plot_mesh_bounds(BCtypName, mesh_bound_cond_df, time, ax=None):
    """
    3D scatter of the mesh nodes, colored by a boundary-condition flag.

    Parameters
    ----------
    BCtypName : str
        Name of the boundary-condition column of ``mesh_bound_cond_df``.
    mesh_bound_cond_df : pd.DataFrame
        Table with the columns "time", "x", "y", "z" and ``BCtypName``.
    time :
        Only the rows whose "time" equals this value are drawn.
    ax : matplotlib 3D axes, optional
        Axes to draw on. A new 3D figure is created when None.

    Returns
    -------
    The collection returned by ``ax.scatter``; nodes whose ``BCtypName``
    value is 0 get color value 1, all the others 0.
    """
    at_time = mesh_bound_cond_df[mesh_bound_cond_df['time'] == time]
    mvalue = [1 if bound_val == 0 else 0 for bound_val in at_time[BCtypName]]

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(projection="3d")

    cmap = ax.scatter(at_time["x"], at_time["y"], at_time["z"], c=mvalue)

    ax.set_xlabel("X Label")
    ax.set_ylabel("Y Label")
    ax.set_zlabel("Z Label")
    ax.set_title(BCtypName + " Time " + str(time))

    return cmap
#%% ---------------------------------------------------------------------------
# -----------------Plot PETRO-PEDOPHYSICS relationships -----------------------
# -----------------------------------------------------------------------------
[docs]
def plot_VGP(df_VGP, savefig=False, **kwargs):
    """
    Plot a retention curve: column "theta" against column "psi".

    Parameters
    ----------
    df_VGP : pd.DataFrame
        Table with the columns "psi" and "theta".
    savefig : bool, optional
        Not used by this function.
    **kwargs :
        ax : matplotlib axes, optional
            Axes to draw on. A new figure is created when missing.
        label : optional
            Label of the curve. The default is an empty list.

    Returns
    -------
    ax : the axes the curve was drawn on.
    """
    if "ax" not in kwargs:
        fig = plt.figure(figsize=(6, 3), dpi=350)
        ax = fig.add_subplot()
    else:
        ax = kwargs["ax"]

    label = kwargs["label"] if "label" in kwargs else []

    psi, theta = df_VGP["psi"], df_VGP["theta"]
    ax.plot(psi, theta, label=label, marker=".")
    ax.set_xlabel("Soil Water Potential (cm)")
    ax.set_ylabel(r"Water content ($m^3/m^3$)")

    return ax
[docs]
def DA_plot_Archie(df_Archie, savefig=False, **kwargs):
    """
    Scatter the converted ER against saturation for the whole ensemble.

    Parameters
    ----------
    df_Archie : pd.DataFrame
        Table with the columns "sw" and "ER_converted".
    savefig : bool, optional
        When True, the current figure is saved as "Archie.png".
    **kwargs :
        ax : matplotlib axes, optional
            Axes to draw on. A new figure is created when missing.
        porosity : optional
            When given, a second scatter of ``sw * porosity`` is added and
            the axes are relabelled "swc" / "ER_converted".

    Returns
    -------
    ax : the axes the points were drawn on.
    """
    if "ax" not in kwargs:
        fig = plt.figure(figsize=(6, 3), dpi=350)
        ax = fig.add_subplot()
    else:
        ax = kwargs["ax"]

    saturation = df_Archie["sw"]
    er_converted = df_Archie["ER_converted"]

    ax.scatter(saturation, er_converted)
    ax.set_xlabel("Saturation (-)")
    ax.set_ylabel("ER converted")
    ax.set_title("[All ensemble]")

    if "porosity" in kwargs:
        ax.scatter(saturation * kwargs["porosity"], er_converted)
        ax.set_xlabel("swc")
        ax.set_ylabel("ER_converted")

    if savefig == True:
        plt.savefig("Archie.png", dpi=300)

    return ax
#%% ----------------------------------------------------------------------------
# -----------------Plot results DATA ASSIMILATION------------------------------
# -----------------------------------------------------------------------------
[docs]
def plot_hist_perturbated_parm(parm, var_per, type_parm, parm_per_array, **kwargs):
    """Histogram of a perturbed parameter ensemble with its nominal value marked.

    Parameters
    ----------
    parm : mapping
        Read keys: ``"units"`` (x label) and the nominal value, looked up under
        ``type_parm + "_nominal"`` or ``"nominal"`` (see notes).
    var_per : mapping
        ``var_per[type_parm]["transf_type"]`` selects the axis: a string
        containing "log" (case-insensitive) plots ``log10`` of the values,
        anything else (including ``None``) plots the raw values.
    type_parm : str
        Parameter name; used as key into ``var_per`` and in the title.
    parm_per_array : array-like
        Perturbed values of the ensemble.
    **kwargs
        savefig : file name, joined to the current working directory; the figure
            is only written when this is provided and truthy.

    Notes
    -----
    * The log histogram uses bins about 0.2 wide *in log10 units*. The bin count
      used to be derived from the raw value range, which gave a single bin for
      small-magnitude parameters.
    * Each branch keeps its historical nominal key (log: ``<type_parm>_nominal``,
      linear: ``nominal``) and falls back to the other one when it is missing.
    """
    fig = plt.figure(figsize=(6, 3), dpi=150)

    transf_type = var_per[type_parm]["transf_type"]
    use_log = transf_type is not None and "log" in transf_type.casefold()

    def _nominal(preferred, fallback):
        # Tolerate either naming of the nominal value in ``parm``.
        return parm[preferred] if preferred in parm else parm[fallback]

    if use_log:
        log_values = np.log10(np.asarray(parm_per_array))
        bin_width = 0.2
        # At least one bin, so a constant ensemble still plots.
        nbins = max(1, math.ceil((log_values.max() - log_values.min()) / bin_width))
        plt.hist(log_values, nbins, alpha=0.5, label="ini_perturbation")
        plt.axvline(
            x=np.log10(_nominal(type_parm + "_nominal", "nominal")),
            linestyle="--",
            color="red",
        )
    else:
        plt.hist(
            parm_per_array,
            bins=7,
            alpha=0.5,
            label="ini_perturbation",
        )
        plt.axvline(
            x=_nominal("nominal", type_parm + "_nominal"),
            linestyle="--",
            color="red",
        )

    plt.legend(loc="upper right")
    plt.xlabel(parm["units"])
    plt.ylabel("Probability")
    plt.title("Histogram of " + type_parm)
    plt.tight_layout()

    # Saving is optional: a missing "savefig" used to raise KeyError after drawing.
    savefig = kwargs.get("savefig")
    if savefig:
        fig.savefig(os.path.join(os.getcwd(), savefig), dpi=350)
def show_DA_process_ens(
    EnsembleX, Data, DataCov, dD, dAS, B, Analysis, savefig=False, **kwargs
):
    """Plot result of Data Assimilation Analysis

    Draws seven matrix panels on a 2x5 grid (three slots stay empty):

    ====  ======================================  ==========
    slot  content                                  title
    ====  ======================================  ==========
    1     ``EnsembleX``                            Prior
    6     ``np.cov(EnsembleX)`` (log colour norm)  cov(Prior)
    2     ``Data`` repeated once per column        (none)
    7     ``DataCov``                              cov(meas)
    3     ``dD``                                   Meas - Sim
    4     ``np.dot(dAS, B)``                       Correction
    5     ``Analysis``                             Posterior
    ====  ======================================  ==========

    Parameters
    ----------
    EnsembleX, Data, DataCov, dD, dAS, B, Analysis : array-like
        Matrices displayed as listed above.
    savefig : bool, optional
        When equal to ``True`` the figure is written to ``<savename>.png``.
    **kwargs
        savename : file stem, default ``"showDA_process_ens"``.
        label_sensor : accepted but currently not drawn anywhere.

    Returns
    -------
    (fig, ax)
        The figure and the axes of the last panel ("Posterior"). The figure is
        closed in pyplot before returning, so it will not be shown by
        ``plt.show()``.
    """
    # Raw string: "\p" is not a valid escape sequence in a normal literal.
    psi_label = r"$\psi$ params #"
    members_label = "Members #"

    fig = plt.figure(figsize=(12, 6), dpi=300)

    def _panel(slot, matrix, title, ylabel, xlabel, *, hide_yticks=True,
               sharey=None, colorbar_kw=None, **matshow_kw):
        # One matrix panel with its own horizontal colour bar underneath.
        axis = fig.add_subplot(2, 5, slot, sharey=sharey)
        image = axis.matshow(matrix, aspect="auto", **matshow_kw)
        if title is not None:
            axis.set_title(title)
        axis.set_ylabel(ylabel)
        axis.set_xlabel(xlabel)
        fig.colorbar(image, location="bottom", **(colorbar_kw or {}))
        if hide_yticks:
            axis.set_yticks([])
        return axis

    ax1 = _panel(1, EnsembleX, "Prior", psi_label, members_label, hide_yticks=False)
    _panel(
        6, np.cov(EnsembleX), "cov(Prior)", psi_label, psi_label,
        colorbar_kw={"format": "$%.1f$"}, cmap="gray", norm=LogNorm(),
    )
    _panel(
        2, np.tile(Data, (np.shape(EnsembleX)[1], 1)).T, None, "Meas", members_label
    )
    _panel(7, DataCov, "cov(meas)", "Meas", "Meas", cmap="gray_r")
    _panel(3, dD, "Meas - Sim", "Meas", members_label, cmap="seismic")
    _panel(
        4, np.dot(dAS, B), "Correction", psi_label, members_label,
        sharey=ax1, cmap="seismic",
    )
    ax = _panel(5, Analysis, "Posterior", psi_label, members_label, sharey=ax1)

    plt.tight_layout()

    savename = kwargs.get("savename", "showDA_process_ens")
    if savefig == True:
        fig.savefig(savename + ".png", dpi=300)

    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    return fig, ax
def DA_RMS(df_performance, sensorName, **kwargs):
    """Plot result of Data Assimilation: RMS evolution over time.

    Parameters
    ----------
    df_performance : pandas.DataFrame
        Must provide ``time``, ``ObsType``, ``OL`` and the two columns
        ``RMSE<sensorName>`` / ``NMRMSE<sensorName>``.
    sensorName : str
        Suffix selecting the RMSE/NMRMSE columns.
    **kwargs
        ax : sequence of axes to draw on. With ``start_date`` the panels at
            index 1 and 2 are used, otherwise index 0 and 1.
        start_date : when truthy, the x axis uses dates built by
            ``change_x2date(atmbc_times, start_date)``.
        atmbc_times : times handed to ``change_x2date``.

    Returns
    -------
    (ax, plt)
        The axes sequence and the ``matplotlib.pyplot`` module.
    """
    ax = kwargs.get("ax")
    start_date = kwargs.get("start_date")
    atmbc_times = kwargs.get("atmbc_times")

    rmse_col = f"RMSE{sensorName}"
    nmrmse_col = f"NMRMSE{sensorName}"

    # With dates, the first panel (index 0) is left free for the caller.
    if start_date:
        keytime, xlabel, first_panel = "time_date", "date", 1
    else:
        keytime, xlabel, first_panel = "time", "assimilation #", 0

    perf = df_performance[["time", "ObsType", rmse_col, nmrmse_col, "OL"]].dropna()
    perf = perf.astype({rmse_col: "float64", nmrmse_col: "float64", "OL": "str"})

    if start_date:
        dates = change_x2date(atmbc_times, start_date)
        time_to_date = dict(zip(perf["time"].unique(), dates))
        perf["time_date"] = perf["time"].map(time_to_date)

    # One column per OL flag, one row per time step.
    rmse_table = perf.pivot(index=keytime, columns="OL", values=rmse_col)
    nmrmse_table = perf.pivot(index=keytime, columns="OL", values=nmrmse_col)

    if ax is None:
        fig, ax = plt.subplots(first_panel + 2, 1, sharex=True)

    rmse_table.plot(ax=ax[first_panel], xlabel=xlabel, ylabel=rmse_col, style=[".-"])
    nmrmse_table.plot(
        ax=ax[first_panel + 1], xlabel=xlabel, ylabel=nmrmse_col, style=[".-"]
    )
    return ax, plt
# def DA_RMS(df_performance, sensorName, **kwargs):
# """Plot result of Data Assimilation: RMS evolution over the time"""
# if "ax" in kwargs:
# ax = kwargs["ax"]
# # else:
# # if "start_date" in kwargs:
# # fig, ax = plt.subplots(3, 1, sharex=True)
# # else:
# # fig, ax = plt.subplots(2, 1)
# keytime = "time"
# xlabel = "assimilation #"
# start_date = None
# keytime = "time"
# xlabel = "assimilation #"
# if kwargs.get("start_date") is not None:
# keytime = "time_date"
# xlabel = "date"
# atmbc_times = None
# if "atmbc_times" in kwargs:
# atmbc_times = kwargs["atmbc_times"]
# header = ["time", "ObsType", "RMSE" + sensorName, "NMRMSE" + sensorName, "OL"]
# df_perf_plot = df_performance[header]
# df_perf_plot["RMSE" + sensorName] = df_perf_plot["RMSE" + sensorName].astype(
# "float64"
# )
# df_perf_plot["NMRMSE" + sensorName] = df_perf_plot["NMRMSE" + sensorName].astype(
# "float64"
# )
# df_perf_plot.OL = df_perf_plot.OL.astype("str")
# df_perf_plot.dropna(inplace=True)
# # len(dates)
# # len(df_perf_plot["time"])
# # len(atmbc_times)
# if start_date is not None:
# dates = change_x2date(atmbc_times, start_date)
# df_perf_plot["time_date"] = df_perf_plot["time"].replace(
# list(df_perf_plot["time"].unique()), dates)
# p0 = df_perf_plot.pivot(index=keytime, columns="OL", values="RMSE" + sensorName)
# p1 = df_perf_plot.pivot(index=keytime, columns="OL", values="NMRMSE" + sensorName)
# if "start_date" in kwargs:
# p0.plot(xlabel=xlabel, ylabel="RMSE" + sensorName, ax=ax[1], style=[".-"])
# p1.plot(xlabel=xlabel, ylabel="NMRMSE" + sensorName, ax=ax[2], style=[".-"])
# else:
# p0.plot(xlabel=xlabel, ylabel="RMSE" + sensorName, ax=ax[0], style=[".-"])
# p1.plot(xlabel=xlabel, ylabel="NMRMSE" + sensorName, ax=ax[1], style=[".-"])
# return ax, plt
def DA_plot_parm_dynamic(
    parm="ks", dict_parm_pert={}, list_assimilation_times=[], savefig=False, **kwargs
):
    """Plot result of Data Assimilation: parameter estimation evolution over the time

    Overlays the histogram of the initial perturbation with one histogram per
    available update step.

    Parameters
    ----------
    parm : str
        Key of the parameter inside ``dict_parm_pert``.
    dict_parm_pert : mapping
        ``dict_parm_pert[parm]`` must hold ``"ini_perturbation"`` and may hold
        ``"update_nb<k>"`` entries. (The mutable default is never modified here.)
    list_assimilation_times : iterable of numbers
        For each value ``nt`` the entry ``"update_nb<int(nt + 1)>"`` is plotted.
    savefig : bool, optional
        Accepted for signature compatibility; nothing is saved here.
    **kwargs
        ax : axes to draw on (a new figure is created otherwise).
        log : when truthy, the x axis is set to log scale.

    Returns
    -------
    The axes that were drawn on.
    """
    ensembles = dict_parm_pert[parm]
    prior = ensembles["ini_perturbation"]
    ensemble_size = len(prior)  # also used as the number of bins

    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        fig, ax = plt.subplots()

    ax.hist(prior, ensemble_size, alpha=0.5, label="ini_perturbation")
    ax.legend(loc="upper right")
    ax.set_ylabel("Probability")

    for nt in list_assimilation_times:
        try:
            update_id = str(int(nt + 1))
            ax.hist(
                ensembles["update_nb" + update_id],
                ensemble_size,
                alpha=0.5,
                label="update nb" + update_id,
            )
        except Exception:
            # Best effort: a missing or unplottable update is skipped.
            # Unlike a bare ``except`` this lets KeyboardInterrupt/SystemExit through.
            continue

    ax.legend(loc="upper right")
    ax.set_ylabel("Probability")

    if kwargs.get("log"):
        ax.set_xscale("log")
    return ax
def DA_plot_parm_dynamic_scatter(
    parm="ks", dict_parm_pert={}, list_assimilation_times=[], savefig=False, **kwargs
):
    """Plot result of Data Assimilation: parameter estimation evolution over the time

    Collects every entry of ``dict_parm_pert[parm]`` whose key contains ``"upd"``
    into a DataFrame (one column per entry), relabels the columns with
    ``list_assimilation_times`` and draws them with ``seaborn.boxplot``.

    Parameters
    ----------
    parm : str
        Key of the parameter inside ``dict_parm_pert``; also the y label.
    dict_parm_pert : mapping
        ``dict_parm_pert[parm]`` must hold ``"ini_perturbation"`` plus the
        update entries.
    list_assimilation_times : sequence
        New column labels of the DataFrame. It is indexed with a list of
        integers when more than 30 columns exist, which a plain Python list
        does not support -- presumably a numpy array or pandas Index is
        expected; TODO confirm against callers.
    savefig : bool, optional
        Not used in this function.
    **kwargs
        ax : axes to draw on (a new 6x3 in figure is created otherwise).
        log : when truthy, the y axis is set to log scale.
        color : read but not used.

    Returns
    -------
    (ax, df)
        The axes and the DataFrame that was plotted.
    """
    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        fig = plt.figure(figsize=(6, 3), dpi=350)
        ax = fig.add_subplot()
    # NOTE(review): ensemble_size, mean_t, mean_t_yaxis and cloud_t are computed
    # here but never read afterwards. `cloud_t[:, 0] = ...` is outside the
    # try block, so it raises when list_assimilation_times is empty.
    ensemble_size = len(dict_parm_pert[parm]["ini_perturbation"])
    mean_t = [np.mean(dict_parm_pert[parm]["ini_perturbation"])]
    mean_t_yaxis = np.mean(dict_parm_pert[parm]["ini_perturbation"])
    cloud_t = np.zeros([ensemble_size, len(list_assimilation_times)])
    cloud_t[:, 0] = np.array(dict_parm_pert[parm]["ini_perturbation"])
    np.shape(cloud_t)
    try:
        for nt in list_assimilation_times[1:]:
            cloud_t[:, int(nt + 1)] = np.hstack(
                dict_parm_pert[parm]["update_nb" + str(int(nt))]
            )
    except:
        # Any failure while filling cloud_t is silently ignored.
        pass
    # Gather the flattened update ensembles, keyed "<parm>_<entry name>".
    dict_parm_new = {}
    name = []
    i = 0
    for k in dict_parm_pert[parm].keys():
        if "upd" in k:
            dict_parm_new[parm + "_" + k] = np.hstack(dict_parm_pert[parm][k])
            name.append(str(i + 1))
            i = i + 1
    df = pd.DataFrame()
    df = pd.DataFrame(data=dict_parm_new)
    df.index.name = "Ensemble_nb"
    # NOTE(review): `color` is not passed on; the boxplot colour is fixed to grey.
    color = 'k'
    if "color" in kwargs:
        color = kwargs["color"]
    # With more than 30 update steps only every 15th column is kept.
    if len(df.columns)>30:
        nii = [int(ni) for ni in np.arange(0,len(df.columns),15)]
        name = [str(ni) for ni in nii]
        df = df.iloc[:,nii]
        # dates = pd.to_datetime(list_assimilation_times[nii])
        list_assimilation_times = list_assimilation_times[nii]
    # NOTE(review): the source indentation was lost; this assignment is assumed
    # to run unconditionally (it requires one time per column) -- confirm.
    df.columns = list_assimilation_times
    boxplot = seaborn.boxplot(
        df,
        color='grey',
        ax=ax
    )
    # ax.xaxis.set_major_formatter(mdates.DateFormatter('%d %b'))
    # Datetime column labels are rendered as "day month" and rotated.
    if pd.api.types.is_datetime64_any_dtype(df.columns):
        ax.set_xticklabels([pd.to_datetime(date).strftime('%d %b') for date in df.columns],
                           rotation=45)
    ax.set_ylabel(parm)
    ax.set_xlabel("assimilation #")
    if "log" in kwargs:
        if kwargs["log"]:
            boxplot.set_yscale("log")
    return ax, df
def prepare_DA_plot_time_dynamic(DA, state="psi", nodes_of_interest=[], **kwargs):
    """Select data from DA_df dataframe

    Splits a copy of ``DA`` into open-loop rows (``OL == True``) and ensemble
    rows (``OL == False``). When ``nodes_of_interest`` is not empty, it also
    computes the ensemble mean/min/max of one column per time and node.

    Parameters
    ----------
    DA : pandas.DataFrame
        Read columns: ``time``, ``OL``, ``Ensemble_nb``, the column named by
        ``key2plot`` and, for ensemble rows, ``rejected``. ``DA`` itself is
        copied, not modified.
    state : str
        When it contains ``"sw"`` the default ``key2plot`` is
        ``"sw_bef_update_"``, otherwise ``"psi_bef_update"``.
    nodes_of_interest : list or scalar
        Node ids to keep; a non-list value is wrapped into a list.
    **kwargs
        start_date : when not None, a ``time_date`` column is built from
            ``change_x2date(atmbc_times, start_date)`` and used as time key.
        atmbc_times : passed to ``change_x2date``.
        keytime : overrides the time key.
        key2plot : overrides the column that is aggregated.

    Returns
    -------
    dict
        Always ``"isENS"`` and ``"isOL"``. When nodes are selected and the
        corresponding rows exist, also ``"ens_{mean,min,max}_isOL_time"`` and
        ``"ens_{mean,min,max}_isENS_time"``: DataFrames with the time key,
        ``idnode`` and one value column.

    Notes
    -----
    NOTE(review): the indentation of this function was lost in the source this
    was recovered from; the nesting below is a reconstruction -- confirm
    against the repository before relying on it.
    """
    DAc = DA.copy()
    keytime = "time"
    start_date = None
    xlabel = "time (h)"  # not used in this function
    if kwargs.get("start_date") is not None:
        start_date = kwargs["start_date"]
        keytime = "time_date"
    if "keytime" in kwargs:
        keytime = kwargs['keytime']
    atmbc_times = None
    if "atmbc_times" in kwargs:
        atmbc_times = kwargs["atmbc_times"]
    if start_date is not None:
        dates = change_x2date(atmbc_times, start_date)
        # Fewer dates than time steps: drop the rows of the trailing time steps.
        if len(dates) < len(list(DAc["time"].unique())):
            DAc = DAc.drop(DAc[DAc.time + 1 >= len(list(DAc["time"].unique()))].index)
        if len(dates) > len(list(DAc["time"].unique())):
            # If every retained date falls at hour 0, artificially add 1 s to
            # the first one to keep the date formatting (otherwise the plot
            # crashes).
            if all(dates[: len(list(DA["time"].unique()))].hour == 0):
                time_delta_change = datetime.timedelta(seconds=1)
                dates.values[0] = dates[0] + time_delta_change
            dates = dates[: len(list(DAc["time"].unique()))]
        unique_times = DAc["time"].unique()
        # Each unique time is paired with the dates starting from the second one.
        # NOTE(review): if `dates` was truncated to len(unique_times) just above,
        # zip() leaves the last time without a date -- confirm this is intended.
        DAc["time_date"] = DAc["time"].map(dict(zip(unique_times, dates[1:])))
    isOL = DAc.loc[DAc["OL"] == True]
    isENS = DAc.loc[DAc["OL"] == False]
    isENS_time_Ens = isENS.set_index(["time", "Ensemble_nb"])
    isOL_time_Ens = isOL.set_index(["time", "Ensemble_nb"])
    if type(nodes_of_interest) != list:
        nodes_of_interest = [nodes_of_interest]
    NENS = int(max(DAc["Ensemble_nb"].unique()))
    key2plot = "psi_bef_update"
    if "sw" in state:
        key2plot = "sw_bef_update_"
    if "key2plot" in kwargs:
        key2plot = kwargs["key2plot"]
    # The dict keeps references to the frames as they are now: the in-place
    # `insert` of "idnode" below is visible through it, but the later rebinding
    # of `isENS` (rejected filter) is not.
    prep_DA = {
        "isENS": isENS,
        "isOL": isOL,
    }
    # -----------------------------#
    if len(nodes_of_interest) > 0:
        if len(isOL) > 0:
            nb_of_times = len(np.unique(isOL["time"]))
            nb_of_ens = len(np.unique(isOL["Ensemble_nb"]))
            # Node count inferred from the number of rows at time == 0.
            nb_of_nodes = len(isOL[isOL["time"] == 0]) / nb_of_ens
            test = len(isOL) / nb_of_nodes
            try:
                int(test)
            except:
                # NOTE(review): the ValueError is constructed but not raised,
                # and int() of a finite float does not fail, so this check does
                # not detect a non-integer ratio.
                ValueError("inconsistent nb of OL data - no plot")
                pass
            else:
                # Label each row with a node id, assuming rows repeat the node
                # sequence for every (time, member) -- TODO confirm ordering.
                isOL.insert(
                    2,
                    "idnode",
                    np.tile(np.arange(nb_of_nodes), nb_of_times * nb_of_ens),
                    True,
                )
                select_isOL = isOL[isOL["idnode"].isin(nodes_of_interest)]
                select_isOL = select_isOL.set_index([keytime, "idnode"])
                select_isOL = select_isOL.reset_index()
                # mean, min and max of Open Loop
                # --------------------------------------------------------------------------
                ens_mean_isOL_time = (
                    select_isOL.set_index([keytime, "Ensemble_nb", "idnode"])
                    .groupby(level=[keytime, "idnode"])[key2plot]
                    .mean()
                )
                ens_mean_isOL_time = ens_mean_isOL_time.reset_index(name="mean(ENS)_OL")
                ens_min_isOL_time = (
                    select_isOL.set_index([keytime, "Ensemble_nb", "idnode"])
                    .groupby(level=[keytime, "idnode"])[key2plot]
                    .min()
                )
                ens_min_isOL_time = ens_min_isOL_time.reset_index(name="min(ENS)_OL")
                ens_max_isOL_time = (
                    select_isOL.set_index([keytime, "Ensemble_nb", "idnode"])
                    .groupby(level=[keytime, "idnode"])[key2plot]
                    .max()
                )
                ens_max_isOL_time = ens_max_isOL_time.reset_index(name="max(ENS)_OL")
                prep_DA.update(
                    {
                        "ens_mean_isOL_time": ens_mean_isOL_time,
                        "ens_max_isOL_time": ens_max_isOL_time,
                        "ens_min_isOL_time": ens_min_isOL_time,
                    }
                )
        if len(isENS) > 0:
            nb_of_times = len(np.unique(isENS["time"]))
            nb_of_ens = len(np.unique(isENS["Ensemble_nb"]))
            # NOTE(review): isENS["time"][0] is a label lookup (index label 0),
            # not "first row" -- confirm the index always contains 0.
            nb_of_nodes = len(isENS[isENS["time"] == isENS["time"][0]]) / nb_of_ens
            # The node-id column is built differently depending on whether
            # open-loop rows exist: from max(time) and NENS in the first case,
            # from the counted times/members/nodes in the second.
            if len(isOL) > 0:
                isENS.insert(
                    2,
                    "idnode",
                    np.tile(
                        np.arange(int(len(isENS) / (max(isENS["time"]) * (NENS)))),
                        int(max(isENS["time"])) * (NENS),
                    ),
                    True,
                )
            else:
                isENS.insert(
                    2,"idnode",
                    np.tile(np.arange(nb_of_nodes), nb_of_times*nb_of_ens),
                    True,
                )
            # Keep only rows whose "rejected" flag is 0.
            check_notrejected = isENS['rejected']==0
            isENS = isENS[check_notrejected]
            select_isENS = isENS[isENS["idnode"].isin(nodes_of_interest)]
            select_isENS = select_isENS.set_index([keytime, "idnode"])
            select_isENS = select_isENS.reset_index()
            # mean, min and max of the ensemble
            # --------------------------------------------------------------------------
            ens_mean_isENS_time = (
                select_isENS.set_index([keytime, "Ensemble_nb", "idnode"])
                .groupby(level=[keytime, "idnode"])[key2plot]
                .mean()
            )
            ens_mean_isENS_time = ens_mean_isENS_time.reset_index(name="mean(ENS)")
            ens_min_isENS_time = (
                select_isENS.set_index([keytime, "Ensemble_nb", "idnode"])
                .groupby(level=[keytime, "idnode"])[key2plot]
                .min()
            )
            ens_min_isENS_time = ens_min_isENS_time.reset_index(name="min(ENS)")
            ens_max_isENS_time = (
                select_isENS.set_index([keytime, "Ensemble_nb", "idnode"])
                .groupby(level=[keytime, "idnode"])[key2plot]
                .max()
            )
            ens_max_isENS_time = ens_max_isENS_time.reset_index(name="max(ENS)")
            prep_DA.update(
                {
                    "ens_mean_isENS_time": ens_mean_isENS_time,
                    "ens_max_isENS_time": ens_max_isENS_time,
                    "ens_min_isENS_time": ens_min_isENS_time,
                }
            )
    else:
        # take the spatial average mean
        # -----------------------------#
        # NOTE(review): the results of this branch are only stored in locals
        # and are not added to prep_DA. The frames are indexed by "time", so
        # grouping on `keytime` presumably fails when it is "time_date".
        spatial_mean_isOL_time_ens = isOL_time_Ens.groupby(
            level=[keytime, "Ensemble_nb"]
        )["analysis"].mean()
        spatial_mean_isOL_time_ens = spatial_mean_isOL_time_ens.reset_index()
        spatial_mean_isENS_time_ens = isENS_time_Ens.groupby(
            level=[keytime, "Ensemble_nb"]
        )["analysis"].mean()
        spatial_mean_isENS_time_ens = spatial_mean_isENS_time_ens.reset_index()
        select_isOL = spatial_mean_isOL_time_ens
        select_isENS = spatial_mean_isENS_time_ens
    return prep_DA
def DA_plot_time_dynamic(
    DA, state="psi", nodes_of_interest=[], savefig=False, **kwargs
):
    """Plot result of Data Assimilation: state estimation evolution over the time

    Uses ``prepare_DA_plot_time_dynamic`` to aggregate the data, then draws the
    ensemble mean (dashed), min and max (solid) per node with a shaded min-max
    band, and the same for the open loop when it is available.

    Parameters
    ----------
    DA : pandas.DataFrame
        Assimilation results, forwarded to ``prepare_DA_plot_time_dynamic``.
    state : str
        Forwarded; when it contains ``"sw"`` the y label is water saturation,
        otherwise pressure head.
    nodes_of_interest : list or scalar
        Forwarded. NOTE(review): the aggregated tables read below appear to be
        produced only when nodes are selected -- confirm before calling with an
        empty list.
    savefig : bool, optional
        When equal to ``True`` the current pyplot figure is written to
        ``<savename>.png`` (300 dpi).
    **kwargs
        Forwarded as a whole to ``prepare_DA_plot_time_dynamic``; read here:
        start_date : when not None, x uses ``"time_date"`` and is labelled "date".
        keytime : overrides the time column; x is then labelled
            "assimilation_times".
        ax : axes to draw on (a new 6x3 in figure is created otherwise).
        colors_minmax : colour of the ensemble lines and band (default grey).
        savename : file stem, default ``"showDA_dynamic"``.

    Returns
    -------
    None
    """
    keytime = "time"
    xlabel = "time (h)"
    if kwargs.get("start_date") is not None:
        xlabel = "date"
        keytime = "time_date"

    prep_DA = prepare_DA_plot_time_dynamic(
        DA,
        state=state,
        nodes_of_interest=nodes_of_interest,
        **kwargs
    )

    if "ax" in kwargs:
        ax = kwargs["ax"]
    else:
        fig = plt.figure(figsize=(6, 3), dpi=350)
        ax = fig.add_subplot()

    alpha = 0.2
    colors_minmax = kwargs.get("colors_minmax", "grey")

    if "keytime" in kwargs:
        keytime = kwargs["keytime"]
        xlabel = "assimilation_times"

    # Raw strings: "\p" is not a valid escape sequence in a normal literal.
    ylabel = r"pressure head $\psi$ (m)"
    if "sw" in state:
        ylabel = "water saturation (-)"

    # --------------------------------------------------------------------------
    if len(prep_DA["isENS"]) > 0:
        prep_DA["ens_mean_isENS_time"].pivot(
            index=keytime,
            columns=["idnode"],
            values=["mean(ENS)"],
        ).plot(
            ax=ax,
            style=["--"],
            color=colors_minmax,
            alpha=0.2,
        )
        prep_DA["ens_max_isENS_time"].pivot(
            index=keytime,
            columns=["idnode"],
            values=["max(ENS)"],
        ).plot(
            ax=ax,
            style=["-"],
            color=colors_minmax,
        )
        prep_DA["ens_min_isENS_time"].pivot(
            index=keytime,
            columns=["idnode"],
            values=["min(ENS)"],
        ).plot(
            ax=ax,
            style=["-"],
            color=colors_minmax,
            xlabel="(assimilation) time - (h)",
            ylabel=r"pressure head $\psi$ (m)",
        )
        ax.fill_between(
            prep_DA["ens_max_isENS_time"][keytime],
            prep_DA["ens_min_isENS_time"]["min(ENS)"],
            prep_DA["ens_max_isENS_time"]["max(ENS)"],
            alpha=alpha,
            color=colors_minmax,
            label="minmax DA",
        )

    if "ens_mean_isOL_time" in prep_DA:
        prep_DA["ens_mean_isOL_time"].pivot(
            index=keytime,
            columns=["idnode"],
            values=["mean(ENS)_OL"],
        ).plot(
            ax=ax,
            style=["-"],
            color="grey",
            label=False,
            ylabel=r"pressure head $\psi$ (m)",
        )
        prep_DA["ens_min_isOL_time"].pivot(
            index=keytime,
            columns=["idnode"],
            values=["min(ENS)_OL"],
        ).plot(ax=ax, style=["--"], color="grey", label=False)
        prep_DA["ens_max_isOL_time"].pivot(
            index=keytime,
            columns=["idnode"],
            values=["max(ENS)_OL"],
        ).plot(ax=ax, style=["--"], color="grey", label=False)
        ax.fill_between(
            prep_DA["ens_mean_isOL_time"][keytime],
            prep_DA["ens_min_isOL_time"]["min(ENS)_OL"],
            prep_DA["ens_max_isOL_time"]["max(ENS)_OL"],
            alpha=0.2,
            color="grey",
            label="minmax OL",
        )

    # Final labels; they override the ones set by the pandas plot calls above.
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    savename = kwargs.get("savename", "showDA_dynamic")
    if savefig == True:
        plt.savefig(savename + ".png", dpi=300)
def DA_plot_ET_dynamic(ET_DA,
                       nodePos=None,
                       nodeIndice=None,
                       observations=None,
                       ax=None,
                       unit='m/s',
                       **kwargs
                       ):
    """Plot actual evapotranspiration of every ensemble member over assimilations.

    Each member is drawn in grey, the ensemble mean in colour, and observations
    (when given) as green dots.

    Parameters
    ----------
    ET_DA : xarray-like
        Must support ``.sel(x=, y=, method="nearest")`` / ``.mean(dim=[...])``
        and hold the variable ``"ACT. ETRA"`` with the dimensions
        ``"assimilation"`` and ``"ensemble"`` (as used by the plot calls).
    nodePos : (x, y), optional
        Plot the node nearest to this position; otherwise the mean over ``x``
        and ``y`` is plotted.
    nodeIndice : optional
        Used to select the observation group ``f"ETact{nodeIndice}"`` when
        ``nodePos`` is given.
    observations : pandas.DataFrame, optional
        Needs the columns ``"data"`` and ``"datetime"``. It is not modified.
    ax : axes, optional
        Axes to draw on; the current pyplot axes are used when omitted.
    unit : str
        ``'mm/day'`` multiplies model and observed values by ``1e3 * 86400``;
        any other value plots them unchanged. Also used in the y label.
    **kwargs
        color : colour of the mean line (default ``'red'``).
        alphaENS : transparency of the member lines (default ``0.1``).

    Returns
    -------
    None
    """
    mean_color = kwargs.pop('color', 'red')
    member_alpha = kwargs.pop('alphaENS', 0.1)
    to_mm_per_day = 1e3 * 86400

    if ax is None:
        # The labels/scatter below need a real axes object; with ax=None the
        # function used to fail on `ax.set_ylabel` after plotting.
        ax = plt.gca()

    if nodePos is not None:
        # Select data for the specific node
        eta = ET_DA.sel(x=nodePos[0],
                        y=nodePos[1],
                        method="nearest"
                        )["ACT. ETRA"]
    else:
        eta = ET_DA.mean(dim=['x', 'y'])
        # A whole dataset cannot be line-plotted: pick the ETa variable, as
        # DA_plot_ET_performance does. A bare data array is used as is.
        if hasattr(eta, "data_vars"):
            eta = eta["ACT. ETRA"]

    if unit == 'mm/day':
        eta = eta * to_mm_per_day

    # Plot each ensemble member in grey
    eta.plot(
        ax=ax,
        x="assimilation",
        hue="ensemble",
        color="grey",
        alpha=member_alpha,
        add_legend=False
    )

    # Plot the mean across ensembles
    eta_mean = eta.mean(dim="ensemble")
    eta_mean.plot(ax=ax, x="assimilation",
                  color=mean_color,
                  alpha=0.5,
                  linewidth=1,
                  label="Mean pred."
                  )

    if observations is not None:
        if nodePos is not None:
            node_obs = observations.xs(f'ETact{nodeIndice}')[['data', 'data_err', 'datetime']]
            # Keep as many observations as there are predicted steps.
            obs2plot = node_obs.iloc[:len(eta_mean)][['data', 'datetime']]
        else:
            obs2plot = observations

        # Convert into a new series instead of writing back into the frame.
        obs_values = obs2plot['data']
        if unit == 'mm/day':
            obs_values = obs_values * to_mm_per_day

        ax.scatter(
            obs2plot['datetime'],
            obs_values,
            label="Observed",
            color='darkgreen',
            s=6
        )

    if nodePos is not None:
        ax.set_title(f"Node at ({nodePos[0]}, {nodePos[1]})")
    ax.set_ylabel(f'ETa - {unit}')
def DA_plot_ET_performance(ET_DA,
                           observations,
                           axi,
                           nodeposi=None,
                           nodei=None,
                           unit='m/s'
                           ):
    """Scatter modelled against observed actual evapotranspiration.

    Every ensemble member is drawn in light grey, the ensemble mean in red,
    together with a 1:1 line and an R² annotation (computed on the mean).

    Parameters
    ----------
    ET_DA : xarray-like
        Must hold the variable ``"ACT. ETRA"`` with an ``"ensemble"`` dimension.
    observations : pandas.DataFrame
        Needs the columns ``"data"`` and ``"data_err"``. With ``nodeposi`` the
        group ``f"ETact{nodei}"`` is selected, otherwise the frame is averaged
        over its index level 1.
    axi : axes
        Axes to draw on.
    nodeposi : (x, y), optional
        Use the node nearest to this position instead of the x/y mean.
    nodei : optional
        Observation group index used together with ``nodeposi``.
    unit : str
        ``'mm/day'`` multiplies model and observed values by ``1e3 * 86400``.

    Returns
    -------
    None
    """
    to_mm_per_day = 1e3 * 86400

    if nodeposi is not None:
        # Select data for the specific node
        ET_DA_node = ET_DA.sel(x=nodeposi[0], y=nodeposi[1], method="nearest")
        obs_table = observations.xs(f'ETact{nodei}')[['data', 'data_err']]
        ET_DA_act_etra = ET_DA_node["ACT. ETRA"]
    else:
        ET_DA_act_etra = ET_DA.mean(dim=['x', 'y'])["ACT. ETRA"]
        obs_table = observations[['data', 'data_err']].groupby(level=1).mean()

    if unit == 'mm/day':
        ET_DA_act_etra = ET_DA_act_etra * to_mm_per_day
    ET_DA_mean = ET_DA_act_etra.mean(dim="ensemble")

    # Observations paired with the predictions: same length, same unit.
    # Computed once; it does not depend on the ensemble member.
    obs2plot = obs_table.iloc[0:len(ET_DA_mean)]["data"]
    if unit == 'mm/day':
        obs2plot = obs2plot * to_mm_per_day

    # Plot data for each ensemble member
    for ensi in range(len(ET_DA_act_etra.ensemble)):
        member = ET_DA_act_etra.isel(ensemble=ensi)
        axi.scatter(
            member.values[:],
            obs2plot.values,
            label=f"Ensemble {ensi}" if ensi == 0 else "",  # Label only once
            color='grey',
            alpha=0.1,
        )

    axi.scatter(
        ET_DA_mean.values[:],
        obs2plot,
        alpha=0.5,
        color='red',
    )

    # Add 1:1 line. The limits use the plotted (unit-converted) observations;
    # using the raw table mixed m/s with mm/day values.
    min_val = min(np.nanmin(ET_DA_act_etra.values), np.nanmin(obs2plot))
    max_val = max(np.nanmax(ET_DA_act_etra.values), np.nanmax(obs2plot))
    axi.plot([min_val, max_val], [min_val, max_val], "k--", label="1:1 Line")
    axi.set_xlim([min_val, max_val])
    axi.set_ylim([min_val, max_val])

    # R² of the ensemble mean against the observations, written on the axes.
    r2, p_value = calculate_r2_p_value(ET_DA_mean.values, obs2plot)
    annotate_r2_p_value(axi, r2, p_value)

    # Customize subplot
    if nodeposi is not None:
        axi.set_title(f"Node at ({nodeposi[0]}, {nodeposi[1]})")
    axi.set_xlabel("Modelled ETa")
    axi.set_ylabel("Observed ETa")
    axi.legend()
    axi.set_aspect('equal')
def calculate_rmse_nrmse(y_pred, y_obs, normalization="range"):
    """Calculate RMSE and NRMSE between two time series.

    Args:
        y_pred (array-like): Predicted values.
        y_obs (array-like): Observed values.
        normalization (str): "mean", "range", or "std" for NRMSE normalization.

    Returns:
        tuple: (RMSE, NRMSE)

    Raises:
        ValueError: If ``normalization`` is not one of the three supported names.

    Notes:
        Inputs are converted with ``np.asarray`` so plain lists and tuples work;
        values are paired by position. A zero normaliser (e.g. constant
        observations with "range") follows numpy's division rules.
    """
    normalizers = {
        "mean": np.mean,
        "range": np.ptp,  # max - min
        "std": np.std,
    }
    if normalization not in normalizers:
        raise ValueError("Invalid normalization method. Choose 'mean', 'range', or 'std'.")

    predicted = np.asarray(y_pred, dtype=float)
    observed = np.asarray(y_obs, dtype=float)

    rmse = np.sqrt(np.mean((observed - predicted) ** 2))
    nrmse = rmse / normalizers[normalization](observed)
    return rmse, nrmse
# Function to calculate R² and p-value
def calculate_r2_p_value(modelled_data, observed_data):
    """Return ``(r2, p_value)`` of the Pearson correlation between two series.

    ``r2`` is the squared correlation coefficient reported by
    ``scipy.stats.pearsonr``; ``p_value`` is passed through unchanged.
    """
    pearson_r, p_value = stats.pearsonr(modelled_data, observed_data)
    return pearson_r * pearson_r, p_value
# Function to annotate the plot with R² and p-value
def annotate_r2_p_value(axi, r2, p_value):
    """Write ``R² = <r2>`` (two decimals) in the upper-left corner of ``axi``.

    ``p_value`` is accepted for interface symmetry but is not displayed.
    """
    # annotation_text = f"R² = {r2:.2f}\np-value = {p_value:.2e}"
    text_box = dict(
        facecolor='white',
        alpha=0.6,
        edgecolor='none',
        boxstyle="round,pad=0.5",
    )
    axi.annotate(
        f"R² = {r2:.2f}",
        xy=(0.05, 0.95),  # in axes-fraction coordinates, see xycoords
        xycoords='axes fraction',
        fontsize=12,
        ha='left',
        va='top',
        bbox=text_box,
    )