Công cụ nhận dạng chữ số Kaggle: Thử rừng ngẫu nhiên


Dương Ðắc Lực
6 năm trước
Hữu ích 3 Chia sẻ Viết bình luận 0
Đã xem 6776

Trước đây tôi đã viết về  phương pháp K- mean  mà  Jen  và tôi đã thực hiện khi cố gắng giải quyết  Công cụ nhận dạng số của Kaggle  và bị đình trệ với độ chính xác khoảng 80%, chúng tôi quyết định thử một trong các thuật toán được đề xuất trong phần  hướng dẫn  -  khu rừng ngẫu nhiên !

Ban đầu, chúng tôi đã sử dụng  thư viện rừng ngẫu nhiên clojure  nhưng đấu tranh để xây dựng rừng ngẫu nhiên từ dữ liệu tập huấn trong một khoảng thời gian hợp lý nên chúng tôi chuyển sang  phiên bản của Mahout  dựa trên  giấy rừng ngẫu nhiên Leo Breiman .

Có  một ví dụ thực sự tốt giải thích cách hoạt động của các nhóm trên blog Thực tế  mà chúng tôi thấy khá hữu ích trong việc giúp chúng tôi hiểu cách các khu rừng ngẫu nhiên được cho là hoạt động.

Một trong những kỹ thuật Machine Learning mạnh mẽ nhất mà chúng tôi hướng đến là tập hợp. Các phương pháp của Makeemble xây dựng các mô hình mạnh đáng ngạc nhiên từ một tập hợp các mô hình yếu được gọi là người học cơ sở và thường yêu cầu điều chỉnh ít hơn nhiều khi so sánh với các mô hình như Support Vector Machines.

Hầu hết các phương pháp tập hợp sử dụng cây quyết định làm người học cơ sở và nhiều kỹ thuật tạo thành, như Rừng ngẫu nhiên và Adaboost, là đặc trưng cho các nhóm cây.

Chúng tôi đã có thể điều chỉnh  BreimanExample  trong phần ví dụ của kho Mahout để làm những gì chúng tôi muốn.

Để bắt đầu, chúng tôi đã viết đoạn mã sau để xây dựng khu rừng ngẫu nhiên:

public class MahoutKaggleDigitRecognizer {
  public static void main(String[] args) throws Exception {
    String descriptor = "L N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N ";
    String[] trainDataValues = fileAsStringArray("data/train.csv");
 
    Data data = DataLoader.loadData(DataLoader.generateDataset(descriptor, false, trainDataValues), trainDataValues);
 
    int numberOfTrees = 100;
    DecisionForest forest = buildForest(numberOfTrees, data);
  }
 
  private static DecisionForest buildForest(int numberOfTrees, Data data) {
    int m = (int) Math.floor(Maths.log(2, data.getDataset().nbAttributes()) + 1);
 
    DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
    treeBuilder.setM(m);
 
    return new SequentialBuilder(RandomUtils.getRandom(), treeBuilder, data.clone()).build(numberOfTrees);
  }
 
  private static String[] fileAsStringArray(String file) throws Exception {
    ArrayList<String> list = new ArrayList<String>();
 
    DataInputStream in = new DataInputStream(new FileInputStream(file));
    BufferedReader br = new BufferedReader(new InputStreamReader(in));
 
    String strLine;
    br.readLine(); // discard top one (header)
    while ((strLine = br.readLine()) != null) {
      list.add(strLine);
    }
 
    in.close();
    return list.toArray(new String[list.size()]);
  }
}

Tệp dữ liệu đào tạo trông hơi giống như thế này:

label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8...,pixel783
1,0,0,0,0,0,0,...,0
0,0,0,0,0,0,0,...,0

Vì vậy, trong trường hợp này, nhãn nằm trong cột đầu tiên được biểu thị là  L  trong bộ mô tả và 784 cột tiếp theo là giá trị số của các pixel trong ảnh (do đó 784  N 'trong bộ mô tả).

Chúng tôi đang bảo nó tạo một khu rừng ngẫu nhiên chứa 100 cây và vì chúng tôi có số lượng danh mục hữu hạn mà một mục nhập có thể được phân loại là chúng tôi chuyển  sai  thành đối số thứ hai (hồi quy) của DataLoader.generateDataSet .

Các  m  giá trị xác định có bao nhiêu thuộc tính (giá trị pixel trong trường hợp này) được sử dụng để xây dựng mỗi cây và được cho là  log 2 (number_of_attributes) + 1  là giá trị tối ưu cho điều đó!

Sau đó chúng tôi đã viết đoạn mã sau để dự đoán nhãn của tập dữ liệu thử nghiệm:

public class MahoutKaggleDigitRecognizer {
  public static void main(String[] args) throws Exception {
    ...
    String[] testDataValues = testFileAsStringArray("data/test.csv");
    Data test = DataLoader.loadData(data.getDataset(), testDataValues);
    Random rng = RandomUtils.getRandom();
 
    for (int i = 0; i < test.size(); i++) {
    Instance oneSample = test.get(i);
 
    double classify = forest.classify(test.getDataset(), rng, oneSample);
    int label = data.getDataset().valueOf(0, String.valueOf((int) classify));
 
    System.out.println("Label: " + label);
  }
 
  private static String[] testFileAsStringArray(String file) throws Exception {
    ArrayList<String> list = new ArrayList<String>();
 
    DataInputStream in = new DataInputStream(new FileInputStream(file));
    BufferedReader br = new BufferedReader(new InputStreamReader(in));
 
    String strLine;
    br.readLine(); // discard top one (header)
    while ((strLine = br.readLine()) != null) {
      list.add("-," + strLine);
    }
 
    in.close();
    return list.toArray(new String[list.size()]);
  }
}

Có một vài điều mà chúng tôi thấy khó hiểu khi tìm ra cách thực hiện điều này:

  1. Định dạng của dữ liệu kiểm tra cần phải giống hệt với dữ liệu huấn luyện bao gồm nhãn theo sau là 784 giá trị số. Rõ ràng với dữ liệu thử nghiệm, chúng tôi không có nhãn nên Mahout chấp nhận chúng tôi vượt qua '-' nơi nhãn sẽ đi nếu không nó sẽ đưa ra một ngoại lệ, điều này giải thích '-' trên   dòng list.add .
  2. Ban đầu, chúng tôi nghĩ rằng giá trị được trả về bởi  Forest. Classify  là dự đoán nhưng thực tế nó là một chỉ số mà sau đó chúng tôi cần tìm kiếm trên tập dữ liệu.

Khi chúng tôi chạy thuật toán này với bộ dữ liệu thử nghiệm với 10 cây, chúng tôi có độ chính xác là 83,8%, với 50 cây chúng tôi có 84,4%, với 100 cây chúng tôi có 96,28% và với 200 cây, chúng tôi có 96,33% hiện đang đạt đỉnh.

Lượng thời gian để xây dựng rừng khi chúng ta tăng số lượng cây cũng bắt đầu trở thành một vấn đề, vì vậy bước tiếp theo của chúng ta là xem xét cách song song với việc tạo rừng hoặc thực hiện một số cách  khai thác tính năng  để cố gắng và cải thiện độ chính xác.

Các  code đang trên github  nếu bạn quan tâm đến chơi với nó hoặc có bất cứ đề xuất về cách cải thiện nó.



Hữu ích 3 Chia sẻ Viết bình luận 0
Đã xem 6776