Sindbad.MachineLearning Module
The MachineLearning module provides the core functionality for integrating machine learning (ML) and hybrid modeling capabilities into the SINDBAD framework.
Purpose
This module brings together all components required for hybrid (process-based + ML) modeling in SINDBAD, including data preparation, model construction, training routines, gradient computation, and optimizer management. It supports flexible configuration, cross-validation, and seamless integration with SINDBAD's process-based modeling workflows.
Dependencies
Related (SINDBAD ecosystem)
OmniTools: Shared helpers used indirectly via internal modules.
External (third-party)
- Base.Iterators: Iterators for batching and repetition (repeated, partition).
- Distributed: Parallel and distributed computing utilities (nworkers, pmap, workers, nprocs, CachingPool).
- JLD2: Saving/loading model artifacts (e.g. fold indices, checkpoints).
- ProgressMeter: Progress utilities (imported symbols only).
Internal (within Sindbad)
- Sindbad.DataLoaders
- Sindbad.Setup
- Sindbad.Simulation
- Sindbad.Types
- SindbadTEM
Optional dependencies (weakdeps / experimental)
Some ML training/AD backends are listed as weak dependencies in the root Project.toml (e.g. Zygote, ForwardDiff, Optimisers, PreallocationTools, etc.) and are enabled via extensions. Flux is a hard dependency of this module.
Included Files
- utilsMachineLearning.jl: Utility functions for machine-learning workflows.
- diffCaches.jl: Caching utilities for differentiation.
- activationFunctions.jl: Implements various activation functions, including custom and Flux-provided activations.
- mlModels.jl: Constructors and utilities for hybrid/ML model components.
- mlOptimizers.jl: Functions for creating and configuring optimizers for ML training (backend-dependent).
- loss.jl: Loss functions and utilities for evaluating model performance and computing gradients.
- prepHybrid.jl: Prepares data structures and loss definitions for hybrid modeling (including data splits and feature extraction).
- mlGradient.jl: Routines for computing gradients using different libraries and methods, supporting both automatic and finite-difference differentiation.
- mlTrain.jl: Training routines for ML and hybrid models, including batching, checkpointing, and evaluation.
- neuralNetwork.jl: Neural network utilities and architectures.
- siteLosses.jl: Site-specific loss calculation utilities.
- oneHots.jl: One-hot encoding utilities.
- loadCovariates.jl: Functions for loading and handling covariate data.
Notes
The module is modular and extensible, allowing users to add new ML models, optimizers, activation functions, and training methods.
It is tightly integrated with the SINDBAD ecosystem, ensuring consistent data handling and reproducibility across hybrid and process-based modeling workflows.
Functions
JoinDenseNN
Sindbad.MachineLearning.JoinDenseNN Function
JoinDenseNN(models::Tuple)
Arguments:
- models :: a tuple of models, i.e. (m1, m2)
Returns:
- all parameters as a vector or matrix (multiple samples)
Example
using Sindbad.MachineLearning
using Flux
using Random
Random.seed!(123)
m_big = Chain(Dense(4 => 5, relu), Dense(5 => 3), Flux.sigmoid)
m_eta = Dense(1=>1, Flux.sigmoid)
x_big_a = rand(Float32, 4, 10)
x_small_a1 = rand(Float32, 1, 10)
x_small_a2 = rand(Float32, 1, 10)
model = JoinDenseNN((m_big, m_eta))
model((x_big_a, x_small_a2))
Code
function JoinDenseNN(models::Tuple)
return Chain(Join(vcat, models...))
end
activationFunction
Sindbad.MachineLearning.activationFunction Function
activationFunction(model_options, act::AbstractActivation)
Return the activation function corresponding to the specified activation type and model options.
This function dispatches on the activation type to provide the appropriate activation function for use in neural network layers. For custom activation types, relevant parameters can be passed via model_options.
Arguments
- model_options: A struct or NamedTuple containing model options, including parameters for custom activation functions (e.g., k_σ for CustomSigmoid).
- act: An activation type specifying the desired activation function. Supported types include:
  - FluxRelu: Rectified Linear Unit (ReLU) activation.
  - FluxTanh: Hyperbolic Tangent (tanh) activation.
  - FluxSigmoid: Sigmoid activation.
  - CustomSigmoid: Custom sigmoid activation with steepness parameter k_σ.
Returns
- A callable activation function suitable for use in neural network layers.
Example
act_fn = activationFunction(model_options, FluxRelu())
y = act_fn(x)
Code
function activationFunction end
function activationFunction(_, ::FluxRelu)
return Flux.relu
end
function activationFunction(_, ::FluxTanh)
return Flux.tanh
end
function activationFunction(_, ::FluxSigmoid)
return Flux.sigmoid
end
function activationFunction(model_options, ::CustomSigmoid)
sigmoid_k(x, K) = one(x) / (one(x) + exp(-K * x))
custom_sigmoid = x -> sigmoid_k(x, model_options.k_σ)
return custom_sigmoid
end
denseNN
Sindbad.MachineLearning.denseNN Function
denseNN(in_dim::Int, n_neurons::Int, out_dim::Int; extra_hlayers=0, activation_hidden=Flux.relu, activation_out=Flux.sigmoid, seed=1618)
Arguments
- in_dim: input dimension
- n_neurons: number of neurons in each hidden layer
- out_dim: output dimension
- extra_hlayers=0: number of extra hidden layers, default is zero
- activation_hidden=Flux.relu: activation function of the hidden layers, default is ReLU
- activation_out=Flux.sigmoid: activation of the output layer, default is sigmoid
- seed=1618: random seed, default is a nod to (1+√5)/2 ≈ 1.618
Returns a Flux.Chain neural network.
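Example
A minimal usage sketch (the layer sizes here are illustrative):
using Sindbad.MachineLearning
using Flux
nn = denseNN(8, 16, 3; extra_hlayers=1) # Dense(8=>16, relu) -> Dense(16=>16, relu) -> Dense(16=>3, sigmoid)
y = nn(rand(Float32, 8, 5))             # 3 × 5 output for 5 samples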
Code
function denseNN(in_dim::Int, n_neurons::Int, out_dim::Int;
extra_hlayers=0,
activation_hidden=Flux.relu,
activation_out=Flux.sigmoid,
seed=1618)
Random.seed!(seed)
return Flux.Chain(Flux.Dense(in_dim => n_neurons, activation_hidden),
[Flux.Dense(n_neurons, n_neurons, activation_hidden) for _ in 0:(extra_hlayers-1)]...,
Flux.Dense(n_neurons => out_dim, activation_out))
end
destructureNN
Sindbad.MachineLearning.destructureNN Function
destructureNN(model; nn_opt=Optimisers.Adam())
Given a model, returns a flat vector with all weights, a `re` object that can rebuild the network structure, and the initial optimiser state.
Arguments
- model: a Flux.Chain neural network.
- nn_opt: optimiser, the default is Optimisers.Adam().
Returns:
- flat: a flat vector with all network weights
- re: an object containing the model structure, used later to reconstruct the neural network
- opt_state: the state of the optimiser
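Example
A minimal sketch showing how the returned pieces relate (model and input sizes are illustrative):
using Sindbad.MachineLearning
using Flux
m = Chain(Dense(4 => 5, relu), Dense(5 => 2))
flat, re, opt_state = destructureNN(m)
m_rebuilt = re(flat)            # rebuild the network from the flat weight vector
y = m_rebuilt(rand(Float32, 4))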
Code
function destructureNN(model; nn_opt=Optimisers.Adam())
flat, re = Optimisers.destructure(model)
opt_state = Optimisers.setup(nn_opt, flat)
return flat, re, opt_state
end
epochLossComponents
Sindbad.MachineLearning.epochLossComponents Function
epochLossComponents(loss_functions::F, loss_array_sites, loss_array_components, epoch_number, scaled_params, sites_list) where {F}
Compute and store the loss metrics and loss components for each site in parallel for a given training epoch.
This function evaluates the provided loss functions for each site using the current scaled parameters, and stores the resulting scalar loss metrics and loss component vectors in the corresponding arrays for the specified epoch. Parallel execution is used to accelerate computation across sites.
Arguments
- loss_functions::F: An array or KeyedArray of loss functions, one per site (where F is a subtype of AbstractArray{<:Function}).
- loss_array_sites: A matrix to store the scalar loss metric for each site and epoch (dimensions: site × epoch).
- loss_array_components: A 3D tensor to store the loss components for each site, component, and epoch (dimensions: site × component × epoch).
- epoch_number: The current epoch number (integer).
- scaled_params: A callable or array providing the scaled parameters for each site (e.g., scaled_params(site=site_name)).
- sites_list: List or array of site identifiers to process.
Notes
- The function uses Julia's threading (Threads.@spawn) to compute losses for multiple sites in parallel.
- Each site's loss metric and components are stored at the corresponding index for the current epoch.
- Designed for use within training loops to track loss evolution over epochs.
Example
epochLossComponents(loss_functions, loss_array_sites, loss_array_components, epoch, scaled_params, sites)
Code
function epochLossComponents(loss_functions::F, loss_array_sites, loss_array_components, epoch_number, scaled_params, sites_list) where {F}
@sync begin
for idx ∈ eachindex(sites_list)
Threads.@spawn begin
site_name = sites_list[idx]
loc_params = scaled_params(site=site_name)
loss_f = loss_functions(site=site_name)
loss_metric, loss_components, loss_indices = loss_f(loc_params)
loss_array_sites[idx, epoch_number] = loss_metric
loss_array_components[idx, loss_indices, epoch_number] = loss_components
end
end
end
end
getCacheFromOutput
Sindbad.MachineLearning.getCacheFromOutput Function
getCacheFromOutput(loc_output, ::MachineLearningGradType)
getCacheFromOutput(loc_output, ::ForwardDiffGrad)
getCacheFromOutput(loc_output, ::PolyesterForwardDiffGrad)
Returns the appropriate Cache type based on the automatic differentiation or finite differences package being used.
Arguments
- loc_output: The local output
- The second argument specifies the differentiation method:
  - MachineLearningGradType: All other libraries, e.g., FiniteDiff.jl, FiniteDifferences.jl, etc. for gradient calculations
  - ForwardDiffGrad: Uses ForwardDiff.jl for automatic differentiation
  - PolyesterForwardDiffGrad: Uses PolyesterForwardDiff.jl for automatic differentiation
Code
function getCacheFromOutput end
function getCacheFromOutput(loc_output, ::MachineLearningGradType)
return loc_output
end
function getCacheFromOutput(loc_output, ::PolyesterForwardDiffGrad)
return getCacheFromOutput(loc_output, ForwardDiffGrad())
end
getIndicesSplit
Sindbad.MachineLearning.getIndicesSplit Function
getIndicesSplit(info, sites, fold_type)
Determine the indices for training, validation, and testing site splits for hybrid (ML) modeling in SINDBAD.
This function dispatches on the fold_type argument to either load precomputed folds from file or to compute the splits on-the-fly based on the provided split ratios and number of folds.
Arguments
- info: The SINDBAD experiment info structure, containing hybrid modeling configuration.
- sites: Array of site identifiers (e.g., site names or indices).
- fold_type: Determines the splitting strategy. Use LoadFoldFromFile() to load folds from file, or CalcFoldFromSplit() to compute splits dynamically.
Returns
- indices_training: Indices of sites assigned to the training set.
- indices_validation: Indices of sites assigned to the validation set.
- indices_testing: Indices of sites assigned to the testing set.
Notes
- When using LoadFoldFromFile, the function loads fold indices from the file specified in info.hybrid.fold.fold_path.
- When using CalcFoldFromSplit, the function splits the sites according to the ratios and number of folds specified in info.hybrid.ml_training.options.
- Ensures reproducibility by using the random seed from info.hybrid.random_seed when shuffling sites.
Example
indices_train, indices_val, indices_test = getIndicesSplit(info, sites, info.hybrid.fold.fold_type)
Code
function getIndicesSplit end
function getIndicesSplit(info, _, ::LoadFoldFromFile)
# load the folds from file
path_data_folds = info.hybrid.fold.fold_path
n_fold = info.hybrid.fold.which_fold
data_folds = load(path_data_folds)
indices_training = data_folds["unfold_training"][n_fold]
indices_validation = data_folds["unfold_validation"][n_fold]
indices_testing = data_folds["unfold_tests"][n_fold]
return indices_training, indices_validation, indices_testing
end
function getIndicesSplit(info, site_indices, ::CalcFoldFromSplit)
site_indices = collect(eachindex(site_indices)) # Ensure site_indices is an array of indices
# split the sites into training, validation and testing
n_fold = info.hybrid.ml_training.options.n_folds
split_ratio = info.hybrid.ml_training.options.split_ratio
test_ratio = split_ratio[2]
val_ratio = split_ratio[1]
train_ratio = 1 - test_ratio - val_ratio
@assert train_ratio + val_ratio + test_ratio ≈ 1.0 "Ratios must sum to 1.0"
return getNFolds(site_indices, train_ratio, val_ratio, test_ratio, n_fold, info.hybrid.ml_training.options.batch_size; seed=info.hybrid.random_seed)
end
getInnerArgs
Sindbad.MachineLearning.getInnerArgs Function
getInnerArgs(idx, grads_lib, scaled_params_batch, selected_models, space_forcing, space_spinup_forcing, loc_forcing_t, space_output, loc_land, tem_info, parameter_to_index, parameter_scaling_type, space_observations, cost_options, constraint_method, indices_batch, sites_batch)
Function to get inner arguments for the loss function.
Arguments
- idx: index into the batch
- grads_lib: gradient library
- scaled_params_batch: scaled parameters batch
- selected_models: selected models
- space_forcing: forcing data location
- space_spinup_forcing: spinup forcing data location
- loc_forcing_t: forcing data time for one time step
- space_output: output data location
- loc_land: initial land state
- tem_info: model information
- parameter_to_index: parameter to index mapping
- parameter_scaling_type: type determining parameter scaling
- space_observations: observation data location
- cost_options: cost options
- constraint_method: constraint method
- indices_batch: indices batch
- sites_batch: sites batch
Code
function getInnerArgs(idx, grads_lib,
scaled_params_batch, # ? input_args
selected_models,
space_forcing,
space_spinup_forcing,
loc_forcing_t,
space_output,
loc_land,
tem_info,
parameter_to_index,
parameter_scaling_type,
space_observations,
cost_options,
constraint_method,
indices_batch,
sites_batch)
site_location = indices_batch[idx]
site_name = sites_batch[idx]
# get site information
x_vals = scaled_params_batch(site=site_name).data.data
loc_forcing = space_forcing[site_location]
loc_obs = space_observations[site_location]
loc_output = space_output[site_location]
loc_spinup_forcing = space_spinup_forcing[site_location]
loc_cost_option = cost_options[site_location]
return (;
loc_params = x_vals,
inner_args = (
selected_models,
loc_forcing,
loc_spinup_forcing,
loc_forcing_t,
getCacheFromOutput(loc_output, grads_lib),
deepcopy(loc_land),
tem_info,
parameter_to_index,
parameter_scaling_type,
loc_obs,
loc_cost_option,
constraint_method)
)
end
getLossForSites
Sindbad.MachineLearning.getLossForSites Function
getLossForSites(gradient_lib, loss_function::F, loss_array_sites, loss_array_split, epoch_number, scaled_params, sites_list, indices_sites, models, space_forcing, space_spinup_forcing, loc_forcing_t, space_output, loc_land, tem_info, parameter_to_index, parameter_scaling_type, space_observations, cost_options, constraint_method) where {F}
Calculates the loss for all sites using the loss_function function. The loss_array_sites and loss_array_split arrays are updated with the loss values: loss_array_sites stores the loss values for each site and epoch, while loss_array_split stores the loss values for each model output and epoch.
Arguments
- gradient_lib: gradient library
- loss_function: loss function
- loss_array_sites: array to store the loss values for each site and epoch
- loss_array_split: array to store the loss values for each model output and epoch
- epoch_number: epoch number
- scaled_params: scaled parameters
- sites_list: list of sites
- indices_sites: indices of sites
- models: list of models
- space_forcing: forcing data location
- space_spinup_forcing: spinup forcing data location
- loc_forcing_t: forcing data time for one time step
- space_output: output data location
- loc_land: initial land state
- tem_info: model information
- parameter_to_index: parameter to index mapping
- parameter_scaling_type: type determining parameter scaling
- space_observations: observation data location
- cost_options: cost options
- constraint_method: constraint method
Code
function getLossForSites(gradient_lib, loss_function::F, loss_array_sites, loss_array_split, epoch_number,
scaled_params, sites_list, indices_sites, models, space_forcing, space_spinup_forcing,
loc_forcing_t, space_output, loc_land, tem_info, parameter_to_index, parameter_scaling_type, space_observations,
cost_options, constraint_method) where {F}
@sync begin
for idx ∈ eachindex(indices_sites)
Threads.@spawn begin
site_location = indices_sites[idx]
site_name = sites_list[idx]
loc_params = scaled_params(site=site_name)
loc_forcing = space_forcing[site_location]
loc_obs = space_observations[site_location]
loc_output = space_output[site_location]
loc_spinup_forcing = space_spinup_forcing[site_location]
loc_cost_option = cost_options[site_location]
gg, gg_split, loss_indices = loss_function(loc_params, gradient_lib, models, loc_forcing, loc_spinup_forcing,
loc_forcing_t, loc_output, deepcopy(loc_land), tem_info, parameter_to_index, parameter_scaling_type, loc_obs, loc_cost_option, constraint_method;
optim_mode=false)
loss_array_sites[idx, epoch_number] = gg
# @show gg_split, idx, loss_indices, epoch_number
loss_array_split[idx, loss_indices, epoch_number] = gg_split
end
end
end
end
getLossFunctionHandles
Sindbad.MachineLearning.getLossFunctionHandles Function
getLossFunctionHandles(info, run_helpers, sites)
Construct loss function handles for each site for use in hybrid (ML) modeling in SINDBAD.
This function generates callable loss functions and loss component functions for each site, encapsulating all necessary arguments and configuration from the experiment info and runtime helpers. These handles are used during training and evaluation to compute the loss and its components for each site efficiently.
Arguments
- info: The SINDBAD experiment info structure, containing model, optimization, and hybrid configuration.
- run_helpers: Helper object returned by prepTEM, containing prepared model, forcing, observation, and output structures.
- sites: Array of site indices or identifiers for which to build loss functions.
Returns
- loss_functions: A KeyedArray of callable loss functions, one per site. Each function takes model parameters as input and returns the scalar loss for that site.
- loss_component_functions: A KeyedArray of callable functions, one per site, that return the vector of loss components (e.g., for multi-objective or constraint-based loss).
Notes
Each loss function is closed over all required data and options for its site, including model structure, parameter indices, scaling, forcing, observations, output cache, cost options, and hybrid/optimization settings.
The returned arrays are keyed by site for convenient lookup and iteration.
Example
loss_functions, loss_component_functions = getLossFunctionHandles(info, run_helpers, sites)
site_loss = loss_functions[site_index](params)
site_loss_components = loss_component_functions[site_index](params)
Code
function getLossFunctionHandles(info, run_helpers, sites)
loss_functions = []
loss_component_functions = []
for site_location in eachindex(sites)
parameter_to_index = getParameterIndices(info.models.forward, info.optimization.parameter_table)
loc_forcing = run_helpers.space_forcing[site_location]
loc_obs = run_helpers.space_observation[site_location]
loc_output = getCacheFromOutput(run_helpers.space_output[site_location], info.hybrid.ml_gradient.method)
loc_spinup_forcing = run_helpers.space_spinup_forcing[site_location]
loc_cost_option = prepCostOptions(loc_obs, info.optimization.cost_options)
loss_tmp(x) = loss(x, info.models.forward, parameter_to_index, info.optimization.run_options.parameter_scaling, loc_forcing, loc_spinup_forcing, run_helpers.loc_forcing_t, loc_output, deepcopy(run_helpers.loc_land), run_helpers.tem_info, loc_obs, loc_cost_option, info.optimization.run_options.multi_constraint_method, info.hybrid.ml_gradient.method, info.hybrid.ml_training.options.loss_function)
loss_vector_tmp(x) = lossComponents(x, info.models.forward, parameter_to_index, info.optimization.run_options.parameter_scaling, loc_forcing, loc_spinup_forcing, run_helpers.loc_forcing_t, loc_output, deepcopy(run_helpers.loc_land), run_helpers.tem_info, loc_obs, loc_cost_option, info.optimization.run_options.multi_constraint_method, info.hybrid.ml_gradient.method, info.hybrid.ml_training.options.loss_function)
push!(loss_functions, loss_tmp)
push!(loss_component_functions, loss_vector_tmp)
end
loss_functions = MachineLearning.KeyedArray(loss_functions; site=sites)
loss_component_functions = MachineLearning.KeyedArray(loss_component_functions; site=sites)
return loss_functions, loss_component_functions
end
getOutputFromCache
Sindbad.MachineLearning.getOutputFromCache Function
getOutputFromCache(loc_output, _, ::MachineLearningGradType)
getOutputFromCache(loc_output, new_params, ::ForwardDiffGrad)
getOutputFromCache(loc_output, new_params, ::PolyesterForwardDiffGrad)
Retrieves output values from the cache based on the differentiation method being used.
Arguments
- loc_output: The cached output values
- _ or new_params: Additional parameters (only used with ForwardDiff)
- The third argument specifies the differentiation method:
  - MachineLearningGradType: Returns the cached output directly when using other libraries, e.g., FiniteDiff.jl, FiniteDifferences.jl, etc.
  - ForwardDiffGrad: Processes the cached output with the new parameters when using ForwardDiff.jl; returns get_tmp.(loc_output, (new_params,))
  - PolyesterForwardDiffGrad: Falls back to the ForwardDiffGrad method with the new parameters
Code
function getOutputFromCache(loc_output, _, ::MachineLearningGradType)
return loc_output
end
function getOutputFromCache(loc_output, new_params, ::PolyesterForwardDiffGrad)
return getOutputFromCache(loc_output, new_params, ForwardDiffGrad())
end
getParamsAct
Sindbad.MachineLearning.getParamsAct Function
getParamsAct(x, parameter_table)
Scales x values in the [0,1] interval to the given lower (lo_b) and upper (up_b) bounds.
Arguments
- x: vector array
- parameter_table: a table with fields initial, lower, and upper that match the x vector.
Returns a vector array with new values scaled into the new interval [lower, upper].
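Example
A minimal sketch, assuming a table-like object with initial, lower, and upper fields (a NamedTuple stands in for the parameter table here):
using Sindbad.MachineLearning
parameter_table = (; initial=Float32[0.5, 15.0], lower=Float32[0.0, 10.0], upper=Float32[1.0, 20.0])
x = Float32[0.25, 0.75]               # e.g. neural-network outputs in [0, 1]
p = getParamsAct(x, parameter_table)  # values scaled into [lower, upper] per parameter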
Code
function getParamsAct(x, parameter_table)
lo_b = oftype(parameter_table.initial, parameter_table.lower)
up_b = oftype(parameter_table.initial, parameter_table.upper)
return scaleToBounds.(x, lo_b, up_b)
end
getPullback
Sindbad.MachineLearning.getPullback Function
getPullback(flat, re, features::AbstractArray)
getPullback(flat, re, features::Tuple)
Arguments:
- flat: weight parameters.
- re: model structure (vanilla Chain of Dense layers).
- features: input features with n predictors and s samples; one of:
  - a vector of predictors
  - a matrix of predictors (p_n × s)
  - a tuple of predictor vectors: (p1, p2)
  - a tuple of predictor matrices: ((p1_n × s), (p2_n × s))
Returns:
- new parameters and pullback function
Example
Here we use a single input feature vector or matrix.
using Sindbad.MachineLearning
using Flux
# model
m = Chain(Dense(4 => 5, relu), Dense(5 => 3), Flux.sigmoid)
# features
_feat = rand(Float32, 4)
# apply
flat, re = destructureNN(m)
# Zygote
new_params, pullback_func = getPullback(flat, re, _feat)
# ? or
_feat_ns = rand(Float32, 4, 3) # `n` predictors and `s` samples.
new_params, pullback_func = getPullback(flat, re, _feat_ns)
Example
Here we use multiple input feature vectors or matrices.
using Sindbad.MachineLearning
using Flux
# model
m1 = Chain(Dense(4 => 5, relu), Dense(5 => 3), Flux.sigmoid)
m2 = Dense(2=>1, Flux.sigmoid)
combo_ms = JoinDenseNN((m1, m2))
# features
_feat1 = rand(Float32, 4)
_feat2 = rand(Float32, 2)
# apply
flat, re = destructureNN(combo_ms)
# Zygote
new_params, pullback_func = getPullback(flat, re, (_feat1, _feat2))
# ? or with multiple samples
_feat1_ns = rand(Float32, 4, 3) # `n` predictors and `s` samples.
_feat2_ns = rand(Float32, 2, 3) # `n` predictors and `s` samples.
new_params, pullback_func = getPullback(flat, re, (_feat1_ns, _feat2_ns))
Code
function getPullback end
function getPullback(flat, re, features::AbstractArray)
new_params, pullback_func = Zygote.pullback(p -> re(p)(features), flat)
return new_params, pullback_func
end
function getPullback(flat, re, features::Tuple)
new_params, pullback_func = Zygote.pullback(p -> re(p)(features), flat)
return new_params, pullback_func
end
gradientBatch!
Sindbad.MachineLearning.gradientBatch! Function
gradientBatch!(grads_lib, grads_batch, chunk_size::Int, loss_f::Function, get_inner_args::Function, input_args...; showprog=false)
gradientBatch!(grads_lib, grads_batch, gradient_options::NamedTuple, loss_functions, scaled_params_batch, sites_batch; showprog=false)
Compute gradients for a batch of samples in hybrid (ML) modeling in SINDBAD.
This function computes the gradients of the loss function with respect to model parameters for a batch of sites or samples, using the specified gradient library. It supports both distributed and multi-threaded execution, and can handle different gradient computation backends (e.g., PolyesterForwardDiff, ForwardDiff, FiniteDiff, etc.).
Arguments
- grads_lib: Gradient computation library or method. Supported types include:
  - PolyesterForwardDiffGrad: Uses PolyesterForwardDiff.jl for multi-threaded chunked gradients.
  - Other MachineLearningGradType subtypes: Use their respective backend.
- grads_batch: Pre-allocated array for storing batched gradients (size: n_parameters × n_samples).
- chunk_size: (Optional) Chunk size for threaded gradient computation (used by PolyesterForwardDiffGrad).
- gradient_options: (Optional) NamedTuple of gradient options (e.g., chunk size).
- loss_f: Loss function to be applied (for all samples).
- get_inner_args: Function to obtain inner arguments for the loss function.
- input_args: Global input arguments for the batch.
- loss_functions: Array or KeyedArray of loss functions, one per site.
- scaled_params_batch: Callable or array providing scaled parameters for each site.
- sites_batch: List or array of site identifiers for the batch.
- showprog: (Optional) If true, display a progress bar during computation (default: false).
Returns
- Updates grads_batch in-place with computed gradients for each sample in the batch.
Notes
- The function automatically selects between distributed (pmap) and multi-threaded (Threads.@spawn) execution depending on the backend and arguments.
- Designed for use within training loops for efficient batch gradient computation.
Example
gradientBatch!(grads_lib, grads_batch, (chunk_size=4,), loss_functions, scaled_params_batch, sites_batch; showprog=true)
Code
function gradientBatch! end
function gradientBatch!(grads_lib::PolyesterForwardDiffGrad, dx_batch, chunk_size::Int,
loss_f::Function, get_inner_args::Function, input_args...; showprog=false)
mapfun = showprog ? progress_pmap : pmap
result = mapfun(CachingPool(workers()), axes(dx_batch, 2)) do idx
x_vals, inner_args = get_inner_args(idx, grads_lib, input_args...)
gradientSite(grads_lib, x_vals, chunk_size, loss_f, inner_args...)
end
for idx in axes(dx_batch, 2)
dx_batch[:, idx] = result[idx]
end
end
function gradientBatch!(grads_lib::PolyesterForwardDiffGrad, dx_batch, gradient_options::NamedTuple, loss_functions, scaled_params_batch, sites_batch; showprog=false)
mapfun = showprog ? progress_pmap : pmap
result = mapfun(CachingPool(workers()), axes(dx_batch, 2)) do idx
site_name = sites_batch[idx]
loss_f = loss_functions(site=site_name)
x_vals = scaled_params_batch(site=site_name).data.data
gradientSite(grads_lib, x_vals, gradient_options, loss_f)
end
for idx in axes(dx_batch, 2)
dx_batch[:, idx] = result[idx]
end
end
function gradientBatch!(grads_lib::MachineLearningGradType, grads_batch, gradient_options::NamedTuple, loss_functions, scaled_params_batch, sites_batch; showprog=false)
# Threads.@spawn allows dynamic scheduling instead of static scheduling
# of Threads.@threads macro.
# See <https://github.com/JuliaLang/julia/issues/21017>
p = Progress(length(axes(grads_batch,2)); desc="Computing batch grads...", color=:cyan, enabled=showprog)
@sync begin
for idx ∈ axes(grads_batch, 2)
Threads.@spawn begin
site_name = sites_batch[idx]
loss_f = loss_functions(site=site_name)
x_vals = scaled_params_batch(site=site_name).data.data
gg = gradientSite(grads_lib, x_vals, gradient_options, loss_f)
grads_batch[:, idx] = gg
next!(p)
end
end
end
end
gradientSite
Sindbad.MachineLearning.gradientSite Function
gradientSite(grads_lib, x_vals, chunk_size::Int, loss_f::Function, args...)
gradientSite(grads_lib, x_vals, gradient_options::NamedTuple, loss_f::Function)
gradientSite(grads_lib, x_vals::AbstractArray, gradient_options::NamedTuple, loss_f::Function)
Compute gradients of the loss function with respect to model parameters for a single site using the specified gradient library.
This function dispatches on the type of grads_lib to select the appropriate differentiation backend (e.g., PolyesterForwardDiff, ForwardDiff, FiniteDiff, FiniteDifferences, Zygote, or Enzyme). It supports both threaded and single-threaded computation, as well as chunked evaluation for memory and speed trade-offs.
Arguments
- grads_lib: Gradient computation library or method. Supported types include:
  - PolyesterForwardDiffGrad: Uses PolyesterForwardDiff.jl for multi-threaded chunked gradients.
  - ForwardDiffGrad: Uses ForwardDiff.jl for automatic differentiation.
  - FiniteDiffGrad: Uses FiniteDiff.jl for finite difference gradients.
  - FiniteDifferencesGrad: Uses FiniteDifferences.jl for finite difference gradients.
  - ZygoteGrad: Uses Zygote.jl for reverse-mode automatic differentiation.
  - EnzymeGrad: Uses Enzyme.jl for AD (experimental).
- x_vals: Parameter values for which to compute gradients.
- chunk_size: (Optional) Chunk size for threaded gradient computation (used by PolyesterForwardDiffGrad).
- gradient_options: (Optional) NamedTuple of gradient options (e.g., chunk size).
- loss_f: Loss function to be differentiated.
- args...: Additional arguments to be passed to the loss function.
Returns
∇x: Array of gradients of the loss function with respect tox_vals.
Notes
- On Apple M1 systems, PolyesterForwardDiffGrad falls back to single-threaded ForwardDiff due to closure issues.
- The function is used internally for both site-level and batch-level gradient computation in hybrid ML training.
Example
grads = gradientSite(ForwardDiffGrad(), x_vals, (chunk_size=4,), loss_f)
Code
function gradientSite end
function gradientSite(grads_lib::MachineLearningGradType, ::Any, ::Any, ::Any)
@warn "
Gradient library `$(nameof(typeof(grads_lib)))` not implemented.
To implement a new gradient library:
- First add a new type as a subtype of `MachineLearningGradType` in `src/Types/MachineLearningTypes.jl`.
- Then, add a corresponding method.
- if it can be implemented as an internal Sindbad method without additional dependencies, implement the method in `src/MachineLearning/mlGradient.jl`.
- if it requires additional dependencies, implement the method in `ext/<extension_name>/MachineLearningGradientSite.jl` extension.
As a fallback, this function will return 10.0f0.
"
return 10.0f0
end
function gradientSite(grads_lib::PolyesterForwardDiffGrad, x_vals, chunk_size::Int, loss_f::F, args...) where {F}
loss_tmp(x) = loss_f(x, grads_lib, args...)
∇x = similar(x_vals) # pre-allocate
if occursin("arm64-apple-darwin", Sys.MACHINE) # fallback due to closure issues on M1 systems
# cfg = ForwardDiff.GradientConfig(loss_tmp, x_vals, Chunk{chunk_size}());
ForwardDiff.gradient!(∇x, loss_tmp, x_vals) # ?, add `cfg` at the end if further control is needed.
else
PolyesterForwardDiff.threaded_gradient!(loss_tmp, ∇x, x_vals, ForwardDiff.Chunk(chunk_size));
end
return ∇x
end
function gradientSite(::PolyesterForwardDiffGrad, x_vals, gradient_options::NamedTuple, loss_f::F) where {F}
∇x = similar(x_vals) # pre-allocate
if occursin("arm64-apple-darwin", Sys.MACHINE) # fallback due to closure issues on M1 systems
# cfg = ForwardDiff.GradientConfig(loss_tmp, x_vals, Chunk{chunk_size}());
ForwardDiff.gradient!(∇x, loss_f, x_vals) # ?, add `cfg` at the end if further control is needed.
else
PolyesterForwardDiff.threaded_gradient!(loss_f, ∇x, x_vals, ForwardDiff.Chunk(gradient_options.chunk_size));
end
return ∇x
end
gradsNaNCheck!
Sindbad.MachineLearning.gradsNaNCheck! Function
gradsNaNCheck!(grads_batch, _params_batch, sites_batch, parameter_table; replace_value=0.0, show_params_for_nan=false)
Utility function to check whether any computed gradients are NaN (if found, please double-check your approach). NaN entries are replaced with replace_value (0.0f0 by default).
Arguments
- grads_batch: gradients array.
- _params_batch: parameter values.
- sites_batch: site names.
- parameter_table: parameter table.
- replace_value=0.0: value used to replace NaN gradients.
- show_params_for_nan=false: if true, show the parameters that caused the NaNs.
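Example
A minimal sketch of the NaN replacement; with show_params_for_nan=false (the default) the remaining arguments are not inspected, so placeholders are used here:
using Sindbad.MachineLearning
grads_batch = Float32[0.1 NaN; 0.3 0.4]
gradsNaNCheck!(grads_batch, nothing, nothing, nothing)
# NaN entries in grads_batch are now replaced by 0.0f0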
Code
function gradsNaNCheck!(grads_batch, _params_batch, sites_batch, parameter_table; replace_value = 0.0, show_params_for_nan=false)
if sum(isnan.(grads_batch))>0
if show_params_for_nan
foreach(findall(x->isnan(x), grads_batch)) do ci
p_index_tmp, si = Tuple(ci)
site_name_tmp = sites_batch[si]
p_vec_tmp = _params_batch(site=site_name_tmp)
parameter_values = Pair(parameter_table.name[p_index_tmp], (p_vec_tmp[p_index_tmp], parameter_table.lower[p_index_tmp], parameter_table.upper[p_index_tmp]))
@info "site: $site_name_tmp, parameter: $parameter_values"
end
end
@warn "NaNs in grads, replacing all by 0.0f0"
replace!(grads_batch, NaN => eltype(grads_batch)(replace_value))
end
end
lcKAoneHotbatch
Sindbad.MachineLearning.lcKAoneHotbatch Function
lcKAoneHotbatch(lc_data, up_bound, lc_name, ka_labels)
Arguments
- lc_data: vector array
- up_bound: last index class; the range goes from 1:up_bound, and any value outside that range is assigned the up_bound value. For PFT use 17 and for KG use 32.
- lc_name: land cover approach, either KG or PFT.
- ka_labels: KeyedArray labels, i.e. site names
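Example
A minimal sketch for three sites with PFT classes (values outside 1:17 are mapped to the last index; the site names are illustrative):
using Sindbad.MachineLearning
lc_data = [1, 5, 99]   # the last entry falls outside 1:17
oneHot_pft = lcKAoneHotbatch(lc_data, 17, "PFT", ["site_A", "site_B", "site_C"])
# KeyedArray with a one-hot features axis (PFT labels) and a site axis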
Code
function lcKAoneHotbatch(lc_data, up_bound, lc_name, ka_labels)
oneHot_lc = Flux.onehotbatch(lc_data, 1:up_bound, up_bound)
feat_labels = "$(lc_name)_".*string.(1:up_bound)
if lowercase(lc_name)=="kg"
feat_labels = KGlabels
elseif lowercase(lc_name)=="pft"
feat_labels = PFTlabels
end
return KeyedArray(Array(oneHot_lc); features=feat_labels, site=ka_labels)
end
loadCovariates
Sindbad.MachineLearning.loadCovariates Function
loadCovariates(sites_forcing; kind="all")
Use the kind argument to select different sets of covariates.
Arguments
sites_forcing: names of forcing sites
kind: defaults to "all"
Other options
- PFT
- KG
- KG_PFT
- PFT_ABCNOPSWB
- KG_ABCNOPSWB
- ABCNOPSWB
- veg_all
- veg
- KG_veg
- veg_ABCNOPSWB
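Example
A usage sketch; note that the default cube_path points to an internal data store, so access to that covariate cube (or a local copy passed via cube_path) is required. The site names below are illustrative:
using Sindbad.MachineLearning
sites_forcing = ["DE-Hai", "FI-Hyy"]
xfeatures = loadCovariates(sites_forcing; kind="KG_PFT")
# Float32 KeyedArray of one-hot KG and PFT features for the sites common to forcing and covariates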
Code
function loadCovariates(sites_forcing; kind="all", cube_path = "/Net/Groups/BGI/work_5/scratch/lalonso/CovariatesFLUXNET_3.zarr")
c_read = Cube(cube_path)
# select features, do only nor
only_nor = occursin.(r"nor", c_read.features)
nor_sel = c_read.features[only_nor].val
nor_sel = [string.(s) for s in nor_sel] |> sort
# select only normalized continuous variables
ds_nor = c_read[features = At(nor_sel)]
xfeat_nor = yaxCubeToKeyedArray(ds_nor)
# apply PCA to xfeat_nor if needed
# ? where is age?
kg_data = c_read[features=At("KG")][:].data
oneHot_KG = lcKAoneHotbatch(kg_data, 32, "KG", string.(c_read.site))
pft_data = c_read[features=At("PFT")][:].data
oneHot_pft = lcKAoneHotbatch(pft_data, 17, "PFT", string.(c_read.site))
oneHot_veg = vegKAoneHotbatch(pft_data, string.(c_read.site))
stackedFeatures = if kind=="all"
reduce(vcat, [oneHot_KG, oneHot_pft, xfeat_nor])
elseif kind=="PFT"
reduce(vcat, [oneHot_pft])
elseif kind=="KG"
reduce(vcat, [oneHot_KG])
elseif kind=="KG_PFT"
reduce(vcat, [oneHot_KG, oneHot_pft])
elseif kind=="PFT_ABCNOPSWB"
reduce(vcat, [oneHot_pft, xfeat_nor])
elseif kind=="KG_ABCNOPSWB"
reduce(vcat, [oneHot_KG, xfeat_nor])
elseif kind=="ABCNOPSWB"
reduce(vcat, [xfeat_nor])
elseif kind =="veg_all"
reduce(vcat, [oneHot_KG, oneHot_veg, xfeat_nor])
elseif kind=="veg"
reduce(vcat, [oneHot_veg])
elseif kind=="KG_veg"
reduce(vcat, [oneHot_KG, oneHot_veg])
elseif kind=="veg_ABCNOPSWB"
reduce(vcat, [oneHot_veg, xfeat_nor])
end
# remove sites (with NaNs and duplicates)
to_remove = [
"CA-NS3",
# "CA-NS4",
"IT-CA1",
# "IT-CA2",
"IT-SR2",
# "IT-SRo",
"US-ARb",
# "US-ARc",
"US-GBT",
# "US-GLE",
"US-Tw1",
# "US-Tw2"
]
not_these = ["RU-Tks", "US-Atq", "US-UMd"] # NaNs
not_these = vcat(not_these, to_remove)
new_sites = setdiff(c_read.site, not_these)
stackedFeatures = stackedFeatures(; site=new_sites)
# get common sites between names in forcing and covariates
sites_feature_all = [s for s in stackedFeatures.site]
sites_common = intersect(sites_feature_all, sites_forcing)
xfeatures = Float32.(stackedFeatures(; site=sites_common))
return xfeatures
end
loadTrainedNN
Sindbad.MachineLearning.loadTrainedNN Function
loadTrainedNN(path_model)
Arguments
path_model: path to the model.
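Example
A usage sketch with a hypothetical checkpoint path (the file must contain the keys written by the training routine, e.g. a checkpoint_epoch_*.jld2 from mixedGradientTraining):
using Sindbad.MachineLearning
nn_props = loadTrainedNN("checkpoint/checkpoint_epoch_3.jld2")
trained_nn = nn_props.trainedNN       # reconstructed network with trained weights
params_scaled = trained_nn(xfeatures) # xfeatures: covariate features, e.g. from loadCovariates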
Code
function loadTrainedNN(path_model)
model_props = JLD2.load(path_model)
return (;
trainedNN=model_props["re"](model_props["flat"]), # ? model structure and trained weights
lower_bound=model_props["lower_bound"], # ? parameters' attributes
upper_bound=model_props["upper_bound"],
ps_names=model_props["ps_names"],
metadata_global=model_props["metadata_global"])
end
loss
Sindbad.MachineLearning.loss Function
loss(params, models, parameter_to_index, parameter_scaling_type, loc_forcing, loc_spinup_forcing, loc_forcing_t, loc_output, land_init, tem_info, loc_obs, cost_options, constraint_method, gradient_lib, ::LossModelObsMachineLearning)
Calculates the scalar loss for a given site in hybrid (ML) modeling in SINDBAD.
This function computes the loss value for a given site by first calling lossVector to obtain the vector of loss components, and then combining them into a scalar loss using the combineMetric function and the specified constraint method.
Arguments
- params: Model parameters (typically output from an ML model).
- models: List of process-based models.
- parameter_to_index: Mapping from parameter names to indices.
- parameter_scaling_type: Parameter scaling configuration.
- loc_forcing: Forcing data for the site.
- loc_spinup_forcing: Spinup forcing data for the site.
- loc_forcing_t: Forcing data for a single time step.
- loc_output: Output data structure for the site.
- land_init: Initial land state.
- tem_info: Model information and configuration.
- loc_obs: Observation data for the site.
- cost_options: Cost function and metric configuration.
- constraint_method: Constraint method for combining metrics.
- gradient_lib: Gradient computation library or method.
- ::LossModelObsMachineLearning: Type dispatch for loss model with observations and machine learning.
Returns
t_loss: Scalar loss value for the site.
Notes
This function is used internally by higher-level training and evaluation routines.
The loss is computed by aggregating the loss vector using the specified constraint method.
Example
t_loss = loss(params, models, parameter_to_index, parameter_scaling_type, loc_forcing, loc_spinup_forcing, loc_forcing_t, loc_output, land_init, tem_info, loc_obs, cost_options, constraint_method, gradient_lib, LossModelObsMachineLearning())
Code
function loss(params, models, parameter_to_index, parameter_scaling_type, loc_forcing, loc_spinup_forcing, loc_forcing_t, loc_output, land_init, tem_info, loc_obs, cost_options, constraint_method, gradient_lib,loss_type::LossModelObsMachineLearning)
loss_vector, _ = lossVector(params, models,parameter_to_index, parameter_scaling_type, loc_forcing, loc_spinup_forcing, loc_forcing_t, loc_output, land_init, tem_info, loc_obs, cost_options, constraint_method, gradient_lib, loss_type)
t_loss = combineMetric(loss_vector, constraint_method)
return t_loss
end
lossComponents
lossComponents computes the loss for a given site and additionally returns the individual loss components: it calls lossVector, combines the loss vector into a scalar loss with combineMetric, and returns the scalar loss, the loss vector, and the corresponding loss indices.
Code
function lossComponents(params, models, parameter_to_index, parameter_scaling_type, loc_forcing, loc_spinup_forcing, loc_forcing_t, loc_output, land_init, tem_info, loc_obs, cost_options, constraint_method, gradient_lib,loss_type::LossModelObsMachineLearning)
loss_vector, loss_indices = lossVector(params, models,parameter_to_index, parameter_scaling_type, loc_forcing, loc_spinup_forcing, loc_forcing_t, loc_output, land_init, tem_info, loc_obs, cost_options, constraint_method, gradient_lib, loss_type)
t_loss = combineMetric(loss_vector, constraint_method)
return t_loss, loss_vector, loss_indices
end
lossSite
Sindbad.MachineLearning.lossSite Function
lossSite(new_params, gradient_lib, models, loc_forcing, loc_spinup_forcing, loc_forcing_t, loc_output, land_init, tem_info, parameter_to_index, parameter_scaling_type, loc_obs, cost_options, constraint_method; optim_mode=true)
Function to calculate the loss for a given site. This is used for optimization, hence the optim_mode argument is set to true by default. A gradient library must be provided, as well as new parameters to update the models. See all input arguments below:
Arguments
- new_params: new parameters
- gradient_lib: gradient library
- models: list of models
- loc_forcing: forcing data location
- loc_spinup_forcing: spinup forcing data location
- loc_forcing_t: forcing data time for one time step
- loc_output: output data location
- land_init: initial land state
- tem_info: model information
- parameter_to_index: parameter to index mapping
- parameter_scaling_type: type determining parameter scaling
- loc_obs: observation data location
- cost_options: cost options
- constraint_method: constraint method
Code
function lossSite(new_params, gradient_lib, models, loc_forcing, loc_spinup_forcing,
loc_forcing_t, loc_output, land_init, tem_info, parameter_to_index, parameter_scaling_type,
loc_obs, cost_options, constraint_method; optim_mode=true)
out_data = getOutputFromCache(loc_output, new_params, gradient_lib)
new_models = updateModels(new_params, parameter_to_index, parameter_scaling_type, models)
return getLoss(new_models, loc_forcing, loc_spinup_forcing, loc_forcing_t, out_data, land_init, tem_info, loc_obs, cost_options, constraint_method; optim_mode)
end
lossVector
Sindbad.MachineLearning.lossVector Function
lossVector(params, models, parameter_to_index, parameter_scaling_type, loc_forcing, loc_spinup_forcing, loc_forcing_t, loc_output, land_init, tem_info, loc_obs, cost_options, constraint_method, gradient_lib, ::LossModelObsMachineLearning)
Calculate the loss vector for a given site in hybrid (ML) modeling in SINDBAD.
This function runs the core TEM model with the provided parameters, forcing data, initial land state, and model information, then computes the loss vector using the specified cost options and metrics. It is typically used for site-level loss evaluation during training and validation.
Arguments
- params: Model parameters (in this case, output from an ML model).
- models: List of process-based models.
- parameter_to_index: Mapping from parameter names to indices.
- parameter_scaling_type: Parameter scaling configuration.
- loc_forcing: Forcing data for the site.
- loc_spinup_forcing: Spinup forcing data for the site.
- loc_forcing_t: Forcing data for a single time step.
- loc_output: Output data structure for the site.
- land_init: Initial land state.
- tem_info: Model information and configuration.
- loc_obs: Observation data for the site.
- cost_options: Cost function and metric configuration.
- constraint_method: Constraint method for combining metrics.
- gradient_lib: Gradient computation library or method.
- ::LossModelObsMachineLearning: Type dispatch for loss model with observations and machine learning.
Returns
- loss_vector: Vector of loss components for the site.
- loss_indices: Indices corresponding to each loss component.
Notes
This function is used internally by higher-level loss and training routines.
The loss vector is typically combined into a scalar loss using combineMetric.
Example
loss_vec, loss_idx = lossVector(params, models, parameter_to_index, parameter_scaling_type, loc_forcing, loc_spinup_forcing, loc_forcing_t, loc_output, land_init, tem_info, loc_obs, cost_options, constraint_method, gradient_lib, LossModelObsMachineLearning())
Code
function lossVector(params, models, parameter_to_index, parameter_scaling_type, loc_forcing, loc_spinup_forcing, loc_forcing_t, loc_output, land_init, tem_info, loc_obs, cost_options, constraint_method, gradient_lib,::LossModelObsMachineLearning)
loc_output_from_cache = getOutputFromCache(loc_output, params, gradient_lib)
models = updateModels(params, parameter_to_index, parameter_scaling_type, models)
coreTEM!(
models,
loc_forcing,
loc_spinup_forcing,
loc_forcing_t,
loc_output_from_cache,
land_init,
tem_info)
loss_vector = metricVector(loc_output_from_cache, loc_obs, cost_options)
loss_indices = cost_options.obs_sn
return loss_vector, loss_indices
end
mixedGradientTraining
Sindbad.MachineLearning.mixedGradientTraining Function
mixedGradientTraining(grads_lib, nn_model, train_refs, test_val_refs, total_constraints, loss_fargs, forward_args; n_epochs=3, optimizer=Optimisers.Adam(), path_experiment="/")
Training function that computes model parameters using a neural network, which are then used by process-based models (PBMs) to estimate parameter gradients. Neural network weights are updated using the product of these gradients with the neural network's Jacobian.
Arguments
- grads_lib: library used to compute PBM parameter gradients.
- nn_model: a Flux.Chain neural network.
- train_refs: training data features.
- test_val_refs: test and validation data features.
- total_constraints: number of loss components (constraints) tracked per site.
- loss_fargs: functions used to calculate the loss.
- forward_args: arguments to evaluate the PBMs.
- path_experiment="/": path where the model checkpoints are saved.
Code
function mixedGradientTraining(grads_lib, nn_model, train_refs, test_val_refs, total_constraints, loss_fargs, forward_args;
n_epochs=3, optimizer=Optimisers.Adam(), path_experiment="/")
sites_training, indices_sites_training, xfeatures, parameter_table, batch_size, chunk_size, metadata_global = train_refs
sites_validation, indices_sites_validation, sites_testing, indices_sites_testing = test_val_refs
lossSite, getInnerArgs = loss_fargs
flat, re, opt_state = destructureNN(nn_model; nn_opt=optimizer)
n_params = length(nn_model[end].bias)
loss_training = fill(zero(Float32), length(sites_training), n_epochs)
loss_validation = fill(zero(Float32), length(sites_validation), n_epochs)
loss_testing = fill(zero(Float32), length(sites_testing), n_epochs)
# ? save also the individual losses
loss_split_training = fill(NaN32, length(sites_training), total_constraints, n_epochs)
loss_split_validation = fill(NaN32, length(sites_validation), total_constraints, n_epochs)
loss_split_testing = fill(NaN32, length(sites_testing), total_constraints, n_epochs)
path_checkpoint = joinpath(path_experiment, "checkpoint")
f_path = mkpath(path_checkpoint)
@showprogress desc="training..." for epoch ∈ 1:n_epochs
x_batches, idx_xbatches = batchShuffler(sites_training, indices_sites_training, batch_size; bs_seed=epoch)
for (sites_batch, indices_sites_batch) in zip(x_batches, idx_xbatches)
grads_batch = zeros(Float32, n_params, length(sites_batch))
x_feat_batch = xfeatures(; site=sites_batch)
new_params, pullback_func = getPullback(flat, re, x_feat_batch)
_params_batch = getParamsAct(new_params, parameter_table)
input_args = (_params_batch, forward_args..., indices_sites_batch, sites_batch)
gradientBatch!(grads_lib, grads_batch, chunk_size, lossSite, getInnerArgs, input_args...)
gradsNaNCheck!(grads_batch, _params_batch, sites_batch, parameter_table) #? checks for NaNs and if any replace them with 0.0f0
# Jacobian-vector product
∇params = pullback_func(grads_batch)[1]
opt_state, flat = Optimisers.update(opt_state, flat, ∇params)
end
# calculate losses for all sites!
_params_epoch = re(flat)(xfeatures)
params_epoch = getParamsAct(_params_epoch, parameter_table)
getLossForSites(grads_lib, lossSite, loss_training, loss_split_training, epoch, params_epoch, sites_training, indices_sites_training, forward_args...)
# ? validation
getLossForSites(grads_lib, lossSite, loss_validation, loss_split_validation, epoch, params_epoch, sites_validation, indices_sites_validation, forward_args...)
# ? test
getLossForSites(grads_lib, lossSite, loss_testing, loss_split_testing, epoch, params_epoch, sites_testing, indices_sites_testing, forward_args...)
jldsave(joinpath(f_path, "checkpoint_epoch_$(epoch).jld2");
lower_bound=parameter_table.lower, upper_bound=parameter_table.upper, ps_names=parameter_table.name,
parameter_table=parameter_table,
metadata_global=metadata_global,
loss_training=loss_training[:, epoch],
loss_validation=loss_validation[:, epoch],
loss_testing=loss_testing[:, epoch],
loss_split_training=loss_split_training[:,:, epoch],
loss_split_validation=loss_split_validation[:,:, epoch],
loss_split_testing=loss_split_testing[:,:, epoch],
re=re,
flat=flat)
end
return nothing
end
mlModel
Sindbad.MachineLearning.mlModel Function
mlModel(info, n_features, ::MachineLearningModelType)
Builds a Flux dense neural network model. This function initializes a neural network model based on the provided info and n_features.
Arguments
- info: The experiment information containing model options and parameters.
- n_features: The number of features in the input data.
- ::MachineLearningModelType: Type dispatch for the machine learning model type.
Supported MachineLearningModelType:
::FluxDenseNN: A simple dense neural network model implemented in Flux.jl.
Returns
The initialized machine learning model.
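Example
The sketch below mirrors the network that the FluxDenseNN method constructs, with illustrative values in place of the entries read from info (these numbers are assumptions, not an actual experiment configuration):
using Flux, Random
n_features, n_neurons, n_layers, n_params = 20, 32, 2, 10
Random.seed!(1618)
flux_model = Flux.Chain(
    Flux.Dense(n_features => n_neurons, Flux.relu),                       # input layer
    [Flux.Dense(n_neurons, n_neurons, Flux.relu) for _ in 1:n_layers]..., # hidden layers
    Flux.Dense(n_neurons => n_params, Flux.sigmoid))                      # one output per ML-calibrated parameter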
Code
function mlModel end
function mlModel(info, n_features, ::FluxDenseNN)
n_params = sum(info.optimization.parameter_table.is_ml);
n_layers = info.hybrid.ml_model.options.n_layers
n_neurons = info.hybrid.ml_model.options.n_neurons
ml_seed = info.hybrid.random_seed;
print_info(mlModel, @__FILE__, @__LINE__, "Flux Dense NN with $n_features features, $n_params parameters, $n_layers hidden/inner layers and $n_neurons neurons.", n_f=2)
print_info(nothing, @__FILE__, @__LINE__, "Seed: $ml_seed", n_f=4)
print_info(nothing, @__FILE__, @__LINE__, "Hidden Layers: $(n_layers)", n_f=4)
print_info(nothing, @__FILE__, @__LINE__, "Total number of parameters: $(sum(info.optimization.parameter_table.is_ml))", n_f=4)
print_info(nothing, @__FILE__, @__LINE__, "Number of neurons per layer: $(n_neurons)", n_f=4)
print_info(nothing, @__FILE__, @__LINE__, "Number of parameters per layer: $(n_params / n_layers)", n_f=4)
activation_hidden = activationFunction(info.hybrid.ml_model.options, info.hybrid.ml_model.options.activation_hidden)
activation_out = activationFunction(info.hybrid.ml_model.options, info.hybrid.ml_model.options.activation_out)
print_info(nothing, @__FILE__, @__LINE__, "Activation function for hidden layers: $(info.hybrid.ml_model.options.activation_hidden)", n_f=4)
print_info(nothing, @__FILE__, @__LINE__, "Activation function for output layer: $(info.hybrid.ml_model.options.activation_out)", n_f=4)
Random.seed!(ml_seed)
flux_model = Flux.Chain(
Flux.Dense(n_features => n_neurons, activation_hidden),
[Flux.Dense(n_neurons, n_neurons, activation_hidden) for _ in 1:n_layers]...,
Flux.Dense(n_neurons => n_params, activation_out)
)
return flux_model
end
mlOptimizer
Sindbad.MachineLearning.mlOptimizer Function
mlOptimizer(optimizer_options, ::MachineLearningOptimizerType)
Create an ML optimizer from the given options and type. The options are passed to the constructor of the selected optimizer.
Arguments:
- optimizer_options: A dictionary or NamedTuple containing options for the optimizer.
- ::MachineLearningOptimizerType: The type used to determine which optimizer to create. Supported types include:
  - OptimisersAdam: For the Adam optimizer.
  - OptimisersDescent: For the Descent optimizer.
Returns:
- An ML optimizer object that can be used to optimize machine learning models.
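Example
A usage sketch; the options are splatted into the Optimisers constructor, so they must match its positional arguments (a single learning rate here, as an assumption):
using Sindbad.MachineLearning
opt = mlOptimizer((0.001,), OptimisersAdam())   # equivalent to Optimisers.Adam(0.001)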
Code
function mlOptimizer end
function mlOptimizer(optimizer_options, ::OptimisersAdam)
return Optimisers.Adam(optimizer_options...)
end
function mlOptimizer(optimizer_options, ::OptimisersDescent)
return Optimisers.Descent(optimizer_options...)
end
oneHotPFT
Sindbad.MachineLearning.oneHotPFT Function
oneHotPFT(pft, up_bound, veg_class)
Arguments
- pft: Plant Functional Type. Any entry not in 1:17 is set to the last index; this includes NaN! The last index is water/NaN.
- up_bound: last index class; the range goes from 1:up_bound, and any value outside that range uses the up_bound value. For PFT use 17.
- veg_class: true or false.
Returns a vector.
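Example
A minimal sketch without vegetation-class grouping (veg_class=false):
using Sindbad.MachineLearning
v = oneHotPFT(3, 17, false)     # 17-element one-hot vector with index 3 set
w = oneHotPFT(NaN, 17, false)   # out-of-range and NaN entries map to the last index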
Code
function oneHotPFT(pft, up_bound, veg_class)
if !veg_class
return Flux.onehot(pft, 1:up_bound, up_bound)
else
_pft = pft
if length(pft)==1
_pft = pft[1]
end
return vegOneHot(toClass(_pft))
end
end
partitionBatches
Sindbad.MachineLearning.partitionBatches Function
partitionBatches(n; batch_size=32)
Return an iterator partitioning a dataset into batches.
Arguments
- n: number of samples
- batch_size: batch size
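Example
A minimal example of the returned batches:
using Sindbad.MachineLearning
collect(partitionBatches(10; batch_size=4))   # [1:4, 5:8, 9:10]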
Code
function partitionBatches(n; batch_size=32)
return partition(1:n, batch_size)
end
prepHybrid
Sindbad.MachineLearning.prepHybrid Function
prepHybrid(forcing, observations, info, ::MachineLearningTrainingType)
Prepare all data structures, loss functions, and machine learning components required for hybrid (process-based + machine learning) modeling in SINDBAD.
This function orchestrates the setup for hybrid modeling by:
Initializing model helpers and runtime structures.
Building loss function handles for each site.
Splitting sites into training, validation, and testing sets according to the hybrid configuration.
Loading covariate features for all sites.
Building the machine learning model as specified in the configuration.
Preparing arrays for storing losses and loss components during training and evaluation.
Initializing the optimizer for machine learning training.
Collecting all relevant metadata and configuration into a single hybrid_helpers NamedTuple for downstream training routines.
Arguments
forcing: Forcing data structure as required by the process-based model.
observations: Observational data structure.
info: The SINDBAD experiment info structure, containing all configuration and runtime options.
::MachineLearningTrainingType: Type specifying the machine learning training method to use (e.g., MixedGradient).
Returns
hybrid_helpers: A NamedTuple containing all prepared data, models, loss functions, indices, features, optimizers, and arrays needed for hybrid machine learning training and evaluation.
Fields of hybrid_helpers
run_helpers: Output of prepTEM, containing prepared model, forcing, observation, and output structures.
sites: NamedTuple with training, validation, and testing site arrays.
indices: NamedTuple with indices for training, validation, and testing sites.
features: NamedTuple with n_features and data (covariate features for all sites).
ml_model: The machine learning model instance (e.g., a Flux neural network).
options: The info.hybrid configuration NamedTuple.
checkpoint_path: Path for saving checkpoints during training.
parameter_table: Parameter table from info.optimization.
loss_functions: KeyedArray of callable loss functions, one per site.
loss_component_functions: KeyedArray of callable loss component functions, one per site.
training_optimizer: The optimizer object for machine learning training.
loss_array: NamedTuple of arrays to store scalar losses for training, validation, and testing.
loss_array_components: NamedTuple of arrays to store loss components for training, validation, and testing.
metadata_global: Global metadata from the output configuration.
Notes
This function is typically called once at the start of a hybrid modeling experiment to set up all necessary components.
The returned hybrid_helpers is designed to be passed directly to training routines such as trainML.
Example
hybrid_helpers = prepHybrid(forcing, observations, info, MixedGradient())
trainML(hybrid_helpers, MixedGradient())
Code
function prepHybrid(forcing, observations, info, ::MachineLearningTrainingType)
run_helpers = prepTEM(info.models.forward, forcing, observations, info)
sites_forcing = forcing.data[1].site;
print_info(prepHybrid, @__FILE__, @__LINE__, "preparing hybridMachine Learninghelpers for $(length(sites_forcing)) sites", n_f=2)
print_info(nothing, @__FILE__, @__LINE__, "Building loss function handles for every site", n_m=4)
loss_functions, loss_component_functions = getLossFunctionHandles(info, run_helpers, sites_forcing)
## split the sites
print_info(prepHybrid, @__FILE__, @__LINE__, "Getting indices and sites for training, validation and testing", n_f=2)
indices_training, indices_validation, indices_testing = getIndicesSplit(info, sites_forcing, info.hybrid.fold.fold_type)
sites_training = sites_forcing[indices_training]
sites_validation = sites_forcing[indices_validation]
sites_testing = sites_forcing[indices_testing]
sites = (; training = sites_training, validation = sites_validation, testing = sites_testing)
indices = (; training = indices_training, validation = indices_validation, testing = indices_testing)
print_info(nothing, @__FILE__, @__LINE__, "Total sites: $(length(sites_forcing))", n_m=4)
print_info(nothing, @__FILE__, @__LINE__, "Training sites: $(length(sites.training))", n_m=4)
print_info(nothing, @__FILE__, @__LINE__, "Validation sites: $(length(sites.validation))", n_m=4)
print_info(nothing, @__FILE__, @__LINE__, "Testing sites: $(length(sites.testing))", n_m=4)
## get covariates
print_info(prepHybrid, @__FILE__, @__LINE__, "Loading covariates for hybridMachine Learningmodel", n_f=2)
print_info(nothing, @__FILE__, @__LINE__, "variables: $(info.hybrid.covariates.variables)", n_m=4)
print_info(nothing, @__FILE__, @__LINE__, "path: $(info.hybrid.covariates.path)", n_m=4)
xfeatures = loadCovariates(sites_forcing; kind=info.hybrid.covariates.variables, cube_path=info.hybrid.covariates.path)
print_info(nothing, @__FILE__, @__LINE__, "Min/Max of features: [$(minimum(xfeatures)), $(maximum(xfeatures))]", n_m=4)
n_features = length(xfeatures.features)
features = (; n_features=n_features, data=xfeatures)
## build machine learning model and get init predictions
print_info(prepHybrid, @__FILE__, @__LINE__, "Preparing machine learning model", n_f=2)
ml_model = mlModel(info, n_features, info.hybrid.ml_model.method)
print_info(prepHybrid, @__FILE__, @__LINE__, "Preparing loss arrays", n_f=2)
n_epochs = info.hybrid.ml_training.options.n_epochs
loss_array_training = fill(zero(Float32), length(sites.training), n_epochs)
loss_array_validation = fill(zero(Float32), length(sites.validation), n_epochs)
loss_array_testing = fill(zero(Float32), length(sites.testing), n_epochs)
# ? save also the individual losses
num_constraints = length(info.optimization.cost_options.variable)
loss_array_components_training = fill(NaN32, length(sites.training), num_constraints, n_epochs)
loss_array_components_validation = fill(NaN32, length(sites.validation), num_constraints, n_epochs)
loss_array_components_testing = fill(NaN32, length(sites.testing), num_constraints, n_epochs)
loss_array_components = (;
training=loss_array_components_training,
validation=loss_array_components_validation,
testing=loss_array_components_testing
)
loss_array = (;
training=loss_array_training,
validation=loss_array_validation,
testing=loss_array_testing
)
print_info(nothing, @__FILE__, @__LINE__, "Number of sites: $(length(sites_forcing))", n_m=4)
print_info(nothing, @__FILE__, @__LINE__, "Loss array shape (training | validation | testing): $(size(loss_array.training)) | $(size(loss_array.validation)) | $(size(loss_array.testing))", n_m=4)
print_info(nothing, @__FILE__, @__LINE__, "Loss array components shape (training | validation | testing): $(size(loss_array_components.training)) | $(size(loss_array_components.validation)) | $(size(loss_array_components.testing))", n_m=4)
print_info(nothing, @__FILE__, @__LINE__, "Number of constraints: $num_constraints", n_m=4)
print_info(prepHybrid, @__FILE__, @__LINE__, "Preparing training optimizer", n_f=2)
print_info(nothing, @__FILE__, @__LINE__, "Method: $(nameof(typeof(info.hybrid.ml_optimizer.method)))", n_m=4)
training_optimizer = mlOptimizer(info.hybrid.ml_optimizer.options, info.hybrid.ml_optimizer.method)
metadata_global = info.output.file_info.global_metadata
options = info.hybrid
hybrid_helpers = (;
run_helpers=run_helpers,
sites=sites,
indices=indices,
features=features,
ml_model=ml_model,
options=options,
checkpoint_path=info.output.dirs.hybrid.checkpoint,
parameter_table=info.optimization.parameter_table,
loss_functions=loss_functions,
loss_component_functions=loss_component_functions,
training_optimizer=training_optimizer,
loss_array=loss_array,
loss_array_components=loss_array_components,
metadata_global=metadata_global
)
return hybrid_helpers
end
shuffleBatches
Sindbad.MachineLearning.shuffleBatches Function
shuffleBatches(list, bs; seed=1)
Arguments
list: an array of samples
bs: batch size
seed: Int
Returns shuffled partitioned batches.
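Example
A minimal sketch with hypothetical site names; an incomplete trailing batch is dropped, so only full batches of size bs are returned:
using Sindbad.MachineLearning
sites = ["site_$(i)" for i in 1:10]
shuffleBatches(sites, 4; seed=1)
# two batches of 4 shuffled site names; the remaining 2 sites are skipped for this seed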
Code
function shuffleBatches(list, bs; seed=1)
bs_idxs = partitionBatches(length(list); batch_size = bs)
s_list = shuffleList(list; seed=seed)
xbatches = [s_list[p] for p in bs_idxs if length(p) == bs]
return xbatches
end
shuffleList
Sindbad.MachineLearning.shuffleList Function
shuffleList(list; seed=123)
Arguments
list: an array of samples
seed: Int
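Example
A minimal sketch of a reproducible shuffle; the same seed always yields the same permutation:
using Sindbad.MachineLearning
shuffleList([10, 20, 30, 40]; seed=123)
shuffleList([10, 20, 30, 40]; seed=123)  # identical ordering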
Code
function shuffleList(list; seed=123)
rand_indxs = randperm(MersenneTwister(seed), length(list))
return list[rand_indxs]
end
siteNameToID
Sindbad.MachineLearning.siteNameToID Function
siteNameToID(site_name, sites_list)
Return the index of site_name in sites_list.
Arguments
site_name: site name
sites_list: list of site names
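Example
A minimal sketch with hypothetical site names; findfirst returns nothing when the name is absent:
using Sindbad.MachineLearning
siteNameToID("DE-Hai", ["AT-Neu", "DE-Hai", "FI-Hyy"])  # 2
siteNameToID("XX-Abc", ["AT-Neu", "DE-Hai", "FI-Hyy"])  # nothing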
Code
function siteNameToID(site_name, sites_list)
return findfirst(s -> s == site_name, sites_list)
end
toClass
Sindbad.MachineLearning.toClass Function
toClass(x::Number; vegetation_rules)
Arguments
x: a key (Number) from vegetation_rules
vegetation_rules: mapping from numeric keys (plus NaN and missing) to vegetation class names
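Example
A minimal sketch; the returned class names depend on the vegetation_rules mapping shipped with the module, and key 99 is assumed to be undefined:
using Sindbad.MachineLearning
toClass(1)    # vegetation class stored under key 1
toClass(NaN)  # falls back to the entry stored under NaN
toClass(99)   # "Unknown key", assuming 99 is not a defined key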
Code
function toClass(x::Number; vegetation_rules=vegetation_rules)
if ismissing(x)
return vegetation_rules[missing]
elseif x isa AbstractFloat && isnan(x)
return vegetation_rules[NaN]
end
new_key = Int(x)
return get(vegetation_rules, new_key, "Unknown key")
end
trainML
Sindbad.MachineLearning.trainML Function
trainML(hybrid_helpers, ::MachineLearningTrainingType)
Train a machine learning (ML) or hybrid model in SINDBAD using the specified training method.
This function performs the training loop for the machine learning model, handling batching, gradient computation, optimizer updates, loss calculation, and checkpointing. It supports hybrid modeling workflows where ML-derived parameters are used in process-based models, and is designed to work with the data structures prepared by prepHybrid.
Arguments
hybrid_helpers: NamedTuple containing all prepared data, models, loss functions, indices, features, optimizers, and arrays needed for machine learning training and evaluation (as returned by prepHybrid).
::MachineLearningTrainingType: Type specifying the machine learning training method to use (e.g., MixedGradient).
Workflow
Iterates over epochs and batches of training sites.
For each batch:
Extracts features and computes model parameters.
Computes gradients using the specified gradient method.
Checks for NaNs in gradients and replaces them if needed.
Updates model parameters using the optimizer.
After each epoch:
Computes and stores losses and loss components for training, validation, and testing sets.
Saves model checkpoints and loss arrays to disk if a checkpoint path is specified.
Notes
The function is extensible to support different training strategies via dispatch on MachineLearningTrainingType.
Designed for use with hybrid modeling, where machine learning models provide parameters to process-based models.
Checkpointing enables resuming or analyzing training progress.
Example
hybrid_helpers = prepHybrid(forcing, observations, info, MixedGradient())
trainML(hybrid_helpers, MixedGradient())
Code
function trainML(hybrid_helpers, ::MixedGradient)
ml_model = hybrid_helpers.ml_model
all_sites = hybrid_helpers.sites
sites_training = all_sites.training
xfeatures = hybrid_helpers.features.data
parameter_table = hybrid_helpers.parameter_table
metadata_global = hybrid_helpers.metadata_global
loss_functions = hybrid_helpers.loss_functions
loss_array = hybrid_helpers.loss_array
loss_array_components = hybrid_helpers.loss_array_components
loss_component_functions = hybrid_helpers.loss_component_functions
ml_optimizer = hybrid_helpers.training_optimizer
flat, re, opt_state = destructureNN(ml_model; nn_opt=ml_optimizer)
n_params = length(parameter_table.name)
options = hybrid_helpers.options
batch_size = options.ml_training.options.batch_size
gradient_options = options.ml_gradient
n_epochs = options.ml_training.options.n_epochs
checkpoint_path = hybrid_helpers.checkpoint_path
@showprogress desc="training..." for epoch ∈ 1:n_epochs
x_batches = shuffleBatches(sites_training, batch_size; seed=epoch)
for sites_batch in x_batches
grads_batch = zeros(Float32, n_params, length(sites_batch))
x_feat_batch = xfeatures(; site=sites_batch)
new_params, pullback_func = getPullback(flat, re, x_feat_batch)
scaled_params_batch = getParamsAct(new_params, parameter_table)
@debug " Epoch $(epoch): training on batch with $(length(sites_batch)) sites, scaled_params: minimum=$(minimum(scaled_params_batch)), maximum=$(maximum(scaled_params_batch))"
gradientBatch!(gradient_options.method, grads_batch, gradient_options.options, loss_functions, scaled_params_batch, sites_batch; showprog=false)
gradsNaNCheck!(grads_batch, scaled_params_batch, sites_batch, parameter_table, replace_value=options.replace_value_for_gradient) #? checks for NaNs and if any replace them with replace_value_for_gradient
# Jacobian-vector product
∇params = pullback_func(grads_batch)[1]
opt_state, flat = Optimisers.update(opt_state, flat, ∇params)
end
# calculate losses for all sites!
if !isempty(checkpoint_path)
f_path = joinpath(checkpoint_path, "epoch_$(epoch).jld2")
_params_epoch = re(flat)(xfeatures)
scaled_params_epoch = getParamsAct(_params_epoch, parameter_table)
for comps in (:training, :validation, :testing)
sites_comp = getproperty(all_sites, comps)
loss_array_epoch = getproperty(loss_array, comps)
loss_array_components_epoch = getproperty(loss_array_components, comps)
epochLossComponents(loss_component_functions, loss_array_epoch, loss_array_components_epoch, epoch, scaled_params_epoch, sites_comp)
end
jldsave(f_path;
lower_bound=parameter_table.lower, upper_bound=parameter_table.upper, parameter_names=parameter_table.name,
parameter_table=parameter_table,
metadata_global=metadata_global,
loss_array_training=loss_array.training[:, epoch],
loss_array_validation=loss_array.validation[:, epoch],
loss_array_testing=loss_array.testing[:, epoch],
loss_array_components_training=loss_array_components.training[:,:, epoch],
loss_array_components_validation=loss_array_components.validation[:,:, epoch],
loss_array_components_testing=loss_array_components.testing[:,:, epoch],
re=re,
flat=flat)
end
end
end
vegKAoneHotbatch
Sindbad.MachineLearning.vegKAoneHotbatch Function
vegKAoneHotbatch(pft_data, ka_labels)
Arguments
pft_data: a Vector of PFT values
ka_labels: KeyedArray labels, i.e. site names
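Example
A minimal sketch with hypothetical PFT keys and site names; KeyedArray is assumed to come from AxisKeys.jl:
using Sindbad.MachineLearning
pft_data = [1, 4, 17]
ka_labels = ["site_a", "site_b", "site_c"]
vegKAoneHotbatch(pft_data, ka_labels)
# KeyedArray with one-hot vegetation classes along features and site names along site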
Code
function vegKAoneHotbatch(pft_data, ka_labels)
oneHot_veg = vegOneHotbatch(toClass.(pft_data))
return KeyedArray(Array(oneHot_veg); features=vegetation_labels, site=ka_labels)
end
vegOneHot
Sindbad.MachineLearning.vegOneHot Function
vegOneHot(v_class; vegetation_labels)
Arguments
v_class: obtain it via toClass(x; vegetation_rules).
vegetation_labels: inspect them by typing vegetation_labels.
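Example
A minimal sketch; the one-hot vector is defined over vegetation_labels, and the PFT key is hypothetical:
using Sindbad.MachineLearning
v_class = toClass(1)  # map a PFT key to its vegetation class
vegOneHot(v_class)    # one-hot vector over vegetation_labels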
Code
function vegOneHot(v_class; vegetation_labels=vegetation_labels)
return Flux.onehot(v_class, vegetation_labels)
end
vegOneHotbatch
Sindbad.MachineLearning.vegOneHotbatch Function
vegOneHotbatch(veg_classes; vegetation_labels)
Arguments
veg_classes: get these from toClass.([x1, x2, ...]).
vegetation_labels: inspect them by typing vegetation_labels.
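Example
A minimal sketch for a batch of hypothetical PFT keys:
using Sindbad.MachineLearning
veg_classes = toClass.([1, 4, 17])
vegOneHotbatch(veg_classes)  # one-hot matrix: vegetation classes × samples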
Code
function vegOneHotbatch(veg_classes; vegetation_labels=vegetation_labels)
return Flux.onehotbatch(veg_classes, vegetation_labels)
end