Source code for dpest.pst

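# os, re, and pandas are used below; import them explicitly rather than
# relying on the star import to provide them.
import os
import re
import tempfile

import pandas as pd
import pyemu
import yaml
from dpest.functions import *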

def pst(
    cultivar_parameters=None,
    ecotype_parameters=None,
    dataframe_observations=None,
    output_path=None,
    model_comand_line=None,
    noptmax=1000,
    pst_filename='PEST_CONTROL.pst',
    input_output_file_pairs=None
):
    """
    Creates a ``PEST control file (.PST)`` for DSSAT crop model calibration. This file guides the model calibration process by specifying input and output files, parameter bounds, and directions for PEST to extract and compare model-generated observations with experimental data. The module takes model parameters (with their values, groupings, and bounds) and observation DataFrames as inputs.

    **Conditionally Required Arguments:**
    =======

    To properly create the ``PEST control file (.PST)``, the user must specify at least one of the following arguments:

        * **cultivar_parameters** (*dict*, *optional, but required if ``ecotype_parameters`` is not specified*): Dictionary containing cultivar model parameters with their values, bounds, and groupings. It is obtained from the ``cul`` module (see ``dpest.wheat.ceres.cul``).
        * **ecotype_parameters** (*dict*, *optional, but required if ``cultivar_parameters`` is not specified*): Dictionary containing ecotype model parameters with their values, bounds, and groupings. It is obtained from the ``eco`` module (see ``dpest.wheat.ceres.eco``).

    **Required Arguments:**
    =======

        * **dataframe_observations** (``pd.DataFrame`` or ``list``): DataFrame or list of DataFrames containing observations to be used during model calibration and included in the ``PEST control file (.PST)``. It can be a single DataFrame, as in ``dataframe_observations = dataframe``, or a list of DataFrames, as in ``dataframe_observations = [dataframe1, dataframe2]``. These DataFrames are created by the ``dpest.wheat.overview`` and ``dpest.wheat.plantgro`` modules, and each DataFrame *must* contain columns named ``'variable_name'``, ``'value_measured'``, and ``'group'``.
        * **model_comand_line** (*str*): Command line used to run the DSSAT model executable.
        * **input_output_file_pairs** (``list``): List of tuples where each tuple contains an input and output file pair. The required tuples depend on the other arguments passed to this module:

            * If ``cultivar_parameters`` is specified, this list *must* contain a tuple with the ``PEST template file (.TPL)`` for the cultivar and the corresponding ``DSSAT cultivar file (.CUL)``.
            * If ``ecotype_parameters`` is specified, this list *must* contain a tuple with the ``PEST template file (.TPL)`` for the ecotype and the corresponding ``DSSAT ecotype file (.ECO)``.
            * For *each* DataFrame specified in ``dataframe_observations``, this list *must* contain a tuple with the ``PEST instruction file (.INS)`` created by the ``overview`` or ``plantgro`` module and the corresponding ``OVERVIEW.OUT`` or ``PlantGro.OUT`` file.

          The list follows this structure: ``[(input_file1, output_file1), (input_file2, output_file2)]``. The first element of each tuple is the path to either a ``PEST template file (.TPL)`` or a ``PEST instruction file (.INS)``, and the second element is the path to the corresponding DSSAT input or output file.

    **Optional Arguments:**
    =======

        * **output_path** (*str*, *default: current working directory*): Directory to save the ``PEST control file (.PST)``. By default, the file is created in the same directory where the script is located.
        * **noptmax** (*int*, *default: 1000*): Maximum number of iterations for the optimization process.
        * **pst_filename** (*str*, *default: "PEST_CONTROL.pst"*): File name for the ``PEST control file (.PST)`` to be created.

    **Returns:**
    =======

        * ``None``: This module creates the ``PEST control file (.PST)`` at the specified ``output_path`` (or in the script's directory by default) with the provided ``pst_filename``. It validates inputs, processes observation data, sets up parameters, and writes the resulting ``PEST control file (.PST)``.

    **Examples:**
    =======

    1. **Creating a PEST Control File with Cultivar and Ecotype Parameters, End-of-Season Crop Performance Metrics, and Plant Growth Dynamics:**

       .. code-block:: python

          from dpest import pst

          pst(
              cultivar_parameters = cultivar_parameters,
              ecotype_parameters = ecotype_parameters,
              dataframe_observations = [overview_observations, plantgro_observations],
              model_comand_line = r'py "C:/pest18/run-dssat.py"',
              input_output_file_pairs = [
                  (cultivar_tpl_path, 'C://DSSAT48/Genotype/WHCER048.CUL'),
                  (ecotype_tpl_path, 'C://DSSAT48/Genotype/WHCER048.ECO'),
                  (overview_ins_path, 'C://DSSAT48/Wheat/OVERVIEW.OUT'),
                  (plantgro_ins_path, 'C://DSSAT48/Wheat/PlantGro.OUT')
              ]
          )

       This example shows how to create a ``PEST control file (.PST)`` using both cultivar and ecotype parameters. The ``dataframe_observations`` argument is assigned a list of two DataFrames: (1) end-of-season crop performance metrics created using the ``dpest.wheat.overview`` module, and (2) plant growth dynamics data created using the ``dpest.wheat.plantgro`` module. The example specifies the model command line and lists the required input and output file pairs.

    2. **Creating a PEST Control File with Only Cultivar Parameters, Model Performance Metrics, and Plant Growth Data:**

       .. code-block:: python

          from dpest import pst

          pst(
              cultivar_parameters = cultivar_parameters,
              dataframe_observations = [overview_observations, plantgro_observations],
              model_comand_line = r'py "C:/pest18/run-dssat.py"',
              input_output_file_pairs = [
                  (cultivar_tpl_path, 'C://DSSAT48/Genotype/WHCER048.CUL'),
                  (overview_ins_path, 'C://DSSAT48/Wheat/OVERVIEW.OUT'),
                  (plantgro_ins_path, 'C://DSSAT48/Wheat/PlantGro.OUT')
              ]
          )

       This example demonstrates how to create a ``PEST control file (.PST)`` using only cultivar parameters. The ``dataframe_observations`` argument uses a list of two DataFrames: one representing model performance data created by the ``dpest.wheat.overview`` module, and another containing plant growth data created by the ``dpest.wheat.plantgro`` module.

    3. **Creating a PEST Control File with Cultivar Parameters and Just Plant Growth Data:**

       .. code-block:: python

          from dpest import pst

          pst(
              cultivar_parameters = cultivar_parameters,
              dataframe_observations = plantgro_observations,
              model_comand_line = r'py "C:/pest18/run-dssat.py"',
              input_output_file_pairs = [
                  (cultivar_tpl_path, 'C://DSSAT48/Genotype/WHCER048.CUL'),
                  (plantgro_ins_path, 'C://DSSAT48/Wheat/PlantGro.OUT')
              ]
          )

       This example shows the use of a single ``dataframe_observations`` argument containing plant growth dynamics metrics created with the ``dpest.wheat.plantgro`` module, along with the cultivar parameters and the appropriate input and output file pairs.
    """
    # Names of the YAML blocks read from arguments.yml
    yml_pst_file_block = 'PST_FILE'
    yml_file_observation_groups = 'OBSERVATION_GROUPS_SPECIFICATIONS'

    try:
        ## Get the yaml_data
        # Get the directory of the current script
        current_dir = os.path.dirname(os.path.abspath(__file__))

        # Construct the path to arguments.yml
        arguments_file = os.path.join(current_dir, 'arguments.yml')

        # Ensure the YAML file exists
        if not os.path.isfile(arguments_file):
            raise FileNotFoundError(f"YAML file not found: {arguments_file}")

        # Load YAML configuration
        with open(arguments_file, 'r') as yml_file:
            yaml_data = yaml.safe_load(yml_file)

        # Validate inputs
        if not (cultivar_parameters or ecotype_parameters):
            raise ValueError(
                "At least one of `cultivar_parameters` or `ecotype_parameters` must be provided and non-empty.")

        if cultivar_parameters and not isinstance(cultivar_parameters, dict):
            raise ValueError("`cultivar_parameters`, if provided, must be a dictionary.")

        if ecotype_parameters and not isinstance(ecotype_parameters, dict):
            raise ValueError("`ecotype_parameters`, if provided, must be a dictionary.")

        # `input_output_file_pairs` is required; fail with a clear message here
        # instead of a TypeError when the checks below iterate over None
        if not input_output_file_pairs:
            raise ValueError("`input_output_file_pairs` must be provided and non-empty.")

        # Additional validation for file extensions based on parameters
        if cultivar_parameters:
            if not any(pair[1].lower().endswith('.cul') for pair in input_output_file_pairs):
                raise ValueError(
                    "If `cultivar_parameters` is provided, at least one file in `input_output_file_pairs` must have a '.CUL' extension.")

        if ecotype_parameters:
            if not any(pair[1].lower().endswith('.eco') for pair in input_output_file_pairs):
                raise ValueError(
                    "If `ecotype_parameters` is provided, at least one file in `input_output_file_pairs` must have a '.ECO' extension.")

        # Validate that at least one file has a '.OUT' extension
        if not any(pair[1].lower().endswith('.out') for pair in input_output_file_pairs):
            raise ValueError("At least one file in `input_output_file_pairs` must have a '.OUT' extension.")

        if dataframe_observations is None:
            raise ValueError("`dataframe_observations` must be provided.")

        # Convert single dataframe to list for consistent processing
        if isinstance(dataframe_observations, pd.DataFrame):
            dataframe_observations = [dataframe_observations]

        if not isinstance(dataframe_observations, list) or not all(
                isinstance(df, pd.DataFrame) for df in dataframe_observations):
            raise ValueError("`dataframe_observations` must be a DataFrame or a list of DataFrames.")

        required_columns = {'variable_name', 'value_measured', 'group'}
        for df in dataframe_observations:
            if not required_columns.issubset(df.columns):
                raise ValueError(
                    "Each DataFrame in `dataframe_observations` must contain 'variable_name', 'value_measured', and 'group' columns.")

        # Get the group settings from the YAML file; they are applied to every
        # parameter group further below
        observation_groups = yaml_data[yml_pst_file_block][yml_file_observation_groups]

        # Merge dictionaries if both are provided, or use the one that exists
        parameters = {
            'parameters': {
                **(cultivar_parameters.get('parameters', {}) if cultivar_parameters else {}),
                **(ecotype_parameters.get('parameters', {}) if ecotype_parameters else {})
            },
            'minima_parameters': {
                **(cultivar_parameters.get('minima_parameters', {}) if cultivar_parameters else {}),
                **(ecotype_parameters.get('minima_parameters', {}) if ecotype_parameters else {})
            },
            'maxima_parameters': {
                **(cultivar_parameters.get('maxima_parameters', {}) if cultivar_parameters else {}),
                **(ecotype_parameters.get('maxima_parameters', {}) if ecotype_parameters else {})
            },
            'parameters_grouped': {
                **(cultivar_parameters.get('parameters_grouped', {}) if cultivar_parameters else {}),
                **(ecotype_parameters.get('parameters_grouped', {}) if ecotype_parameters else {})
            }
        }

        # Flatten the comma-separated group strings into one list of parameter names
        all_params = [
            param
            for group in parameters['parameters_grouped'].values()
            for param in group.replace(' ', '').split(',')
        ]

        # Create a minimal PST object
        pst = pyemu.pst_utils.generic_pst(all_params)

        # Populate parameters
        for param in all_params:
            pst.parameter_data.loc[param, 'parval1'] = float(parameters['parameters'][param])
            pst.parameter_data.loc[param, "parlbnd"] = float(parameters['minima_parameters'][param])
            pst.parameter_data.loc[param, "parubnd"] = float(parameters['maxima_parameters'][param])
            # Split the group string the same way as for `all_params`, so the
            # lookup also works when names are separated by ',' without a space
            pst.parameter_data.loc[param, "pargp"] = next(
                (group for group, params in parameters['parameters_grouped'].items()
                 if param in params.replace(' ', '').split(',')),
                None)

            # Add PARTRANS and PARCHGLIM
            pst.parameter_data.loc[param, "partrans"] = "none"  # Set PARTRANS to none
            pst.parameter_data.loc[param, "parchglim"] = "relative"  # Set PARCHGLIM to relative

        # Create parameter groups using values from observation_groups
        pargp_data = []
        for group in parameters['parameters_grouped'].keys():
            pargp_entry = {"pargpnme": group}  # Start with the group name
            pargp_entry.update(observation_groups)  # Update with values from observation_groups
            pargp_data.append(pargp_entry)

        # Convert parameter groups list to DataFrame
        pst.parameter_groups = pd.DataFrame(pargp_data)

        # Clear existing observation data
        pst.observation_data = pst.observation_data.iloc[0:0]

        # Process all dataframes
        for df in dataframe_observations:
            # Validate and clean observation data on a copy, so the caller's
            # DataFrame is not modified
            df = df.assign(value_measured=pd.to_numeric(df['value_measured'], errors='coerce'))
            df = df.dropna(subset=['value_measured'])

            for _, row in df.iterrows():
                obsnme = row['variable_name']
                obsval = row['value_measured']
                obgnme = row['group']

                pst.observation_data.loc[obsnme, 'obsnme'] = obsnme
                pst.observation_data.loc[obsnme, 'obsval'] = obsval
                pst.observation_data.loc[obsnme, 'obgnme'] = obgnme
                pst.observation_data.loc[obsnme, 'weight'] = 1.0  # Default weight

        # ~~~~~~~~ Handle input and output files
        if input_output_file_pairs:
            # Validate file pairs
            if not all(len(pair) == 2 for pair in input_output_file_pairs):
                raise ValueError("Each input_output_file_pair must contain exactly two elements")

            if not all(pair[0].lower().endswith(('.tpl', '.ins')) for pair in input_output_file_pairs):
                raise ValueError("The first element of each pair must be a .tpl or .ins file")

            # Validate file existence
            for pair in input_output_file_pairs:
                validate_file_path(pair[0])  # Validate PEST file (TPL or INS)
                validate_file_path(pair[1])  # Validate model file

            # Function to count TPL and INS files
            def count_file_types(file_pairs):
                tpl_count = sum(1 for pair in file_pairs if pair[0].lower().endswith('.tpl'))
                ins_count = sum(1 for pair in file_pairs if pair[0].lower().endswith('.ins'))
                return tpl_count, ins_count

            # Wrap paths that contain spaces in double quotes
            def escape_spaces(file_pairs):
                return [
                    (f'"{pair[0]}"' if ' ' in pair[0] else pair[0],
                     f'"{pair[1]}"' if ' ' in pair[1] else pair[1])
                    for pair in file_pairs
                ]

            # Count TPL and INS files before quoting, while the paths still end
            # with their extension rather than with a closing quote
            tpl_count, ins_count = count_file_types(input_output_file_pairs)

            # Escape spaces in paths
            input_output_file_pairs = escape_spaces(input_output_file_pairs)

            # Set input files (TPL files)
            pst.model_input_data = pd.DataFrame({
                'pest_file': [pair[0] for pair in input_output_file_pairs
                              if pair[0].strip('"').lower().endswith('.tpl')],
                'model_file': [pair[1] for pair in input_output_file_pairs
                               if pair[0].strip('"').lower().endswith('.tpl')]
            })

            # Set output files (INS files)
            pst.model_output_data = pd.DataFrame({
                'pest_file': [pair[0] for pair in input_output_file_pairs
                              if pair[0].strip('"').lower().endswith('.ins')],
                'model_file': [pair[1] for pair in input_output_file_pairs
                               if pair[0].strip('"').lower().endswith('.ins')]
            })

            # Set NTPLFLE and NINSFLE
            pst.control_data.ntplfle = tpl_count
            pst.control_data.ninsfle = ins_count
        # ~~~~~~~~/ Handle input and output files

        # Set NUMCOM, JACFILE, and MESSFILE
        pst.control_data.numcom = 1
        pst.control_data.jacfile = 0
        pst.control_data.messfile = 0

        # Set mode of operation to use
        pst.pestmode = "estimation"

        # ~~~~~~~~ Customize SVD section as a custom attribute
        # Store the original write method
        original_write = pst.write

        # Define a new write method that updates the SVD section
        def custom_write(self, filename):
            # First, write to a temporary file
            with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
                original_write(temp_file.name)
                temp_filename = temp_file.name

            # Read the content of the temporary file
            with open(temp_filename, 'r') as f:
                content = f.read()

            # Compute SVD defaults based on number of parameters
            npar = self.npar
            svdmode = 1  # enable SVD
            maxsing = npar  # allow up to number of parameters
            eigthresh = 5e-7  # recommended in PEST manual for most cases
            eigwrite = 0  # only singular values, smaller .svd file

            # Build SVD section text (matches PEST format)
            svd_section = (
                "* singular value decomposition\n"
                f" {svdmode}\n"
                f" {maxsing} {eigthresh:.6E}\n"
                f" {eigwrite}\n"
            )

            # Replace existing SVD section, or insert a new one after * control data
            if re.search(r'\* singular value decomposition.*?(?=\*|$)', content,
                         flags=re.DOTALL | re.IGNORECASE):
                # Replace existing SVD block
                content = re.sub(
                    r'\* singular value decomposition.*?(?=\*|$)',
                    svd_section,
                    content,
                    flags=re.DOTALL | re.IGNORECASE
                )
            else:
                # Insert SVD block immediately after the * control data section
                content = re.sub(
                    r'(\* control data.*?(?=\*|$))',
                    r'\1\n' + svd_section,
                    content,
                    flags=re.DOTALL | re.IGNORECASE
                )

            # Write modified content to the final file
            with open(filename, 'w') as f:
                f.write(content)

            # Remove the temporary file
            os.unlink(temp_filename)

        # Replace the write method
        pst.write = custom_write.__get__(pst)
        # ~~~~~~~~/ Customize SVD section as a custom attribute

        # # ~~~~~~~~ Add LSQR section as a custom attribute
        #
        # pst.lsqr_data = {
        #     "lsqrmode": 1,
        #     "lsqr_atol": 1e-4,
        #     "lsqr_btol": 1e-4,
        #     "lsqr_conlim": 28.0,
        #     "lsqr_itnlim": 28,
        #     "lsqrwrite": 0
        # }
        #
        # # Store the original write method
        # original_write = pst.write
        #
        # # Define a new write method that replaces SVD with LSQR
        # def custom_write(self, filename):
        #     # First, write to a temporary file
        #     with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
        #         original_write(temp_file.name)
        #         temp_filename = temp_file.name
        #
        #     # Read the content of the temporary file
        #     with open(temp_filename, 'r') as f:
        #         content = f.read()
        #
        #     # Replace SVD section with LSQR
        #     lsqr_section = f"* lsqr\n {self.lsqr_data['lsqrmode']}\n {self.lsqr_data['lsqr_atol']} {self.lsqr_data['lsqr_btol']} {self.lsqr_data['lsqr_conlim']} {self.lsqr_data['lsqr_itnlim']}\n {self.lsqr_data['lsqrwrite']}\n"
        #     content = re.sub(r'\* singular value decomposition.*?(?=\*|$)', lsqr_section, content, flags=re.DOTALL)
        #
        #     # Write modified content to the final file
        #     with open(filename, 'w') as f:
        #         f.write(content)
        #
        #     # Remove the temporary file
        #     os.unlink(temp_filename)
        #
        # # Replace the write method
        # pst.write = custom_write.__get__(pst)
        #
        # # ~~~~~~~~/ Add LSQR section as a custom attribute

        # Set additional control data parameters
        pst.control_data.rlambda1 = 10.0
        pst.control_data.numlam = 10
        pst.control_data.icov = 1
        pst.control_data.icor = 1
        pst.control_data.ieig = 1

        # Add the command used to run the model executable
        pst.model_command = [model_comand_line]

        # Add number of iterations
        pst.control_data.noptmax = noptmax

        # Validate output_path
        output_path = validate_output_path(output_path)

        # Create the path and name for the output file
        pst_file_path = os.path.join(output_path, pst_filename)

        # Write the PST file
        pst.write(pst_file_path)

        print(f"PST file successfully created: {pst_file_path}")

    except ValueError as ve:
        print(f"ValueError: {ve}")
    except FileNotFoundError as fe:
        print(f"FileNotFoundError: {fe}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
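
The text rewrite that ``custom_write`` applies to the control file can be tried on its own, without pyemu or DSSAT. The sketch below is only an illustration: the helper names ``build_svd_section`` and ``apply_svd_section`` and the two sample strings are hypothetical and are not part of dpest. It reuses the two regular expressions from the listing and assumes, as the listing does, that a section runs from its ``* name`` header up to the next ``*`` or the end of the text.

```python
import re

# Same patterns as in custom_write: a section runs from its "* name" header
# up to the next "*" character or the end of the text.
SVD_PATTERN = r'\* singular value decomposition.*?(?=\*|$)'
CONTROL_PATTERN = r'(\* control data.*?(?=\*|$))'
FLAGS = re.DOTALL | re.IGNORECASE


def build_svd_section(npar, eigthresh=5e-7):
    """Return the SVD block text (hypothetical helper, values as in the listing)."""
    return (
        "* singular value decomposition\n"
        " 1\n"                        # SVDMODE = 1
        f" {npar} {eigthresh:.6E}\n"  # MAXSING = number of parameters, EIGTHRESH
        " 0\n"                        # EIGWRITE = 0
    )


def apply_svd_section(content, npar):
    """Replace an existing SVD block, or insert one after '* control data'."""
    svd_section = build_svd_section(npar)
    if re.search(SVD_PATTERN, content, flags=FLAGS):
        # A function replacement avoids any backslash processing of the text.
        return re.sub(SVD_PATTERN, lambda m: svd_section, content, flags=FLAGS)
    # No SVD block yet: append it to the control data section, separated by a
    # newline as in the listing.
    return re.sub(
        CONTROL_PATTERN,
        lambda m: m.group(1) + "\n" + svd_section,
        content,
        flags=FLAGS,
    )


# Made-up, heavily shortened control file fragments for illustration only.
sample_with_svd = (
    "pcf\n"
    "* control data\n"
    "restart estimation\n"
    "* singular value decomposition\n"
    " 1\n"
    " 10000000 1.000000E-06\n"
    " 1\n"
    "* parameter groups\n"
)
sample_without_svd = (
    "pcf\n"
    "* control data\n"
    "restart estimation\n"
    "* parameter groups\n"
)

print(apply_svd_section(sample_with_svd, npar=3))
print(apply_svd_section(sample_without_svd, npar=3))
```

Because each section is delimited by the next ``*`` character, this approach relies on no ``*`` appearing inside the section being matched, for example in a file path or a name.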