
Machine Learning in Action: Decision Tree (Java Implementation)

Posted: 2023-01-02 18:51:18



Without further ado, here is the code. If it helps you, please give it a like.

For the Python version, or implementations of other machine learning algorithms, you can email: 476562571@

Main features implemented:

Binary-valued feature discrimination

Recursive traversal of a file directory to load the training data set

Recall computation

Decision tree construction (a worked information-gain example follows this list)

Decision tree storage (written to a JSON file); requires com.alibaba fastjson-1.2.7.jar

Decision tree loading (read from a JSON file); requires com.alibaba fastjson-1.2.7.jar
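
The split criterion used by chooseBestFetureToSplit in the code below is plain ID3 information gain: the Shannon entropy of the class column minus the weighted entropy left after splitting on a feature. As a quick sanity check against the toy data set returned by createDataset() (two "yes" rows, three "no" rows), the base entropy is H(D) = -(2/5)*log2(2/5) - (3/5)*log2(3/5) ≈ 0.971. Splitting on the first feature ("surfacing") leaves subsets with labels {yes, yes, no} and {no, no}, so the remaining entropy is (3/5)*0.918 + (2/5)*0 ≈ 0.551 and the gain is about 0.420, while splitting on "flippers" gains only about 0.171; chooseBestFetureToSplit therefore returns index 0 for this data.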

package com.code.ku.qa.metion.classifier;

import com.alibaba.fastjson.JSONObject;
import com.code.ku.qa.metion.Metion;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.*;

/**
 * @Date: /11/15
 * @Time: 19:08
 * @User: Likf
 * @Description: ID3 decision tree: training, classification, and JSON persistence.
 */
public class DecisionTreeID3 {

    /** Logger */
    private static final Logger _LOG = LoggerFactory.getLogger(DecisionTreeID3.class);

    public static TreeNode tree = null;

    static {
        tree = loadTreeFromJsonFile(Metion.Config.getPath("classify\\id3\\tree.json"));
    }

    public DecisionTreeID3() {
    }

    public static String classify(List<String> labels, List<String> testData) {
        return classify(tree, labels, testData);
    }

    /**
     * Compute the Shannon entropy of the class labels (the last column of each row).
     * @param dataset
     */
    public double calChannonEnt(List<List<String>> dataset) {
        Map<String, Double> outLabels = new HashMap<>();
        for (List<String> fetures : dataset) {
            String outLabel = fetures.get(fetures.size() - 1);
            if (!outLabels.keySet().contains(outLabel)) {
                outLabels.put(outLabel, 0.0);
            }
            outLabels.put(outLabel, outLabels.get(outLabel) + 1);
        }
        double channonEnt = 0.0;
        for (Map.Entry<String, Double> entry : outLabels.entrySet()) {
            double pi = entry.getValue() / dataset.size();
            channonEnt -= pi * (Math.log(pi) / Math.log(2.0));
        }
        return channonEnt;
    }

    /**
     * Split the data set: keep the rows whose feature at fetureIndex equals value,
     * and drop that feature column from the kept rows.
     * @param dataset
     * @param fetureIndex
     * @param value
     * @return
     */
    private List<List<String>> splitDataSet(List<List<String>> dataset, int fetureIndex, String value) {
        List<List<String>> subDataSet = new ArrayList<>();
        for (List<String> fetures : dataset) {
            try {
                if (fetures.get(fetureIndex).equals(value)) {
                    List<String> reduceFetures = new LinkedList<>();
                    reduceFetures.addAll(fetures.subList(0, fetureIndex));
                    reduceFetures.addAll(fetures.subList(fetureIndex + 1, fetures.size()));
                    subDataSet.add(reduceFetures);
                }
            } catch (Exception e) {
                _LOG.trace("Malformed feature row: " + fetures);
            }
        }
        return subDataSet;
    }

    /**
     * Choose the feature with the largest information gain to split the data set on.
     * @param dataSet
     * @return
     */
    private int chooseBestFetureToSplit(List<List<String>> dataSet) {
        int numFetures = dataSet.get(0).size() - 1;
        double baseEntropy = calChannonEnt(dataSet);
        double bestInfoGain = 0.0;
        int bestFeture = -1;
        for (int i = 0; i < numFetures; i++) {
            double infoGain = 0.0;
            double newEntropy = 0.0;
            Set<String> featureVals = getFetureVals(dataSet, i);
            for (String fetureVal : featureVals) {
                List<List<String>> subDataSet = splitDataSet(dataSet, i, fetureVal);
                double prob = (double) subDataSet.size() / dataSet.size();
                newEntropy += prob * calChannonEnt(subDataSet);
            }
            infoGain = baseEntropy - newEntropy;
            if (infoGain > bestInfoGain) {
                bestInfoGain = infoGain;
                bestFeture = i;
            }
        }
        return bestFeture;
    }

    /**
     * Majority vote over the class labels.
     * @param classifyList
     */
    public String majorityCnt(List<String> classifyList) {
        Map<String, Double> classCount = new HashMap<>();
        // Count every occurrence of each label (fixes the off-by-one in the first count).
        classifyList.forEach(classify -> classCount.merge(classify, 1.0, Double::sum));
        double max = 0.0;
        String key = null;
        for (Map.Entry<String, Double> entry : classCount.entrySet()) {
            if (entry.getValue() > max) {
                max = entry.getValue();
                key = entry.getKey();
            }
        }
        return key;
    }

    /**
     * Build the decision tree recursively.
     * @param dataset
     * @param labels
     */
    public TreeNode createTree(List<List<String>> dataset, List<String> labels) {
        List<String> classifyList = getFetureLists(dataset, dataset.get(0).size() - 1);
        Set<String> classifySet = new HashSet<>(classifyList);
        if (classifySet.size() == 1) {
            return new TreeNode(classifyList.get(0));
        }
        if (dataset.get(0).size() == 1) {
            return new TreeNode(majorityCnt(classifyList));
        }
        int bestFeat = chooseBestFetureToSplit(dataset);
        if (bestFeat == -1) {
            bestFeat = labels.size() - 1;
        }
        String bestFeatLabel = labels.get(bestFeat);
        TreeNode tree = new TreeNode(bestFeatLabel);
        List<String> subLabels = new ArrayList<>();
        for (int i = 0; i < labels.size(); i++) {
            if (i != bestFeat) {
                subLabels.add(labels.get(i));
            }
        }
        Set<String> fetureVals = getFetureVals(dataset, bestFeat);
        for (String fv : fetureVals) {
            TreeNode node = tree.addChild(createTree(splitDataSet(dataset, bestFeat, fv), subLabels));
            node.setLabel(fv);
            node.setValue(fv);
        }
        return tree;
    }

    /**
     * Classify a new sample by walking the tree.
     * @param tree
     * @param featLabels
     * @param testData
     * @return
     */
    public static String classify(TreeNode tree, List<String> featLabels, List<String> testData) {
        Map<String, Integer> featLabelMap = new HashMap<>();
        for (int i = 0; i < featLabels.size(); i++) {
            featLabelMap.put(featLabels.get(i), i);
        }
        int featIndex = featLabelMap.get(tree.name);
        String classifyLabel = null;
        for (TreeNode node : tree.childs) {
            if (node.value.equals(testData.get(featIndex))) {
                if (node.childs.isEmpty()) {
                    classifyLabel = node.name;
                } else {
                    classifyLabel = classify(node, featLabels, testData);
                }
            }
        }
        if (classifyLabel == null) {
            // TODO: handle feature values never seen while building the tree
        }
        return classifyLabel;
    }

    /**
     * Compute the recall: the fraction of test rows whose predicted label matches the real one.
     * @param tree
     * @param labels
     * @param testDataSet
     * @return
     */
    public double recallRate(TreeNode tree, List<String> labels, List<List<String>> testDataSet) {
        int flagPos = testDataSet.get(0).size() - 1;
        double count = 0.0;
        for (List<String> fetures : testDataSet) {
            String realVal = fetures.get(flagPos);
            String preVal = classify(tree, labels, fetures.subList(0, flagPos));
            if (realVal.equals(preVal)) {
                count++;
            }
        }
        return count / testDataSet.size();
    }

    /**
     * All distinct values taken by the i-th feature.
     * @param dataSet
     * @param i
     * @return
     */
    private Set<String> getFetureVals(List<List<String>> dataSet, int i) {
        Set<String> vals = new LinkedHashSet<>(getFetureLists(dataSet, i));
        return vals;
    }

    /**
     * All values taken by the i-th feature (with duplicates).
     * @param dataSet
     * @param i
     * @return
     */
    private List<String> getFetureLists(List<List<String>> dataSet, int i) {
        List<String> vals = new LinkedList<>();
        for (List<String> fetures : dataSet) {
            try {
                vals.add(fetures.get(i));
            } catch (ArrayIndexOutOfBoundsException e) {
                _LOG.error("Malformed feature row: " + fetures.toString());
            }
        }
        return vals;
    }

    /**
     * Build the tree from the training data and store it as a JSON file
     * (requires Alibaba fastjson).
     * @param trainDirPath directory whose files hold the comma-separated training rows
     * @param treePath     path of the JSON file to write
     * @param labels       feature label names
     */
    public void storeDTreeToJson(String trainDirPath, String treePath, List<String> labels) {
        List<List<String>> trainDataSet = loadDataSet(trainDirPath);
        DecisionTreeID3.TreeNode tree = createTree(trainDataSet, labels);
        File treeFile = new File(treePath);
        try {
            if (!treeFile.exists()) {
                treeFile.createNewFile();
            }
            FileUtils.writeStringToFile(treeFile, JSONObject.toJSONString(tree), "utf-8");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Load the decision tree from a JSON file (requires Alibaba fastjson).
     * @param treePath
     * @return
     */
    public static TreeNode loadTreeFromJsonFile(String treePath) {
        try {
            File treeFile = new File(treePath);
            String jsonText = FileUtils.readFileToString(treeFile, "utf-8");
            return JSONObject.parseObject(jsonText, DecisionTreeID3.TreeNode.class);
        } catch (IOException e) {
            e.printStackTrace();
            _LOG.error("Failed to load the decision tree.");
            return null;
        }
    }

    /**
     * A small data set for testing.
     * @return
     */
    public List<List<String>> createDataset() {
        List<List<String>> dataset = new LinkedList<>();
        dataset.add(Arrays.asList("1", "1", "yes"));
        dataset.add(Arrays.asList("1", "1", "yes"));
        dataset.add(Arrays.asList("1", "0", "no"));
        dataset.add(Arrays.asList("0", "1", "no"));
        dataset.add(Arrays.asList("0", "1", "no"));
        return dataset;
    }

    public List<String> createLabels() {
        List<String> labels = new ArrayList<>();
        labels.add("surfacing");
        labels.add("flippers");
        return labels;
    }

    /**
     * Print the tree structure.
     * @param tree
     * @param root
     */
    public void print(TreeNode tree, String root) {
        System.out.println(root + " " + tree.value + ":" + tree.name);
        if (tree.childs == null || tree.childs.isEmpty()) {
            return;
        }
        for (TreeNode node : tree.childs) {
            print(node, root + root);
        }
    }

    /**
     * Tree node.
     */
    public static class TreeNode implements Serializable {
        private String name;
        private String label;
        private List<TreeNode> childs = new ArrayList<>();
        private String value;

        public TreeNode() {
        }

        public TreeNode(String name) {
            this.name = name;
        }

        public TreeNode addChild(TreeNode node) {
            childs.add(node);
            return node;
        }

        public TreeNode addChild(String name) {
            return addChild(new TreeNode(name));
        }

        public String getValue() {
            return value;
        }

        public void setValue(String value) {
            this.value = value;
        }

        @Override
        public String toString() {
            return name;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public List<TreeNode> getChilds() {
            return childs;
        }

        public void setChilds(List<TreeNode> childs) {
            this.childs = childs;
        }

        public String getLabel() {
            return label;
        }

        public void setLabel(String label) {
            this.label = label;
        }
    }

    /**
     * Load the data set: every line of every file under dirPath is a comma-separated row.
     * @param dirPath
     * @return
     */
    public List<List<String>> loadDataSet(String dirPath) {
        File dir = new File(dirPath);
        List<File> dsFiles = new ArrayList<>();
        travelDir(dir, dsFiles);
        List<List<String>> dataset = new ArrayList<>();
        for (File file : dsFiles) {
            try {
                List<String> lines = FileUtils.readLines(file, "utf-8");
                for (String line : lines) {
                    String[] lineArr = line.split(",");
                    dataset.add(Arrays.asList(lineArr));
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return dataset;
    }

    /**
     * Recursively collect all files under dir.
     * @param dir
     * @param dsFiles
     * @return
     */
    public List<File> travelDir(File dir, List<File> dsFiles) {
        File[] files = dir.listFiles();
        if (files == null) {
            // dir is not a directory or cannot be read
            return dsFiles;
        }
        for (File file : files) {
            if (file.isDirectory()) {
                travelDir(file, dsFiles);
            } else {
                dsFiles.add(file);
            }
        }
        return dsFiles;
    }

    public static void main(String[] args) {
        DecisionTreeID3 id3 = new DecisionTreeID3();
        List<List<String>> dataset = id3.createDataset();
        _LOG.trace("Shannon entropy: " + id3.calChannonEnt(dataset));
        _LOG.trace("Split: " + id3.splitDataSet(dataset, 0, "1"));
        _LOG.trace("Best feature to split on: " + id3.chooseBestFetureToSplit(dataset));
        TreeNode tree = id3.createTree(dataset, id3.createLabels());
        id3.print(tree, "->");
        List<String> testData1 = Arrays.asList("1", "0");
        List<String> testData2 = Arrays.asList("1", "1");
        List<String> testData3 = Arrays.asList("0", "1");
        System.out.println(testData1 + ":" + DecisionTreeID3.classify(tree, id3.createLabels(), testData1));
        System.out.println(testData2 + ":" + DecisionTreeID3.classify(tree, id3.createLabels(), testData2));
        System.out.println(testData3 + ":" + DecisionTreeID3.classify(tree, id3.createLabels(), testData3));
        String dir = "D:\\workspace\\zl\\pre\\mist-parent\\mist-kbqa\\datas\\metion\\train\\train-10";
        System.out.println(id3.loadDataSet(dir));
    }
}

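For completeness, a minimal usage sketch follows, assuming the class above is on the classpath together with commons-io and fastjson. The demo class name, the directory and file paths, and the label names are placeholders rather than anything from the original project; it relies on the two-argument storeDTreeToJson listed above (training directory plus output JSON path); and the static initializer above will also try to load a tree through the project-specific Metion.Config helper when the class is first used, so adapt that to your own environment.

import com.code.ku.qa.metion.classifier.DecisionTreeID3;

import java.util.Arrays;
import java.util.List;

public class DecisionTreeID3Demo {
    public static void main(String[] args) {
        // Placeholder paths and labels for illustration only.
        DecisionTreeID3 id3 = new DecisionTreeID3();
        List<String> labels = Arrays.asList("surfacing", "flippers");

        // Build a tree from every comma-separated file under the training directory
        // and persist it as JSON.
        id3.storeDTreeToJson("datas/train", "datas/tree.json", labels);

        // Reload the persisted tree and classify a new sample.
        DecisionTreeID3.TreeNode tree = DecisionTreeID3.loadTreeFromJsonFile("datas/tree.json");
        System.out.println(DecisionTreeID3.classify(tree, labels, Arrays.asList("1", "1")));

        // Evaluate on a held-out set loaded the same way as the training data.
        List<List<String>> testSet = id3.loadDataSet("datas/test");
        System.out.println("recall: " + id3.recallRate(tree, labels, testSet));
    }
}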