Árbol de decisión en R - Árbol de clasificación & Código en R con ejemplo

Tabla de contenido:

Anonim

¿Qué son los árboles de decisión?

Los árboles de decisión son un algoritmo versátil de aprendizaje automático que puede realizar tareas de clasificación y regresión. Son algoritmos muy poderosos, capaces de ajustar conjuntos de datos complejos. Además, los árboles de decisión son componentes fundamentales de los bosques aleatorios, que se encuentran entre los algoritmos de aprendizaje automático más potentes disponibles en la actualidad.

Entrenamiento y visualización de árboles de decisión

Para construir su primer árbol de decisión en el ejemplo de R, procederemos de la siguiente manera en este tutorial de árbol de decisión:

  • Paso 1: importar los datos
  • Paso 2: limpia el conjunto de datos
  • Paso 3: crear un tren / conjunto de prueba
  • Paso 4: construye el modelo
  • Paso 5: haz una predicción
  • Paso 6: medir el rendimiento
  • Paso 7: ajuste los hiperparámetros

Paso 1) Importa los datos

Si tienes curiosidad sobre el destino del titanic, puedes ver este video en Youtube. El propósito de este conjunto de datos es predecir qué personas tienen más probabilidades de sobrevivir después de la colisión con el iceberg. El conjunto de datos contiene 13 variables y 1309 observaciones. El conjunto de datos está ordenado por la variable X.

set.seed(678)path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'titanic <-read.csv(path)head(titanic)

Producción:

## X pclass survived name sex## 1 1 1 1 Allen, Miss. Elisabeth Walton female## 2 2 1 1 Allison, Master. Hudson Trevor male## 3 3 1 0 Allison, Miss. Helen Loraine female## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female## 6 6 1 1 Anderson, Mr. Harry male## age sibsp parch ticket fare cabin embarked## 1 29.0000 0 0 24160 211.3375 B5 S## 2 0.9167 1 2 113781 151.5500 C22 C26 S## 3 2.0000 1 2 113781 151.5500 C22 C26 S## 4 30.0000 1 2 113781 151.5500 C22 C26 S## 5 25.0000 1 2 113781 151.5500 C22 C26 S## 6 48.0000 0 0 19952 26.5500 E12 S## home.dest## 1 St Louis, MO## 2 Montreal, PQ / Chesterville, ON## 3 Montreal, PQ / Chesterville, ON## 4 Montreal, PQ / Chesterville, ON## 5 Montreal, PQ / Chesterville, ON## 6 New York, NY
tail(titanic)

Producción:

## X pclass survived name sex age sibsp## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0## parch ticket fare cabin embarked home.dest## 1304 0 2627 14.4583 C## 1305 0 2665 14.4542 C## 1306 0 2665 14.4542 C## 1307 0 2656 7.2250 C## 1308 0 2670 7.2250 C## 1309 0 315082 7.8750 S

Desde la salida de cabeza y cola, puede notar que los datos no se mezclan. ¡Este es un gran problema! Cuando divida sus datos entre un conjunto de trenes y un conjunto de prueba, seleccionará solo al pasajero de la clase 1 y 2 (ningún pasajero de la clase 3 se encuentra en el 80 por ciento superior de las observaciones), lo que significa que el algoritmo nunca verá el características del pasajero de la clase 3. Este error dará lugar a una mala predicción.

Para solucionar este problema, puede utilizar la función sample ().

shuffle_index <- sample(1:nrow(titanic))head(shuffle_index)

Árbol de decisión Código R Explicación

  • sample (1: nrow (titanic)): genera una lista aleatoria de índices de 1 a 1309 (es decir, el número máximo de filas).

Producción:

## [1] 288 874 1078 633 887 992 

Utilizará este índice para mezclar el conjunto de datos titánico.

titanic <- titanic[shuffle_index, ]head(titanic)

Producción:

