¿Qué es la regresión logística?
La regresión logística se utiliza para predecir una clase, es decir, una probabilidad. La regresión logística puede predecir un resultado binario con precisión.
Imagine que desea predecir si un préstamo se rechaza / acepta en función de muchos atributos. La regresión logística es de la forma 0/1. y = 0 si se rechaza un préstamo, y = 1 si se acepta.
Un modelo de regresión logística se diferencia del modelo de regresión lineal en dos formas.
- En primer lugar, la regresión logística acepta únicamente una entrada dicotómica (binaria) como variable dependiente (es decir, un vector de 0 y 1).
- En segundo lugar, el resultado se mide mediante la siguiente función de enlace probabilístico llamada sigmoidea debido a su forma de S:
La salida de la función siempre está entre 0 y 1. Verifique la imagen a continuación
La función sigmoidea devuelve valores de 0 a 1. Para la tarea de clasificación, necesitamos una salida discreta de 0 o 1.
Para convertir un flujo continuo en un valor discreto, podemos establecer un límite de decisión en 0.5. Todos los valores por encima de este umbral se clasifican como 1
En este tutorial, aprenderá
- ¿Qué es la regresión logística?
- Cómo crear un modelo de revestimiento generalizado (GLM)
- Paso 1) Verifique las variables continuas
- Paso 2) Verifique las variables de los factores
- Paso 3) Ingeniería de funciones
- Paso 4) Resumen estadístico
- Paso 5) Equipo de entrenamiento / prueba
- Paso 6) Construye el modelo
- Paso 7) Evaluar el desempeño del modelo
Cómo crear un modelo de revestimiento generalizado (GLM)
Usemos el conjunto de datos de adultos para ilustrar la regresión logística. El "adulto" es un gran conjunto de datos para la tarea de clasificación. El objetivo es predecir si el ingreso anual en dólares de un individuo superará los 50.000. El conjunto de datos contiene 46,033 observaciones y diez características:
- age: edad del individuo. Numérico
- educación: Nivel educativo del individuo. Factor.
- marital.status: Estado civil del individuo. Factor, es decir, nunca casado, cónyuge-civil-casado, ...
- género: género del individuo. Factor, es decir, masculino o femenino
- ingresos: Variable objetivo. Ingresos por encima o por debajo de 50K. Factor, es decir,> 50 K, <= 50 K
Entre otros
library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)
Producción:
Observations: 48,842Variables: 10$ x1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass Private, Private, Local-gov, Private, ?, Private,… $ education 11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status Never-married, Married-civ-spouse, Married-civ-sp… $ race Black, White, White, Black, White, White, Black,… $ gender Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5…
Procederemos de la siguiente manera:
- Paso 1: Verifique las variables continuas
- Paso 2: Verifique las variables de los factores
- Paso 3: ingeniería de funciones
- Paso 4: Estadística de resumen
- Paso 5: Entrenar / probar el equipo
- Paso 6: construye el modelo
- Paso 7: evaluar el desempeño del modelo
- paso 8: mejorar el modelo
Su tarea es predecir qué individuo tendrá un ingreso superior a 50K.
En este tutorial, se detallará cada paso para realizar un análisis en un conjunto de datos real.
Paso 1) Verifique las variables continuas
En el primer paso, puede ver la distribución de las variables continuas.
continuous <-select_if(data_adult, is.numeric)summary(continuous)
Explicación del código
- continuo <- select_if (data_adult, is.numeric): Use la función select_if () de la biblioteca dplyr para seleccionar solo las columnas numéricas
- resumen (continuo): imprime la estadística de resumen
Producción:
## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
En la tabla anterior, puede ver que los datos tienen escalas y horas totalmente diferentes. Por semana tiene valores atípicos grandes (por ejemplo, observe el último cuartil y el valor máximo).
Puede solucionarlo siguiendo dos pasos:
- 1: Trace la distribución de horas por semana
- 2: Estandarizar las variables continuas
- Trazar la distribución
Veamos más de cerca la distribución de horas por semana.
# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")
Producción:
La variable tiene muchos valores atípicos y una distribución no bien definida. Puede abordar parcialmente este problema eliminando el 0.01 por ciento superior de las horas por semana.
Sintaxis básica del cuantil:
quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.
Calculamos el percentil 2 superior
top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent
Explicación del código
- quantile (data_adult $ hours.per.week, .99): Calcule el valor del 99 por ciento del tiempo de trabajo
Producción:
## 99%## 80
El 98 por ciento de la población trabaja menos de 80 horas semanales.
Puede eliminar las observaciones por encima de este umbral. Utiliza el filtro de la biblioteca dplyr.
data_adult_drop <-data_adult %>%filter(hours.per.weekProducción:
## [1] 45537 10
- Estandarizar las variables continuas
Puede estandarizar cada columna para mejorar el rendimiento porque sus datos no tienen la misma escala. Puede utilizar la función mutate_if de la biblioteca dplyr. La sintaxis básica es:
mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the functionPuede estandarizar las columnas numéricas de la siguiente manera:
data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)Explicación del código
- mutate_if (is.numeric, funs (scale)): la condición es solo una columna numérica y la función es scale
Producción:
## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50KPaso 2) Verifique las variables de los factores
Este paso tiene dos objetivos:
- Verifique el nivel en cada columna categórica
- Definir nuevos niveles
Dividiremos este paso en tres partes:
- Seleccione las columnas categóricas
- Almacene el gráfico de barras de cada columna en una lista
- Imprime las gráficas
Podemos seleccionar las columnas de factores con el siguiente código:
# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)Explicación del código
- data.frame (select_if (data_adult, is.factor)): Almacenamos las columnas de factor en factor en un tipo de marco de datos. La biblioteca ggplot2 requiere un objeto de marco de datos.
Producción:
## [1] 6El conjunto de datos contiene 6 variables categóricas
El segundo paso es más hábil. Desea trazar un gráfico de barras para cada columna en el factor de marco de datos. Es más conveniente automatizar el proceso, especialmente en situaciones en las que hay muchas columnas.
library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))Explicación del código
- lapply (): use la función lapply () para pasar una función en todas las columnas del conjunto de datos. Almacena la salida en una lista
- function (x): La función se procesará para cada x. Aquí x son las columnas
- ggplot (factor, aes (get (x))) + geom_bar () + theme (axis.text.x = element_text (angle = 90)): crea un gráfico de caracteres de barras para cada elemento x. Tenga en cuenta que para devolver x como una columna, debe incluirlo dentro de get ()
El último paso es relativamente sencillo. Quieres imprimir los 6 gráficos.
# Print the graphgraphProducción:
## [[1]]## ## [[2]]## ## [[3]]## ## [[4]]## ## [[5]]## ## [[6]]Nota: Utilice el botón siguiente para navegar al siguiente gráfico.
Paso 3) Ingeniería de funciones
Educación refundida
En el gráfico anterior, puede ver que la variable educación tiene 16 niveles. Esto es sustancial y algunos niveles tienen un número relativamente bajo de observaciones. Si desea mejorar la cantidad de información que puede obtener de esta variable, puede modificarla a un nivel superior. Es decir, crea grupos más grandes con un nivel de educación similar. Por ejemplo, el bajo nivel de educación se convertirá en deserción. Los niveles superiores de educación se cambiarán a máster.
Aquí está el detalle:
Nivel antiguo
Nuevo nivel
Preescolar
abandonar
Décimo
Abandonar
11º
Abandonar
12º
Abandonar
1º-4º
Abandonar
5to-6to
Abandonar
7 ° al 8 °
Abandonar
Noveno
Abandonar
HS-Grad
HighGrad
Alguna educación superior
Comunidad
Assoc-acdm
Comunidad
Assoc-voc
Comunidad
Solteros
Solteros
Maestros
Maestros
Prof-escuela
Maestros
Doctorado
Doctor
recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))Explicación del código
- Usamos el verbo mutar de la biblioteca dplyr. Cambiamos los valores de la educación con la declaración ifelse
En la siguiente tabla, crea una estadística resumida para ver, en promedio, cuántos años de educación (valor z) se necesitan para obtener la licenciatura, la maestría o el doctorado.
recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)Producción:
## # A tibble: 6 x 3## education average_educ_year count#### 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557 Refundición del estado civil
También es posible crear niveles más bajos para el estado civil. En el siguiente código, cambia el nivel de la siguiente manera:
Nivel antiguo
Nuevo nivel
Nunca casado
No casado
Cónyuge-casado-ausente
No casado
Cónyuge-AF-casado
Casado
Cónyuge-civil-casado
Apartado
Apartado
Divorciado
Viudas
Viuda
# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))Puede verificar el número de personas dentro de cada grupo.table(recast_data$marital.status)Producción:
## ## Married Not_married Separated Widow## 21165 15359 7727 1286Paso 4) Resumen estadístico
Es hora de comprobar algunas estadísticas sobre nuestras variables objetivo. En el gráfico a continuación, cuenta el porcentaje de personas que ganan más de 50k según su género.
# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()Producción:
A continuación, compruebe si el origen del individuo afecta sus ingresos.
# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))Producción:
El número de horas trabajadas por género.
# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()Producción:
El diagrama de caja confirma que la distribución del tiempo de trabajo se ajusta a diferentes grupos. En el diagrama de caja, ambos sexos no tienen observaciones homogéneas.
Puede consultar la densidad del tiempo de trabajo semanal por tipo de educación. Las distribuciones tienen muchas selecciones distintas. Probablemente se pueda explicar por el tipo de contrato en los EE. UU.
# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()Explicación del código
- ggplot (recast_data, aes (x = hours.per.week)): una gráfica de densidad solo requiere una variable
- geom_density (aes (color = education), alpha = 0.5): El objeto geométrico para controlar la densidad
Producción:
Para confirmar sus pensamientos, puede realizar una prueba ANOVA unidireccional:
anova <- aov(hours.per.week~education, recast_data)summary(anova)Producción:
## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1La prueba ANOVA confirma la diferencia de promedio entre grupos.
No linealidad
Antes de ejecutar el modelo, puede ver si el número de horas trabajadas está relacionado con la edad.
library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()Explicación del código
- ggplot (recast_data, aes (x = age, y = hours.per.week)): establece la estética del gráfico
- geom_point (aes (color = ingresos), tamaño = 0.5): Construye el diagrama de puntos
- stat_smooth (): agregue la línea de tendencia con los siguientes argumentos:
- método = 'lm': grafica el valor ajustado si la regresión lineal
- formula = y ~ poly (x, 2): Ajustar una regresión polinomial
- se = TRUE: Agrega el error estándar
- aes (color = ingresos): divide el modelo por ingresos
Producción:
En pocas palabras, puede probar los términos de interacción en el modelo para detectar el efecto de no linealidad entre el tiempo de trabajo semanal y otras características. Es importante detectar en qué condiciones difiere el tiempo de trabajo.
Correlación
La siguiente comprobación es visualizar la correlación entre las variables. Convierte el tipo de nivel de factor en numérico para poder trazar un mapa de calor que contenga el coeficiente de correlación calculado con el método de Spearman.
library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")Explicación del código
- data.frame (lapply (recast_data, as.integer)): convierte datos en numéricos
- ggcorr () traza el mapa de calor con los siguientes argumentos:
- método: método para calcular la correlación
- nbreaks = 6: Número de pausas
- hjust = 0.8: Posición de control del nombre de la variable en el gráfico
- label = TRUE: agrega etiquetas en el centro de las ventanas
- label_size = 3: Etiquetas de tamaño
- color = "grey50"): Color de la etiqueta
Producción:
Paso 5) Equipo de entrenamiento / prueba
Cualquier tarea de aprendizaje automático supervisada requiere dividir los datos entre un conjunto de trenes y un conjunto de prueba. Puede utilizar la "función" que creó en los otros tutoriales de aprendizaje supervisado para crear un conjunto de entrenamiento / prueba.
set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)Producción:
## [1] 36429 9dim(data_test)Producción:
## [1] 9108 9Paso 6) Construye el modelo
Para ver cómo funciona el algoritmo, use el paquete glm (). El modelo lineal generalizado es una colección de modelos. La sintaxis básica es:
glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")Está listo para estimar el modelo logístico para dividir el nivel de ingresos entre un conjunto de características.
formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)Explicación del código
- fórmula <- ingresos ~.: Crea el modelo para ajustar
- logit <- glm (fórmula, data = data_train, family = 'binomial'): Ajuste un modelo logístico (family = 'binomial') con los datos de data_train.
- resumen (logit): Imprime el resumen del modelo
Producción:
#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6El resumen de nuestro modelo revela información interesante. El rendimiento de una regresión logística se evalúa con métricas clave específicas.
- AIC (Akaike Information Criteria): este es el equivalente de R2 en regresión logística. Mide el ajuste cuando se aplica una penalización al número de parámetros. Los valores de AIC más pequeños indican que el modelo está más cerca de la verdad.
- Desviación nula: se ajusta al modelo solo con la intersección. El grado de libertad es n-1. Podemos interpretarlo como un valor de Chi-cuadrado (valor ajustado diferente de la prueba de hipótesis del valor real).
- Desviación residual: Modelo con todas las variables. También se interpreta como una prueba de hipótesis de Chi-cuadrado.
- Número de iteraciones de puntuación de Fisher: número de iteraciones antes de la convergencia.
La salida de la función glm () se almacena en una lista. El siguiente código muestra todos los elementos disponibles en la variable logit que construimos para evaluar la regresión logística.
# La lista es muy larga, imprima solo los primeros tres elementos
lapply(logit, class)[1:3]Producción:
## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"Cada valor se puede extraer con el signo $ seguido del nombre de las métricas. Por ejemplo, almacenó el modelo como logit. Para extraer los criterios AIC, utiliza:
logit$aicProducción:
## [1] 27086.65Paso 7) Evaluar el desempeño del modelo
Matriz de confusión
La matriz de confusión es una mejor opción para evaluar el rendimiento de la clasificación en comparación con las diferentes métricas que vio antes. La idea general es contar el número de veces que las instancias verdaderas se clasifican como falsas.
Para calcular la matriz de confusión, primero debe tener un conjunto de predicciones para poder compararlas con los objetivos reales.
predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_matExplicación del código
- predict (logit, data_test, type = 'response'): calcula la predicción en el conjunto de prueba. Establezca type = 'response' para calcular la probabilidad de respuesta.
- table (data_test $ ingresos, predecir> 0.5): Calcule la matriz de confusión. predecir> 0,5 significa que devuelve 1 si las probabilidades predichas están por encima de 0,5, en caso contrario 0.
Producción:
#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229Cada fila en una matriz de confusión representa un objetivo real, mientras que cada columna representa un objetivo previsto. La primera fila de esta matriz considera los ingresos inferiores a 50k (la clase Falso): 6241 se clasificaron correctamente como individuos con ingresos inferiores a 50k ( Verdadero negativo ), mientras que el restante se clasificó erróneamente como superior a 50k ( Falso positivo ). La segunda fila considera los ingresos por encima de 50k, la clase positiva fue 1229 ( verdadero positivo ), mientras que el verdadero negativo fue 1074.
Puede calcular la precisión del modelo sumando el verdadero positivo + el verdadero negativo sobre la observación total
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_TestExplicación del código
- sum (diag (table_mat)): Suma de la diagonal
- sum (table_mat): Suma de la matriz.
Producción:
## [1] 0.8277339El modelo parece tener un problema: sobrestima el número de falsos negativos. A esto se le llama la paradoja de la prueba de precisión . Dijimos que la precisión es la relación entre las predicciones correctas y el número total de casos. Podemos tener una precisión relativamente alta pero un modelo inútil. Ocurre cuando hay una clase dominante. Si mira hacia atrás en la matriz de confusión, puede ver que la mayoría de los casos están clasificados como verdaderos negativos. Imagínese ahora, el modelo clasificó todas las clases como negativas (es decir, inferiores a 50k). Tendría una precisión del 75 por ciento (6718/6718 + 2257). Su modelo funciona mejor, pero le cuesta distinguir el verdadero positivo del verdadero negativo.
En tal situación, es preferible tener una métrica más concisa. Podemos mirar:
- Precisión = TP / (TP + FP)
- Recuperar = TP / (TP + FN)
Precisión vs recuperación
Precision mira la exactitud de la predicción positiva. Recall es la proporción de instancias positivas que son detectadas correctamente por el clasificador;
Puede construir dos funciones para calcular estas dos métricas
- Construir precisión
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}Explicación del código
- mat [1,1]: Devuelve la primera celda de la primera columna del marco de datos, es decir, el verdadero positivo
- estera [1,2]; Devuelve la primera celda de la segunda columna del marco de datos, es decir, el falso positivo
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}Explicación del código
- mat [1,1]: Devuelve la primera celda de la primera columna del marco de datos, es decir, el verdadero positivo
- estera [2,1]; Devuelve la segunda celda de la primera columna del marco de datos, es decir, el falso negativo
Puedes probar tus funciones
prec <- precision(table_mat)precrec <- recall(table_mat)recProducción:
## [1] 0.712877## [2] 0.5336518Cuando el modelo dice que es un individuo por encima de 50k, es correcto en solo el 54 por ciento de los casos, y puede reclamar individuos por encima de 50k en el 72 por ciento de los casos.
Puede crear la es una media armónica de estas dos métricas, lo que significa que da más peso a los valores más bajos.
f1 <- 2 * ((prec * rec) / (prec + rec))f1Producción:
## [1] 0.6103799Compensación de precisión vs recuperación
Es imposible tener una alta precisión y una alta recuperación.
Si aumentamos la precisión, se predecirá mejor el individuo correcto, pero perderíamos muchos de ellos (menor recuerdo). En alguna situación, preferimos una mayor precisión que recordar. Existe una relación cóncava entre precisión y recuerdo.
- Imagínese, necesita predecir si un paciente tiene una enfermedad. Quieres ser lo más preciso posible.
- Si necesita detectar posibles personas fraudulentas en la calle a través del reconocimiento facial, sería mejor atrapar a muchas personas etiquetadas como fraudulentas aunque la precisión sea baja. La policía podrá liberar al individuo no fraudulento.
La curva ROC
La curva de características de funcionamiento del receptor es otra herramienta común que se utiliza con la clasificación binaria. Es muy similar a la curva de precisión / recuperación, pero en lugar de representar la precisión frente a la recuperación, la curva ROC muestra la tasa de verdaderos positivos (es decir, la recuperación) frente a la tasa de falsos positivos. La tasa de falsos positivos es la proporción de casos negativos que se clasifican incorrectamente como positivos. Es igual a uno menos la tasa negativa verdadera. La verdadera tasa negativa también se llama especificidad . Por lo tanto, la curva ROC traza la sensibilidad (recuperación) frente a la especificidad 1
Para trazar la curva ROC, necesitamos instalar una biblioteca llamada RORC. Lo podemos encontrar en la biblioteca de conda. Puede escribir el código:
conda install -cr r-rocr --sí
Podemos graficar la ROC con las funciones de predicción () y rendimiento ().
library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))Explicación del código
- predicción (predecir, prueba_datos $ ingresos): la biblioteca ROCR necesita crear un objeto de predicción para transformar los datos de entrada
- rendimiento (ROCRpred, 'tpr', 'fpr'): Devuelve las dos combinaciones para producir en el gráfico. Aquí, se construyen tpr y fpr. Tot trazar precisión y recordar juntos, use "prec", "rec".
Producción:
Paso 8) Mejora el modelo
Puede intentar agregar no linealidad al modelo con la interacción entre
- edad y horas por semana
- género y horas por semana.
Debe utilizar la prueba de puntuación para comparar ambos modelos
formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2Producción:
## [1] 0.6109181La puntuación es ligeramente superior a la anterior. Puede seguir trabajando en los datos para intentar superar la puntuación.
Resumen
Podemos resumir la función para entrenar una regresión logística en la siguiente tabla:
Paquete
Objetivo
función
argumento
-
Crear un conjunto de datos de entrenamiento / prueba
create_train_set ()
datos, tamaño, tren
glm
Entrenar un modelo lineal generalizado
glm ()
fórmula, datos, familia *
glm
Resume el modelo
resumen()
modelo ajustado
base
Hacer predicción
predecir()
modelo ajustado, conjunto de datos, tipo = 'respuesta'
base
Crea una matriz de confusión
mesa()
y, predecir ()
base
Crear puntuación de precisión
suma (diag (tabla ()) / suma (tabla ()
ROCR
Crear ROC: Paso 1 Crear predicción
predicción()
predecir (), y
ROCR
Crear ROC: Paso 2 Crear rendimiento
rendimiento()
predicción (), 'tpr', 'fpr'
ROCR
Crear ROC: Paso 3 Trazar un gráfico
trama()
rendimiento()
Los otros modelos de GLM son:
- binomio: (enlace = "logit")
- gaussiano: (enlace = "identidad")
- Gamma: (enlace = "inverso")
- inverso.gaussiano: (enlace = "1 / mu 2")
- poisson: (enlace = "registro")
- cuasi: (enlace = "identidad", varianza = "constante")
- cuasibinomio: (enlace = "logit")
- quasipoisson: (enlace = "registro")