This blog features classification in Mahout and the underlying concepts. I will explain the basic classification process, training a Logistic Regression model with Stochastic Gradient Descent, and give a walkthrough of classifying the Iris flower dataset with Mahout.
Clustering versus Classification
One of my previous blogs focused on text clustering in Mahout. Clustering is an example of unsupervised learning: the clustering algorithm finds groups within the data without being told what to look for upfront. This contrasts with classification, an example of supervised machine learning, which is the process of determining to which class an observation belongs. A common application of classification is spam filtering. With spam filtering we use labeled data to train the classifier: e-mails marked as spam or ham. We can then test the classifier to see whether it does a good job of detecting spam in e-mail messages it hasn't seen during the training phase.
The basic classification process
Classification is a deep and broad subject with many different algorithms and optimizations. The basic process however remains the same:
- Obtain a dataset
- Transform the dataset into a set of records with a field-oriented format containing the features the classifier trains on
- Label items in the training set
- Split the dataset into a training set and a test set
- Encode the training set and the test set into vectors
- Create a model by training the classifier with the training set, with multiple runs and passes if necessary
- Test the classifier with the test set
- Evaluate the classifier
- Improve the classifier and repeat the process
Logistic Regression & Stochastic Gradient Descent
Before we dive into Mahout, let's look at how Logistic Regression and Stochastic Gradient Descent work. This is a very short and superficial introduction to the topic, but I hope it gives enough of an idea of how the algorithms work to follow the example later on. I have included links here and there to Wikipedia and to videos from the Coursera Machine Learning course for more information.
The Logistic function
Before I discuss Logistic Regression and SGD, let's look at their foundation: the logistic function. The logistic function is an S-shaped function whose range lies between 0 and 1, which makes it useful for modeling probabilities. When used in classification, an output close to 1 can indicate that an item belongs to a certain class. See the formula and graph below.
Logistic function
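For reference, this is the function being plotted:

$$\sigma(x) = \frac{1}{1 + e^{-x}}$$

As $$x$$ goes to minus infinity the output approaches 0; as $$x$$ grows the output approaches 1, with $$\sigma(0) = 0.5$$ in the middle.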
Logistic Regression model
Logistic Regression builds upon the logistic function. In contrast to the logistic function above, which takes a single x value as input, a Logistic Regression model allows many input variables: a vector of variables. Additionally, it has a weight or coefficient for each input variable. The resulting Logistic Regression model looks like this:
Logistic regression model
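Written out, the model applies the logistic function to a weighted sum of the input variables:

$$p(y = 1 \mid \mathbf{x}) = \frac{1}{1 + e^{-(\beta_0 + \beta_1 x_1 + \beta_2 x_2 + \cdots + \beta_n x_n)}}$$

where $$x_1, \ldots, x_n$$ are the input variables and $$\beta_0, \ldots, \beta_n$$ are the coefficients the training process has to learn.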
The goal now is to find values for the $$\beta$$s, the regression coefficients, such that the model can classify data with high accuracy. The classifier is accurate if the difference between the predicted and the actual probabilities is low. This difference is also called the cost. By minimizing the cost function of the Logistic Regression model we learn the values of the $$\beta$$ coefficients. See the following Coursera video on minimizing the cost function.
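For the binary case, the cost function being minimized is the standard log loss, averaged over the $$m$$ training examples:

$$J(\beta) = -\frac{1}{m} \sum_{i=1}^{m} \left[ y^{(i)} \log h_\beta(x^{(i)}) + \left(1 - y^{(i)}\right) \log\left(1 - h_\beta(x^{(i)})\right) \right]$$

where $$h_\beta(x^{(i)})$$ is the probability the model predicts for example $$i$$ and $$y^{(i)}$$ is its actual label (0 or 1). The same idea generalizes to multiple classes, as in the Iris example below.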
Stochastic Gradient Descent
The minimum of the cost function of Logistic Regression cannot be calculated directly, so we try to reach it via Stochastic Gradient Descent, also known as Online Gradient Descent. In this process we descend along the cost function towards its minimum for each training observation we encounter. As a result, the $$\beta$$ coefficients are updated at every step, and as we keep taking steps towards the minimum the cost is reduced and our model improves.
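For a single training example $$(x^{(i)}, y^{(i)})$$, the standard SGD update of each coefficient looks like this:

$$\beta_j \leftarrow \beta_j - \alpha \left( h_\beta(x^{(i)}) - y^{(i)} \right) x_j^{(i)}$$

where $$\alpha$$ is the learning rate. Mahout's implementation adds refinements on top of this basic rule, such as annealing the learning rate over time and applying regularization lazily.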
Classifying the Iris flower dataset
Now that you have a general idea of Logistic Regression and Stochastic Gradient Descent, let's look at an example. The Mahout source comes with a great example that demonstrates the classification process described above. The unit test `OnlineLogisticRegressionTest` contains a test case for classifying the well-known Iris flower dataset: a small dataset from 1936 of 150 flowers of 3 different species, Setosa, Versicolor and Virginica, together with the width and length of their sepals and petals. The dataset is often used as a benchmark for testing classification and clustering algorithms.
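To give an idea of the input, the UCI distribution of the dataset consists of comma-separated rows like the ones below. Note that the copy bundled with Mahout has a header row, which the test code skips; the exact header and label spellings can differ between distributions:

```
5.1,3.5,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
```

Each row holds the sepal length and width, the petal length and width, and the species.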
The code follows most of the steps of the classification process described above. To follow along, make sure you have checked out the Mahout source code. Open `OnlineLogisticRegressionTest` and look at the `iris()` test case. In the unit test a classifier is trained that can classify a flower's species based on the dimensions of its sepals and petals. In the next sections I show how Mahout can be used to train and test the classifier and finally assert its accuracy.
Setup and parsing the dataset
See the code snippet below. The first part of the code is concerned with setting up a few data structures and loading the dataset. We create two separate `List`s, `data` and `target`, for the features of the dataset and the classes we want to predict. Note that the type of the `target` `List` is `Integer`, because the classes of species are encoded to `Integer`s via a `Dictionary`, based on the order in which they are first encountered in the dataset. More on that later. Another `List` called `order` is created to shuffle the contents of the dataset.
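The `Dictionary` hands out integer ids in order of first appearance, which is what makes the `Integer` encoding of the species work. A minimal sketch of that behavior (the label strings are assumed examples):

```java
Dictionary dict = new Dictionary();
int a = dict.intern("Iris-setosa");     // 0: the first label seen gets the first id
int b = dict.intern("Iris-versicolor"); // 1: a new label gets the next id
int c = dict.intern("Iris-setosa");     // 0 again: this label is already known
List<String> labels = dict.values();    // [Iris-setosa, Iris-versicolor], useful for decoding
```

Because the Iris file lists each species in one contiguous block, the three species end up encoded as 0, 1 and 2.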
In the `for` loop we iterate over the lines of the CSV file, skipping the header. Each line is split on commas into separate fields, and the first four fields are put into a vector, which is added to `data`, the list of vectors. The first position of each vector, which corresponds to $$\beta_{0}$$, is set to 1. $$\beta_{0}$$ is also known as the intercept term; look at the regression graph on the following link to see why we need it. The target variable is the 5th field of the CSV file, hence we use `Iterables.get(values, 4)` to obtain it and add it to `target`. The remaining positions of the `data` vector hold the other variables.
```java
@Test
public void iris() throws IOException {
    // Snip ...
    RandomUtils.useTestSeed();
    Splitter onComma = Splitter.on(",");

    // read the data
    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);

    // holds features
    List<Vector> data = Lists.newArrayList();

    // holds target variable
    List<Integer> target = Lists.newArrayList();

    // for decoding target values
    Dictionary dict = new Dictionary();

    // for permuting data later
    List<Integer> order = Lists.newArrayList();

    for (String line : raw.subList(1, raw.size())) {
        // order gets a list of indexes
        order.add(order.size());

        // parse the predictor variables
        Vector v = new DenseVector(5);
        v.set(0, 1);
        int i = 1;
        Iterable<String> values = onComma.split(line);
        for (String value : Iterables.limit(values, 4)) {
            v.set(i++, Double.parseDouble(value));
        }
        data.add(v);

        // and the target
        target.add(dict.intern(Iterables.get(values, 4)));
    }

    // randomize the order ... original data has each species all together
    // note that this randomization is deterministic
    Random random = RandomUtils.getRandom();
    Collections.shuffle(order, random);

    // select training and test data
    List<Integer> train = order.subList(0, 100);
    List<Integer> test = order.subList(100, 150);
    logger.warn("Training set = {}", train);
    logger.warn("Test set = {}", test);
```
Training the Logistic Regression model
Now that all the proper data structures are in place, let's train the Logistic Regression model. The `iris()` test performs 200 runs, which means it creates 200 instances of the LR algorithm. For each run it also makes 30 passes through the training set to improve the accuracy of the classifier. Mahout's Logistic Regression code is based on the pseudocode in the appendix of Bob Carpenter's paper on Stochastic Gradient Descent.
The LR algorithm is created by instantiating `OnlineLogisticRegression` with the number of classes and the number of features. We pass the values 3 and 5 into the constructor of `OnlineLogisticRegression` because we have 3 classes, Setosa, Versicolor and Virginica, and 5 features: the intercept term, the sepal length and width, and the petal length and width. The third constructor parameter is used for regularization. See the following Coursera video on regularization.
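Besides the prior used for regularization, `OnlineLogisticRegression` exposes a few chainable knobs that the Iris test leaves at their defaults. As a sketch, a more explicitly configured instantiation could look like this; the values below are arbitrary illustrations, not tuned settings:

```java
OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1))
    .lambda(1.0e-4)    // strength of the regularization applied during training
    .learningRate(50)  // initial learning rate, annealed as training progresses
    .alpha(1);         // exponential decay factor for the learning rate
```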
In the same loop, after the 30 passes over the training set we test the classifier. We iterate through the test set and call the `classifyFull` method, which takes a single argument: an observation from the test set. Here it gets interesting: the method returns a `Vector` with probabilities for each of the classes, which means that all its elements have to sum to 1; see the `testClassify()` method, which checks this invariant. To find the class predicted by the classifier we call `maxValueIndex`, which returns the index of the class with the highest probability.
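Put together, reading a prediction out of the classifier looks roughly like this (a minimal sketch; `lr`, `data`, `dict` and the index `k` are as in the test code):

```java
Vector p = lr.classifyFull(data.get(k));       // one probability per class, summing to 1
int predicted = p.maxValueIndex();             // index of the most probable class
String species = dict.values().get(predicted); // decode the index back to a species name
```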
```java
// now train many times and collect information on accuracy each time
int[] correct = new int[test.size() + 1];
for (int run = 0; run < 200; run++) {
    OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));

    // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
    for (int pass = 0; pass < 30; pass++) {
        Collections.shuffle(train, random);
        for (int k : train) {
            lr.train(target.get(k), data.get(k));
        }
    }

    // check the accuracy on held out data
    int x = 0;
    int[] count = new int[3];
    for (Integer k : test) {
        int r = lr.classifyFull(data.get(k)).maxValueIndex();
        count[r]++;
        x += r == target.get(k) ? 1 : 0;
    }
    correct[x]++;
}
```
Assert accuracy
After we have performed the 200 runs, each with 30 passes, we test for accuracy. The `correct` array is a histogram: `correct[x]` counts how many of the 200 runs classified exactly `x` of the 50 test observations correctly. The snippet below checks that `correct` does not contain any entries corresponding to less than 95% accuracy. It also checks that there are no accuracies of 100%, which would be too good to be true, because that would probably indicate a target leak. A target leak is information in the training set that unintentionally provides information about the target class, such as identifiers or timestamps, but it can also be a very subtle piece of information.
```java
// verify we never saw worse than 95% correct,
for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
    assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ",
        correct[i], 100.0 * i / test.size()), 0, correct[i]);
}
// nor perfect
assertEquals(String.format("%d trials had unrealistic accuracy of 100%%",
    correct[test.size() - 1]), 0, correct[test.size()]);
```
Next steps
This blog gave a short overview of Logistic Regression and Stochastic Gradient Descent in Mahout, using the Iris dataset as an example. A next step would be to apply the algorithm to a more complex dataset that requires Mahout's vector encoders. The vector encoders are used to classify text and word-like variables instead of the `double`s used in the Iris dataset. Other things not covered in this blog are ways to evaluate classifiers beyond accuracy. In closing, do you have questions or other feedback? Let me know by leaving a comment!
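To give a taste of those encoders, here is a rough sketch of encoding one categorical and one continuous field into a feature vector. The field names and values are made-up examples; the encoder classes live in `org.apache.mahout.vectorizer.encoders`:

```java
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;

Vector v = new DenseVector(100);
StaticWordValueEncoder jobEncoder = new StaticWordValueEncoder("job");
ContinuousValueEncoder ageEncoder = new ContinuousValueEncoder("age");

jobEncoder.addToVector("management", v); // hashes the word into one or more vector positions
ageEncoder.addToVector("42", v);         // parses the number and adds it at a hashed position
```

See the 'Bank Marketing' example linked in the comments below for a complete walkthrough.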
Comments

Thank you for writing this. Unfortunately I'm stuck on an easy step… opening the source. Where in the source is OnlineLogisticRegressionTest located?
Nevermind the above comment:
…core/src/main/java/org/apache/mahout/classifier/sgd/
Hi Frank,
I tried to run this code but it is giving me very low accuracy on the Iris 2D dataset. Could you please tell me what I am doing wrong? My code is at
https://github.com/achala0309/mahout-sgd-classifier/tree/master
Hi Achala,
Could you please add a pom.xml file to your project so I can easily compile it, run it and have a look?
Cheers,
Frank
Hi Frank,
I added a pom.xml to the GitHub repository.
Thanks,
Achala
Hi Frank,
Did you get some time to check it?
Thanks and Regards,
Achala
Hi Achala,
Thx, I will have a look this weekend.
Cheers,
Frank
Thank you so much for the insight. I was wondering if there is a way to store the classifier so that it can later be tested on new data.
Yes, you can store your classifier with ModelSerializer.writeBinary.
Achala,
You need to perform more passes over a shuffled training set.
The iris test uses 30 passes.
Also, your code does not work out of the box since you don't have src/main/{java,resources} folders or a package.
Cheers,
Frank
Great post, but how can we deal with more complex dataset (e.g. categorical features) ?
@Roy Use Mahout’s vector encoders. See this example with the ‘Bank Marketing’ dataset from UCI: https://github.com/frankscholten/mahout-sgd-bank-marketing
Actually, this code was already committed in master: https://github.com/apache/mahout/tree/d850a091d3240f7863c92380fc01624c27f783c4/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing
I don’t understand the difference between features and predictors, does it train on all the features, or just the predictors?
@Sara – The features are the predictor variables here: the classifier trains on the sepal and petal measurements in order to predict the target, in this case the species of flower: Iris Setosa, Iris Versicolor or Iris Virginica.
Hi, what coding language is required to work with Mahout? Please suggest.
@Sridharan – Any JVM based language such as Java, Groovy or Scala. Mahout itself is written in Java.
Thank you for this example. I’m using Mahout 0.9 jars, I’m having errors (using Eclipse) with “Iterables”, the Charsets within the Resources.readLines(), and the OnlineLogisticRegression class doesn’t seem to inherit the train() and classifyFull() methods … can you share your pom.xml? or your complete “import” list ? THX !
Hi Frank, i’ve just been taking a look at your post and it’s a great explanation. The one thing I don’t quite understand is the use of the list “Order” to select test and training data. The only thing that ever gets added to the list is its own size, which is surely 0, so how does this work when this data is later used in the training and test lists?
Thanks,
Simon 🙂
Can actually ignore the above comment… I was being stupid and figured it out ^^