## X pclass survived## 288 288 1 0## 874 874 3 0## 1078 1078 3 1## 633 633 3 0## 887 887 3 1## 992 992 3 1## name sex age## 288 Sutton, Mr. Frederick male 61## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42## 1078 O'Driscoll, Miss. Bridget female NA## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39## 887 Jermyn, Miss. Annie female NA## 992 Mamee, Mr. Hanna male NA## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ## 874 0 0 348121 7.6500 F G63 S## 1078 0 0 14311 7.7500 Q## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN## 887 0 0 14313 7.7500 Q## 992 0 0 2677 7.2292 C

Paso 2) Limpiar el conjunto de datos

La estructura de los datos muestra que algunas variables tienen NA. La limpieza de datos se realizará de la siguiente manera

  • Suelta las variables home.dest, cabin, name, X y ticket
  • Cree variables de factor para pclass y sobrevivió
  • Suelta la NA
library(dplyr)# Drop variablesclean_titanic <- titanic % > %select(-c(home.dest, cabin, name, X, ticket)) % > %#Convert to factor levelmutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %na.omit()glimpse(clean_titanic)

Explicación del código

  • select (-c (home.dest, cabin, name, X, ticket)): Elimina las variables innecesarias
  • pclass = factor (pclass, niveles = c (1,2,3), labels = c ('Superior', 'Medio', 'Inferior')): Agrega una etiqueta a la variable pclass. 1 se convierte en Superior, 2 en Medio y 3 en Inferior
  • factor (sobrevivió, niveles = c (0,1), etiquetas = c ('No', 'Sí')): agregue una etiqueta a la variable sobrevivió. 1 se convierte en No y 2 se convierte en Sí
  • na.omit (): Elimina las observaciones de NA

Producción:

## Observations: 1,045## Variables: 8## $ pclass  Upper, Lower, Lower, Upper, Middle, Upper, Middle, U… ## $ survived  No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y… ## $ sex  male, male, female, female, male, male, female, male… ## $ age  61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0,… ## $ sibsp  0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,… ## $ parch  0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,… ## $ fare  32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542,… ## $ embarked  S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C… 

Paso 3) Crear tren / conjunto de prueba

Antes de entrenar su modelo, debe realizar dos pasos:

  • Cree un tren y un conjunto de prueba: entrena el modelo en el conjunto de trenes y prueba la predicción en el conjunto de prueba (es decir, datos no vistos)
  • Instale rpart.plot desde la consola

La práctica común es dividir los datos en 80/20, el 80 por ciento de los datos sirve para entrenar el modelo y el 20 por ciento para hacer predicciones. Necesita crear dos marcos de datos separados. No desea tocar el conjunto de prueba hasta que termine de construir su modelo. Puede crear un nombre de función create_train_test () que tome tres argumentos.

create_train_test(df, size = 0.8, train = TRUE)arguments:-df: Dataset used to train the model.-size: Size of the split. By default, 0.8. Numerical value-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample < - 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}

Explicación del código

  • función (datos, tamaño = 0.8, tren = VERDADERO): agregue los argumentos en la función
  • n_row = nrow (datos): cuenta el número de filas en el conjunto de datos
  • total_row = size * n_row: Devuelve la enésima fila para construir el conjunto de trenes
  • train_sample <- 1: total_row: Seleccione la primera fila hasta la n-ésima fila
  • if (train == TRUE) {} else {}: Si la condición se establece en verdadera, devuelve el conjunto de trenes, de lo contrario, el conjunto de pruebas.

Puede probar su función y verificar la dimensión.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)dim(data_train)

Producción:

## [1] 836 8
dim(data_test)

Producción:

## [1] 209 8 

El conjunto de datos del tren tiene 1046 filas, mientras que el conjunto de datos de prueba tiene 262 filas.

Utiliza la función prop.table () combinada con table () para verificar si el proceso de aleatorización es correcto.

prop.table(table(data_train$survived))

Producción:

#### No Yes## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Producción:

#### No Yes## 0.5789474 0.4210526

En ambos conjuntos de datos, la cantidad de supervivientes es la misma, alrededor del 40 por ciento.

Instalar rpart.plot

rpart.plot no está disponible en las bibliotecas de conda. Puedes instalarlo desde la consola:

install.packages("rpart.plot") 

Paso 4) Construye el modelo

