```python
def multivariate_normal_fit(X):
    N = X.shape[0]  # Number of samples
    D = X.shape[1]  # Dimension of sample
    beta = 1
    m = np.zeros(D)
    W_inv = np.linalg.inv(np.diag(np.ones(D)))
    nu = D
    beta_hat = N + beta
    m_hat = (X.sum(axis=0) + beta * m) / beta_hat
    X_sum = np.zeros([D, D])
    for i in range(N):
        X_sum += np.dot(X[i].reshape(-1, 1), X[i].reshape(1, -1))
    W_hat_inv = X_sum + beta * np.dot(m.reshape(-1, 1), m.reshape(1, -1)) \
        - beta_hat * np.dot(m_hat.reshape(-1, 1), m_hat.reshape(1, -1)) + W_inv
    nu_hat = N + nu
    return m_hat, beta_hat, nu_hat, W_hat_inv
```
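For reference, the assignments above are the standard conjugate (Gauss-Wishart) posterior update, restated here in the code's notation; this is only a transcription of what the function computes, not additional functionality:

```latex
\begin{aligned}
\hat{\beta} &= N + \beta \\
\hat{\mathbf{m}} &= \frac{1}{\hat{\beta}} \left( \sum_{n=1}^{N} \mathbf{x}_n + \beta \mathbf{m} \right) \\
\hat{\mathbf{W}}^{-1} &= \sum_{n=1}^{N} \mathbf{x}_n \mathbf{x}_n^{\top}
  + \beta \mathbf{m} \mathbf{m}^{\top}
  - \hat{\beta} \hat{\mathbf{m}} \hat{\mathbf{m}}^{\top}
  + \mathbf{W}^{-1} \\
\hat{\nu} &= N + \nu
\end{aligned}
```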
Multivariate Student's t-distribution
To check the probability distribution obtained after learning, we implement the multivariate Student's t-distribution as a class.

To evaluate the probability density function (PDF), the class provides a pdf method. Passing an array of points to the pdf method returns the probability density at each of those points.
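For reference, the density that the pdf method evaluates is the multivariate Student's t-distribution with mean μ, matrix Λ, and ν degrees of freedom:

```latex
\mathrm{St}(\mathbf{x} \mid \boldsymbol{\mu}, \boldsymbol{\Lambda}, \nu)
  = \frac{\Gamma\!\left(\frac{\nu + D}{2}\right)}{\Gamma\!\left(\frac{\nu}{2}\right)}
    \, \frac{|\boldsymbol{\Lambda}|^{1/2}}{(\pi \nu)^{D/2}}
    \left[ 1 + \frac{(\mathbf{x} - \boldsymbol{\mu})^{\top} \boldsymbol{\Lambda} (\mathbf{x} - \boldsymbol{\mu})}{\nu} \right]^{-\frac{\nu + D}{2}}
```

In the code below, temp1 is the ratio of gamma functions, temp2 the determinant factor, temp3 the bracketed quadratic term, and temp4 its exponent.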
```python
import math

import numpy as np
import scipy.stats
import matplotlib.pyplot as plt


class multivariate_student_t():
    def __init__(self, mu, lam, nu):
        # mu: D size array, lam: DxD matrix, nu: scalar
        self.D = mu.shape[0]
        self.mu = mu
        self.lam = lam
        self.nu = nu

    def pdf(self, x):
        # Gamma-function ratio and normalization factor
        temp1 = np.exp(math.lgamma((self.nu + self.D) / 2) - math.lgamma(self.nu / 2))
        temp2 = np.sqrt(np.linalg.det(self.lam)) / (np.pi * self.nu) ** (self.D / 2)
        if x.ndim == 1:
            # Single sample given as a 1-D array
            temp3 = 1 + np.dot(np.dot((x - self.mu).T, self.lam), x - self.mu) / self.nu
        else:
            # Batch of samples given as an NxD array
            temp3 = []
            for a in x:
                temp3 += [1 + np.dot(np.dot((a - self.mu).T, self.lam), a - self.mu) / self.nu]
        temp4 = -(self.nu + self.D) / 2
        return temp1 * temp2 * (np.array(temp3) ** temp4)


def multivariate_normal_fit(X):
    N = X.shape[0]  # Number of samples
    D = X.shape[1]  # Dimension of sample

    # Prior hyperparameters
    beta = 1
    m = np.zeros(D)
    W_inv = np.linalg.inv(np.diag(np.ones(D)))
    nu = D

    # Posterior hyperparameters
    beta_hat = N + beta
    m_hat = (X.sum(axis=0) + beta * m) / beta_hat
    X_sum = np.zeros([D, D])
    for i in range(N):
        X_sum += np.dot(X[i].reshape(-1, 1), X[i].reshape(1, -1))
    W_hat_inv = X_sum + beta * np.dot(m.reshape(-1, 1), m.reshape(1, -1)) \
        - beta_hat * np.dot(m_hat.reshape(-1, 1), m_hat.reshape(1, -1)) + W_inv
    nu_hat = N + nu
    return m_hat, beta_hat, nu_hat, W_hat_inv


if __name__ == "__main__":
    np.random.seed(0)
    mean = np.array([0, 1])
    cov = np.array([[2, 1], [1, 2]])
    Ns = 100  # Number of samples
    X = np.random.multivariate_normal(mean, cov, Ns)  # Sample data

    # Scatter plot of the sampled data
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.scatter(X[:, 0], X[:, 1])
    ax.axis('square')
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.grid()
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    fig.tight_layout()
    plt.show()

    # Posterior hyperparameters, then parameters of the predictive Student's t
    m_hat, beta_hat, nu_hat, W_hat_inv = multivariate_normal_fit(X)
    D = m_hat.shape[0]
    mu_hat = m_hat
    lam_hat = (1 - D + nu_hat) * beta_hat * np.linalg.inv(W_hat_inv) / (1 + beta_hat)
    nu_hat = 1 - D + nu_hat
    mt = multivariate_student_t(mu_hat, lam_hat, nu_hat)

    # Evaluate the original and the inferred (predictive) PDFs on a grid
    X1, X2 = np.meshgrid(np.arange(-5, 5, 0.1), np.arange(-5, 5, 0.1))
    Y = np.vstack([X1.ravel(), X2.ravel()]).T
    mn_pdf = scipy.stats.multivariate_normal.pdf(Y, mean=mean, cov=cov)
    mn_pdf = mn_pdf.reshape(X1.shape[0], -1)
    mt_pdf = mt.pdf(Y)
    mt_pdf = mt_pdf.reshape(X1.shape[0], -1)

    # Plot the two PDFs side by side
    fig, ax = plt.subplots(ncols=2, figsize=(10, 4))
    ax0 = ax[0].pcolor(X1, X2, mn_pdf, cmap="Blues", vmin=0, vmax=0.1)
    ax1 = ax[1].pcolor(X1, X2, mt_pdf, cmap="Blues", vmin=0, vmax=0.1)
    for i in range(2):
        ax[i].axis('equal')
        ax[i].grid()
        ax[i].set_xlabel("x1")
        ax[i].set_ylabel("x2")
    ax[0].set_title("Original PDF")
    ax[1].set_title("Inferred PDF")
    plt.colorbar(ax=ax[0], mappable=ax0)
    plt.colorbar(ax=ax[1], mappable=ax1)
    fig.tight_layout()
    plt.show()
```
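As an optional sanity check, the hand-written pdf can be compared against SciPy's built-in multivariate t-distribution. This is a minimal sketch, assuming SciPy 1.6 or later (which added scipy.stats.multivariate_t); check_against_scipy is a hypothetical helper, and note that SciPy is parameterized by a scale matrix, i.e. the inverse of lam:

```python
import numpy as np
import scipy.stats


def check_against_scipy(mt, points):
    # Hypothetical helper: compare multivariate_student_t.pdf with
    # scipy.stats.multivariate_t (available in SciPy >= 1.6).
    ref = scipy.stats.multivariate_t(
        loc=mt.mu,
        shape=np.linalg.inv(mt.lam),  # SciPy expects a scale matrix, not a precision matrix
        df=mt.nu,
    )
    return np.allclose(mt.pdf(points), ref.pdf(points))
```

With the grid points from the script above, check_against_scipy(mt, Y) would be expected to return True if the two parameterizations line up.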