本文共 1767 字,大约阅读时间需要 5 分钟。
学习莫烦的pytorch视频,对一些函数进行搜索,记录功能以及测试
import torchimport numpy as npfrom torch.autograd import Variableimport torch.nn.functional as Fimport matplotlib.pyplot as pltimport pdbx=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)#增加dim个维数,类似 int a=0==>int a[1][1]={0}y=x.pow(2)+0.2*torch.rand(x.size())x, y =Variable(x), Variable(y)#tensor2Variable#plt.scatter(x.data.numpy(), y.data.numpy())#散点图#plt.show()class Net(torch.nn.Module):#继承torch的这个模块 def __init__(self, n_features, n_hidden, n_output):#层的信息 super(Net,self).__init__() self.hidden = torch.nn.Linear(n_features, n_hidden) #self.hidden2 = torch.nn.Linear(n_hidden,n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self,x): x=F.relu(self.hidden(x))#效果最好 #x=torch.sigmoid(self.hidden(x)) #x=torch.tanh(self.hidden(x)) #x=F.relu(self.hidden2(x)) x=self.predict(x) return xnet = Net(1, 10, 1)#print(net)#打印网络结构plt.ion()#打开交互模式#optimizer = torch.optim.SGD(net.parameters(), lr=0.5)#优化方法loss_func = torch.nn.MSELoss()#均方差损失函数'''t=0for i in net.parameters(): print(t,'\n') t=t+1 print(i)'''for t in range(200): l=0.5 if(t>100&&l>=0.5): l=l/10 optimizer = torch.optim.SGD(net.parameters(), lr=l)#优化方法 prediction = net(x) #pdb.set_trace() loss = loss_func(prediction, y)#预测值,真实值 optimizer.zero_grad()#梯度置0,梯度随着loss的变化而变化,每个epoch都不一样 loss.backward()#反向传播,计算梯度 optimizer.step()#根据梯度,学习率更新权重参数 if t%5==0: plt.cla()#清除当前axes对象 plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy()) plt.text(0.5, 0, 'Loss=%.4f'%loss.item(), fontdict={ 'size':20, 'color':'red'}) plt.pause(0.1)#暂停plt.ioff()#关闭交互模式plt.show()
转载地址:http://twksi.baihongyu.com/