Estás listo para construir el modelo. La sintaxis de la función de árbol de decisión de Rpart es:

rpart(formula, data=, method='')arguments:- formula: The function to predict- data: Specifies the data frame- method:- "class" for a classification tree- "anova" for a regression tree

Utiliza el método de clase porque predice una clase.

library(rpart)library(rpart.plot)fit <- rpart(survived~., data = data_train, method = 'class')rpart.plot(fit, extra = 106

Explicación del código

  • rpart (): Función para ajustar el modelo. Los argumentos son:
    • sobrevivido ~ .: Fórmula de los árboles de decisión
    • data = data_train: conjunto de datos
    • método = 'clase': ajusta un modelo binario
  • rpart.plot (fit, extra = 106): traza el árbol. Las funciones adicionales se establecen en 101 para mostrar la probabilidad de la segunda clase (útil para respuestas binarias). Puede consultar la viñeta para obtener más información sobre las otras opciones.

Producción:

Comienza en el nodo raíz (profundidad 0 sobre 3, la parte superior del gráfico):

  1. En la parte superior, es la probabilidad general de supervivencia. Muestra la proporción de pasajeros que sobrevivieron al accidente. El 41 por ciento de los pasajeros sobrevivió.
  2. Este nodo pregunta si el género del pasajero es masculino. Si es así, entonces baja al nodo secundario izquierdo de la raíz (profundidad 2). El 63 por ciento son hombres con una probabilidad de supervivencia del 21 por ciento.
  3. En el segundo nodo, se pregunta si el pasajero masculino tiene más de 3,5 años. En caso afirmativo, la probabilidad de supervivencia es del 19 por ciento.
  4. Continúa así para comprender qué características afectan la probabilidad de supervivencia.

Tenga en cuenta que una de las muchas cualidades de los árboles de decisión es que requieren muy poca preparación de datos. En particular, no requieren escalado o centrado de características.

De forma predeterminada, la función rpart () utiliza la medida de impureza de Gini para dividir la nota. Cuanto mayor sea el coeficiente de Gini, más instancias diferentes dentro del nodo.

Paso 5) Haz una predicción

Puede predecir su conjunto de datos de prueba. Para hacer una predicción, puede utilizar la función predecir (). La sintaxis básica de predecir para R árbol de decisión es:

predict(fitted_model, df, type = 'class')arguments:- fitted_model: This is the object stored after model estimation.- df: Data frame used to make the prediction- type: Type of prediction- 'class': for classification- 'prob': to compute the probability of each class- 'vector': Predict the mean response at the node level

Desea predecir qué pasajeros tienen más probabilidades de sobrevivir después de la colisión desde el equipo de prueba. Es decir, sabrá entre esos 209 pasajeros, cuál sobrevivirá o no.

predict_unseen <-predict(fit, data_test, type = 'class')

Explicación del código

  • predecir (ajuste, prueba_datos, tipo = 'clase'): predice la clase (0/1) del conjunto de prueba

Probando al pasajero que no lo logró y a los que sí.

table_mat <- table(data_test$survived, predict_unseen)table_mat

Explicación del código

  • table (data_test $ survived, predict_unseen): cree una tabla para contar cuántos pasajeros se clasifican como sobrevivientes y fallecieron en comparación con la clasificación correcta del árbol de decisiones en R

Producción:

## predict_unseen## No Yes## No 106 15## Yes 30 58

El modelo predijo correctamente 106 pasajeros muertos, pero clasificó a 15 supervivientes como muertos. Por analogía, el modelo clasificó erróneamente a 30 pasajeros como supervivientes cuando resultaron estar muertos.

Paso 6) Medir el rendimiento

Puede calcular una medida de precisión para la tarea de clasificación con la matriz de confusión :

La matriz de confusión es una mejor opción para evaluar el desempeño de la clasificación. La idea general es contar el número de veces que las instancias verdaderas se clasifican como falsas.

