- Python Machine Learning Cookbook
- Prateek Joshi
- 2025-02-24 18:38:24
Building a simple classifier
Let's see how to build a simple classifier using some training data.
How to do it…
- We will use the `simple_classifier.py` file that is already provided to you for reference. Assuming that you imported the `numpy` and `matplotlib.pyplot` packages as we did in the last chapter, let's create some sample data:
```python
import numpy as np
import matplotlib.pyplot as plt

X = np.array([[3, 1], [2, 5], [1, 8], [6, 4], [5, 2], [3, 5], [4, 7], [4, -1]])
```
- Let's assign some labels to these points:
```python
y = [0, 1, 1, 0, 0, 1, 1, 0]
```
- As we have only two classes, the `y` list contains 0s and 1s. In general, if you have N classes, the values in `y` will range from 0 to N-1. Let's separate the data into classes based on the labels:
```python
# Collect the points labeled 0 and the points labeled 1
class_0 = np.array([X[i] for i in range(len(X)) if y[i] == 0])
class_1 = np.array([X[i] for i in range(len(X)) if y[i] == 1])
```
- To get an idea about our data, let's plot it, as follows:
```python
plt.figure()
# Squares for class_0, crosses for class_1
plt.scatter(class_0[:, 0], class_0[:, 1], color='black', marker='s')
plt.scatter(class_1[:, 0], class_1[:, 1], color='black', marker='x')
```
This is a scatterplot, where we use squares and crosses to plot the points. In this context, the `marker` parameter specifies the shape you want to use: we use squares (`'s'`) to denote points in `class_0` and crosses (`'x'`) to denote points in `class_1`. If you run this code, you will see the following figure:
- When separating the data earlier, we used the mapping between `X` and `y` to split the points into `class_0` and `class_1`. If you were asked to inspect the datapoints visually and draw a separating line, what would you do? You would simply draw a line between the two groups. Let's go ahead and do this:
```python
# A straight line through the points (0, 0), (1, 1), ..., (9, 9)
line_x = range(10)
line_y = line_x
```
- We just created a line with the mathematical equation y = x. Let's plot it, as follows:
```python
plt.figure()
plt.scatter(class_0[:, 0], class_0[:, 1], color='black', marker='s')
plt.scatter(class_1[:, 0], class_1[:, 1], color='black', marker='x')
# Draw the separating line y = x on top of the scatterplot
plt.plot(line_x, line_y, color='black', linewidth=3)
plt.show()
```
- If you run this code, you should see the following figure:
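The list comprehensions in the steps above work, but NumPy boolean masking does the same selection more directly and extends to any number of classes with labels 0 to N-1. This is a minimal sketch that assumes `y` is converted to a NumPy array first; the `classes` dictionary is a hypothetical name used only for illustration:

```python
import numpy as np

# Same sample data and labels as in the recipe
X = np.array([[3, 1], [2, 5], [1, 8], [6, 4], [5, 2], [3, 5], [4, 7], [4, -1]])
y = np.array([0, 1, 1, 0, 0, 1, 1, 0])

# y == 0 is a boolean array; indexing X with it keeps only the matching rows
class_0 = X[y == 0]
class_1 = X[y == 1]

# Hypothetical generalization: one array of points per label, for N classes
classes = {label: X[y == label] for label in np.unique(y)}

print(class_0)
print(class_1)
```

Masking keeps the rows in their original order, so `class_0` and `class_1` contain the same points as the arrays built with list comprehensions.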
There's more…
We built a simple classifier using the following rule: the input point (a, b) belongs to `class_0` if a is greater than or equal to b; otherwise, it belongs to `class_1`. If you inspect the points one by one, you will see that this holds for all eight of them: every point in `class_0`, such as (6, 4), has a ≥ b, and every point in `class_1`, such as (1, 8), has a < b. That's it! You just built a linear classifier that can classify unknown data. It's a linear classifier because the separating line is a straight line; if the boundary were a curve, it would be a nonlinear classifier.
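The same rule can be written as a linear decision function. The notation here (a weight vector $\mathbf{w}$ and an offset $c$, with $\mathbf{x} = (a, b)$) is our own, not the book's, and is only meant to show why the rule counts as linear: the boundary $f = 0$ is exactly the line y = x drawn in the plot.

$$
f(a, b) = \mathbf{w}^\top \mathbf{x} + c = a - b, \qquad \mathbf{w} = \begin{pmatrix} 1 \\ -1 \end{pmatrix}, \quad c = 0
$$

$$
\hat{y} = \begin{cases} 0 & \text{if } f(a, b) \ge 0 \\ 1 & \text{if } f(a, b) < 0 \end{cases}
$$

For example, $f(6, 4) = 2 \ge 0$ gives class 0, and $f(1, 8) = -7 < 0$ gives class 1, matching the labels assigned earlier.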
This approach worked fine because there were only a few points, and we could inspect them visually. What if there are thousands of points? How do we generalize this process? Let's discuss that in the next recipe.
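Applying a fixed rule to thousands of points is not the hard part; NumPy evaluates it in one vectorized call. What does not scale is choosing the line by eye. The sketch below wraps the recipe's rule in a hypothetical `classify` function, checks it against the eight training labels, and applies it to 5,000 random points (the sampling range and seed are arbitrary assumptions for illustration):

```python
import numpy as np

def classify(points):
    """Hypothetical helper for the recipe's rule: class 0 if a >= b, else class 1.

    `points` is an (n, 2) array whose rows are (a, b).
    """
    points = np.asarray(points)
    # Decision function f(a, b) = a - b; its zero set is the line y = x
    scores = points[:, 0] - points[:, 1]
    return np.where(scores >= 0, 0, 1)

X = np.array([[3, 1], [2, 5], [1, 8], [6, 4], [5, 2], [3, 5], [4, 7], [4, -1]])
y = np.array([0, 1, 1, 0, 0, 1, 1, 0])

# The hand-drawn rule reproduces every training label
train_pred = classify(X)
print("training accuracy:", np.mean(train_pred == y))

# Classifying thousands of new points is a single call
rng = np.random.default_rng(0)
new_points = rng.uniform(-1, 9, size=(5000, 2))
new_pred = classify(new_points)
print("points per class:", np.bincount(new_pred))
```

The line itself was still fixed by hand here; learning it from the data is the problem the next recipe takes up.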