/*
 * ComputeTopicWordDist
 *
 * 标签:#test# 时间:2018/08/16 16:59:38 作者:小木
 * (Tags: #test#   Posted: 2018/08/16 16:59:38   Author: 小木)
 */

import com.google.common.collect.Lists;
import org.apache.commons.io.FileUtils;

import java.io.File;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Computes the intra-cluster similarity of documents.
 *
 * <p>Inputs (paths configured in the static fields below):
 *
 * <ul>
 *   <li>{@code docClusterFile}: line i holds the 0-based integer cluster index of document i;
 *   <li>{@code docWordTopicFile}: line i holds the whitespace-separated topic index assigned to
 *       each word of document i (a blank line means the document has no words).
 * </ul>
 *
 * <p>Each document's topic distribution is estimated with additive smoothing,
 * {@code theta[d][k] = (n_dk + alpha) / (n_d + K * alpha)}. For every cluster the program prints
 * the average cosine similarity over all unordered pairs of its documents. A cluster with exactly
 * one document scores 1; an empty cluster (a skipped cluster index) scores 0.
 */
public class ClusterSimilarity2 {

  private static double alpha = 0.01;      // additive smoothing added to every topic count
  private static int topicNumber = 15;     // number of topics (K)

  private static String docClusterFile = "d:/document_cluster_assignment";
  private static String docWordTopicFile = "d:/doc_word_topic_assignment";

  private static int docNumber;                   // total number of documents
  private static int maxClusterNumber = 0;        // number of clusters (largest index + 1)
  private static int[] docCluster;                // cluster index of each document
  private static List<List<Integer>> clusterDocs; // document ids contained in each cluster
  private static int[][] wordCountByDocAndTopic;  // word count per document per topic
  private static int[] totalWordNumberByDoc;      // total word count per document
  private static double[][] topicDistByDoc;       // topic probability per document


  /**
   * Loads both input files, estimates the per-document topic distributions and prints the
   * average pairwise cosine similarity of the documents in each cluster.
   *
   * @param args ignored; inputs and parameters come from the static fields
   * @throws IOException if either input file cannot be read
   * @throws IllegalArgumentException if an input line holds a malformed or out-of-range index
   * @throws IllegalStateException if the topic file has more lines than the cluster file
   */
  public static void main(String[] args) throws IOException {

    if (topicNumber == 0) {
      System.err.println("主题数量没有正确设置!");
      System.exit(1);
    }

    readClusterIndex();
    readWordCountByDocAndTopic();
    computeTopicDistByDoc();

    System.out.println("topic number:" + topicNumber);
    System.out.println("cluster number:" + maxClusterNumber);
    System.out.println("doc number:" + docNumber);

    double[] result = new double[maxClusterNumber];

    for (int c = 0; c < maxClusterNumber; c++) {

      List<Integer> docs = clusterDocs.get(c);
      int docSizeOfCluster = docs.size();

      if (docSizeOfCluster == 0) {
        // Cluster indices need not be contiguous. An empty cluster has no document pairs; without
        // this branch the average would be 1/0 * 0 = NaN.
        result[c] = 0;
      } else if (docSizeOfCluster == 1) {
        // A lone document is treated as perfectly similar to itself.
        result[c] = 1;
      } else {
        // Average over all C(n, 2) unordered document pairs.
        double term1 = 1.0 / combination(docSizeOfCluster);

        double term2 = 0;
        for (int i = 0; i < docSizeOfCluster - 1; i++) {
          double[] theta1 = topicDistByDoc[docs.get(i)];
          double norm1 = mod(theta1);  // loop-invariant for the inner loop

          for (int j = i + 1; j < docSizeOfCluster; j++) {
            double[] theta2 = topicDistByDoc[docs.get(j)];
            // Cosine similarity: mod() already returns the L2 norm, so it must not be
            // square-rooted again.
            term2 += dot(theta1, theta2) / (norm1 * mod(theta2));
          }
        }

        result[c] = term1 * term2;

        // Topic distributions are non-negative, so a negative average indicates a bug.
        if (result[c] < 0) {
          System.err.println("very strange!" + result[c] + "\t" + term1 + "\t" + docSizeOfCluster);
        }
      }
    }

    System.out.println(Arrays.toString(result));
  }


  // Reads the cluster index of every document (line i = cluster of document i) and groups the
  // document ids by cluster. Sets docNumber, docCluster, maxClusterNumber and clusterDocs.
  private static void readClusterIndex() throws IOException {

    List<String> docClusterList =
        Files.readAllLines(Paths.get(docClusterFile), StandardCharsets.UTF_8);
    docNumber = docClusterList.size();
    docCluster = new int[docNumber];

    // Parse each line once, tracking the largest cluster index seen.
    int largestClusterID = -1;
    for (int docID = 0; docID < docNumber; docID++) {
      int clusterID = Integer.parseInt(docClusterList.get(docID).trim());
      if (clusterID < 0) {
        throw new IllegalArgumentException(
            "negative cluster index " + clusterID + " for document " + docID);
      }
      docCluster[docID] = clusterID;
      largestClusterID = Math.max(largestClusterID, clusterID);
    }

    // Indices are 0-based, so largest index + 1 is the cluster count. Assigned (not incremented)
    // so repeated runs do not accumulate, and an empty file yields zero clusters.
    maxClusterNumber = largestClusterID + 1;

    clusterDocs = new ArrayList<>(maxClusterNumber);
    for (int c = 0; c < maxClusterNumber; c++) {
      clusterDocs.add(new ArrayList<Integer>());
    }
    for (int docID = 0; docID < docNumber; docID++) {
      clusterDocs.get(docCluster[docID]).add(docID);
    }
  }

  // Counts how many words of each document were assigned to each topic. Line i of
  // docWordTopicFile lists the topic index of every word of document i; a blank line is a
  // document without words. Documents beyond the last line keep zero counts.
  private static void readWordCountByDocAndTopic() throws IOException {

    List<String> docWordTopicAssignList =
        Files.readAllLines(Paths.get(docWordTopicFile), StandardCharsets.UTF_8);
    if (docWordTopicAssignList.size() > docNumber) {
      throw new IllegalStateException("docWordTopicFile has " + docWordTopicAssignList.size()
          + " lines but docClusterFile lists only " + docNumber + " documents");
    }

    wordCountByDocAndTopic = new int[docNumber][topicNumber];
    totalWordNumberByDoc = new int[docNumber];

    for (int docID = 0; docID < docWordTopicAssignList.size(); docID++) {
      String line = docWordTopicAssignList.get(docID).trim();
      if (line.isEmpty()) {
        continue;  // document without words: its counts stay zero
      }
      // Split on whitespace runs so repeated spaces or tabs do not produce empty tokens.
      for (String token : line.split("\\s+")) {
        int topic = Integer.parseInt(token);
        if (topic < 0 || topic >= topicNumber) {
          throw new IllegalArgumentException("topic index " + topic + " of document " + docID
              + " is outside [0, " + topicNumber + ")");
        }
        wordCountByDocAndTopic[docID][topic]++;
        totalWordNumberByDoc[docID]++;
      }
    }
  }

  // Computes each document's smoothed topic distribution,
  // theta[d][k] = (n_dk + alpha) / (n_d + K * alpha). Each row sums to 1. With alpha > 0 every
  // entry is strictly positive; with alpha == 0 a document without words would yield 0/0 = NaN.
  private static void computeTopicDistByDoc() {

    topicDistByDoc = new double[docNumber][topicNumber];
    for (int docID = 0; docID < docNumber; docID++) {
      double denominator = totalWordNumberByDoc[docID] + topicNumber * alpha;
      for (int k = 0; k < topicNumber; k++) {
        topicDistByDoc[docID][k] = (wordCountByDocAndTopic[docID][k] + alpha) / denominator;
      }
    }
  }

  // Returns the Euclidean (L2) norm of a vector.
  private static double mod(double[] a) {
    double res = 0;
    for (double v : a) {
      res += v * v;
    }
    return Math.sqrt(res);
  }

  /**
   * Returns the inner (dot) product of two vectors of equal length.
   */
  private static double dot(double[] a, double[] b) {
    double res = 0;
    for (int i = 0; i < a.length; i++) {
      res += a[i] * b[i];
    }
    return res;
  }

  /**
   * Returns the number of unordered pairs that can be drawn from {@code n} items, C(n, 2).
   *
   * @param n number of items (expected to be non-negative)
   * @return {@code n * (n - 1) / 2}, computed in {@code long} arithmetic
   */
  public static long combination(int n) {
    // Widen before multiplying: in int arithmetic n * (n - 1) overflows once n > 46341.
    return (long) n * (n - 1) / 2;
  }
}
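// 欢迎大家关注DataLearner官方微信,接受最新的AI技术推送
// (Follow the DataLearner official WeChat account for the latest AI technology updates.)
// 相关博客 (Related blog posts)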