
MXNet "Dive into Deep Learning" Notes: Linear Regression


Below is the complete from-scratch implementation: synthetic data generation, a mini-batch iterator, the linear model, squared loss, SGD, and the training loop.

#coding:utf-8
from mxnet import ndarray as nd
from mxnet import autograd
import random
import matplotlib.pyplot as plt

num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2

# synthetic dataset: y = Xw + b, plus random noise
X = nd.random_normal(shape=(num_examples, num_inputs))
y = true_w[0] * X[:, 0] + true_w[1] * X[:, 1] + true_b
y += .01 * nd.random_normal(shape=y.shape)

# data iterator: yields shuffled mini-batches
def data_iter():
    batch_size = 10
    idx = list(range(num_examples))
    random.shuffle(idx)
    for i in range(0, num_examples, batch_size):
        j = nd.array(idx[i:min(i + batch_size, num_examples)])
        yield nd.take(X, j), nd.take(y, j)

# initialize model parameters and attach gradient buffers
w = nd.random_normal(shape=(num_inputs, 1))
b = nd.zeros((1,))
params = [w, b]
for param in params:
    param.attach_grad()

# define the model
def net(X):
    return nd.dot(X, w) + b

# define the loss function
def square_loss(yhat, y):
    # reshape y to match yhat, avoiding implicit broadcasting
    return (yhat - y.reshape(yhat.shape)) ** 2

# optimizer: mini-batch stochastic gradient descent
def SGD(params, lr):
    for param in params:
        param[:] = param - lr * param.grad

# the ground-truth function, for comparison with the learned model
def real_fun(X):
    return true_w[0] * X[:, 0] + true_w[1] * X[:, 1] + true_b

def plot(losses, sample_size=100):
    xs = list(range(len(losses)))
    f, (fg1, fg2) = plt.subplots(1, 2)
    fg1.set_title('Loss during training')
    fg1.plot(xs, losses, '-r')
    fg2.set_title('Estimated vs real function')
    fg2.plot(X[:sample_size, 1].asnumpy(),
             net(X[:sample_size, :]).asnumpy(), 'or', label='Estimated')
    fg2.plot(X[:sample_size, 1].asnumpy(),
             real_fun(X[:sample_size, :]).asnumpy(), '*g', label='Real')
    fg2.legend()
    plt.show()

# training
epochs = 5
learning_rate = 0.01
niter = 0
losses = []
moving_loss = 0
smoothing_constant = 0.01

for e in range(epochs):
    total_loss = 0
    for data, label in data_iter():
        with autograd.record():
            output = net(data)
            loss = square_loss(output, label)
        loss.backward()
        SGD(params, learning_rate)
        total_loss += nd.sum(loss).asscalar()

        # track the exponentially weighted moving average of the loss
        niter += 1
        curr_loss = nd.mean(loss).asscalar()
        moving_loss = (1 - smoothing_constant) * moving_loss + \
                      smoothing_constant * curr_loss
        # bias correction for the zero-initialized moving average
        est_loss = moving_loss / (1 - (1 - smoothing_constant) ** niter)
        if (niter + 1) % 100 == 0:
            losses.append(est_loss)
            print('Epoch %s, batch %s. Moving avg of loss: %s, average loss: %f'
                  % (e, niter, est_loss, total_loss / num_examples))

print(true_w, w)
print(true_b, b)

# visualize the loss curve and the fitted vs. true function
plot(losses)
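A note on the moving-average bookkeeping in the training loop: because moving_loss starts at 0, the raw exponential average is biased toward zero during the first batches, and the division applies the standard bias correction. With α = smoothing_constant = 0.01 and t = niter,

    moving_loss_t = (1 - α) * moving_loss_{t-1} + α * curr_loss_t
    est_loss_t    = moving_loss_t / (1 - (1 - α)^t)

For example, after the first batch (t = 1), moving_loss_1 = α * curr_loss_1, and dividing by 1 - (1 - α) = α recovers curr_loss_1 exactly.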

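For comparison, the same model can be trained much more concisely with MXNet's higher-level gluon API. The following is a minimal sketch, not part of the original notes: it reuses the synthetic X and y from the script above, and the batch size, learning rate, and epoch count mirror the from-scratch version. Note that gluon's L2Loss computes half the squared error, so its reported loss is half that of square_loss above.

#coding:utf-8
# concise version using gluon; assumes X and y are defined as above
from mxnet import autograd, gluon
from mxnet.gluon import nn

batch_size = 10
dataset = gluon.data.ArrayDataset(X, y)
data_loader = gluon.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# a single Dense(1) layer is exactly the linear model w*x + b
net = nn.Sequential()
net.add(nn.Dense(1))
net.initialize()

loss_fn = gluon.loss.L2Loss()  # 1/2 * (yhat - y)^2
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})

for e in range(5):
    for data, label in data_loader:
        with autograd.record():
            loss = loss_fn(net(data), label)
        loss.backward()
        trainer.step(batch_size)
    print('epoch %d, loss %f' % (e, loss_fn(net(X), y).mean().asscalar()))

# the learned parameters should again be close to true_w and true_b
dense = net[0]
print(dense.weight.data(), dense.bias.data())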