Source code for resilientplotterclass.interactive

import branca
import folium
import geopandas as gpd
import ipyleaflet
import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import xyzservices
from folium.plugins import Draw
from matplotlib.colors import to_hex


def _explore_image(da, m, cmap="Spectral_r", vmin=None, vmax=None, legend=None, legend_kwds=None, **kwargs):
    """Plot data interactively using folium.

    The data are normalised to ``[0, 1]`` with ``vmin``/``vmax``, converted to an RGBA image and added to
    ``m`` as an image overlay. A 2D array (or a 3D array with a single leading band) is coloured with
    ``cmap``; a 3D array with three leading bands is used directly as RGB with full opacity.

    :param da:          Data to plot.
    :type da:           xarray.DataArray
    :param m:           Folium map to plot on.
    :type m:            folium.Map
    :param cmap:        Colormap to use.
    :type cmap:         str, optional
    :param vmin:        Minimum value for colormap. Defaults to the minimum of ``da``.
    :type vmin:         float, optional
    :param vmax:        Maximum value for colormap. Defaults to the maximum of ``da``.
    :type vmax:         float, optional
    :param legend:      Whether to show legend. Defaults to ``True`` only when ``da`` is 2D.
    :type legend:       bool, optional
    :param legend_kwds: Keyword arguments for legend.
    :type legend_kwds:  dict, optional
    :param kwargs:      Keyword arguments for :func:`folium.raster_layers.ImageOverlay`.
    :type kwargs:       dict, optional
    :return:            Plot.
    :rtype:             folium.Map
    :raises ValueError: If ``da`` is neither 2D nor 3D with a leading axis of length 1 or 3.

    :See also: `folium.raster_layers.ImageOverlay <https://python-visualization.github.io/folium/latest/reference.html#folium.raster_layers.ImageOverlay>`_
    :See also: `folium.LinearColormap <https://python-visualization.github.io/folium/latest/advanced_guide/colormaps.html>`_
    """

    # Use a fresh dict per call instead of a shared mutable default
    if legend_kwds is None:
        legend_kwds = {}

    # Convert da to number
    if not np.issubdtype(da.dtype, np.number):
        da = da.astype(float)

    # Get cmap
    cmap = plt.get_cmap(cmap)

    # Get vmin and vmax
    if vmin is None:
        vmin = da.min().values
    if vmax is None:
        vmax = da.max().values

    # Get legend
    if legend is None:
        legend = da.ndim == 2

    # Default kwargs
    kwargs.setdefault("mercator_project", True)

    # Normalise values
    da = ((da - vmin) / (vmax - vmin)).clip(0, 1)

    # Work on the raw array; a single-band 3D array is reduced to 2D so the image is always (y, x, 4)
    values = da.values
    if values.ndim == 3 and values.shape[0] == 1:
        values = values[0]

    # Get colors and a per-pixel (y, x) mask of missing data
    if values.ndim == 2:
        rgba = cmap(values)
        nan_mask = np.isnan(values)
    elif values.ndim == 3 and values.shape[0] == 3:
        rgb = values.transpose(1, 2, 0)
        rgba = np.concatenate([rgb, np.ones((*rgb.shape[:2], 1))], axis=2)
        nan_mask = np.isnan(rgb).any(axis=2)
    else:
        raise ValueError("DataArray must be 2D or 3D with shape (3, y, x)")

    # Fix color range for mercator projection
    if kwargs.get("mercator_project"):
        # Pin the RGB channels of two pixels to 0 and 1 so that the image spans the full value range.
        # NOTE(review): presumably needed because the projected image is rescaled by its maximum when
        # it is encoded - confirm against folium's image handling.
        # Prefer the first two NaN pixels for this.
        rows, cols = np.nonzero(nan_mask)

        # If there are fewer than two NaN pixels, fall back to the first two pixels of the first row
        if rows.size < 2:
            rows, cols = np.array([0, 0]), np.array([0, 1])

        # Set the two chosen pixels to 0 and 1
        rgba[rows[0], cols[0], :3] = 0
        rgba[rows[1], cols[1], :3] = 1

    # Convert range to 0-255
    rgba = (rgba * 255).astype(np.uint8)

    # Get bounds and reorder them into two [second, first] corner pairs
    bounds = da.rio.bounds()
    bounds = [[bounds[1], bounds[0]], [bounds[3], bounds[2]]]

    # Create image
    img = folium.raster_layers.ImageOverlay(image=rgba, bounds=bounds, **kwargs)

    # Add to map
    m.add_child(img)

    # Add legend
    if legend:
        # Sample the colormap and convert the colors to hex
        colors = [to_hex(color) for color in cmap(np.linspace(0, 1, 256))]

        # Create colormap legend spanning the original (un-normalised) value range
        colorbar = folium.LinearColormap(colors, vmin=vmin, vmax=vmax, **legend_kwds)

        # Add to map
        m.add_child(colorbar)

    # Return map
    return m


