Source code for wavebreaking.processing.plots

""""""
"""
This file is part of WaveBreaking.

WaveBreaking provides indices to detect, classify
and track Rossby Wave Breaking (RWB) in climate and weather data.
The tool was developed during my master's thesis at the University of Bern.
Link to thesis: https://occrdata.unibe.ch/students/theses/msc/406.pdf

---

Plotting functions
"""

# package metadata
__author__ = "Severin Kaderli"
__email__ = "severin.kaderli@unibe.ch"
__license__ = "MIT"

# import modules
import xarray as xr
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import matplotlib
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from wavebreaking.utils.data_utils import (
    check_argument_types,
    get_dimension_attributes,
    check_empty_dataframes,
)
from wavebreaking.utils import plot_utils
from wavebreaking.processing import spatial


@check_argument_types(["flag_data"], [xr.DataArray])
@get_dimension_attributes("flag_data")
def plot_clim(
    flag_data,
    seasons=None,
    proj=None,
    size=None,
    smooth_passes=5,
    periodic=True,
    labels=True,
    levels=None,
    cmap=None,
    title="",
    *args,
    **kwargs
):
    """
    Creates a simple climatological plot showing the occurrence frequency
    of the detected events.

    Dimension names ("time_name", "lon_name", "lat_name"), size ("ntime", "nlon", "nlat")
    and resolution ("dlon", "dlat") can be passed as key=value argument.

    Parameters
    ----------
        flag_data : xarray.DataArray
            data containing the locations of the events flagged with the value 1
            (every value > 0 is counted as an event)
        seasons : list or array, optional
            months of the seasons for occurrence frequency calculation (e.g. [12, 1, 2])
        proj : cartopy.crs, optional
            cartopy projection object (default: cartopy.crs.PlateCarree())
        size : tuple of integers, optional
            size of the figure in the format (width, height) (default: (12, 8))
        smooth_passes : int or float, optional
            number of smoothing passes of the 5-point smoothing of the occurrence frequencies
        periodic : bool, optional
            If True, the first longitude is added at the end to close the gap in a polar projection
        labels : bool, optional
            If False, no labels are added to the plot
        levels : list or array, optional
            Colorbar levels. If not provided, levels are derived from the minimum
            and maximum of the smoothed frequencies.
        cmap : string, optional
            Name of a valid cmap. If not provided, a cmap built from "RdYlBu_r"
            by plot_utils.get_new_cmap is used.
        title : string, optional
            Title of the plot

    Returns
    -------
        None
            The plot is drawn on a newly created matplotlib figure; use
            ``plt.show()`` or ``plt.savefig()`` to display or store it.

    Raises
    ------
        ValueError
            If ``seasons`` is given but no time step falls into these months.
    """
    time_name = kwargs["time_name"]

    # calculate occurrence frequencies in percent of the time steps,
    # optionally restricted to the months given in `seasons`
    if seasons is None:
        freq = xr.where(flag_data > 0, 1, 0).sum(dim=time_name) / kwargs["ntime"] * 100
    else:
        in_season = flag_data[time_name].dt.month.isin(seasons)
        ds_season = flag_data.sel({time_name: in_season})
        n_season = len(ds_season[time_name])
        if n_season == 0:
            # avoid a silent division by zero that would plot only NaN/inf
            raise ValueError(
                "No time steps found for the months {}!".format(list(seasons))
            )
        freq = xr.where(ds_season > 0, 1, 0).sum(dim=time_name) / n_season * 100

    # the smoothing helper is called with a dummy "time" dimension which is
    # removed again afterwards (presumably it operates per time step)
    freq = spatial.calculate_smoothed_field(
        freq.expand_dims("time"), smooth_passes
    ).isel(time=0)

    # add longitude to ensure that there is no gap in a periodic field
    if periodic:
        freq = plot_utils.calculate_periodic_field(freq, **kwargs)

    # define levels and cmap
    if levels is None:
        levels = plot_utils.get_levels(freq.min(), freq.max())
    if cmap is None:
        cmap = plot_utils.get_new_cmap("RdYlBu_r")

    # input coordinates are plain longitude/latitude values
    data_crs = ccrs.PlateCarree()

    # initialize figure (only after the input has been validated)
    proj = proj if proj is not None else data_crs
    size = size if size is not None else (12, 8)
    fig, ax = plt.subplots(1, 1, subplot_kw=dict(projection=proj), figsize=size)

    # plot frequencies; zero frequencies are set to -999, presumably below the
    # lowest level, so that they stay unfilled (only extend="max" is used)
    p = freq.where(freq > 0, -999).plot.contourf(
        ax=ax,
        cmap=cmap,
        levels=levels,
        transform=data_crs,
        add_colorbar=False,
        extend="max",
    )

    # colorbar in a separate axes right of the map
    position = ax.get_position()
    cax = fig.add_axes([position.x1 + 0.05, position.y0, 0.015, position.height])
    plot_utils.add_colorbar(p, cax, levels, label="Occurrence frequency in %")

    # add coast lines and grid lines
    ax.add_feature(cfeature.COASTLINE, color="dimgrey")
    ax.gridlines(draw_labels=False, color="black", linestyle="dotted", linewidth=1.1)

    # plot labels
    if labels:
        plot_utils.add_grid_lines(ax)

    # make a circular cut out for any NorthPolarStereo projection
    # (isinstance also matches projections with non-default parameters)
    if isinstance(proj, ccrs.NorthPolarStereo):
        plot_utils.add_circular_boundary(ax)
        plot_utils.add_circular_patch(ax)

    # set title
    ax.set_title(title, fontweight="bold", fontsize=20)
