This post illustrates the basic concepts underlying Gradient Boosted Machine (GBM) models. There are many different variants of the basic GBM model (XGBoost, AdaBoost, DeepBoost, CatBoost, LightGBM, etc.).
I illustrate in this post the basic GBM algorithm for 2 applications: first a regression problem, then a classification problem.
I got the idea for this post from this article: \(\quad\) https://medium.com/mlreview/gradient-boosting-from-scratch-1e317ae4587d
I also used this vignette extensively: \(\quad\) https://cran.r-project.org/web/packages/gbm/vignettes/gbm.pdf
A quick note on terminology: throughout this post, I use the term predicted values to refer to the model's fitted values on the training data.
First, we load the libraries that we need:
library(tidyverse) # for data manipulation and plotting
library(rpart) # for fitting regression trees
library(rpart.plot) # for plotting regression trees
library(knitr) # for nice display of tables using knitr::kable()
# change formatting of the code output:
knitr::opts_chunk$set(
class.output = "bg-primary",
class.message = "bg-info text-info",
class.warning = "bg-warning text-warning",
class.error = "bg-danger text-danger"
)
We set global plot styling settings for the ggplot() function:
theme_set( theme_bw() +
theme( plot.background = element_rect(fill="black"),
panel.background = element_rect(fill="black"),
axis.text = element_text(colour="white"),
axis.title = element_text(colour="white"),
plot.title = element_text(colour="white"),
legend.text = element_text(colour="white"),
legend.title = element_text(colour="white"),
legend.background = element_rect(fill="black"),
legend.key = element_rect(fill="black"),
panel.grid.major = element_line(colour="grey20"),
panel.grid.minor = element_line(colour="grey20")
)
)
The GBM model takes in a vector of explanatory variable values (features) \(x_1,x_2,x_3...\) as input, and produces a prediction \(\hat{y}=f(x_1,x_2,x_3,...)\) for the true observed response \(y\) corresponding to these \(x_1,x_2,x_3...\).
For example, suppose that the 5th individual in our dataset has the following height, weight and running speed:
\[\begin{array}{lcl} x^{(5)}_{height} &=& 1.6 \\ x^{(5)}_{weight} &=& 63 \\ x^{(5)}_{speed} &=& 14 \\ \end{array}\]
The model will take in these values \(X_5=[1.6,63,14]\) (corresponding to individual/observation \(i=5\) in our dataset), and produce a prediction \(f(x^{(5)}_{height},x^{(5)}_{weight},x^{(5)}_{speed})\) for this individual’s age \(y_5\).
Suppose, for purposes of example, that the model predicts that this individual’s age is
\[\begin{array}{lcl} \hat{y}_5 &=& f(1.6,63,14) \\ &=& 29 \\ \end{array}\]
and that their true age is \(y_5=35\). The model is fairly close in its prediction, but not perfect.
We can use a loss function to measure how close this prediction \(\hat{y}_5\) is to the true response value \(y_5\).
For example, the squared-error loss function is
\[L\Big(y_i \space, \space \hat{y}_i \Big) \quad = \quad \Big(y_i- \hat{y}_i\Big)^2\]
So, our prediction of \(\hat{y}_5=29\) for a true response of \(y_5=35\) gives a (squared error) loss of:
\[L\Big(35 \space, \space 29 \Big) \quad = \quad \Big(35- 29\Big)^2 \quad=\quad 36\] Here are the squared error loss values that we’d see for some other predicted values:
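A minimal sketch to reproduce this (assuming the same true value \(y_5=35\) as above), computing and plotting the squared error loss over a range of candidate predictions:
# squared error loss for a range of candidate predictions of y_5 = 35:
true_y <- 35
candidate_preds <- seq(20, 50, by = 1)
loss_values <- (true_y - candidate_preds)^2
# loss is smallest near the true value and grows quadratically away from it:
plot( candidate_preds, loss_values,
      type = "b", pch = 16,
      xlab = "predicted value", ylab = "squared error loss"
    )
abline( v = true_y, lty = 2 )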
We can see that the squared error loss function gives (relatively) very small loss values to predictions that are close to the true \(y_i\) value, with penalties increasing at an accelerating rate as predictions begin to stray further away from the truth.
The gradient (slope) of the squared error loss function (with respect to the prediction \(\hat y_i\)) is
\[\displaystyle\frac{\partial \space L(y_i,\hat{y_i})}{\partial \space \hat{y_i}} \quad=\quad -2(y_i-\hat{y_i})\]
The loss value for an individual observation can be used to measure the accuracy of the model prediction for that single observation. We can use this information to train a predictive model.
Suppose that we have the following model fit (3 individuals/observations):
i | predicted age | true age | squared error loss for this prediction | gradient of loss function for this prediction |
---|---|---|---|---|
1 | 8 | 5 | 9 | 6 |
2 | 12 | 21 | 81 | -18 |
3 | 50 | 49 | 1 | 2 |
The large squared error loss, and large gradient, tell us that by far the largest error (worst prediction) is the one the model is making on observation 2.
The gradient of -18 means that, at the current prediction, the squared error loss for this observation is decreasing at a rate of 18 per unit increase in the prediction. For example, changing our prediction for observation 2 from \(12\) to \(12.001\) results in a reduction in squared error loss of \((12-21)^2-(12.001-21)^2\approx0.018\), i.e. roughly the gradient (18) times the step size (0.001).
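We can check this arithmetic directly in R:
# change in squared error loss when the prediction for observation 2
# moves from 12 to 12.001 (the true value is 21):
(12 - 21)^2 - (12.001 - 21)^2   # ~0.018: roughly the gradient (18) times the step (0.001)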
We can use this gradient to inform the improvement of our predictions.
Suppose that we choose a learning rate (\(\lambda\)) of 0.2. If we subtract \(\lambda \times [\text{gradient } i]\) from each of our predictions then each prediction improves:
i | predicted age | true age | update | updated prediction |
---|---|---|---|---|
1 | 8 | 5 | -(6*0.2) = -1.2 | 6.8 |
2 | 12 | 21 | -(-18*0.2) = +3.6 | 15.6 |
3 | 50 | 49 | -(2*0.2) = -0.4 | 49.6 |
Using the gradient of the loss function to inform the updates ensures that the biggest changes/corrections/updates are made to the predictions which are furthest from the true values (biggest errors). The learning rate \(\lambda\) ensures that the learning is gradual (it helps to prevent us overshooting the correct prediction and creating an error in the opposite direction).
We could take the new updated predictions, calculate loss values and gradients for these new predictions, and update them again in the same way to get even better predictions:
i | predicted age | true age | squared error loss | gradient | update | updated prediction |
---|---|---|---|---|---|---|
1 | 6.8 | 5 | 3.24 | 3.6 | -(3.6*0.2) = -0.72 | 6.08 |
2 | 15.6 | 21 | 29.16 | -10.8 | -(-10.8*0.2) = +2.16 | 17.76 |
3 | 49.6 | 49 | 0.36 | 1.2 | -(1.2*0.2) = -0.24 | 49.36 |
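A minimal sketch of these manual updates in R, using the same three observations and learning rate \(\lambda=0.2\):
# gradient-descent updates applied directly to the predictions
# (same three observations as in the tables above, lambda = 0.2):
true_age <- c(5, 21, 49)
pred_age <- c(8, 12, 50)
lambda   <- 0.2
for( step in 1:2 ){
  gradients <- -2 * (true_age - pred_age)      # gradient of squared error loss
  pred_age  <- pred_age - lambda * gradients   # move each prediction "downhill"
  cat( "after update", step, ":", round(pred_age, 2), "\n" )
}
# after update 1 : 6.8 15.6 49.6
# after update 2 : 6.08 17.76 49.36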
The GBM model builds on this concept of using the gradient of the loss function to iteratively modify the model predictions.
The basic GBM algorithm (used to train the model) is as follows:

1. Choose a loss function appropriate to the problem and the data.
2. Fit an initial model to the data: get a prediction \(\hat{y_i}\) for every observation \(y_i\).
3. For each observation, calculate the gradient \(g_i\) of the loss function.
4. Optional: take a random sample of the data (sample of rows and/or columns).
5. Build a model on the data (subsample). This model aims to produce a prediction \(\hat g_i\) for each \(g_i\). The desired update to the model prediction is \(\rho_i=\lambda \times (-g_i)\) (the learning rate times the negative gradient).
6. Update each prediction as \(\hat{y_i}+\hat{\rho_i}=\hat y_i - \lambda\hat g_i\) (i.e. the previous prediction for observation \(i\) plus the predicted update, which amounts to subtracting the learning rate times the predicted gradient).
7. We now have an updated model prediction \(\hat{y_i}\) for every row/observation \(i\) in our data. Go back to step 3 with these updated predictions, calculating new gradients and updating the predictions again.
So, the final prediction (after repeating this step many times) for each observation \(i\) will be:
\[\underset{\text{model 3 prediction}}{\underbrace{\underset{\text{model 2 prediction}}{\underbrace{\underset{\text{model 1 prediction}}{\underbrace{\hat{y_i}}} + \overset{\text{predicted update to model 1 prediction}}{\hat{\rho_i}^{(1)}}}} + \overset{\text{predicted update to model 2 prediction}}{\hat{\rho_i}^{(2)}}}} \quad + \quad ...........\]
Each subsequent model fit tries to predict the errors of the previous model (errors measured by the gradients of the loss function). These predictions are then used to adjust the predictions of the previous model fit, with the speed at which the model learns being controlled by the learning rate \(\lambda\). This is the concept of Gradient Boosting.
The subsampling of rows and columns of the data in each iteration (step 4 above) helps to reduce overfitting (improves performance on non-training data) by not allowing the model to fixate on the same part of the data in each iteration.
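Putting the steps above together, a compact (and simplified) sketch of the training loop might look like the following. The functions loss_gradient() and fit_base_learner() are placeholders for the loss-specific gradient and the weak learner (e.g. a shallow rpart tree); concrete versions of both appear in the worked examples below.
# a simplified sketch of the GBM training loop (not the full algorithm):
# - loss_gradient(y, y_hat): gradient of the loss at each current prediction
# - fit_base_learner(target, x): fits a weak learner (e.g. a shallow tree) to 'target'
gbm_sketch <- function( x, y, n_iter=100, lambda=0.1, loss_gradient, fit_base_learner ){
  y_hat <- rep( mean(y), length(y) )            # step 2: initial (constant) prediction
  for( m in 1:n_iter ){
    g     <- loss_gradient(y, y_hat)            # step 3: gradients at current predictions
    model <- fit_base_learner(g, x)             # step 5: model the gradients
    y_hat <- y_hat - lambda * predict(model)    # step 6: move predictions downhill
  }
  y_hat                                         # fitted values on the training data
}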
…
Let’s try this on some data!
The regression problem is to predict a continuous variable \(y\) using one or more explanatory variables \(x_1,x_2,x_3,...\).
For example, suppose that we have the following data, consisting of a continuous response variable \(y\) and a single explanatory variable \(x\):
# create random data
dataset <-
tibble( x = runif(100, 0,100) ) %>%
rowwise() %>%
mutate( y = case_when( x < 20 ~ rnorm(1, 60, 5),
x < 40 ~ rnorm(1, 20, 5),
x < 60 ~ rnorm(1, 80, 5),
x < 80 ~ rnorm(1, 10, 5),
TRUE ~ rnorm(1, 70, 5)
)
) %>%
ungroup()
# print the first 6 rows of the data:
dataset %>% head(6)
# set plot styling:
par( bg="black", col="white", col.axis="white", col.lab="white", col.main="white" )
plot( dataset$y ~ dataset$x,
pch = 16,
cex = 0.5,
col = sample(2:20,replace=TRUE,size=nrow(dataset)),
xlab = "x", ylab="y"
)
Given any particular value \(x_i\) of \(x\) (e.g. \(x_5\) = 15.468), we would like our model to provide an estimate \(\hat{y_i}\) of \(y_i\).
This simple dataset problem could be handled using a much simpler model than a GBM (such as a regression tree of depth greater than 3), but solving it using a GBM is a nice illustration of the GBM concept.
We initialise our model with a prediction of \(y_i = 40\) for every observation \(x_i\):
latest_predictions <- rep(40, nrow(dataset))
latest_gradients <- -2*(dataset$y-latest_predictions)
# (code for plots omitted)
The mean squared error (average value of \((y_i-\hat{y_i})^2\)) for this initial model is MSE=805.46. This is a measure of the overall model accuracy/quality.
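This initial MSE can be computed directly from the initial constant predictions:
# mean squared error of the initial (constant) prediction of 40 for every observation:
mean( (dataset$y - latest_predictions)^2 )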
Now, we iteratively fit regression trees of depth 1 (stumps) to predict the loss-function gradients \(g_i\) of the previous model's predictions \(\hat y_i\), and in each iteration use these predicted gradients to update our overall model predictions (notice how the mean squared error decreases with each update):
par( mfrow = c(1,3), # expand plot window to accommodate 3 plots per line
bg="black", col="white", col.axis="white", col.lab="white", col.main="white" # plot styling
)
learnrate <- 0.5 # set the learning rate at lambda=0.5
for( i in 2:15 ){ # for 14 iterations
# fit new model to gradient of previous model:
model_m <- rpart( grad ~ x, data=tibble(grad=latest_gradients, x=dataset$x), maxdepth=1, model=TRUE )
# plot the model fit:
prp( model_m ,
main = paste0("model ",i),
branch.col="red",
border.col="red",
under.col = "white",
split.border.col = "red",
split.col="white",
nn.col="white",
box.col="black"
)
# plot the model predictions of the previous model's gradients:
plot( x = dataset$x,
y = latest_gradients,
main = paste0("model ", i, " predictions on loss function \n gradients of model ", i-1, " predictions"),
col = "blue"
)
points( x = dataset$x, y = predict(model_m), col=2, pch=16 )
abline( h=0 )
# update our predictions by subtracting the predicted gradients (scaled by the learning rate):
latest_predictions <- latest_predictions - (learnrate * predict(model_m))
# calculate gradients of new updated predictions:
latest_gradients <- -2*(dataset$y-latest_predictions)
# plot our new predictions over our data:
plot( x = dataset$x,
y = dataset$y,
main = paste0("latest prediction \n (model", i, ")"),
pch = 16
)
points( x = dataset$x, y = latest_predictions, pch=16, col=3 )
# print the MSE for our latest predictions:
print(paste0( "model ", i, " MSE: ", round(mean( (dataset$y-latest_predictions)^2 ), digits=2) ) )
}
## [1] "model 2 MSE: 633.92"
## [1] "model 3 MSE: 471.45"
## [1] "model 4 MSE: 347.74"
## [1] "model 5 MSE: 248.99"
## [1] "model 6 MSE: 220.56"
## [1] "model 7 MSE: 160.98"
## [1] "model 8 MSE: 109.29"
## [1] "model 9 MSE: 92.51"
## [1] "model 10 MSE: 64.14"
## [1] "model 11 MSE: 50.16"
## [1] "model 12 MSE: 43.42"
## [1] "model 13 MSE: 35.24"
## [1] "model 14 MSE: 31.33"
## [1] "model 15 MSE: 27.47"
We can see that after only 14 consecutive fits of a one-split regression tree stump (arguably the simplest possible predictive model), on top of a constant initial prediction, the fit to the training data looks very good.
Further iterations would likely lead to overfitting to the training data.
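As an aside, a roughly equivalent model can be fitted with the gbm package (assuming it is installed; it is not loaded earlier in this post). The hyperparameters below are chosen to mirror the manual fit above: squared error loss, stumps, learning rate 0.5, 15 trees, no subsampling. Results may differ slightly from the manual fit because gbm initialises its predictions differently.
library(gbm)   # assumed installed

gbm_fit <- gbm( y ~ x,
                data              = dataset,
                distribution      = "gaussian",   # squared error loss
                n.trees           = 15,           # number of boosting iterations
                interaction.depth = 1,            # stumps (one split per tree)
                shrinkage         = 0.5,          # learning rate lambda
                bag.fraction      = 1,            # no row subsampling
                n.minobsinnode    = 1
              )

# training MSE of the gbm package fit:
mean( (dataset$y - predict(gbm_fit, newdata=dataset, n.trees=15))^2 )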
In a binary classification problem, we are again trying to predict a response value \(y\) using explanatory variables \(x_1, x_2, x_3,...\), except that the variable \(y\) is known to take on a value of \(y=0\) or \(y=1\).
I use the same GBM algorithm as above to solve this problem, except with a different loss function.
For this classification problem, I choose the Bernoulli loss function:
\[\begin{array}{lcl} \mathcal{L}\Big(y_i \space, \space \hat y_i \Big) &=& -y_i \space \hat y_i + \log\Big(1+e^{\hat y _i}\Big)\\ \end{array}\]
The gradient of the Bernoulli loss function (rate of change per unit change in prediction \(\hat y_i\)) is
\[\displaystyle\frac{\partial \space \mathcal{L}}{\partial \space \hat y_i} \quad=\quad -y_i + \displaystyle\frac{e^{\hat y_i}}{1+e^{\hat y_i}}\]
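Written as small R helper functions (a sketch matching the formulas above; these names are only for illustration and are not used in the loop below, which computes the same quantities inline):
# Bernoulli loss and its gradient, as used in this post:
bernoulli_loss     <- function(y, y_hat){ -y*y_hat + log( 1 + exp(y_hat) ) }
bernoulli_gradient <- function(y, y_hat){ -y + exp(y_hat) / ( 1 + exp(y_hat) ) }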
For this problem, we have the following data:
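The code that generated classifydat is not shown in this post. As a purely hypothetical stand-in with the same structure (100 rows, a binary response y, and two explanatory variables x1 and x2 on a 0-100 scale), something like the following would let the code below run:
# hypothetical stand-in for classifydat (the original data-generation code is not shown):
classifydat <-
  tibble( x1 = runif(100, 0, 100),
          x2 = runif(100, 0, 100)
        ) %>%
  mutate( y = if_else( x1 + x2 > 100, 1, 0 ) )   # an arbitrary example decision rule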
We make a custom plotting function:
make_plot_ftn <-
function( prediction_vector, plot_title="" ){
plot( x = 1:nrow(classifydat),
y = classifydat$y*100,
pch = 1,
axes=FALSE,
xlab="", ylab="",
xlim=c(0,nrow(classifydat)),
main = plot_title
)
points( x = 1:nrow(classifydat), y = classifydat$x1, pch=2, col="red", cex=0.8 )
points( x = 1:nrow(classifydat), y = classifydat$x2, pch=4, col="blue", cex=0.8 )
abline( v=1:nrow(classifydat), lwd=0.1 )
points( x = 1:nrow(classifydat), y = prediction_vector*100, pch=16)
legend("topleft", pch=c(1,16,2,4), legend=c("true label","predicted label","x1 value","x2 value"),
col=c("black","black","red","blue") )
}
# # test the plot function: (not run)
# make_plot_ftn( prediction_vector = runif( n = nrow(classifydat), min=0, max=1) )
First, we initialise all predictions at \(\hat y=0.5\). Then, we iteratively fit the GBM model, much the same as in the previous example:
# initialise all predictions at y=0.5
current_estimates <- rep(0.5, nrow(classifydat) )
# specify the learning rate:
learnrate <- 0.5
# get the gradient of the loss function for each of the initial predictions
get_gradients <- -(classifydat$y) + exp(current_estimates) / ( 1 + exp(current_estimates) )
# model iterations:
for( i in 2:150 ){ # for 149 iterations
# print iteration count
print( paste0("iteration ", i) )
# fit regression tree to the gradients:
modeldata <- tibble( y = get_gradients,
x1 = classifydat$x1,
x2 = classifydat$x2
)
fit_rpart <- rpart( y ~ x1+x2, data=modeldata, maxdepth=3 )
# make plots of the model fit:
par( mfrow = c(1,2) )
prp( fit_rpart,
main = paste0( "model ", i)
)
plot( x = 1:length(get_gradients),
y = get_gradients,
xlab = "observation ID",
ylab = "negative gradient",
main = paste0("model ",i, " fit to model ", i-1, "\n negative gradients")
)
points( x = 1:length(get_gradients),
y = predict(fit_rpart),
pch = 16,
col = "red"
)
# update the global prediction using this iteration's model fit:
current_estimates <- current_estimates - learnrate*predict(fit_rpart)
# force estimates outside of the allowed range [0,1] of y into range [0,1]:
estimates_to_plot <- current_estimates
estimates_to_plot[estimates_to_plot>1] <- 1
estimates_to_plot[estimates_to_plot<0] <- 0
# print out loss and accuracy for the current model fit:
paste0(
"Model ", i,
": ",
"Bernoulli loss: ",
-2 * sum( classifydat$y*estimates_to_plot - log( 1 + exp(estimates_to_plot) ) ),
" ",
"accuracy (num correct/num predictions): ",
sum(round(estimates_to_plot) == classifydat$y) / nrow(classifydat)
) %>%
print()
# print model plots:
par( mfrow = c(1,1) )
make_plot_ftn( prediction_vector = estimates_to_plot,
plot_title = paste0("model ", i, " predictions")
)
print(
tibble( x1 = classifydat$x1,
x2 = classifydat$x2,
true_label = factor(classifydat$y),
prediction = estimates_to_plot
) %>%
ggplot( data = .,
aes( x = x1,
y = x2,
shape = true_label,
colour = prediction
)
) +
geom_point() +
ggtitle( paste0("model ", i, " predictions") )
)
# get gradients for each updated prediction:
get_gradients <- -(classifydat$y) + exp(current_estimates) / ( 1 + exp(current_estimates) )
# stop doing iterations once the model accuracy reaches 0.99:
accuracy <- sum(round(estimates_to_plot) == classifydat$y) / nrow(classifydat)
if( accuracy >= 0.99 ){ break }
}
## [1] "iteration 2"
## [1] "Model 2: Bernoulli loss: 147.457386525874 accuracy (num correct/num predictions): 0.71"
## [1] "iteration 3"
## [1] "Model 3: Bernoulli loss: 138.953675559695 accuracy (num correct/num predictions): 0.74"
## [1] "iteration 4"
## [1] "Model 4: Bernoulli loss: 136.053299359402 accuracy (num correct/num predictions): 0.74"
## [1] "iteration 5"
## [1] "Model 5: Bernoulli loss: 132.129939522351 accuracy (num correct/num predictions): 0.77"
## [1] "iteration 6"
## [1] "Model 6: Bernoulli loss: 130.678097812958 accuracy (num correct/num predictions): 0.77"
## [1] "iteration 7"
## [1] "Model 7: Bernoulli loss: 130.57452089907 accuracy (num correct/num predictions): 0.74"
## [1] "iteration 8"
## [1] "Model 8: Bernoulli loss: 130.741472361916 accuracy (num correct/num predictions): 0.77"
## [1] "iteration 9"
## [1] "Model 9: Bernoulli loss: 129.633260006247 accuracy (num correct/num predictions): 0.76"
## [1] "iteration 10"
## [1] "Model 10: Bernoulli loss: 128.552588044856 accuracy (num correct/num predictions): 0.77"
## [1] "iteration 11"
## [1] "Model 11: Bernoulli loss: 128.502431229817 accuracy (num correct/num predictions): 0.74"
## [1] "iteration 12"
## [1] "Model 12: Bernoulli loss: 128.686264357917 accuracy (num correct/num predictions): 0.75"
## [1] "iteration 13"
## [1] "Model 13: Bernoulli loss: 127.906732258984 accuracy (num correct/num predictions): 0.79"
## [1] "iteration 14"
## [1] "Model 14: Bernoulli loss: 127.036628742525 accuracy (num correct/num predictions): 0.82"
## [1] "iteration 15"
## [1] "Model 15: Bernoulli loss: 126.563204264433 accuracy (num correct/num predictions): 0.83"
## [1] "iteration 16"
## [1] "Model 16: Bernoulli loss: 125.741271928684 accuracy (num correct/num predictions): 0.83"
## [1] "iteration 17"
## [1] "Model 17: Bernoulli loss: 125.731441530405 accuracy (num correct/num predictions): 0.83"
## [1] "iteration 18"
## [1] "Model 18: Bernoulli loss: 125.225561822776 accuracy (num correct/num predictions): 0.83"
## [1] "iteration 19"
## [1] "Model 19: Bernoulli loss: 124.492543480206 accuracy (num correct/num predictions): 0.82"
## [1] "iteration 20"
## [1] "Model 20: Bernoulli loss: 124.272690017358 accuracy (num correct/num predictions): 0.83"
## [1] "iteration 21"
## [1] "Model 21: Bernoulli loss: 124.13981145081 accuracy (num correct/num predictions): 0.83"
## [1] "iteration 22"
## [1] "Model 22: Bernoulli loss: 123.982530435131 accuracy (num correct/num predictions): 0.83"
## [1] "iteration 23"
## [1] "Model 23: Bernoulli loss: 123.76861930998 accuracy (num correct/num predictions): 0.83"
## [1] "iteration 24"
## [1] "Model 24: Bernoulli loss: 123.492457744244 accuracy (num correct/num predictions): 0.83"
## [1] "iteration 25"
## [1] "Model 25: Bernoulli loss: 122.778689449196 accuracy (num correct/num predictions): 0.84"
## [1] "iteration 26"
## [1] "Model 26: Bernoulli loss: 122.231027779188 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 27"
## [1] "Model 27: Bernoulli loss: 122.039689161153 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 28"
## [1] "Model 28: Bernoulli loss: 122.003300209983 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 29"
## [1] "Model 29: Bernoulli loss: 120.650323278289 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 30"
## [1] "Model 30: Bernoulli loss: 120.293383930841 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 31"
## [1] "Model 31: Bernoulli loss: 119.722623133076 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 32"
## [1] "Model 32: Bernoulli loss: 118.884519927166 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 33"
## [1] "Model 33: Bernoulli loss: 118.809405100696 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 34"
## [1] "Model 34: Bernoulli loss: 118.831592522436 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 35"
## [1] "Model 35: Bernoulli loss: 118.698268594699 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 36"
## [1] "Model 36: Bernoulli loss: 118.691304638152 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 37"
## [1] "Model 37: Bernoulli loss: 118.087431893827 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 38"
## [1] "Model 38: Bernoulli loss: 118.022523218004 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 39"
## [1] "Model 39: Bernoulli loss: 117.781522192635 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 40"
## [1] "Model 40: Bernoulli loss: 117.741247943596 accuracy (num correct/num predictions): 0.85"
## [1] "iteration 41"
## [1] "Model 41: Bernoulli loss: 117.523017694007 accuracy (num correct/num predictions): 0.87"
## [1] "iteration 42"
## [1] "Model 42: Bernoulli loss: 117.500888914215 accuracy (num correct/num predictions): 0.87"
## [1] "iteration 43"
## [1] "Model 43: Bernoulli loss: 117.480767005083 accuracy (num correct/num predictions): 0.87"
## [1] "iteration 44"
## [1] "Model 44: Bernoulli loss: 117.223580021481 accuracy (num correct/num predictions): 0.91"
## [1] "iteration 45"
## [1] "Model 45: Bernoulli loss: 117.143298307407 accuracy (num correct/num predictions): 0.91"
## [1] "iteration 46"
## [1] "Model 46: Bernoulli loss: 116.979419900343 accuracy (num correct/num predictions): 0.91"
## [1] "iteration 47"
## [1] "Model 47: Bernoulli loss: 116.540621798573 accuracy (num correct/num predictions): 0.92"
## [1] "iteration 48"
## [1] "Model 48: Bernoulli loss: 116.041740289446 accuracy (num correct/num predictions): 0.91"
## [1] "iteration 49"
## [1] "Model 49: Bernoulli loss: 116.017923201755 accuracy (num correct/num predictions): 0.93"
## [1] "iteration 50"
## [1] "Model 50: Bernoulli loss: 115.8322966993 accuracy (num correct/num predictions): 0.93"
## [1] "iteration 51"
## [1] "Model 51: Bernoulli loss: 115.909506876591 accuracy (num correct/num predictions): 0.91"
## [1] "iteration 52"
## [1] "Model 52: Bernoulli loss: 115.550397485861 accuracy (num correct/num predictions): 0.93"
## [1] "iteration 53"
## [1] "Model 53: Bernoulli loss: 115.179122931567 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 54"
## [1] "Model 54: Bernoulli loss: 115.214799306254 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 55"
## [1] "Model 55: Bernoulli loss: 115.247924967059 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 56"
## [1] "Model 56: Bernoulli loss: 115.278677880804 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 57"
## [1] "Model 57: Bernoulli loss: 115.307224236807 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 58"
## [1] "Model 58: Bernoulli loss: 115.333719130887 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 59"
## [1] "Model 59: Bernoulli loss: 115.358307224674 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 60"
## [1] "Model 60: Bernoulli loss: 115.381123378411 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 61"
## [1] "Model 61: Bernoulli loss: 115.402293256095 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 62"
## [1] "Model 62: Bernoulli loss: 115.421933902313 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 63"
## [1] "Model 63: Bernoulli loss: 115.440154290578 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 64"
## [1] "Model 64: Bernoulli loss: 115.457055843284 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 65"
## [1] "Model 65: Bernoulli loss: 115.472732923698 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 66"
## [1] "Model 66: Bernoulli loss: 115.487273300578 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 67"
## [1] "Model 67: Bernoulli loss: 115.500758586211 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 68"
## [1] "Model 68: Bernoulli loss: 115.513264648746 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 69"
## [1] "Model 69: Bernoulli loss: 115.5248619998 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 70"
## [1] "Model 70: Bernoulli loss: 115.535616158379 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 71"
## [1] "Model 71: Bernoulli loss: 115.403345362177 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 72"
## [1] "Model 72: Bernoulli loss: 115.391498314188 accuracy (num correct/num predictions): 0.94"
## [1] "iteration 73"
## [1] "Model 73: Bernoulli loss: 115.055597056662 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 74"
## [1] "Model 74: Bernoulli loss: 115.056134001193 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 75"
## [1] "Model 75: Bernoulli loss: 114.805300055058 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 76"
## [1] "Model 76: Bernoulli loss: 114.849290344209 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 77"
## [1] "Model 77: Bernoulli loss: 114.574410243661 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 78"
## [1] "Model 78: Bernoulli loss: 114.484278300389 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 79"
## [1] "Model 79: Bernoulli loss: 114.424654650151 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 80"
## [1] "Model 80: Bernoulli loss: 114.289010174047 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 81"
## [1] "Model 81: Bernoulli loss: 114.325266054307 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 82"
## [1] "Model 82: Bernoulli loss: 113.73060754938 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 83"
## [1] "Model 83: Bernoulli loss: 113.664224839561 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 84"
## [1] "Model 84: Bernoulli loss: 113.714855096403 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 85"
## [1] "Model 85: Bernoulli loss: 113.69403256864 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 86"
## [1] "Model 86: Bernoulli loss: 113.228893060657 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 87"
## [1] "Model 87: Bernoulli loss: 113.169906651779 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 88"
## [1] "Model 88: Bernoulli loss: 113.085770874467 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 89"
## [1] "Model 89: Bernoulli loss: 113.02835620746 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 90"
## [1] "Model 90: Bernoulli loss: 112.877134672836 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 91"
## [1] "Model 91: Bernoulli loss: 112.915917422199 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 92"
## [1] "Model 92: Bernoulli loss: 112.790734149071 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 93"
## [1] "Model 93: Bernoulli loss: 112.693706905733 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 94"
## [1] "Model 94: Bernoulli loss: 112.624854799186 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 95"
## [1] "Model 95: Bernoulli loss: 112.588358183126 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 96"
## [1] "Model 96: Bernoulli loss: 112.499564262808 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 97"
## [1] "Model 97: Bernoulli loss: 112.450408591645 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 98"
## [1] "Model 98: Bernoulli loss: 112.431959916432 accuracy (num correct/num predictions): 0.95"
## [1] "iteration 99"
## [1] "Model 99: Bernoulli loss: 112.316111087328 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 100"
## [1] "Model 100: Bernoulli loss: 112.271777113975 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 101"
## [1] "Model 101: Bernoulli loss: 112.228384178718 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 102"
## [1] "Model 102: Bernoulli loss: 112.221824026575 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 103"
## [1] "Model 103: Bernoulli loss: 112.146777415676 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 104"
## [1] "Model 104: Bernoulli loss: 112.140644555765 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 105"
## [1] "Model 105: Bernoulli loss: 112.12804747259 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 106"
## [1] "Model 106: Bernoulli loss: 112.079572222538 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 107"
## [1] "Model 107: Bernoulli loss: 112.06450940358 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 108"
## [1] "Model 108: Bernoulli loss: 111.983819338788 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 109"
## [1] "Model 109: Bernoulli loss: 111.918531244172 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 110"
## [1] "Model 110: Bernoulli loss: 111.854020906991 accuracy (num correct/num predictions): 0.96"
## [1] "iteration 111"
## [1] "Model 111: Bernoulli loss: 111.699428027502 accuracy (num correct/num predictions): 0.98"
## [1] "iteration 112"
## [1] "Model 112: Bernoulli loss: 111.626268610393 accuracy (num correct/num predictions): 0.98"
## [1] "iteration 113"
## [1] "Model 113: Bernoulli loss: 111.564618659494 accuracy (num correct/num predictions): 0.98"
## [1] "iteration 114"
## [1] "Model 114: Bernoulli loss: 111.429219590107 accuracy (num correct/num predictions): 0.98"
## [1] "iteration 115"
## [1] "Model 115: Bernoulli loss: 111.372342255467 accuracy (num correct/num predictions): 0.98"
## [1] "iteration 116"
## [1] "Model 116: Bernoulli loss: 111.250153875383 accuracy (num correct/num predictions): 0.98"
## [1] "iteration 117"
## [1] "Model 117: Bernoulli loss: 111.22221200592 accuracy (num correct/num predictions): 0.98"
## [1] "iteration 118"
## [1] "Model 118: Bernoulli loss: 111.169744708211 accuracy (num correct/num predictions): 0.99"