def _set_map_bounds(m, bounds):
    """Fit the map view to ``bounds``, extended by the bounds the map already has.

    :param m:      Map to update.
    :type m:       folium.Map
    :param bounds: Bounds indexed as ``(xmin, ymin, xmax, ymax)``.
    :type bounds:  sequence of float
    :return:       The map that was passed in, after fitting its bounds.
    :rtype:        folium.Map
    """
    # Reorder bounds into the two corner pairs [[ymin, xmin], [ymax, xmax]]
    lower = [bounds[1], bounds[0]]
    upper = [bounds[3], bounds[2]]

    # Get map bounds
    map_bounds = m.get_bounds()

    # Combine bounds, but only when every existing map bound is set
    if all(value is not None for corner in map_bounds for value in corner):
        lower = [min(lower[0], map_bounds[0][0]), min(lower[1], map_bounds[0][1])]
        upper = [max(upper[0], map_bounds[1][0]), max(upper[1], map_bounds[1][1])]

    # Set bounds
    m.fit_bounds([lower, upper])

    # Return map
    return m


class _Draw(Draw):
    """Variant of :class:`folium.plugins.Draw` that adds a styled export button.

    When ``self.export`` is truthy, rendering also adds a stylesheet and an ``#export`` anchor
    to the figure that contains this element.
    """

    def render(self, **kwargs):
        """Render the draw plugin and, if exporting is enabled, the export button.

        :param kwargs: Keyword arguments forwarded to the parent ``render``.
        :type kwargs:  dict, optional
        """
        super().render(**kwargs)

        # The button is attached to the root figure, so this element must live inside one
        root = self.get_root()
        assert isinstance(root, branca.element.Figure), "You cannot render this Element if it is not in a Figure."

        # Nothing else to add when exporting is disabled
        if not self.export:
            return

        button_css = """
            <style>
                #export {
                    position: absolute;
                    bottom: 12px;
                    left: 12px;
                    z-index: 1000;
                    background: white;
                    outline: 2px solid rgba(0, 0, 0, 0.2);
                    padding: 7px;
                    border-radius: 2px;
                    cursor: pointer;
                    font-size: 12px;
                    text-decoration: none;
                }

                #export:hover {
                    background: #f0f0f0;
                }
            </style>
        """
        button_html = """<a href='#' id='export'>💾</a>"""

        # Stylesheet goes into the figure header, the button itself into the figure body
        root.header.add_child(branca.element.Element(button_css), name="export")
        root.html.add_child(branca.element.Element(button_html), name="export_button")


def pcolormesh(da, m=None, skip=1, smooth=1, **kwargs):
    """Plot data interactively using pcolormesh (not implemented yet).

    All arguments are accepted but currently unused.

    :param da:     Data to plot (currently unused).
    :param m:      Map to plot on (currently unused).
    :param skip:   Currently unused.
    :param smooth: Currently unused.
    :param kwargs: Currently unused.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError("pcolormesh plot not implemented yet")
def imshow(da, m=None, skip=1, smooth=1, **kwargs):
    """Plot data interactively using imshow.

    :param da:     Data to plot.
    :type da:      xarray.DataArray
    :param m:      Folium map to plot on. A new map is created when omitted.
    :type m:       folium.Map, optional
    :param skip:   Plot every nth value in x and y direction.
    :type skip:    int, optional
    :param smooth: Smooth data array with rolling mean in x and y direction.
    :type smooth:  int, optional
    :param kwargs: Keyword arguments for :func:`resilientplotterclass.interactive._explore_image`.
    :type kwargs:  dict, optional
    :return:       Plot.
    :rtype:        folium.Map
    """
    # Reproject DataArray (done once; everything below works in EPSG:4326)
    if da.rio.crs != "EPSG:4326":
        print("\033[93mReprojecting DataArray to EPSG:4326.\033[0m")
        da = da.rio.reproject("EPSG:4326")

    # Get map
    if m is None:
        m = folium.Map()

    # Set map bounds
    m = _set_map_bounds(m, da.rio.bounds())

    # Skip DataArray values
    if skip > 1:
        da = da.isel(x=slice(None, None, skip), y=slice(None, None, skip))

    # Smooth DataArray with a centred rolling mean, first along x and then along y
    if smooth > 1:
        da = da.rolling(x=smooth, center=True).mean().rolling(y=smooth, center=True).mean()

    # Plot DataArray and return plot
    return _explore_image(da, m=m, **kwargs)
def scatter(da, m=None, skip=1, smooth=1, **kwargs):
    """Plot data interactively using scatter (not implemented yet).

    All arguments are accepted but currently unused.

    :param da:     Data to plot (currently unused).
    :param m:      Map to plot on (currently unused).
    :param skip:   Currently unused.
    :param smooth: Currently unused.
    :param kwargs: Currently unused.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError("scatter plot not implemented yet")
