Source code for imod.mf6.model

import abc
import collections
import inspect
import pathlib
from copy import deepcopy

import cftime
import jinja2
import numpy as np
import tomli
import tomli_w
import xugrid as xu
from jinja2 import Template

import imod
from imod.mf6 import qgs_util
from imod.mf6.pkgbase import Package
from imod.mf6.statusinfo import NestedStatusInfo, StatusInfo, StatusInfoBase
from imod.mf6.validation import pkg_errors_to_status_info
from imod.schemata import ValidationError


def initialize_template(name: str) -> Template:
    loader = jinja2.PackageLoader("imod", "templates/mf6")
    env = jinja2.Environment(loader=loader, keep_trailing_newline=True)
    return env.get_template(name)
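
# A minimal usage sketch (hedged): "gwf-nam.j2" is one of the packaged
# templates referenced below, and the "packages" key matches what
# Modflow6Model.render passes in; the triplet contents are illustrative.
#
#     template = initialize_template("gwf-nam.j2")
#     text = template.render({"packages": [("dis6", "gwf_1/dis.dis", "dis")]})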


class Modflow6Model(collections.UserDict, abc.ABC):
    _mandatory_packages = None
    _model_id = None

    def __init__(self, **kwargs):
        collections.UserDict.__init__(self)
        for k, v in kwargs.items():
            self[k] = v

        self._options = {}
        self._template = None

    def __setitem__(self, key, value):
        if len(key) > 16:
            raise KeyError(
                f"Received key with more than 16 characters: '{key}'"
                "Modflow 6 has a character limit of 16."
            )

        super().__setitem__(key, value)

    def update(self, *args, **kwargs):
        for k, v in dict(*args, **kwargs).items():
            self[k] = v

    def __get_diskey(self):
        dis_pkg_ids = ["dis", "disv", "disu"]

        diskeys = [
            self._get_pkgkey(pkg_id)
            for pkg_id in dis_pkg_ids
            if self._get_pkgkey(pkg_id) is not None
        ]

        if len(diskeys) > 1:
            raise ValueError(f"Found multiple discretizations {diskeys}")
        elif len(diskeys) == 0:
            raise ValueError("No model discretization found")
        else:
            return diskeys[0]

    def _get_pkgkey(self, pkg_id):
        """
        Get package key that belongs to a certain pkg_id, since the keys are
        user specified.
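
        For example, with a (hypothetical) model entry ``model["mydis"]``
        whose package has ``_pkg_id == "dis"``, ``_get_pkgkey("dis")``
        returns ``"mydis"``.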
        """
        key = [pkgname for pkgname, pkg in self.items() if pkg._pkg_id == pkg_id]
        nkey = len(key)
        if nkey > 1:
            raise ValueError(f"Multiple instances of {key} detected")
        elif nkey == 1:
            return key[0]
        else:
            return None

    def _check_for_required_packages(self, modelkey: str) -> None:
        # Check for mandatory packages
        pkg_ids = set([pkg._pkg_id for pkg in self.values()])
        dispresent = "dis" in pkg_ids or "disv" in pkg_ids or "disu" in pkg_ids
        if not dispresent:
            raise ValueError(f"No dis/disv/disu package found in model {modelkey}")
        for required in self._mandatory_packages:
            if required not in pkg_ids:
                raise ValueError(f"No {required} package found in model {modelkey}")
        return

    def _use_cftime(self):
        """
        Also checks if datetime types are homogeneous across packages.
        """
        types = [
            type(pkg.dataset["time"].values[0])
            for pkg in self.values()
            if "time" in pkg.dataset.coords
        ]
        set_of_types = set(types)
        # Types will be empty if there's no time dependent input
        if len(set_of_types) == 0:
            return False
        else:  # there is time dependent input
            if not len(set_of_types) == 1:
                raise ValueError(
                    f"Multiple datetime types detected: {set_of_types}"
                    "Use either cftime or numpy.datetime64[ns]."
                )
            # Since we compare types and not instances, we use issubclass
            if issubclass(types[0], cftime.datetime):
                return True
            elif issubclass(types[0], np.datetime64):
                return False
            else:
                raise ValueError("Use either cftime or numpy.datetime64[ns].")

    def _yield_times(self):
        modeltimes = []
        for pkg in self.values():
            if "time" in pkg.dataset.coords:
                modeltimes.append(pkg.dataset["time"].values)
            repeat_stress = pkg.dataset.get("repeat_stress")
            if repeat_stress is not None and repeat_stress.values[()] is not None:
                modeltimes.append(repeat_stress.isel(repeat_items=0).values)
        return modeltimes

    def render(self, modelname: str):
        dir_for_render = pathlib.Path(modelname)
        d = {k: v for k, v in self._options.items() if not (v is None or v is False)}
        packages = []
        for pkgname, pkg in self.items():
            # Add the "6" suffix to the package id, as expected in the namefile
            pkg_id = pkg._pkg_id
            key = f"{pkg_id}6"
            path = dir_for_render / f"{pkgname}.{pkg_id}"
            packages.append((key, path.as_posix(), pkgname))
        d["packages"] = packages
        return self._template.render(d)
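
    # For reference (hedged): the rendered namefile's "packages" block lists
    # the (ftype, fname, pname) triplets assembled above; the exact layout
    # comes from the Jinja template. Illustrative output:
    #
    #     begin packages
    #       dis6 gwf_1/dis.dis dis
    #     end packages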

    def _model_checks(self, modelkey: str):
        """
        Check model integrity (called before writing)
        """

        self._check_for_required_packages(modelkey)

    def _validate(self, model_name: str = "") -> StatusInfoBase:
        try:
            diskey = self.__get_diskey()
        except Exception as e:
            status_info = StatusInfo(f"{model_name} model")
            status_info.add_error(str(e))
            return status_info

        dis = self[diskey]
        # We'll use the idomain for checking dims, shape, nodata.
        idomain = dis["idomain"]
        bottom = dis["bottom"]

        model_status_info = NestedStatusInfo(f"{model_name} model")
        for pkg_name, pkg in self.items():
            # Check for all schemata when writing. Types and dimensions
            # may have been changed after initialization...

            if pkg_name in ["adv"]:
                continue  # some packages can be skipped

            # Concatenate write and init schemata.
            schemata = deepcopy(pkg._init_schemata)
            for key, value in pkg._write_schemata.items():
                if key not in schemata.keys():
                    schemata[key] = value
                else:
                    schemata[key] += value

            pkg_errors = pkg._validate(
                schemata=schemata,
                idomain=idomain,
                bottom=bottom,
            )
            if len(pkg_errors) > 0:
                model_status_info.add(pkg_errors_to_status_info(pkg_name, pkg_errors))

        return model_status_info

    def write(
        self, directory, modelname, globaltimes, binary=True, validate: bool = True
    ) -> StatusInfoBase:
        """
        Write model namefile
        Write packages
        """

        workdir = pathlib.Path(directory)
        modeldirectory = workdir / modelname
        modeldirectory.mkdir(exist_ok=True, parents=True)
        if validate:
            model_status_info = self._validate(modelname)
            if model_status_info.has_errors():
                return model_status_info

        # write model namefile
        namefile_content = self.render(modelname)
        namefile_path = modeldirectory / f"{modelname}.nam"
        with open(namefile_path, "w") as f:
            f.write(namefile_content)

        # write package contents
        for pkg_name, pkg in self.items():
            try:
                pkg.write(
                    directory=modeldirectory,
                    pkgname=pkg_name,
                    globaltimes=globaltimes,
                    binary=binary,
                )
            except Exception as e:
                raise type(e)(f"{e}\nError occurred while writing {pkg_name}")

        return NestedStatusInfo(modelname)

    def dump(
        self, directory, modelname, validate: bool = True, mdal_compliant: bool = False
    ):
        modeldirectory = pathlib.Path(directory) / modelname
        modeldirectory.mkdir(exist_ok=True, parents=True)
        if validate:
            statusinfo = self._validate()
            if statusinfo.has_errors():
                raise ValidationError(statusinfo.to_string())

        toml_content = collections.defaultdict(dict)
        for pkgname, pkg in self.items():
            pkg_path = f"{pkgname}.nc"
            toml_content[type(pkg).__name__][pkgname] = pkg_path
            dataset = pkg.dataset
            if isinstance(dataset, xu.UgridDataset):
                if mdal_compliant:
                    dataset = pkg.dataset.ugrid.to_dataset()
                    mdal_dataset = imod.util.mdal_compliant_ugrid2d(dataset)
                    mdal_dataset.to_netcdf(modeldirectory / pkg_path)
                else:
                    pkg.dataset.ugrid.to_netcdf(modeldirectory / pkg_path)
            else:
                pkg.dataset.to_netcdf(modeldirectory / pkg_path)

        toml_path = modeldirectory / f"{modelname}.toml"
        with open(toml_path, "wb") as f:
            tomli_w.dump(toml_content, f)

        return toml_path
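
    # For reference: ``dump`` writes a TOML file mapping package class names
    # to netCDF paths, which ``from_file`` reads back. With a (hypothetical)
    # package ``model["dis"]`` of class StructuredDiscretization:
    #
    #     [StructuredDiscretization]
    #     dis = "dis.nc"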

    @classmethod
    def from_file(cls, toml_path):
        pkg_classes = {
            name: pkg_cls
            for name, pkg_cls in inspect.getmembers(imod.mf6, inspect.isclass)
            if issubclass(pkg_cls, Package)
        }

        toml_path = pathlib.Path(toml_path)
        with open(toml_path, "rb") as f:
            toml_content = tomli.load(f)

        parentdir = toml_path.parent
        instance = cls()
        for key, entry in toml_content.items():
            for pkgname, path in entry.items():
                pkg_cls = pkg_classes[key]
                instance[pkgname] = pkg_cls.from_file(parentdir / path)

        return instance

    def clip_box(
        self,
        time_min=None,
        time_max=None,
        layer_min=None,
        layer_max=None,
        x_min=None,
        x_max=None,
        y_min=None,
        y_max=None,
    ):
        """
        Clip a model by a bounding box (time, layer, y, x).

        Slicing intervals may be half-bounded, by providing None:

        * To select 500.0 <= x <= 1000.0:
          ``clip_box(x_min=500.0, x_max=1000.0)``.
        * To select x <= 1000.0: ``clip_box(x_min=None, x_max=1000.0)``
          or ``clip_box(x_max=1000.0)``.
        * To select x >= 500.0: ``clip_box(x_min = 500.0, x_max=None.0)``
          or ``clip_box(x_min=1000.0)``.

        Parameters
        ----------
        time_min: optional
        time_max: optional
        layer_min: optional, int
        layer_max: optional, int
        x_min: optional, float
        x_max: optional, float
        y_min: optional, float
        y_max: optional, float

        Returns
        -------
        clipped : Modflow6Model
        """
        clipped = type(self)(**self._options)
        for key, pkg in self.items():
            clipped[key] = pkg.clip_box(
                time_min=time_min,
                time_max=time_max,
                layer_min=layer_min,
                layer_max=layer_max,
                x_min=x_min,
                x_max=x_max,
                y_min=y_min,
                y_max=y_max,
            )
        return clipped
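
    # A hedged usage sketch of ``clip_box`` (model contents and globaltimes
    # are hypothetical):
    #
    #     clipped = model.clip_box(x_min=500.0, x_max=1000.0, layer_max=3)
    #     clipped.write("clipped_dir", "clipped_model", globaltimes)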


class GroundwaterFlowModel(Modflow6Model):
    _mandatory_packages = ("npf", "ic", "oc", "sto")
    _model_id = "gwf6"

    def __init__(
        self,
        listing_file: str = None,
        print_input: bool = False,
        print_flows: bool = False,
        save_flows: bool = False,
        newton: bool = False,
        under_relaxation: bool = False,
    ):
        super().__init__()
        self._options = {
            "listing_file": listing_file,
            "print_input": print_input,
            "print_flows": print_flows,
            "save_flows": save_flows,
            "newton": newton,
            "under_relaxation": under_relaxation,
        }
        self._template = initialize_template("gwf-nam.j2")

    def write_qgis_project(self, directory, crs, aggregate_layers=False):
        """
        Write a QGIS project file and accompanying netCDF files that can be
        read in QGIS.

        Parameters
        ----------
        directory : Path
            directory of qgis project
        crs : str, int,
            anything that can be converted to a pyproj.crs.CRS
        aggregate_layers : optional, bool
            If True, aggregate layers by taking the mean, i.e.
            ds.mean(dim="layer")
        """
        ext = ".qgs"
        directory = pathlib.Path(directory)
        directory.mkdir(exist_ok=True, parents=True)

        pkgnames = [
            pkgname
            for pkgname, pkg in self.items()
            if all(i in pkg.dataset.dims for i in ["x", "y"])
        ]

        data_paths = []
        data_vars_ls = []
        for pkgname in pkgnames:
            pkg = self[pkgname].rio.write_crs(crs)
            data_path = pkg._netcdf_path(directory, pkgname)
            data_path = "./" + data_path.relative_to(directory).as_posix()
            data_paths.append(data_path)
            # FUTURE: MDAL has matured enough that we do not necessarily
            # have to write separate netcdfs anymore
            data_vars_ls.append(
                pkg.write_netcdf(directory, pkgname, aggregate_layers=aggregate_layers)
            )

        qgs_tree = qgs_util._create_qgis_tree(
            self, pkgnames, data_paths, data_vars_ls, crs
        )
        qgs_util._write_qgis_projectfile(qgs_tree, directory / ("qgis_proj" + ext))


class GroundwaterTransportModel(Modflow6Model):
    """
    The GroundwaterTransportModel (GWT) simulates transport of a single
    solute species flowing in groundwater.
    """

    _mandatory_packages = ("mst", "dsp", "oc", "ic")
    _model_id = "gwt6"

    def __init__(
        self,
        listing_file: str = None,
        print_input: bool = False,
        print_flows: bool = False,
        save_flows: bool = False,
    ):
        super().__init__()
        self._options = {
            "listing_file": listing_file,
            "print_input": print_input,
            "print_flows": print_flows,
            "save_flows": save_flows,
        }
        self._template = initialize_template("gwt-nam.j2")
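

# A hedged end-to-end sketch (package construction elided; the package names
# are hypothetical and globaltimes must match the packages' time coordinates):
#
#     gwf = GroundwaterFlowModel(save_flows=True)
#     gwf["dis"] = ...  # one of dis/disv/disu is required
#     gwf["npf"] = ...  # plus the other mandatory packages: ic, oc, sto
#     status = gwf.write("simulation", "gwf_1", globaltimes, validate=True)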