@check_argument_types(["flag_data"], [xr.DataArray])
@get_dimension_attributes("flag_data")
def plot_step(
    flag_data,
    step,
    data=None,
    contour_levels=None,
    proj=None,
    size=None,
    periodic=True,
    labels=True,
    levels=None,
    cmap="RdBu_r",
    color_events="gold",
    title="",
    *args,
    **kwargs
):
    """
    Creates a plot showing the events at one time step.

    Dimension names ("time_name", "lon_name", "lat_name"), size ("ntime", "nlon", "nlat")
    and resolution ("dlon", "dlat") can be passed as key=value argument.

    Parameters
    ----------
        flag_data : xarray.DataArray
            Data containing the locations of the events flagged with the value 1
        step : int, string, numpy.datetime64 or datetime.datetime
            Integer position of the time step, or a time label
            (e.g. "2000-01-01T12") that is selected with ``.sel``
        data : xarray.DataArray, optional
            Data that has been used to calculate the contours and the indices
        contour_levels : scalar or array_like, optional
            contour levels of `data` that are drawn as black lines
        proj : cartopy.crs, optional
            cartopy projection object (default: cartopy.crs.PlateCarree())
        size : tuple of integers, optional
            size of the figure in the format (width, height) (default: (12, 8))
        periodic : bool, optional
            If True, the first longitude is added at the end to close the gap in a polar projection
        labels : bool, optional
            If False, no labels are added to the plot
        levels : list or array, optional
            Colorbar levels of `data`. If not provided, levels are derived from
            the minimum and maximum of the selected field.
        cmap : string, optional
            Name of a valid cmap used for `data` (default: "RdBu_r").
        color_events : string, optional
            Color of the events
        title : string, optional
            Title of the plot

    Returns
    -------
        None
            The plot is drawn on a newly created matplotlib figure.

    Raises
    ------
        KeyError
            If `step` does not match a time label or is out of range.
    """
    import datetime  # only needed to recognise datetime time-step labels

    time_name = kwargs["time_name"]

    # select the time step before creating the figure, so that an invalid
    # step does not leave an empty figure behind
    try:
        if isinstance(step, (str, np.datetime64, datetime.datetime)):
            # label-based selection
            flag = flag_data.sel({time_name: step})
        else:
            # position-based selection; xarray signals an out-of-range
            # position with IndexError rather than KeyError
            flag = flag_data.isel({time_name: step})
    except (KeyError, IndexError) as err:
        errmsg = "step {} not supported or out of range!".format(step)
        raise KeyError(errmsg) from err

    # get date (datetime values of any unit are formatted to hourly precision)
    date = flag[time_name].values
    if np.issubdtype(date.dtype, np.datetime64):
        date = pd.Timestamp(date).strftime("%Y-%m-%dT%H")

    # input coordinates are plain longitude/latitude values
    data_crs = ccrs.PlateCarree()

    # initialize figure
    proj = proj if proj is not None else data_crs
    size = size if size is not None else (12, 8)
    fig, ax = plt.subplots(1, 1, subplot_kw=dict(projection=proj), figsize=size)

    # plot field data if provided
    if data is not None:
        field = data.sel({time_name: date})
        if periodic:
            field = plot_utils.calculate_periodic_field(field, **kwargs)

        if levels is None:
            levels = plot_utils.get_levels(field.min(), field.max())

        p = field.plot.contourf(
            ax=ax,
            cmap=cmap,
            levels=levels,
            transform=data_crs,
            add_colorbar=False,
            alpha=0.8,
        )

        if contour_levels is not None:
            # a single (non-iterable) contour level is wrapped into a list
            try:
                iter(contour_levels)
            except TypeError:
                contour_levels = [contour_levels]

            field.plot.contour(
                ax=ax,
                transform=data_crs,
                levels=contour_levels,
                linestyles="-",
                linewidths=2,
                colors="#000000",
            )

        # colorbar label from the CF-style attributes, if both are available
        if all(x in field.attrs for x in ["units", "long_name"]):
            cbar_label = f"{field.long_name} [{field.units}]"
        else:
            cbar_label = None

        position = ax.get_position()
        cax = fig.add_axes([position.x1 + 0.05, position.y0, 0.015, position.height])
        plot_utils.add_colorbar(p, cax, levels, label=cbar_label)

    # plot flag data
    if periodic:
        flag = plot_utils.calculate_periodic_field(flag, **kwargs)

    flag.where(flag > 0).plot.contourf(
        ax=ax,
        colors=["white", color_events],
        levels=[0, 0.5],
        transform=data_crs,
        add_colorbar=False,
    )

    # add the date to the map axes (plt.text would use the current axes,
    # which is the colorbar axes after fig.add_axes)
    ax.text(
        0.99,
        0.98,
        "Date: " + str(date),
        fontsize=10,
        fontweight="bold",
        ha="right",
        va="top",
        transform=ax.transAxes,
    )

    # add coast lines and grid lines
    ax.add_feature(cfeature.COASTLINE, color="dimgrey")
    ax.gridlines(draw_labels=False, color="black", linestyle="dotted", linewidth=1.1)

    # plot labels
    if labels:
        plot_utils.add_grid_lines(ax)

    # make a circular cut out for any NorthPolarStereo projection
    if isinstance(proj, ccrs.NorthPolarStereo):
        plot_utils.add_circular_boundary(ax)
        plot_utils.add_circular_patch(ax)

    # set title
    ax.set_title(title, fontweight="bold", fontsize=20)
