
Python Experiment: Linear Regression + Gradient Descent to Predict Boston Housing Prices

Date: 2024-01-19


1. Dataset Introduction

First, the dataset: it contains 506 samples, each with 13 features plus the true house price. Preprocessing proceeds in five steps:

1. Load the Boston dataset from sklearn.
2. Plot the relationship between each feature and the price.
3. Keep only the features with a stronger correlation to the price.
4. Min-max normalize the remaining features.
5. Split the data into training and test sets.

```python
def preprocess():
    # get the Boston dataset
    data = boston()
    X = data.data
    y = data.target
    name_data = data.feature_names
    # draw the relationship between each feature and the price
    plt.figure()
    for i in range(len(X[0])):
        plt.subplot(4, 4, i + 1)
        plt.scatter(X[:, i], y, s=20)
        plt.title(name_data[i])
    plt.show()
    # delete less relevant features
    X = np.delete(X, [0, 1, 3, 4, 6, 7, 8, 9, 11], axis=1)
    # min-max normalization
    for i in range(len(X[0])):
        X[:, i] = (X[:, i] - X[:, i].min()) / (X[:, i].max() - X[:, i].min())
    # split into train and test sets
    Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, y, test_size=0.3, random_state=10)
    return Xtrain, Xtest, Ytrain, Ytest, X
```

2. Calling sklearn's Linear Regression

This part is straightforward, so here is the code directly:

```python
def lr(Xtrain, Xtest, Ytrain, Ytest, if_figure):
    # use LinearRegression
    reg = LR().fit(Xtrain, Ytrain)
    y_pred = reg.predict(Xtest)
    loss = mean_squared_error(Ytest, y_pred)
    print("*************LR*****************")
    print("w\t= {}".format(reg.coef_))
    print("b\t= {:.4f}".format(reg.intercept_))
    # draw the predicted results
    if if_figure:
        plt.figure()
        plt.plot(range(len(Ytest)), Ytest, c="blue", label="real")
        plt.plot(range(len(y_pred)), y_pred, c="red", linestyle=':', label="predict")
        plt.title("predict results from raw LR")
        plt.legend()
        plt.show()
    return loss
```

3. Hand-Written Gradient Descent

The main idea of gradient descent is to take the gradient as the direction of each iterative update and move the parameters by a step size, repeating until the optimum is reached.

To make the two methods easy to compare, mean squared error is also used as the loss function here.
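The per-sample gradient used in the implementation below is grad_w = -(y - yp) * x, which is the exact gradient of the *halved* squared error 0.5 * (y - yp)^2 (the factor of 2 from the plain squared error is absorbed into the learning rate). A minimal sketch on made-up numbers checks this analytic gradient against a finite-difference estimate:

```python
import numpy as np

# Sketch (synthetic values, not the Boston data): verify that
# grad_w = -(y - yp) * x matches a finite-difference estimate of
# d/dw [0.5 * (y - w.x - b)^2].
def half_sq_err(w, b, x, y):
    return 0.5 * (y - (w.dot(x) + b)) ** 2

rng = np.random.default_rng(0)
x = rng.normal(size=4)
w = rng.normal(size=4)
b, y = 0.5, 2.0

yp = w.dot(x) + b
grad_w = -(y - yp) * x  # analytic gradient, as used in the article's grad()

eps = 1e-6
numeric = np.array([
    (half_sq_err(w + eps * np.eye(4)[i], b, x, y)
     - half_sq_err(w - eps * np.eye(4)[i], b, x, y)) / (2 * eps)
    for i in range(4)
])
print(np.allclose(grad_w, numeric, atol=1e-5))  # the two gradients agree
```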

```python
def gradDescnet(Xtrain, Xtest, Ytrain, Ytest, X, if_figure, rate):
    # gradient of the (halved) squared error for a single sample
    def grad(y, yp, X):
        grad_w = (y - yp) * (-X)
        grad_b = (y - yp) * (-1)
        return [grad_w, grad_b]

    # set training parameters
    epoch_train = 100
    learning_rate = rate
    w = np.random.normal(0.0, 1.0, (1, len(X[0])))
    b = 0.0
    loss_train = []
    loss_test = []
    for epoch in range(epoch_train + 1):
        loss1 = 0
        for i in range(len(Xtrain)):
            yp = w.dot(Xtrain[i]) + b
            # accumulate the training loss
            err = Ytrain[i] - yp
            loss1 += err ** 2
            # update w and b
            gw, gb = grad(Ytrain[i], yp, Xtrain[i])
            w = w - learning_rate * gw
            b = b - learning_rate * gb
        # record the training loss
        loss_train.append(loss1 / len(Xtrain))
        loss11 = 0
        for i in range(len(Xtest)):
            yp2 = w.dot(Xtest[i]) + b
            err2 = Ytest[i] - yp2
            loss11 += err2 ** 2
        # record the test loss
        loss_test.append(loss11 / len(Xtest))
        # shuffle the data between epochs
        Xtrain, Ytrain = shuffle(Xtrain, Ytrain)
    # draw the loss curves
    if if_figure:
        plt.figure()
        plt.title("figure of loss")
        plt.plot(range(len(loss_train)), loss_train, c="blue", linestyle=":", label="train")
        plt.plot(range(len(loss_test)), loss_test, c="red", label="test")
        plt.legend()
        plt.show()
    # draw the predicted results
    if if_figure:
        Predict_value = []
        for i in range(len(Xtest)):
            Predict_value.append(w.dot(Xtest[i]) + b)
        plt.figure()
        plt.title("predict results from gradDescent")
        plt.plot(range(len(Xtest)), Ytest, c="blue", label="real")
        plt.plot(range(len(Xtest)), Predict_value, c="red", linestyle=':', label="predict")
        plt.legend()
        plt.show()
    return loss_test[-1], w, b
```
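Note that the loop above updates w and b one sample at a time (stochastic gradient descent). For comparison, the same update can be written in vectorized full-batch form. The sketch below uses synthetic data in place of the Boston split, so the specific numbers are illustrative only:

```python
import numpy as np

# Sketch: full-batch vectorized gradient descent on synthetic linear data,
# equivalent in spirit to the per-sample loop above but averaging the
# gradient over the whole training set at each step.
rng = np.random.default_rng(1)
X = rng.normal(size=(300, 4))
y = X @ np.array([1.0, -2.0, 0.5, 3.0]) + 4.0  # noiseless linear target

w = np.zeros(4)
b = 0.0
rate = 0.05
for _ in range(2000):
    yp = X @ w + b
    err = y - yp
    w -= rate * (-(err @ X)) / len(X)   # mean gradient over the batch
    b -= rate * (-err.sum()) / len(X)

print(np.round(w, 2), round(b, 2))  # recovers the true weights and bias
```

Averaging over the batch gives a deterministic, smoother descent; the per-sample version in the article trades that smoothness for cheaper updates.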

4. Comparing the Two Methods

To keep the final code tidy, the comparison is also wrapped in a function.

The gradient-descent step size is set to 0.01 here; this hyperparameter will be tuned in the next section.

```python
def test():
    if_figure = True
    Xtrain, Xtest, Ytrain, Ytest, X = preprocess()
    loss_lr = lr(Xtrain, Xtest, Ytrain, Ytest, if_figure)
    loss_gd, w, b = gradDescnet(Xtrain, Xtest, Ytrain, Ytest, X, if_figure, 0.01)
    print("*************GD*****************")
    print("w\t: {}".format(w))
    print("b\t: {}".format(b))
    print("************loss****************")
    print("lr\t: %.4f" % loss_lr)
    print("gd\t: %.4f" % loss_gd)
```

Output:

```
*************LR*****************
w	= [ -0.3923  21.25173835  -8.18006811 -21.61002144]
b	= 23.0543
*************GD*****************
w	: [[ -0.43534889  21.65996503  -8.1076  -21.3622824 ]]
b	: [22.83733711]
************loss****************
lr	: 31.4272
gd	: 31.2842
```
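As a sanity check on the sklearn side of the comparison, ordinary least squares has a closed form (the normal equation), so the fitted coefficients can be reproduced without any iterative training. This sketch uses synthetic data standing in for the Boston split; the names and values are illustrative:

```python
import numpy as np

# Sketch: recover linear-regression coefficients via least squares
# (equivalent to solving the normal equation X^T X theta = X^T y).
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 4))
true_w = np.array([-0.4, 21.0, -8.0, -21.0])  # illustrative, not the Boston fit
y = X @ true_w + 23.0 + rng.normal(scale=0.1, size=200)

Xb = np.hstack([X, np.ones((200, 1))])        # append a bias column
theta = np.linalg.lstsq(Xb, y, rcond=None)[0]
w_hat, b_hat = theta[:-1], theta[-1]
print(np.round(w_hat, 1), round(b_hat, 1))    # close to true_w and 23.0
```

This is why the sklearn loss above is fixed while the gradient-descent loss depends on initialization and step size: the closed form lands directly on the least-squares optimum.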

5. Hyperparameter Tuning

A step size that is too small makes the updates slow, while one that is too large tends to overshoot the optimum.

Here step sizes from 0.001 to 0.05 are tried, and the relationship between step size and loss is recorded.

Again, this is wrapped in a function:

```python
def searchRate():
    if_figure = False
    Xtrain, Xtest, Ytrain, Ytest, X = preprocess()
    loss_grad = []
    w_grad = []
    b_grad = []
    rates = list(np.arange(0.001, 0.05, 0.001))
    epoch = 1
    for rate in rates:
        loss, w, b = gradDescnet(Xtrain, Xtest, Ytrain, Ytest, X, if_figure, rate)
        loss_grad.append(loss[0])
        w_grad.append(w)
        b_grad.append(b)
        print("epoch %d: %.4f" % (epoch, loss_grad[-1]))
        epoch += 1
    plt.figure()
    plt.plot(rates, loss_grad)
    plt.title("loss under different rate")
    plt.show()
    # pick the rate with the smallest test loss
    loss_grad_min = min(loss_grad)
    position = loss_grad.index(loss_grad_min)
    w = w_grad[position]
    b = b_grad[position]
    rate = rates[position]
    loss_lr = lr(Xtrain, Xtest, Ytrain, Ytest, if_figure)
    print("*************GD*****************")
    print("w\t: {}".format(w))
    print("b\t: {}".format(b))
    print("rate: %.3f" % rate)
    print("************loss****************")
    print("lr\t: %.4f" % loss_lr)
    print("gd\t: %.4f" % loss_grad_min)
```

Output:

```
epoch 1: 35.1047
epoch 2: 31.9512
epoch 3: 31.6400
epoch 4: 31.8814
epoch 5: 31.3429
epoch 6: 31.7260
epoch 7: 31.5825
epoch 8: 31.5523
epoch 9: 32.4876
epoch 10: 31.4287
epoch 11: 31.1475
epoch 12: 32.0841
epoch 13: 32.0033
epoch 14: 31.5768
epoch 15: 31.1828
epoch 16: 31.6558
epoch 17: 32.2582
epoch 18: 32.4916
epoch 19: 31.2118
epoch 20: 32.2877
epoch 21: 31.7237
epoch 22: 32.1203
epoch 23: 32.7307
epoch 24: 32.7434
epoch 25: 32.6421
epoch 26: 31.8588
epoch 27: 31.1762
epoch 28: 33.0360
epoch 29: 32.5580
epoch 30: 32.4591
epoch 31: 31.4191
epoch 32: 31.1398
epoch 33: 31.4291
epoch 34: 31.3900
epoch 35: 31.2239
epoch 36: 31.4200
epoch 37: 31.2967
epoch 38: 32.5322
epoch 39: 32.3174
epoch 40: 34.3984
epoch 41: 31.1794
epoch 42: 31.8992
epoch 43: 32.0060
epoch 44: 34.0944
epoch 45: 34.3244
epoch 46: 31.1479
epoch 47: 32.8374
epoch 48: 31.7111
epoch 49: 33.6676
*************LR*****************
w	= [ -0.3923  21.25173835  -8.18006811 -21.61002144]
b	= 23.0543
*************GD*****************
w	: [[ -0.29030409  21.60092767  -8.02647596 -21.79164094]]
b	: [23.35049725]
rate: 0.032
************loss****************
lr	: 31.4272
gd	: 31.1398
```

Clearly, tuning the step size has a considerable impact on the final result.
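The sensitivity to step size can be seen even on a one-dimensional quadratic f(w) = w^2, whose gradient is 2w: gradient descent contracts toward the minimum for small enough rates and diverges once the rate is too large. A toy sketch (unrelated to the Boston data):

```python
# Sketch: gradient descent on f(w) = w^2 with gradient 2w.
# For this function, |w| shrinks by a factor of |1 - 2*rate| per step,
# so rates below 1.0 converge and rates above 1.0 diverge.
def run(rate, steps=50, w=1.0):
    for _ in range(steps):
        w -= rate * 2 * w
    return abs(w)

print(run(0.1) < 1e-3)   # small step: converges toward 0
print(run(1.1) > 1e3)    # too-large step: diverges
```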

6. Imports Summary

```python
from sklearn.linear_model import LinearRegression as LR
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston as boston  # note: removed in scikit-learn 1.2; requires an older version
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import numpy as np
from sklearn.metrics import mean_squared_error
```