def contourf(da, m=None, skip=1, smooth=1, **kwargs):
    """Plot data interactively using contourf (not implemented yet).

    All arguments are accepted but currently unused.

    :param da:     Data to plot (currently unused).
    :param m:      Map to plot on (currently unused).
    :param skip:   Currently unused.
    :param smooth: Currently unused.
    :param kwargs: Currently unused.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError("contourf plot not implemented yet")
def contour(da, m=None, skip=1, smooth=1, **kwargs):
    """Plot data interactively using contour (not implemented yet).

    All arguments are accepted but currently unused.

    :param da:     Data to plot (currently unused).
    :param m:      Map to plot on (currently unused).
    :param skip:   Currently unused.
    :param smooth: Currently unused.
    :param kwargs: Currently unused.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError("contour plot not implemented yet")
def quiver(da, m=None, skip=1, smooth=1, **kwargs):
    """Plot data interactively using quiver (not implemented yet).

    All arguments are accepted but currently unused.

    :param da:     Data to plot (currently unused).
    :param m:      Map to plot on (currently unused).
    :param skip:   Currently unused.
    :param smooth: Currently unused.
    :param kwargs: Currently unused.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError("quiver plot not implemented yet")
def streamplot(da, m=None, skip=1, smooth=1, **kwargs):
    """Plot data interactively using streamplot (not implemented yet).

    All arguments are accepted but currently unused.

    :param da:     Data to plot (currently unused).
    :param m:      Map to plot on (currently unused).
    :param skip:   Currently unused.
    :param smooth: Currently unused.
    :param kwargs: Currently unused.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError("streamplot plot not implemented yet")
def plot_geometries(gdf, m=None, **kwargs):
    """Plot geometries interactively using folium.

    :param gdf:    GeoDataFrame to plot.
    :type gdf:     geopandas.GeoDataFrame
    :param m:      Folium map to plot on. A new map is created when omitted.
    :type m:       folium.Map, optional
    :param kwargs: Keyword arguments for :func:`geopandas.explore`.
    :type kwargs:  dict, optional
    :return:       Plot.
    :rtype:        folium.Map

    :See also: `geopandas.explore <https://geopandas.org/en/stable/docs/reference/api/geopandas.GeoDataFrame.explore.html>`_
    """
    # Reproject GeoDataFrame
    if gdf.crs != "EPSG:4326":
        print("\033[93mReprojecting GeoDataFrame to EPSG:4326.\033[0m")
        gdf = gdf.to_crs("EPSG:4326")

    # Get map
    if m is None:
        m = folium.Map()

    # Set map bounds
    m = _set_map_bounds(m, gdf.total_bounds)

    # Plot GeoDataFrame onto the existing map
    gdf.explore(m=m, **kwargs)

    # Return plot
    return m
def plot_basemap(m=None, **kwargs):
    """Plot basemaps interactively using folium.

    :param m:      Folium map to plot on. A new map is created when omitted.
    :type m:       folium.Map, optional
    :param kwargs: Keyword arguments for :class:`folium.TileLayer`.
    :type kwargs:  dict, optional
    :return:       Plot.
    :rtype:        folium.Map

    :See also: `folium.TileLayer <https://python-visualization.github.io/folium/latest/reference.html#folium.raster_layers.TileLayer>`_
    """
    # Get map
    if m is None:
        m = folium.Map()

    # Plot basemap
    folium.TileLayer(**kwargs).add_to(m)

    # Return plot
    return m
