This summer, I took a free machine learning course for engineers offered by the University of Denver. Since the Build-A-Bot team is getting closer to publishing our game so that the public can play, we are one step closer to working on the machine learning portion of the research project. The course is preparing me for when the Build-A-Bot team needs to work on the machine learning algorithm.
This week in the course, we went over the concept of linear regression. But before I get into the details of linear regression, what exactly is machine learning? In 1959, the machine learning pioneer Arthur Samuel described it as a "field of study that gives computers the ability to learn without being explicitly programmed." Normally, the programmer has to explicitly tell the computer what to do and when. With machine learning, however, the computer runs a handful of functions and gives you a prediction or result based on the inputs you have given it.
Machine learning can be used in a variety of ways. You know those YouTube and Netflix recommendations? Those were generated by a machine learning algorithm. You know the language-learning app Duolingo? It uses machine learning to decide how difficult to make your lessons. These systems work by training on a lot of data. Those YouTube and Netflix recommendations are based on what you have already watched, and on what other users watched after watching the same thing. Without sufficient data, the algorithm can't make good recommendations.
For the sake of our example, we will be working with a company that has given us data pairing each city's population (in thousands) with the profit made in that city (in thousands). The company wants us to create a machine learning algorithm that can closely predict how much profit a city of a given population will make. If we were to plot the data the company gave us, it would look like this:
Looking at the graph, it is obvious to us as humans that the bigger the population, the bigger the profit. But we are not capable of predicting with any sort of accuracy right off the bat. Since we are only looking at one input value (population of the city) and one output value (profit of the city), we can use a line to predict what this trend is going to look like. But how will the computer know which line to draw? We will tell the computer to find the line using linear regression. Linear regression is commonly used for predictive analysis in statistics. The way we will code it is as follows:
Generate a random m and b for the equation of the line (y = mx + b)
For each data point that we are given, determine its cost
Cost is simply how far the data point is from the line; in other words, how far off is the line from that point?
Find the average cost of all points
Modify m and b to lower the average cost until we converge at a point
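The steps above can be sketched in plain Python. This is a minimal sketch that assumes the cost is the mean squared error (one common choice of cost; the course may define it slightly differently), and the learning rate and iteration count are made-up values:

```python
import random

def cost(m, b, xs, ys):
    """Average squared distance between the line y = m*x + b and each point."""
    n = len(xs)
    return sum((m * x + b - y) ** 2 for x, y in zip(xs, ys)) / n

def fit_line(xs, ys, learning_rate=0.01, iterations=10000):
    """Linear regression via gradient descent, following the steps above."""
    random.seed(42)                              # fixed seed so the sketch is reproducible
    m, b = random.random(), random.random()      # step 1: random m and b
    n = len(xs)
    for _ in range(iterations):
        # steps 2-3: gradients of the average cost with respect to m and b
        dm = sum(2 * (m * x + b - y) * x for x, y in zip(xs, ys)) / n
        db = sum(2 * (m * x + b - y) for x, y in zip(xs, ys)) / n
        # step 4: nudge m and b downhill to lower the average cost
        m -= learning_rate * dm
        b -= learning_rate * db
    return m, b

# Toy data lying exactly on y = 2x + 1; the fit should recover m near 2, b near 1.
xs = [0, 1, 2, 3, 4, 5]
ys = [2 * x + 1 for x in xs]
m, b = fit_line(xs, ys)
```

On real, noisy data like the company's, the fitted line won't pass through every point; it settles on whichever m and b give the lowest average cost.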
This is the basic gist of how linear regression works. However, one issue you may run into is figuring out how to modify m and b to lower the average cost. Do you increase or decrease the slope (m)? Do you increase or decrease the y-intercept (b)? By how much? This might seem hopeless, but it is really an optimization problem: all we are trying to find is the line with the minimum average cost. So let's use gradient descent to find that minimum. This graph displays the costs of the variations of the line.
So how does gradient descent work? Well, imagine a marble randomly placed on a hill. The marble rolls down the side along the slope of the hill. In the same way, our algorithm starts the y-intercept and slope at a random point on this graph. The computer then calculates the slope of the cost surface at that point and uses it to adjust the y-intercept and slope values for the next iteration. By doing this, the point (our marble) approaches a minimum of the cost. If we were to graph the cost at every iteration, we would get a graph that looks like this:
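The marble analogy works for any cost surface, not just linear regression. The sketch below rolls a point down a made-up bowl-shaped cost, estimating the slope numerically with finite differences (one simple way to get a slope; exact derivatives, like the ones used for linear regression, are the usual choice in practice):

```python
def numerical_slope(f, m, b, eps=1e-6):
    """Slope of the cost surface at (m, b), estimated by finite differences."""
    dm = (f(m + eps, b) - f(m - eps, b)) / (2 * eps)
    db = (f(m, b + eps) - f(m, b - eps)) / (2 * eps)
    return dm, db

def roll_downhill(f, m, b, learning_rate=0.1, iterations=2000):
    """The 'marble on a hill': repeatedly step against the slope of f."""
    for _ in range(iterations):
        dm, db = numerical_slope(f, m, b)
        m -= learning_rate * dm
        b -= learning_rate * db
    return m, b

def bowl(m, b):
    """A simple bowl-shaped cost with its minimum at m = 3, b = -2."""
    return (m - 3) ** 2 + (b + 2) ** 2

m, b = roll_downhill(bowl, m=0.0, b=0.0)
```

Wherever the marble starts, each step moves it against the local slope, so it ends up at the bottom of the bowl.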
Notice how the cost starts off extremely high, but as you go through more iterations of the gradient descent function, the cost plateaus around a cost of 5. This is how we know that the machine learning algorithm has found accurate y-intercept (b) and slope (m) values. Now if we plot the line with those values and overlay it with the points we have been given, we get a graph that looks like this:
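One simple way to detect that plateau in code, rather than by eyeballing the graph, is to check how much the cost has improved over a window of recent iterations. This is a sketch of one common convergence check, run here on made-up cost values that level off near 5 (the window size and tolerance are arbitrary choices):

```python
def has_plateaued(costs, window=10, tolerance=1e-4):
    """True once the cost has stopped improving meaningfully over `window` iterations."""
    if len(costs) < window + 1:
        return False
    return abs(costs[-window - 1] - costs[-1]) < tolerance

# A fake cost curve that drops sharply and then flattens out near 5,
# like the graph above.
costs = [5 + 30 * (0.9 ** i) for i in range(200)]
```

Early in training the cost is still falling fast, so the check fails; once the curve flattens, it passes and training can stop.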
The machine learning algorithm has found that this is the line of best fit. Now we can use this line to predict future values. For example, let's say our company wanted to know the profit of a city with a population of 17.5 thousand. Looking at the line of best fit, the algorithm predicts that the profit of that city will be about 15,000. Overall, it's crazy to see how powerful a simple line can be in machine learning. Can't wait to learn some more!
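Once the algorithm has settled on m and b, making a prediction is just one multiplication and one addition. Here is a quick sketch; the m and b values below are hypothetical, chosen only so the numbers line up with the example above:

```python
# Hypothetical fitted values: these exact numbers are made up for illustration,
# not the actual output of the algorithm.
m, b = 1.05, -3.4

def predict_profit(population_thousands):
    """Predict a city's profit (in thousands) from its population (in thousands)."""
    return m * population_thousands + b

profit = predict_profit(17.5)   # roughly 15, i.e. a profit of about 15,000
```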