基于MapReduce实现决策树算法

本文实例为大家分享了MapReduce实现决策树算法的具体代码,供大家参考,具体内容如下

首先,基于C4.5决策树算法实现对应的Mapper算子,相关的代码如下:

/**
 * Mapper for one pass of the C4.5 job.
 *
 * <p>Each input line is split on whitespace; the last token is taken as the
 * class label and all preceding tokens as attribute values. A record is kept
 * only if, for every (attribute index, value) pair of the current
 * {@link Split}, the record carries that value at that index. For every
 * attribute index that is not part of the split, the mapper emits
 * {@code "<index> <value> <classLabel>" -> 1}.
 */
public class MapClass extends MapReduceBase implements Mapper<LongWritable, Text, Text, IntWritable> {

  private final static IntWritable one = new IntWritable(1);
  private Text attValue = new Text();
  /** Attribute count (tokens minus the class label) of the most recently mapped record. */
  public static int no_Attr;
  /** Split being evaluated in this pass; deserialized from the job configuration. */
  public Split split = null;

  /** Value of the "current_index" job property (the driver stores the split's attribute count there). */
  public int size_split_1 = 0;

  /** Index of the node processed by this pass; used only to name the counter group. */
  private int nodeIndex = 0;

  /**
   * Reads the serialized split and the index properties from the job configuration.
   *
   * @throws IllegalStateException if the split cannot be deserialized; failing here
   *         avoids a NullPointerException on every record later in {@code map}
   */
  public void configure(JobConf conf) {
    try {
      split = (Split) ObjectSerializable.unSerialize(conf.get("currentsplit"));
    } catch (ClassNotFoundException e) {
      throw new IllegalStateException("Cannot deserialize job property \"currentsplit\"", e);
    } catch (IOException e) {
      throw new IllegalStateException("Cannot deserialize job property \"currentsplit\"", e);
    }
    if (split == null) {
      throw new IllegalStateException("Job property \"currentsplit\" did not yield a split");
    }
    size_split_1 = Integer.parseInt(conf.get("current_index"));
    // Static fields of the driver class are not updated inside a separate task JVM,
    // so prefer the value shipped in the job configuration; the static is only a fallback.
    nodeIndex = conf.getInt("currentIndex", Main.current_index);
  }

  /**
   * Emits one count per (attribute index, attribute value, class label) for every
   * record that matches the current split.
   *
   * @param key      input key supplied by the input format (not used)
   * @param value    one whitespace-separated record; the last token is the class label
   * @param output   collector receiving {@code "<index> <value> <classLabel>" -> 1}
   * @param reporter used to increment a per-record counter
   */
  public void map(LongWritable key, Text value, OutputCollector<Text, IntWritable> output, Reporter reporter)
      throws IOException {
    StringTokenizer itr = new StringTokenizer(value.toString());
    int tokenCount = itr.countTokens();
    if (tokenCount == 0) {
      // A blank line would otherwise lead to a negative array size below.
      return;
    }

    int attrCount = tokenCount - 1;
    no_Attr = attrCount;
    String[] attr = new String[attrCount];
    for (int i = 0; i < attrCount; i++) {
      attr[i] = itr.nextToken();
    }
    String classLabel = itr.nextToken();

    int size_split = split.attr_index.size();
    Counter counter = reporter.getCounter("reporter-" + nodeIndex, size_split + " " + size_split_1);
    counter.increment(1L);

    // The record belongs to the current node only if it agrees with every
    // (index, value) constraint stored in the split.
    boolean match = true;
    for (int count = 0; count < size_split; count++) {
      int index = (Integer) split.attr_index.get(count);
      String attr_value = (String) split.attr_value.get(count);
      if (!attr[index].equals(attr_value)) {
        match = false;
        break;
      }
    }
    if (!match) {
      return;
    }

    for (int l = 0; l < attrCount; l++) {
      if (!split.attr_index.contains(l)) {
        // This attribute value occurred once together with this class label.
        attValue.set(l + " " + attr[l] + " " + classLabel);
        output.collect(attValue, one);
      }
    }
    if (size_split == attrCount) {
      // Every attribute is already part of the split: emit a placeholder key so
      // the class label is still counted.
      attValue.set(attrCount + " " + "null" + " " + classLabel);
      output.collect(attValue, one);
    }
  }

}

然后,基于C4.5决策树算法实现对应的Reducer算子,相关的代码如下:

public class Reduce extends MapReduceBase implements Reducer {
 
  static int cnt = 0;
  ArrayList ar = new ArrayList();
  String data = null;
  private static int currentIndex;
 
  public void configure(JobConf conf) {
    currentIndex = Integer.valueOf(conf.get("currentIndex"));
  }
 
  public void reduce(Text key, Iterator values, OutputCollector output,
      Reporter reporter) throws IOException {
    int sum = 0;
    //sum表示按照某个属性进行划分的子数据集上的某个类出现的个数
    while (values.hasNext()) {
      sum += values.next().get();
    }
    //最后将这个属性上的取值写入output中;
    output.collect(key, new IntWritable(sum));
 
    String data = key + " " + sum;
    ar.add(data);
    //将最终结果写入到文件中;
    writeToFile(ar);
    ar.add("\n");
  }
 
  public static void writeToFile(ArrayList text) {
    try {
      cnt++;
      Path input = new Path("C45/intermediate" + currentIndex + ".txt");
      Configuration conf = new Configuration();
      FileSystem fs = FileSystem.get(conf);
      BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(fs.create(input, true)));
 
      for (String str : text) {
        bw.write(str);
      }
      bw.newLine();
      bw.close();
    } catch (Exception e) {
      System.out.println("File is not creating in reduce");
    }
  }
}

最后,编写Main函数来启动MapReduce作业。由于决策树的每个待分裂节点都需要单独统计一次,因此需要启动多趟作业(每个节点一趟),代码如下:

package com.hackecho.hadoop;
 
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
 
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.log4j.PropertyConfigurator;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TreeModel;
 
/**
 * Driver that grows a C4.5 decision tree by running one MapReduce job per tree node.
 *
 * <p>The job's role is to count, for the sub-dataset selected by the current
 * split, how often each attribute value occurs with each class label. The
 * driver then decides whether to split the node further or to treat it as a leaf.
 */
public class Main extends Configured implements Tool {

  /** Split (node) currently being processed. */
  public static Split currentsplit = new Split();
  /** All nodes created so far; used as a work queue indexed by {@link #current_index}. */
  public static List<Split> splitted = new ArrayList<Split>();
  /** Position in {@link #splitted} of the node currently being processed. */
  public static int current_index = 0;

  public static ArrayList<String> ar = new ArrayList<String>();

  /** Nodes that are not split any further. */
  public static List<Split> leafSplits = new ArrayList<Split>();

  public static final String PROJECT_HOME = System.getProperty("user.dir");

