Pseudo-document-based Topic Model(基于伪文档的主题模型)的理解以及源码解读

2,733 阅读

原文地址:https://blog.csdn.net/qy20115549/article/details/79877825

本文作者:合肥工业大学 管理学院 钱洋 email:1563178220@qq.com 内容可能有不到之处,欢迎交流。

未经本人允许禁止转载。 #论文来源

Zuo Y, Wu J, Zhang H, et al. Topic modeling of short texts: A pseudo-document view[C]//Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining. ACM, 2016: 2105-2114. 来自于16年,计算机顶会KDD的文章。作者是北航的学者。 #论文简介 主题模型的底层原理是基于共现,但是对于短文本来说,这种共现是很稀疏的,这将导致模型学习的效果不好。当然,有很多种方法来处理短文本主题学习。作者这篇文章提供了一种伪文档策略。 下面我们来看看模型的概率图:

这里写图片描述
这里写图片描述
(a)图是基本的PTM,(b)图引入了稀疏性先验,即Spike and Slab prior该先验在很多主题模型都使用过,具体可以看我之前的一些博客分享。这里使用的目的是实现伪文档主题分布的稀疏性。 模型的生成过程如下:
这里写图片描述
这里写图片描述
引入稀疏性,只是改了右半边的生成方式,如下图所示:

#模型推理 首先,抽取文档所属的伪文档,如下图所示,该公式是跟对包含稀疏性的SPTM,如果是PTM则简单的改动一下就行。

