I am implementing a linear regression model from scratch, that is, without using the sklearn package.
Everything was working fine until I tried plotting the result.
I looked at a number of solutions, but none of them addressed my problem.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
data = pd.read_csv(r'C:\Salary.csv')
x=data['Salary']
y=data['YearsExperience']
# y = m*x + b
m = 0
b = 0
Learning_Rate = .01
epochs = 5000
n = float(x.shape[0])
error = []
for i in range(epochs):
    Y_hat = m*x + b
    # error
    mse = (1/n)*np.sum((y-Y_hat)**2)
    error.append(mse)
    # gradient descent
    db = (-2/n) * np.sum(x*(y-Y_hat))
    dm = (-2/n) * np.sum((y-Y_hat))
    m = m - Learning_Rate * dm
    b = b - Learning_Rate * db
# tracing the x and y line
x_line = np.linspace(0, 15, 100)
y_line = (m*x_line) + b
# plotting the result
plt.figure(figsize=(8,6))
plt.title('LR result')
plt.plot(x_line,y_line) # the problem is apparently here
# I just don't know what to do
plt.scatter(x,y)
plt.show()
Apart from that, there is no problem with the code.
CodePudding user response:
Your code has multiple problems:
- You are plotting the line from 0 to 15, while the data range from about 40000 to 140000. Even if you compute the line correctly, you will plot it in a region far away from your data.
- In the loop, the computations of dm and db are swapped. The corrected expressions are:
    dm = (-2/n)*np.sum(x*(y - Y_hat))
    db = (-2/n)*np.sum((y - Y_hat))
- Your x and y data are on very different scales: x is of magnitude ~10⁴, while y is ~10¹. For this reason, m and b will likely also differ by orders of magnitude. This is why you should use two different learning rates for the quantities you are optimizing: Learning_Rate_m for m and Learning_Rate_b for b.
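As a rough sketch of the three points above, the loop could look something like this. The data here are synthetic stand-ins (I don't have your Salary.csv), and fit_line, base_rate and the rule "slope rate = base_rate / mean(x**2)" are my own assumptions for picking the two learning rates, not the only way to do it. np.polyfit is used only to check the result.

```python
import numpy as np

def fit_line(x, y, epochs=5000, base_rate=0.4):
    """Fit y = m*x + b by gradient descent, one learning rate per parameter."""
    n = float(len(x))
    m, b = 0.0, 0.0
    # Assumption: shrink the slope's rate by mean(x**2) so that both
    # updates end up on a comparable scale.
    learning_rate_m = base_rate / np.mean(x ** 2)
    learning_rate_b = base_rate
    for _ in range(epochs):
        residual = y - (m * x + b)
        dm = (-2 / n) * np.sum(x * residual)  # gradient w.r.t. the slope
        db = (-2 / n) * np.sum(residual)      # gradient w.r.t. the intercept
        m -= learning_rate_m * dm
        b -= learning_rate_b * db
    return m, b

# Synthetic stand-in for Salary.csv (hypothetical values, not the real data).
rng = np.random.default_rng(0)
x = rng.uniform(40_000, 140_000, size=30)
y = 1e-4 * x - 2.0 + rng.normal(0.0, 0.5, size=30)

m, b = fit_line(x, y)
m_ref, b_ref = np.polyfit(x, y, 1)  # least-squares fit, used only as a check

# Trace the line over the same range as the data, not over 0..15.
x_line = np.linspace(x.min(), x.max(), 100)
y_line = m * x_line + b
```

With x_line built from x.min() and x.max(), plt.plot(x_line, y_line) draws the line on top of the scatter points. Standardizing x before fitting would be another way to get by with a single learning rate.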
CodePudding user response:
The problem does not occur in the plotting itself; it lies in the values passed to plt.plot(x_line, y_line). I tested your code and found that y_line contains only NaN values. Double-check the calculations of y_line, m, and dm.
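One way to see this for yourself is to check m and b for non-finite values inside the loop. The sketch below reuses the update from the question unchanged, but on synthetic data in the same range (an assumption, since I am not including your CSV here); first_bad_epoch is just a name I made up for the check.

```python
import numpy as np

# Synthetic stand-in for Salary.csv (hypothetical values, not the real data).
rng = np.random.default_rng(0)
x = rng.uniform(40_000, 140_000, size=30)
y = 1e-4 * x - 2.0 + rng.normal(0.0, 0.5, size=30)

n = float(len(x))
m, b = 0.0, 0.0
learning_rate = 0.01
first_bad_epoch = None

# Silence overflow warnings; the explicit check below reports the problem instead.
with np.errstate(over="ignore", invalid="ignore"):
    for epoch in range(5000):
        y_hat = m * x + b
        # Same update as in the question (single learning rate, db/dm as posted).
        db = (-2 / n) * np.sum(x * (y - y_hat))
        dm = (-2 / n) * np.sum(y - y_hat)
        m = m - learning_rate * dm
        b = b - learning_rate * db
        # Remember the first epoch at which a parameter stops being finite.
        if first_bad_epoch is None and not (np.isfinite(m) and np.isfinite(b)):
            first_bad_epoch = epoch

    x_line = np.linspace(0, 15, 100)
    y_line = m * x_line + b

print("first non-finite epoch:", first_bad_epoch)
print("y_line is all NaN:", np.isnan(y_line).all())
```

Note that warnings.filterwarnings('ignore') in the question hides the overflow warnings NumPy would otherwise print, so a check like this (or removing that line) makes the divergence much easier to spot.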