def _split_border_segment(seg, min_lon, max_lon):
    """Split a two-point track segment that crosses the longitude border.

    `seg` is a (2, >=2) array whose rows are (lon, lat, ...) positions.
    Returns ``(lons_plot, lats_plot)``; each is a list of two (start, end)
    tuples, one per piece. The latitude at the border is linearly
    interpolated along the wrapped path.
    """
    lon0, lat0 = seg[0, 0], seg[0, 1]
    lon1, lat1 = seg[1, 0], seg[1, 1]

    if lon1 < lon0:
        # upstream split: leave through max_lon, re-enter at min_lon
        lons_plot = [(lon0, max_lon), (min_lon, lon1)]
    else:
        # downstream split: leave through min_lon, re-enter at max_lon
        lons_plot = [(lon0, min_lon), (max_lon, lon1)]

    dlon_first = lons_plot[0][1] - lons_plot[0][0]
    dlon_second = lons_plot[1][1] - lons_plot[1][0]
    slope = (lat1 - lat0) / (dlon_first + dlon_second)

    lats_plot = [
        (lat0, lat0 + dlon_first * slope),
        (lat1 - dlon_second * slope, lat1),
    ]
    return lons_plot, lats_plot


@check_argument_types(["data", "events"], [xr.DataArray, gpd.GeoDataFrame])
@check_empty_dataframes
@get_dimension_attributes("data")
def plot_tracks(
    data,
    events,
    proj=None,
    size=None,
    min_path=0,
    plot_events=False,
    labels=True,
    title="",
    *args,
    **kwargs
):
    """
    Creates a plot showing the tracks of the temporally coherent events.

    Dimension names ("time_name", "lon_name", "lat_name"), size ("ntime", "nlon", "nlat")
    and resolution ("dlon", "dlat") can be passed as key=value argument.

    Parameters
    ----------
        data : xarray.DataArray
            Data that has been used to calculate the contours and the indices
        events : geopandas.GeoDataFrame
            Identified events with a "label" column (track id) and a "com"
            column holding the (lon, lat) position of each event
        proj : cartopy.crs, optional
            cartopy projection object (default: cartopy.crs.PlateCarree())
        size : tuple of integers, optional
            size of the figure in the format (width, height) (default: (12, 8))
        min_path : int, optional
            Only tracks with more than `min_path` events are plotted
        plot_events : bool, optional
            If True, the events are also plotted by a shaded area
        labels : bool, optional
            If False, no labels are added to the plot
        title : string, optional
            Title of the plot

    Returns
    -------
        None
            The plot is drawn on a newly created matplotlib figure.
    """
    # input coordinates are plain longitude/latitude values
    data_crs = ccrs.PlateCarree()

    # initialize figure
    proj = proj if proj is not None else data_crs
    size = size if size is not None else (12, 8)
    fig, ax = plt.subplots(1, 1, subplot_kw=dict(projection=proj), figsize=size)

    # set background color
    ax.set_facecolor((0.1, 0.1, 0.1, 0.05))

    # get one color per plotted track
    lab, count = np.unique(events.label, return_counts=True)
    lab_sel = lab[count > min_path]
    track_cmap = plt.get_cmap("rainbow")
    color_range = {
        name: track_cmap(r / len(lab_sel)) for r, name in enumerate(lab_sel)
    }

    # longitude range of the grid; the upper bound is extended by one grid
    # spacing (presumably so that a global grid spans the full circle)
    lon_values = data[kwargs["lon_name"]].values
    max_lon = lon_values.max() + kwargs["dlon"]
    min_lon = lon_values.min()

    # a longitude jump of more than half the range between two consecutive
    # positions is treated as a crossing of the date border
    split_threshold = (max_lon - min_lon) / 2

    # group event data by label and plot each path
    for name, group in events.groupby("label"):
        if len(group) <= min_path:
            continue

        coords = np.asarray(group.com.tolist())
        lons = coords[:, 0]
        lats = coords[:, 1]

        # plot start point of each path
        ax.scatter(
            lons[0],
            lats[0],
            s=14,
            zorder=10,
            facecolors="none",
            edgecolor="black",
            transform=data_crs,
        )
        ax.plot(lons[0], lats[0], ".", color="red", transform=data_crs, alpha=0.7)

        # plot the coordinates of the events
        if plot_events:
            group.plot(ax=ax, color=color_range[name], transform=data_crs, alpha=0.5)

        # plot each segment between consecutive positions, splitting those
        # that cross the date border into two pieces
        crosses_border = np.abs(np.diff(lons)) > split_threshold
        for i, crosses in enumerate(crosses_border):
            seg = coords[i : i + 2]
            if not crosses:
                ax.plot(seg[:, 0], seg[:, 1], "-", transform=data_crs, color="black")
                continue

            lons_plot, lats_plot = _split_border_segment(seg, min_lon, max_lon)
            for lon, lat in zip(lons_plot, lats_plot):
                ax.plot(lon, lat, "-", transform=data_crs, color="black")

    # plot invisible data to get a plot of the full grid
    data.isel({kwargs["time_name"]: 0}).plot.contourf(
        ax=ax, transform=data_crs, add_colorbar=False, alpha=0
    )

    # add coast lines and grid lines
    ax.add_feature(cfeature.COASTLINE, color="dimgrey")
    ax.gridlines(draw_labels=False, color="black", linestyle="dotted", linewidth=1.1)

    # plot labels
    if labels:
        plot_utils.add_grid_lines(ax)

    # make a circular cut out for any NorthPolarStereo projection
    if isinstance(proj, ccrs.NorthPolarStereo):
        plot_utils.add_circular_boundary(ax)
        plot_utils.add_circular_patch(ax)

    # set title
    ax.set_title(title, fontweight="bold", fontsize=20)