bnns

Overview

The bnns package provides an efficient and user-friendly implementation of Bayesian Neural Networks (BNNs) for regression, binary classification, and multiclass classification problems. By integrating Bayesian inference, bnns allows for uncertainty quantification in predictions and robust parameter estimation.

This vignette covers:

  1. Installing and loading the package
  2. Preparing the data
  3. Fitting a BNN model
  4. Summarizing the model
  5. Making predictions
  6. Evaluating the model
  7. Customizing priors

1. Installation

To install the package, use the following commands:

# Install from CRAN (if available)
# install.packages("bnns")

# Or install the development version from GitHub
# devtools::install_github("swarnendu-stat/bnns")

Load the package in your R session:

library(bnns)

2. Preparing the Data

The bnns package uses a formula interface: the predictors and the response are supplied as columns of a data frame, and the model is specified with a formula such as y ~ -1 + x1 + x2.

Here’s an example of generating synthetic data:

# Generate training data
set.seed(123)
df <- data.frame(x1 = runif(10), x2 = runif(10), y = rnorm(10))

For binary or multiclass classification:

# Binary classification response
df$y_bin <- sample(0:1, 10, replace = TRUE)

# Multiclass classification response
df$y_cat <- factor(sample(letters[1:3], 10, replace = TRUE)) # 3 classes

3. Fitting a Bayesian Neural Network Model

Fit a Bayesian Neural Network with the bnns() function. The network architecture is specified through arguments such as the number of hidden layers (L), the nodes per layer (nodes), and the per-layer activation functions (act_fn); the output activation function (out_act_fn) determines whether the model performs regression, binary classification, or multiclass classification.

Regression Example

model_reg <- bnns(
  y ~ -1 + x1 + x2,
  data = df,
  L = 1, # Number of hidden layers
  nodes = 2, # Nodes per layer
  act_fn = 3, # Activation functions: 3 = ReLU
  out_act_fn = 1, # Output activation function: 1 = Identity (for regression)
  iter = 1e1, # A very low number of iterations is used here for speed; increase to at least 1e3 for meaningful inference
  warmup = 5, # A very low number of warmup iterations is used here for speed; increase to at least 2e2 for meaningful inference
  chains = 1
)

Binary Classification Example

model_bin <- bnns(
  y_bin ~ -1 + x1 + x2,
  data = df,
  L = 1,
  nodes = c(16),
  act_fn = c(2),
  out_act_fn = 2, # Output activation: 2 = Logistic sigmoid
  iter = 2e2,
  warmup = 1e2,
  chains = 1
)

Multiclass Classification Example

model_cat <- bnns(
  y_cat ~ -1 + x1 + x2,
  data = df,
  L = 3,
  nodes = c(32, 16, 8),
  act_fn = c(3, 2, 2),
  out_act_fn = 3, # Output activation: 3 = Softmax
  iter = 2e2,
  warmup = 1e2,
  chains = 1
)

4. Summarizing the Model

Use the summary() function to view details of the fitted model, including the network architecture, posterior distributions, and predictive performance.

summary(model_reg)
#> Call:
#> bnns.default(formula = y ~ -1 + x1 + x2, data = df, L = 1, nodes = 2, 
#>     act_fn = 3, out_act_fn = 1, iter = 10, warmup = 5, chains = 1)
#> 
#> Data Summary:
#> Number of observations: 10 
#> Number of features: 2 
#> 
#> Network Architecture:
#> Number of hidden layers: 1 
#> Nodes per layer: 2 
#> Activation functions: 3 
#> Output activation function: 1 
#> 
#> Posterior Summary (Key Parameters):
#>                mean    se_mean        sd       2.5%        25%        50%
#> w_out[1] -0.2682487 0.19315006 0.3610847 -0.6939969 -0.5144810 -0.2992260
#> w_out[2] -0.4426392 0.10830231 0.2024660 -0.7055733 -0.5981513 -0.3399174
#> b_out     0.5266503 0.08137059 0.1521184  0.3630533  0.3630533  0.6022671
#> sigma     1.1249599 0.26761385 0.5002912  0.8186562  0.8849170  0.9576394
#>                  75%       97.5%   n_eff      Rhat
#> w_out[1]  0.09320342  0.09320342 3.49485 0.9075369
#> w_out[2] -0.27880903 -0.27880903 3.49485 3.7425590
#> b_out     0.62355410  0.67554666 3.49485 0.8134989
#> sigma     0.95763938  1.90774292 3.49485 0.9124295
#> 
#> Model Fit Information:
#> Iterations: 10 
#> Warmup: 5 
#> Thinning: 1 
#> Chains: 1 
#> 
#> Predictive Performance:
#> RMSE (training): 0.8624338 
#> MAE (training): 0.6994271 
#> 
#> Notes:
#> Check convergence diagnostics for parameters with high R-hat values.
summary(model_bin)
summary(model_cat)

5. Making Predictions

The predict() function generates predictions for new data. The format of predictions depends on the output activation function.

# New data
test_x <- matrix(runif(10), nrow = 5, ncol = 2) |>
  data.frame() |>
  `colnames<-`(c("x1", "x2"))

# Regression predictions
pred_reg <- predict(model_reg, test_x)

# Binary classification predictions
pred_bin <- predict(model_bin, test_x)

# Multiclass classification predictions
pred_cat <- predict(model_cat, test_x)
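
Since the structure of the returned predictions depends on the output activation function, it can help to inspect them before downstream use. The sketch below uses only base R inspection tools, nothing bnns-specific:

str(pred_reg) # structure of the regression predictions
str(pred_bin) # structure of the binary classification predictions
str(pred_cat) # structure of the multiclass predictions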

6. Evaluating the Model

The bnns package includes utility functions such as measure_cont(), measure_bin(), and measure_cat() for evaluating model performance on continuous, binary, and categorical responses, respectively.

Regression Evaluation

# True responses
test_y <- rnorm(5)

# Evaluate predictions
metrics_reg <- measure_cont(obs = test_y, pred = pred_reg)
print(metrics_reg)
#> $rmse
#> [1] 0.8597093
#> 
#> $mae
#> [1] 0.7284594
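
As a rough sanity check, similar metrics can be computed by hand. The sketch below assumes that pred_reg stores posterior draws with one row per test observation, so that rowMeans() yields a point prediction per observation; adjust the aggregation if str(pred_reg) shows a different layout.

# Posterior-mean point prediction per observation (layout assumption; see above)
pred_point <- rowMeans(as.matrix(pred_reg))

# RMSE and MAE computed directly; these should be close to metrics_reg if the
# aggregation matches what measure_cont() uses internally
sqrt(mean((test_y - pred_point)^2))
mean(abs(test_y - pred_point))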

Binary Classification Evaluation

# True responses
test_y_bin <- sample(c(rep(0, 2), rep(1, 3)), 5)

# Evaluate predictions
metrics_bin <- measure_bin(obs = test_y_bin, pred = pred_bin)

Multiclass Classification Evaluation

# True responses
test_y_cat <- factor(sample(letters[1:3], 5, replace = TRUE))

# Evaluate predictions
metrics_cat <- measure_cat(obs = test_y_cat, pred = pred_cat)

7. Customized Prior

Customized priors can be used for the weights as well as for the sigma parameter (in regression). Here we show a Cauchy prior on the weights in the multiclass classification case; a sketch of customizing the noise prior follows at the end of this section.

model_cat_cauchy <- bnns(
  y_cat ~ -1 + x1 + x2,
  data = df,
  L = 3,
  nodes = c(32, 16, 8),
  act_fn = c(3, 2, 2),
  out_act_fn = 3, # Output activation: 3 = Softmax
  iter = 2e2,
  warmup = 1e2,
  chains = 1,
  prior_weights = list(dist = "cauchy", params = list(mu = 0, sigma = 2.5))
)
# Evaluate predictions
metrics_cat_cauchy <- measure_cat(obs = test_y_cat, pred = predict(model_cat_cauchy, test_x))
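
For regression models, the prior on the noise parameter can be customized in a similar way. The sketch below assumes a prior_sigma argument that follows the same list format as prior_weights; the exact argument name and the accepted distribution labels should be checked in ?bnns.

model_reg_custom <- bnns(
  y ~ -1 + x1 + x2,
  data = df,
  L = 1,
  nodes = 2,
  act_fn = 3,
  out_act_fn = 1,
  iter = 2e2,
  warmup = 1e2,
  chains = 1,
  prior_sigma = list(dist = "cauchy", params = list(mu = 0, sigma = 2.5)) # assumed argument; see ?bnns
)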

8. Notes on Bayesian Neural Networks

  • Bayesian inference allows prior knowledge about the weights to be incorporated into the model.
  • It also provides uncertainty quantification for predictions.
  • Always check convergence diagnostics such as R-hat values (see the sketch after this list).
  • Use informative priors when possible to stabilize the model.
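
A quick way to review R-hat values is to rerun summary() and look at the Rhat column of the posterior summary (as in Section 4). The sketch below also inspects the fitted object; the idea that it exposes an underlying Stan fit (for example in a slot named fit) is an assumption to verify with str().

summary(model_reg) # posterior summary including the Rhat column
str(model_reg, max.level = 1) # inspect the fitted object's components

# If the underlying stanfit is exposed (slot name assumed here), rstan's
# built-in diagnostics can be applied directly, e.g.:
# rstan::check_hmc_diagnostics(model_reg$fit)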

References

For more details, consult the package documentation and the source code on GitHub (https://github.com/swarnendu-stat/bnns).