Topic Models: Cross Validation with Loglikelihood or Perplexity

The accepted answer to this question is good as far as it goes, but it doesn't actually address how to estimate perplexity on a validation dataset and how to use cross-validation.

Using perplexity for simple validation

Perplexity measures how well a probability model predicts a new set of data. In the topicmodels R package it is simple to compute with the perplexity function, which takes a previously fitted topic model and a new set of data as arguments and returns a single number. Lower is better.
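As a rough illustration of what that single number means, here is a minimal base R sketch of the usual definition of perplexity: the exponential of the negative mean log-likelihood per token. The probabilities and the helper name perplexity_from_probs are made up for illustration; this is not the internal code of topicmodels.

```r
# Hypothetical per-token probabilities that some fitted model assigns to
# the tokens of a held-out document (made-up numbers, illustration only).
token_probs <- c(0.01, 0.002, 0.05, 0.01, 0.001)

# Perplexity as the exponential of the negative mean log-likelihood per token.
# perplexity_from_probs is a hypothetical helper, not a topicmodels function.
perplexity_from_probs <- function(p) exp(-mean(log(p)))

perplexity_from_probs(token_probs)

# Sanity check: a model that spreads probability uniformly over a
# vocabulary of V words has a perplexity of V, whatever the text length.
V <- 1000
uniform_perplexity <- perplexity_from_probs(rep(1 / V, 20))
```

Read this way, a perplexity of a few thousand says the model is about as uncertain as a uniform guess over a few thousand words.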

For example, splitting the AssociatedPress data into a training set (75% of the rows) and a validation set (25% of the rows):

# load up some R packages including a few we'll need later
library(topicmodels)
library(doParallel)
library(ggplot2)
library(scales)

data("AssociatedPress", package = "topicmodels")

# settings for the Gibbs sampler
burnin <- 1000
iter <- 1000
keep <- 50

full_data <- AssociatedPress
n <- nrow(full_data)
#-----------validation--------
k <- 5 # number of topics for this single train/validation split

# randomly assign 75% of the documents to training, the rest to validation
splitter <- sample(1:n, round(n * 0.75))
train_set <- full_data[splitter, ]
valid_set <- full_data[-splitter, ]

fitted <- LDA(train_set, k = k, method = "Gibbs",
              control = list(burnin = burnin, iter = iter, keep = keep))
perplexity(fitted, newdata = train_set) # about 2700
perplexity(fitted, newdata = valid_set) # about 4300

The perplexity is higher for the validation set than the training set, because the topics have been optimised based on the training set.

Using perplexity and cross-validation to determine a good number of topics

The extension of this idea to cross-validation is straightforward. Divide the data into different subsets (say 5), and each subset gets one turn as the validation set and four turns as part of the training set. However, it's really computationally intensive, particularly when trying out the larger numbers of topics.
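A minimal base R sketch of the fold bookkeeping, assuming a hypothetical collection of 23 documents. It uses a balanced assignment (shuffled repeated labels), which is one common choice; drawing labels with replacement is an alternative that gives folds of only roughly equal size. The names n_docs and fold_id are illustrative.

```r
set.seed(123)  # arbitrary seed so the split is reproducible
n_docs <- 23   # hypothetical number of documents
folds <- 5

# Balanced assignment: repeat the labels 1..folds up to n_docs, then shuffle
fold_id <- sample(rep(seq_len(folds), length.out = n_docs))

# Each fold takes one turn as the validation set; the rest form the training set
for (i in seq_len(folds)) {
  valid_rows <- which(fold_id == i)
  train_rows <- which(fold_id != i)
  cat("fold", i, ": train", length(train_rows),
      "valid", length(valid_rows), "\n")
}
```

Every document lands in exactly one validation set and in the training set of the other four folds.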

You might be able to use caret to do this, but I suspect it doesn't handle topic modelling yet. In any case, it's the sort of thing I prefer to do myself to be sure I understand what's going on.

The code below, even with parallel processing on 7 logical CPUs, took 3.5 hours to run on my laptop:

#----------------5-fold cross-validation, different numbers of topics----------------
# set up a cluster for parallel processing
cluster <- makeCluster(detectCores(logical = TRUE) - 1) # leave one CPU spare...
registerDoParallel(cluster)

# load up the needed R package on all the parallel sessions
clusterEvalQ(cluster, {
library(topicmodels)
})

folds <- 5
splitfolds <- sample(1:folds, n, replace = TRUE)
candidate_k <- c(2, 3, 4, 5, 10, 20, 30, 40, 50, 75, 100, 200, 300) # candidates for how many topics

# export all the needed R objects to the parallel sessions
clusterExport(cluster, c("full_data", "burnin", "iter", "keep", "splitfolds", "folds", "candidate_k"))

# We parallelise over the candidate numbers of topics. Each processor is allocated
# a value of k and runs the cross-validation folds for that k serially. This assumes
# there are more candidate values of k than cross-validation folds, so it is more
# efficient to parallelise over k than over folds.
system.time({
  results <- foreach(j = seq_along(candidate_k), .combine = rbind) %dopar% {
    k <- candidate_k[j]
    # one row per fold: the value of k and the hold-out perplexity
    results_1k <- matrix(0, nrow = folds, ncol = 2)
    colnames(results_1k) <- c("k", "perplexity")
    for (i in 1:folds) {
      train_set <- full_data[splitfolds != i, ]
      valid_set <- full_data[splitfolds == i, ]

      fitted <- LDA(train_set, k = k, method = "Gibbs",
                    control = list(burnin = burnin, iter = iter, keep = keep))
      results_1k[i, ] <- c(k, perplexity(fitted, newdata = valid_set))
    }
    return(results_1k)
  }
})
stopCluster(cluster)

results_df <- as.data.frame(results)

ggplot(results_df, aes(x = k, y = perplexity)) +
  geom_point() +
  geom_smooth(se = FALSE) +
  ggtitle("5-fold cross-validation of topic modelling with the 'Associated Press' dataset",
          "(ie five different models fit for each candidate number of topics)") +
  labs(x = "Candidate number of topics",
       y = "Perplexity when fitting the trained model to the hold-out set")

We see in the results that 200 topics is too many and has some over-fitting, and 50 is too few. Of the numbers of topics tried, 100 is the best, with the lowest average perplexity on the five different hold-out sets.

[Figure: perplexity on the hold-out set plotted against the candidate number of topics]
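To pick the candidate with the lowest average hold-out perplexity numerically rather than by eye, something like the following sketch could be used. The toy_results data frame holds made-up values in the same shape as results_df (columns k and perplexity); they are not the results of the run above.

```r
# Made-up perplexities, three folds per candidate k (illustration only)
toy_results <- data.frame(
  k = rep(c(50, 100, 200), each = 3),
  perplexity = c(2500, 2520, 2480, 2400, 2390, 2410, 2450, 2470, 2440)
)

# Average the hold-out perplexity over the folds for each candidate k
mean_by_k <- aggregate(perplexity ~ k, data = toy_results, FUN = mean)

# Candidate with the lowest mean hold-out perplexity
best_k <- mean_by_k$k[which.min(mean_by_k$perplexity)]
best_k
```

With the real results, replacing toy_results by results_df should give the same kind of summary.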

Use Log Likelihood to compare different mallet topic models?

The intention of the log likelihood calculation is to provide a metric that is comparable across different models. That said, I wouldn't recommend using it in that way.

First, if you actually care about language model predictive likelihood, you should use one of many more recent deep neural models.

Second, likelihood is very sensitive to smoothing parameters, so the fact that you get consistent differences may be just an artifact of your own settings. Preprocessing decisions like tokenization and multi-word terms can also have a bigger impact than choice of model.
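To make the smoothing point concrete, here is a toy base R sketch. It deliberately swaps in a much simpler model than a topic model: a unigram model with additive smoothing, evaluated on a tiny made-up held-out text. The token vectors, the alpha values, and the helper heldout_perplexity are all assumptions for illustration and have nothing to do with mallet's actual hyperparameters.

```r
# Toy data (made up): "d" appears in the held-out text but never in training
train_tokens <- c("a", "a", "a", "b", "b", "c")
test_tokens <- c("a", "b", "d")
vocab <- c("a", "b", "c", "d")

# Held-out perplexity of a unigram model with additive smoothing alpha.
# heldout_perplexity is a hypothetical helper for this sketch.
heldout_perplexity <- function(alpha) {
  counts <- sapply(vocab, function(w) sum(train_tokens == w))
  probs <- (counts + alpha) / (sum(counts) + alpha * length(vocab))
  exp(-mean(log(probs[test_tokens])))
}

# Same training data, same held-out text, only the smoothing changes
alphas <- c(0.01, 0.1, 1)
perps <- sapply(alphas, heldout_perplexity)
data.frame(alpha = alphas, perplexity = perps)
```

Nothing about the data or the model family changes between rows, yet the held-out perplexity differs substantially, which is why a likelihood gap between two runs may say more about settings than about the models.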

Third, if you are actually interested in topic model output, you should be clear about what you want from the model, and what characteristics of a model make it useful for your specific needs. I like to suggest that people think of a topic model as more like making a map than fitting a regression. The best resolution of the map depends on where you want to go.

Finally, you are almost certainly better off with the simplest model.

How to calculate perplexity for LDA with Gibbs sampling

The perplexity call needs one more argument, estimate_theta. Use the code below, where ldaOut is the fitted model and dtm is the document-term matrix:

perplexity(ldaOut, newdata = dtm, estimate_theta = FALSE)