class Draw_Map(ipyleaflet.Map):
    """A class to create a map to draw geometries."""

    # Style applied to drawn shapes and to geometries loaded onto the map
    STYLE = {"weight": 2}

    def __init__(
        self,
        center: tuple[float, float] = None,
        zoom: int = 8,
        basemap: xyzservices.TileProvider = ipyleaflet.basemaps.OpenStreetMap.Mapnik,
        file_path_gdf: str = None,
        gdf: gpd.GeoDataFrame = None,
        **kwargs,
    ):
        """Constructor for the draw map class.

        :param center:        Center of the map as ``(y, x)``. Defaults to a representative point of the
                              geometries when geometries are given.
        :type center:         tuple[float], optional
        :param zoom:          Zoom level of the map.
        :type zoom:           int, optional
        :param basemap:       Basemap layer for the map.
        :type basemap:        xyzservices.TileProvider, optional
        :param file_path_gdf: File path to the GeoDataFrame file. Takes precedence over ``gdf``.
        :type file_path_gdf:  str, optional
        :param gdf:           GeoDataFrame to display on the map.
        :type gdf:            gpd.GeoDataFrame, optional
        :param kwargs:        Keyword arguments for :class:`ipyleaflet.Map`.
        :type kwargs:         dict, optional
        """
        # Get geometries from file or from the given GeoDataFrame, always in EPSG:4326
        if file_path_gdf is not None:
            gdf = gpd.read_file(file_path_gdf).to_crs("EPSG:4326")
        elif gdf is not None and not gdf.empty:
            if isinstance(gdf, gpd.GeoDataFrame):
                # Geometries without a CRS are taken to be in EPSG:4326 already
                gdf = gdf.set_crs("EPSG:4326") if gdf.crs is None else gdf.to_crs("EPSG:4326")
            else:
                # Plain geometries (e.g. a GeoSeries) are wrapped into a GeoDataFrame
                gdf = gpd.GeoDataFrame(geometry=gdf, crs="EPSG:4326")

        # Get center of the map
        if center is None and gdf is not None and not gdf.empty:
            centroid = gdf.geometry.union_all().representative_point()
            center = (centroid.y, centroid.x)

        # Set default kwargs
        kwargs.setdefault("scroll_wheel_zoom", True)  # Enable scroll wheel zoom
        kwargs.setdefault("attribution_control", False)  # Disable attribution control
        kwargs.setdefault("layout", ipywidgets.Layout(height="600px", width="100%"))  # Set layout height and width

        # Only pass a center when one is known, so the superclass default applies otherwise
        if center is not None:
            kwargs["center"] = center

        # Initialize map superclass
        super().__init__(zoom=zoom, basemap=basemap, **kwargs)

        # Flag that is set to True once geometries have been set on the map
        self.drawn = False

        # Initialise draw control
        self.draw_control = ipyleaflet.DrawControl(
            polygon={"shapeOptions": self.STYLE},
            rectangle={"shapeOptions": self.STYLE},
            circlemarker={"shapeOptions": self.STYLE},
            polyline={"shapeOptions": self.STYLE},
        )

        # Add draw control to the map
        self.add_control(self.draw_control)

        # Set geometries to the map
        if gdf is not None and not gdf.empty:
            self.set_geometries(gdf)

            # total_bounds is (minx, miny, maxx, maxy); fit_bounds takes two [y, x] corners
            minx, miny, maxx, maxy = gdf.total_bounds
            self.fit_bounds([[float(miny), float(minx)], [float(maxy), float(maxx)]])

    def set_geometries(self, gdf: gpd.GeoDataFrame) -> None:
        """Set geometries to the map.

        The geometries are appended to those already held by the draw control. The given
        GeoDataFrame itself is not modified.

        :param gdf: The GeoDataFrame containing the geometries.
        :type gdf:  gpd.GeoDataFrame
        :return:    None
        """
        # Update draw flag
        self.drawn = True

        # The draw control data is read back as EPSG:4326, so store it in that CRS
        if gdf.crs is not None and gdf.crs != "EPSG:4326":
            gdf = gdf.to_crs("EPSG:4326")

        # Add style on a copy so the caller's GeoDataFrame does not gain a "style" column
        styled = gdf.assign(style=[self.STYLE] * len(gdf))

        # Add geometries to the draw control
        self.draw_control.data = self.draw_control.data + list(styled.iterfeatures())

    def get_geometries(self, crs: str = "EPSG:4326") -> gpd.GeoDataFrame:
        """Get geometries from the map.

        :param crs: The coordinate reference system to reproject the geometries to.
        :type crs:  str, optional
        :return:    A GeoDataFrame containing the geometries from the map.
        :rtype:     gpd.GeoDataFrame
        """
        # Get geometries from the map
        features = self.draw_control.data
        if not features:
            gdf = gpd.GeoDataFrame(columns=["geometry"], crs="EPSG:4326")
        else:
            # Drop the styling column if present; it is display information only
            gdf = gpd.GeoDataFrame.from_features(features, crs="EPSG:4326").drop(columns="style", errors="ignore")

        # Reproject geometries
        return gdf.to_crs(crs)