这里写图片描述
这里写图片描述
再抽取文档单词所属的主题,如下图所示:
这里写图片描述
这里写图片描述
接着,抽取伪文档是否包含某主题,即伪文档主题选择器。该公式依据的是Wang等人的抽样方式,该文章是非参模型,且提供了详细的推导过程,大家可以学习。 C. Wang and D. M. Blei. Decoupling sparsity and smoothness in the discrete hierarchical dirichlet process. In Advances in neural information processing systems, pages 1982{1989. 2009.
这里写图片描述
这里写图片描述

#源码解读 这里解读的源码是PTM模型,根据公式理解还是很简单的。

package main;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

public class PseudoDocTM implements Runnable {

	public int K1 = 1000;  //设置伪文档数量
	public int K2 = 100; //
	
	public int M;
	public int V;

	public double alpha1 = 0.1;
	public double alpha2 = 0.1;

	public double beta = 0.01;

	public int mp[]; //分配到每个伪文档文档的数量

	public int npk[][];  //伪文档l由主题k生成的单词数量
	public int npkSum[];  //伪文档对应的单词总数

	public int nkw[][]; //主题k对应的单词w的数量
	public int nkwSum[]; //主题k对应的单词总数

	public int zAssigns_1[];  //文档分配伪文档
	public int zAssigns_2[][]; //文档单词分配主题
 
	public int niters = 200; 
	public int saveStep = 1000; 
	public String inputPath="";
	public String outputPath="";

	public int innerSteps = 10;

	public List<List<Integer>> docs = new ArrayList<List<Integer>>(); //文档表示
	public HashMap<String, Integer> w2i = new HashMap<String, Integer>(); //词的编号
	public HashMap<Integer, String> i2w = new HashMap<Integer, String>(); //编号转化为词


	public PseudoDocTM(int P,int K,int iter,int innerStep,int saveStep,double alpha1,double alpha2,double beta,String inputPath,String outputPath){
		this.K1=P;
		this.K2=K;
		this.niters=iter;
		this.innerSteps= innerStep;
		this.saveStep =saveStep;
		this.alpha1=alpha1;
		this.alpha2= alpha2;
		this.beta = beta;
		this.inputPath=inputPath;
		this.outputPath=outputPath;
	}
	//加载语料
	public void loadTxts(String txtPath) {
		BufferedReader reader = IOUtil.getReader(txtPath, "UTF-8");

		String line;
		try {
			line = reader.readLine();
			while (line != null) {
				List<Integer> doc = new ArrayList<Integer>();

				String[] tokens = line.trim().split("\\s+");
				for (String token : tokens) {
					if (!w2i.containsKey(token)) {
						w2i.put(token, w2i.size());
						i2w.put(w2i.get(token), token);
					}
					doc.add(w2i.get(token));
				}
				docs.add(doc);
				line = reader.readLine();
			}
			reader.close();
		} catch (IOException e) {
			e.printStackTrace();
		}

		//文档数量
		M = docs.size();
		//语料词的数量
		V = w2i.size();

		return;
	}
	//初始化模型
	public void initModel() {
		
		mp = new int[K1];

		npk = new int[K1][K2];
		npkSum = new int[K1];

		nkw = new int[K2][V];
		nkwSum = new int[K2];
		
		zAssigns_1 = new int[M]; //文档所属的伪文档
		zAssigns_2 = new int[M][]; //文档每个单词所属的主题

		for (int m = 0; m != M; m++) {
			//文档单词的数量
			int N = docs.get(m).size();
			//初始化
			zAssigns_2[m] = new int[N];
			//随机分配文档所属的伪文档
			int z1 = (int) Math.floor(Math.random()*K1);
			zAssigns_1[m] = z1;

			mp[z1] ++; //伪文档对应的文本数量增加
			//对每个单词随机分配主题
			for (int n = 0; n != N; n++) {
				int w = docs.get(m).get(n);
				int z2 = (int) Math.floor(Math.random()*K2);

				npk[z1][z2] ++;
				npkSum[z1] ++;

				nkw[z2][w] ++;
				nkwSum[z2] ++;

				zAssigns_2[m][n] = z2;
			}
		}
	}
	//抽取文档所属的伪文档
	public void sampleZ1(int m) {
		int z1 = zAssigns_1[m];  //获取文档所属的伪文档
		int N = docs.get(m).size(); //获取文档单词的数量

		mp[z1] --; //移除该文档,伪文档z1对应的单词数量减少
		
		Map<Integer, Integer> k2Count = new HashMap<Integer, Integer>();
		for (int n = 0; n != N; n++){ //循环文档的每个单词
			int z2 = zAssigns_2[m][n]; //获取单词的主题分配
			if (k2Count.containsKey(z2)) { //计算每个主题包含该文档单词的总数量
				k2Count.put(z2, k2Count.get(z2)+1);
			} else {
				k2Count.put(z2, 1);
			}

			npk[z1][z2] --;
			npkSum[z1] --;
		}
		
		double k2Alpha2 = K2 * alpha2;   //分母的K*alpha

		double[] pTable = new double[K1];
		//循环每个伪文档
		for (int k = 0; k != K1; k++) {
			double expectTM = 1.0;
			int index = 0;
			//这里要计算单词的频次,进行连乘
			for (int z2 : k2Count.keySet()) {
				int c = k2Count.get(z2);
				for (int i = 0; i != c; i++) {
					expectTM *= (npk[k][z2] + alpha2 + i) / (k2Alpha2 + npkSum[k] + index);
					index ++;
				}
			}
			//基于公式计算概率
			pTable[k] = (mp[k] + alpha1) / (M + K1 * alpha1) * expectTM;
		}
		//轮盘赌选择
		for (int k = 1; k != K1; k++) { //这里注意k=1开始,不能k=0
			pTable[k] += pTable[k-1];
		}

		double r = Math.random() * pTable[K1-1];

		for (int k = 0; k != K1; k++) {
			if (pTable[k] > r) {
				z1 = k;
				break;
			}
		}
		//基于轮盘赌选择的伪文档,重新统计
		mp[z1] ++;
		for (int n =0; n != N; n++) {
			int z2 = zAssigns_2[m][n];
			npk[z1][z2] ++;
			npkSum[z1] ++;
		}

		zAssigns_1[m] = z1;
	}
	//抽取文档m第n个单词的主题
	public void sampleZ2(int m, int n) {
		
		int z1 = zAssigns_1[m]; //获取文档所属的伪文档
		int z2 = zAssigns_2[m][n]; //获取文档m第n个所属的主题
		int w = docs.get(m).get(n); //获取单词编号

		npk[z1][z2] --;  //统计伪文档z1、主题z2生成的单词数量
		npkSum[z1] --; //伪文档z1对应的总单词数量
		nkw[z2][w] --; //主题z2对应的单词w的数量
		nkwSum[z2] --; //主题z2中所有单词的数量

		double VBeta = V * beta; //分母中的V*beta
		double k2Alpha2 = K2 * alpha2; //分母中的 K*alpha

		double[] pTable = new double[K2];
		//基于公式计算-----这里和公式有差异,公式应该按照这里写,及主题词分母应该按照前面的表达
		for (int k = 0; k != K2; k++) {
			pTable[k] = (npk[z1][k] + alpha2) / (npkSum[z1] + k2Alpha2) *
					(nkw[k][w] + beta) / (nkwSum[k] + VBeta);
		}
		//轮盘赌选择
		for (int k = 1; k != K2; k++) {
			pTable[k] += pTable[k-1];
		}

		double r = Math.random() * pTable[K2-1];

		for (int k = 0; k != K2; k++) {
			if (pTable[k] > r) {
				z2 = k;
				break;
			}
		}
		//重新统计相关词频
		npk[z1][z2] ++;
		npkSum[z1] ++;
		nkw[z2][w] ++;
		nkwSum[z2] ++;

		zAssigns_2[m][n] = z2;
		return;
	}

	public void estimate() {
		long start = 0;
		for (int iter = 0; iter != niters; iter++) {
			start = System.currentTimeMillis();
			System.out.println("PAM4ST Iteration: " + iter + " ...");
			if(iter%this.saveStep==0&&iter!=0&&iter!=this.niters-1){
				this.storeResult(iter);
			}
			//对每篇文档循环,将文档分配到伪文档
			for (int i = 0; i != innerSteps; i++) {
				for (int m = 0; m != M; m++) {
					this.sampleZ1(m);
				}
			}
			//对每篇文档进行循环,抽取每个单词所属的主题
			for (int i = 0; i != innerSteps; i++) {
				for (int m = 0; m != M; m++) {
					int N = docs.get(m).size();
					for (int n = 0; n != N; n++) {
						sampleZ2(m, n);
					}
				}
			}
			System.out.println("cost time:"+(System.currentTimeMillis()-start));
		}
		return;
	}
	//计算伪文档的主题分布---相当于LDA的文档主题分布
	public double[][] computeThetaP() {
		double[][] theta = new double[K1][K2];
		for (int k1 = 0; k1 != K1; k1++) {
			for (int k2 = 0; k2 != K2; k2++) {
				theta[k1][k2] = (npk[k1][k2] + alpha2) / (npkSum[k1] + K2*alpha2);
			}
		}
		return theta;
	}
	
	public void saveThetaP(String path) throws IOException {
		BufferedWriter writer = IOUtil.getWriter(path);
		double[][] theta = this.computeThetaP();
		for (int k1 = 0; k1 != K1; k1++) {
			for (int k2 = 0; k2 != K2; k2++) {
				writer.append(theta[k1][k2]+" ");
			}
			writer.newLine();
		}
		writer.flush();
		writer.close();
	}
	
	public void saveZAssigns1(String path) throws IOException {
		BufferedWriter writer = IOUtil.getWriter(path);
		
		for (int m = 0; m != M; m++) {
			writer.append(zAssigns_1[m]+"\n");
		}
		
		writer.flush();
		writer.close();
	}
	//计算主题词分布
	public double[][] computePhi() {
		double[][] phi = new double[K2][V];
		for (int k = 0; k != K2; k++) {
			for (int v = 0; v != V; v++) {
				phi[k][v] = (nkw[k][v] + beta) / (nkwSum[k] + V*beta);
			}
		}
		return phi;
	}
	//排序算法
	public ArrayList<List<Entry<String, Double>>> sortedTopicWords(
			double[][] phi, int T) {
		ArrayList<List<Entry<String, Double>>> res = new ArrayList<List<Entry<String, Double>>>();
		for (int k = 0; k != T; k++) {
			HashMap<String, Double> term2weight = new HashMap<String, Double>();
			for (String term : w2i.keySet())
				term2weight.put(term, phi[k][w2i.get(term)]);

			List<Entry<String, Double>> pairs = new ArrayList<Entry<String, Double>>(
					term2weight.entrySet());
			Collections.sort(pairs, new Comparator<Entry<String, Double>>() {
				public int compare(Entry<String, Double> o1,
						Entry<String, Double> o2) {
					return (o2.getValue().compareTo(o1.getValue()));
				}
			});
			res.add(pairs);
		}
		return res;
	}


	public void printTopics(String path,int top_n) throws IOException {
		BufferedWriter writer = IOUtil.getWriter(path);
		double[][] phi = computePhi();
		ArrayList<List<Entry<String, Double>>> pairsList = this
				.sortedTopicWords(phi, K2);
		for (int k = 0; k != K2; k++) {
			writer.write("Topic " + k + ":\n");
			for (int i = 0; i != top_n; i++) {
				writer.write(pairsList.get(k).get(i).getKey() + " "
						+ pairsList.get(k).get(i).getValue()+"\n");
			}
		}
		writer.close();
	}

	public void savePhi(String path) {
		BufferedWriter writer = IOUtil.getWriter(path, "utf-8");

		double[][] phi = computePhi();
		int K = phi.length;
		assert K > 0;
		int V = phi[0].length;

		try {
			for (int k = 0; k != K; k++) {
				for (int v = 0; v != V; v++) {
					writer.append(phi[k][v]+" ");
				}
				writer.append("\n");
			}
			writer.flush();
			writer.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
		return;
	}

	public void saveWordmap(String path) {
		BufferedWriter writer = IOUtil.getWriter(path, "utf-8");

		try {
			for (String word : w2i.keySet())
				writer.append(word + "\t" + w2i.get(word) + "\n");

			writer.flush();
			writer.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
		return;
	}

	public void saveAssign(String path){
		BufferedWriter writer = IOUtil.getWriter(path, "utf-8");
		try {
			for(int i=0;i<zAssigns_2.length;i++){
				for(int j=0;j<zAssigns_2[i].length;j++){
					writer.write(docs.get(i).get(j)+":"+zAssigns_2[i][j]+" ");
				}
				writer.write("\n");
			}
			writer.flush();
			writer.close();
		} catch (IOException e) {
			e.printStackTrace();
		}

		return;
	}
	public void printModel(){
		System.out.println("\tK1 :"+this.K1+
				"\tK2 :"+this.K2+
				"\tniters :"+this.niters+
				"\tinnerSteps :"+this.innerSteps+
				"\tsaveStep :"+this.saveStep +
				"\talpha1 :"+this.alpha1+
				"\talpha2 :"+this.alpha2+
				"\tbeta :"+this.beta +
				"\tinputPath :"+this.inputPath+
				"\toutputPath :"+this.outputPath);
	}
	
	int[][] ndk;
	int[] ndkSum;
	
	public void convert_zassigns_to_arrays_theta(){
		ndk = new int[M][K2];
		ndkSum = new int[M];
		
		for (int m = 0; m != M; m++) {
			for (int n = 0; n != docs.get(m).size(); n++) {
				ndk[m][zAssigns_2[m][n]] ++;
				ndkSum[m] ++;
			}
		}
	}
	//计算文档主题分布
	public double[][] computeTheta() {
		convert_zassigns_to_arrays_theta();
		double[][] theta = new double[M][K2];
		for (int m = 0; m != M; m++) {
			for (int k = 0; k != K2; k++) {
				theta[m][k] = (ndk[m][k] + alpha2) / (ndkSum[m] + K2 * alpha2);
			}
		}
		return theta;
	}
	
	public void saveTheta(String path) {
		BufferedWriter writer = IOUtil.getWriter(path, "utf-8");
		
		double[][] theta = computeTheta();
		try {
			for (int m = 0; m != M; m++) {
				for (int k = 0; k != K2; k++) {
					writer.append(theta[m][k]+" ");
				}
				writer.append("\n");
			}
			writer.flush();
			writer.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
		return;
	}
	
	public void storeResult(int times){
		String appendString="final";
		if(times!=0){
			appendString =times+"";
		}
		try {
			this.printTopics(outputPath+"/model-"+appendString+".twords",20);
			this.saveWordmap(outputPath+"/wordmap.txt");
			this.savePhi(outputPath+"/model-"+appendString+".phi");
			this.saveAssign(outputPath+"/model-"+appendString+".tassign");
			this.saveTheta(outputPath+"/model-"+appendString+".theta");
			this.saveThetaP(outputPath+"/model-"+appendString+".thetap");
			this.saveZAssigns1(outputPath+"/model-"+appendString+".assign1");
		} catch (IOException e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}
	public void run() {
		printModel();
		this.loadTxts(inputPath);//加载语料
		this.initModel(); //初始化模型
		this.estimate(); //估计
		this.storeResult(0); //保存结果

	}
	
	
	public static void PseudoDocTM(int P,int K,int iter,int innerStep,int saveStep,double alpha1,double alpha2,double beta,int threadNum,String path){
		File trainFile = new File(path);
		String parent_path = trainFile.getParentFile().getAbsolutePath();
		(new File(parent_path+"/PTM_with_case_"+P+"_"+K+"_"+iter+"_"+alpha1+"_"+alpha2+"_"+beta+"/")).mkdirs();
		try {
			Thread.sleep(1000);
		} catch (InterruptedException e) {
			e.printStackTrace();
		}
		(new PseudoDocTM(P,K,iter,innerStep,saveStep,alpha1,alpha2,beta,path,parent_path+"/PTM_with_case_"+P+"_"+K+"_"+iter+"_"+alpha1+"_"+alpha2+"_"+beta)).run();

	}
}

DataLearner 官方微信

欢迎关注 DataLearner 官方微信,获得最新 AI 技术推送

DataLearner 官方微信二维码