from__future__importannotationsimporttimefromIPythonimportdisplayfrommatplotlibimportpyplotaspltimportnumpyasnpdefuse_svg_display():"""Use the svg format to display a plot in Jupyter. Defined in :numref:`sec_calculus`"""display.set_matplotlib_formats('svg')defset_figsize(figsize=(3.5,2.5)):"""Set the figure size for matplotlib. Defined in :numref:`sec_calculus`"""use_svg_display()plt.rcParams['figure.figsize']=figsizedefset_axes(axes,xlabel,ylabel,xlim,ylim,xscale,yscale,legend):"""Set the axes for matplotlib. Defined in :numref:`sec_calculus`"""axes.set_xlabel(xlabel)axes.set_ylabel(ylabel)axes.set_xscale(xscale)axes.set_yscale(yscale)axes.set_xlim(xlim)axes.set_ylim(ylim)iflegend:axes.legend(legend)axes.grid()defplot(X,Y=None,xlabel=None,ylabel=None,legend=None,xlim=None,ylim=None,xscale='linear',yscale='linear',fmts=('-','m--','g-.','r:'),figsize=(3.5,2.5),axes=None):"""Plot data points. Defined in :numref:`sec_calculus`"""iflegendisNone:legend=[]set_figsize(figsize)axes=axesifaxeselseplt.gca()# Return True if `X` (tensor or list) has 1 axisdefhas_one_axis(X):return(hasattr(X,"ndim")andX.ndim==1orisinstance(X,list)andnothasattr(X[0],"__len__"))ifhas_one_axis(X):X=[X]ifYisNone:X,Y=[[]]*len(X),Xelifhas_one_axis(Y):Y=[Y]iflen(X)!=len(Y):X=X*len(Y)axes.cla()forx,y,fmtinzip(X,Y,fmts):iflen(x):axes.plot(x,y,fmt)else:axes.plot(y,fmt)set_axes(axes,xlabel,ylabel,xlim,ylim,xscale,yscale,legend)
class Timer:
    """Record multiple running times."""

    def __init__(self):
        """Create a timer with no recorded times and start it immediately.

        Defined in :numref:`subsec_linear_model`"""
        self.times = []
        self.start()

    def start(self):
        """Start the timer."""
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump (NTP, DST, manual changes).
        self.tik = time.perf_counter()

    def stop(self):
        """Stop the timer and record the time in a list.

        Returns:
            The seconds elapsed since the last ``start()``; this value is
            also appended to ``self.times``.
        """
        self.times.append(time.perf_counter() - self.tik)
        return self.times[-1]

    def avg(self):
        """Return the average time (0.0 if no time has been recorded)."""
        if not self.times:
            return 0.0
        return sum(self.times) / len(self.times)

    def sum(self):
        """Return the sum of time."""
        # Inside a method, ``sum`` resolves to the builtin, not this method.
        return sum(self.times)

    def cumsum(self):
        """Return the accumulated time as a list of running totals."""
        return np.array(self.times).cumsum().tolist()
class Accumulator:
    """For accumulating sums over `n` variables."""

    def __init__(self, n):
        """Create ``n`` running sums, all starting at 0.0.

        Defined in :numref:`sec_softmax_scratch`"""
        self.data = [0.0] * n

    def add(self, *args):
        """Add ``float(args[i])`` to the i-th running sum.

        Raises:
            ValueError: if the number of values differs from the number of
                accumulated variables.  (Previously ``zip`` silently dropped
                extra values or, worse, shrank ``self.data`` when too few
                were given.)
        """
        if len(args) != len(self.data):
            raise ValueError(
                f"expected {len(self.data)} values, got {len(args)}")
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        """Set every running sum back to 0.0, keeping the variable count."""
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        """Return the running sum(s) at ``idx`` (int or slice)."""
        return self.data[idx]
class Animator:
    """For plotting data in animation."""

    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        """Create the figure; data is added later with ``add``.

        Only the first subplot (``self.axes[0]``) is drawn on.

        Defined in :numref:`sec_softmax_scratch`"""
        # Incrementally plot multiple lines
        if legend is None:
            legend = []
        use_svg_display()
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        else:
            # plt.subplots returns a 2-D array when nrows > 1 and ncols > 1;
            # flatten it so self.axes[0] is always a single Axes object.
            self.axes = self.axes.ravel()
        # Use a lambda function to capture arguments
        self.config_axes = lambda: set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Append data points and redraw the figure in the notebook.

        ``y`` is a scalar or one value per series; ``x`` is a scalar shared
        by all series or one value per series.  Pairs where either value is
        None are skipped.  At most ``len(self.fmts)`` series are drawn.
        """
        # Add multiple data points into the figure
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        # The series count is fixed by the first call to add().
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(xs, ys, fmt)
        self.config_axes()
        display.display(self.fig)
        # wait=True defers clearing until new output arrives, avoiding flicker.
        display.clear_output(wait=True)