Cada fila en una matriz de confusión representa un objetivo real, mientras que cada columna representa un objetivo previsto. La primera fila de esta matriz considera pasajeros muertos (la clase Falso): 106 se clasificaron correctamente como muertos ( Verdadero negativo ), mientras que el restante se clasificó erróneamente como superviviente ( Falso positivo ). La segunda fila considera a los sobrevivientes, la clase positiva fue 58 ( Verdadero positivo ), mientras que la Verdadera negativa fue 30.

Puede calcular la prueba de precisión a partir de la matriz de confusión:

Es la proporción de verdadero positivo y verdadero negativo sobre la suma de la matriz. Con R, puede codificar de la siguiente manera:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Explicación del código

  • sum (diag (table_mat)): Suma de la diagonal
  • sum (table_mat): Suma de la matriz.

Puede imprimir la precisión del conjunto de prueba:

print(paste('Accuracy for test', accuracy_Test))

Producción:

## [1] "Accuracy for test 0.784688995215311" 

Tiene una puntuación del 78 por ciento para el conjunto de pruebas. Puede replicar el mismo ejercicio con el conjunto de datos de entrenamiento.

Paso 7) Ajustar los hiperparámetros

El árbol de decisión en R tiene varios parámetros que controlan aspectos del ajuste. En la biblioteca del árbol de decisiones de rpart, puede controlar los parámetros utilizando la función rpart.control (). En el siguiente código, introduce los parámetros que ajustará. Puede consultar la viñeta para conocer otros parámetros.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)Arguments:-minsplit: Set the minimum number of observations in the node before the algorithm perform a split-minbucket: Set the minimum number of observations in the final note i.e. the leaf-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Procederemos de la siguiente manera:

  • Construya la función para devolver la precisión
  • Sintoniza la profundidad máxima
  • Ajuste la cantidad mínima de muestra que debe tener un nodo antes de que pueda dividirse
  • Ajustar la cantidad mínima de muestra que debe tener un nodo hoja

Puede escribir una función para mostrar la precisión. Simplemente envuelve el código que usaste antes:

  1. predecir: predecir_no visto <- predecir (ajuste, prueba_datos, tipo = 'clase')
  2. Producir tabla: table_mat <- table (data_test $ sobrevivido, predict_unseen)
  3. Calcular la precisión: precision_Test <- sum (diag (table_mat)) / sum (table_mat)
accuracy_tune <- function(fit) {predict_unseen <- predict(fit, data_test, type = 'class')table_mat <- table(data_test$survived, predict_unseen)accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test}

Puede intentar ajustar los parámetros y ver si puede mejorar el modelo sobre el valor predeterminado. Como recordatorio, debe obtener una precisión superior a 0,78

control <- rpart.control(minsplit = 4,minbucket = round(5 / 3),maxdepth = 3,cp = 0)tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)accuracy_tune(tune_fit)

Producción:

## [1] 0.7990431 

Con el siguiente parámetro:

minsplit = 4minbucket= round(5/3)maxdepth = 3cp=0 

Obtienes un rendimiento superior al del modelo anterior. Enhorabuena

Resumen

Podemos resumir las funciones para entrenar un algoritmo de árbol de decisión en R

Biblioteca

Objetivo

función

clase

parámetros

detalles

rpart

Árbol de clasificación de trenes en R

rpart ()

clase

fórmula, df, método

rpart

Árbol de regresión de tren

rpart ()

anova

fórmula, df, método

rpart

Trazar los árboles

rpart.plot ()

modelo ajustado

base

predecir

predecir()

clase

modelo ajustado, tipo

base

predecir

predecir()

problema

modelo ajustado, tipo

base

predecir

predecir()

vector

modelo ajustado, tipo

rpart

Parámetros de control

rpart.control ()

minplit

Establecer el número mínimo de observaciones en el nodo antes de que el algoritmo realice una división

minbucket

Establezca el número mínimo de observaciones en la nota final, es decir, la hoja

máxima profundidad

Establezca la profundidad máxima de cualquier nodo del árbol final. El nodo raíz se trata a una profundidad 0

rpart

Modelo de tren con parámetro de control

rpart ()

fórmula, df, método, control

Nota: Entrene el modelo con datos de entrenamiento y pruebe el rendimiento en un conjunto de datos invisible, es decir, un conjunto de prueba.