  /**
   * Processes the node queue until no unprocessed node is left, then persists
   * the resulting tree as PMML under {@code C45/DecisionTree.pmml}.
   *
   * @param args {@code args[0]} input path, {@code args[1]} output path prefix
   *             (both consumed by {@link #run})
   */
  public static void main(String[] args) throws Exception {
    PropertyConfigurator.configure(PROJECT_HOME + "/conf/log/log4j.properties");
    // The root node goes in first, so the queue starts with size 1.
    splitted.add(currentsplit);

    Path c45 = new Path("C45");
    Configuration conf = new Configuration();
    FileSystem fs = FileSystem.get(conf);
    if (fs.exists(c45)) {
      fs.delete(c45, true);
    }
    fs.mkdirs(c45);

    int res = 0;
    int split_index = 0;
    double gainratio = 0;
    double best_gainratio = 0;
    double entropy = 0;
    String classLabel = null;
    // NOTE(review): the attribute count is hard-coded; MapClass.no_Attr is only
    // assigned inside map(), so it was not a usable source here.
    int total_attributes = 4;
    int split_size = splitted.size();
    GainRatio gainObj;
    Split newnode;

    while (split_size > current_index) {
      currentsplit = splitted.get(current_index);
      gainObj = new GainRatio();
      res = ToolRunner.run(new Configuration(), new Main(), args);
      System.out.println("Current NODE INDEX . ::" + current_index);
      gainObj.getcount();
      // Entropy and majority class label of the current node.
      entropy = gainObj.currNodeEntophy();
      classLabel = gainObj.majorityLabel();
      currentsplit.classLabel = classLabel;

      if (entropy != 0.0 && currentsplit.attr_index.size() != total_attributes) {
        System.out.println("");
        System.out.println("Entropy NOTT zero  SPLIT INDEX::  " + entropy);
        best_gainratio = 0;
        // Evaluate the gain ratio of every attribute not yet used on this path
        // and remember the best one.
        for (int j = 0; j < total_attributes; j++) {
          if (!currentsplit.attr_index.contains(j)) {
            gainratio = gainObj.gainratio(j, entropy);
            if (gainratio >= best_gainratio) {
              split_index = j;
              best_gainratio = gainratio;
            }
          }
        }

        // split_index is the attribute chosen for the split; attr_values_split
        // is the whitespace-separated list of values that attribute takes.
        String attr_values_split = gainObj.getvalues(split_index);
        StringTokenizer attrs = new StringTokenizer(attr_values_split);
        int number_splits = attrs.countTokens();
        System.out.println(" INDEX :: " + split_index);
        System.out.println(" SPLITTING VALUES " + attr_values_split);

        // Create one child node per value: the parent's constraints plus
        // (split_index == value).
        for (int splitnumber = 1; splitnumber <= number_splits; splitnumber++) {
          newnode = new Split();
          newnode.attr_index.addAll(currentsplit.attr_index);
          newnode.attr_value.addAll(currentsplit.attr_value);
          String red = attrs.nextToken();

          newnode.attr_index.add(split_index);
          newnode.attr_value.add(red);
          splitted.add(newnode);
        }
      } else {
        // Leaf: either the node is pure (entropy 0) or every attribute has
        // already been used on this path. Both kinds are recorded; previously
        // the second kind was left out of the model.
        // writeRuleToFile can be used to dump the same rule as plain text.
        leafSplits.add(currentsplit);
      }
      split_size = splitted.size();
      System.out.println("TOTAL NODES::  " + split_size);
      current_index++;
    }

    // Persist the model once, after all nodes are processed. Previously this only
    // happened when some node had used every attribute, so a shallower tree
    // produced no model file at all.
    TreeModel tree = PmmlDecisionTree.buildTreeModel(leafSplits);
    PMML pmml = new PMML();
    pmml.addModels(tree);
    PmmlModelFactory.pmmlPersistence("C45/DecisionTree.pmml", pmml);

    System.out.println("Done!");
    System.exit(res);
  }

  /**
   * Writes the given strings back to back, followed by a line separator, to
   * {@code C45/rule.txt}, replacing any existing file.
   *
   * @param text the strings to write, in order
   * @throws IOException if the file system cannot be opened or the write fails
   */
  public static void writeRuleToFile(ArrayList<String> text) throws IOException {
    Path rule = new Path("C45/rule.txt");
    Configuration conf = new Configuration();
    FileSystem fs = FileSystem.get(conf);
    BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(fs.create(rule, true)));
    try {
      for (String str : text) {
        bw.write(str);
      }
      bw.newLine();
    } finally {
      // Release the stream even when a write fails.
      bw.close();
    }
  }

  /**
   * Configures and runs the counting job for {@link #currentsplit}.
   *
   * @param args {@code args[0]} input path, {@code args[1]} output path prefix;
   *             the current node index is appended to the prefix
   * @return 0 once the job has run
   * @throws IllegalArgumentException if fewer than two arguments are given
   */
  public int run(String[] args) throws Exception {
    System.out.println("In main ---- run");
    if (args == null || args.length < 2) {
      throw new IllegalArgumentException("Usage: Main <input path> <output path prefix>");
    }

    JobConf conf = new JobConf(getConf(), Main.class);
    conf.setJobName("C45");
    // Ship the current node to the tasks through the job configuration.
    conf.set("currentsplit", ObjectSerializable.serialize(currentsplit));
    conf.set("current_index", String.valueOf(currentsplit.attr_index.size()));
    conf.set("currentIndex", String.valueOf(current_index));

    // Keys are "<index> <value> <classLabel>" strings, values are counts.
    conf.setOutputKeyClass(Text.class);
    conf.setOutputValueClass(IntWritable.class);

    conf.setMapperClass(MapClass.class);
    conf.setReducerClass(Reduce.class);
    System.out.println("back to run");

    FileSystem fs = FileSystem.get(conf);

    // Each pass gets its own output directory; remove a stale one first.
    Path out = new Path(args[1] + current_index);
    if (fs.exists(out)) {
      fs.delete(out, true);
    }
    FileInputFormat.setInputPaths(conf, args[0]);
    FileOutputFormat.setOutputPath(conf, out);

    JobClient.runJob(conf);
    return 0;
  }
}

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持呐喊教程。

声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:notice#nhooo.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。