GLM i R: Generalisert lineær modell med eksempel

Innholdsfortegnelse:

Anonim

Hva er logistisk regresjon?

Logistisk regresjon brukes til å forutsi en klasse, dvs. en sannsynlighet. Logistisk regresjon kan forutsi et binært utfall nøyaktig.

Tenk deg at du vil forutsi om et lån blir nektet / akseptert basert på mange attributter. Den logistiske regresjonen er av formen 0/1. y = 0 hvis et lån avvises, y = 1 hvis det godtas.

En logistisk regresjonsmodell skiller seg fra lineær regresjonsmodell på to måter.

  • Først og fremst aksepterer den logistiske regresjonen bare dikotom (binær) inngang som en avhengig variabel (dvs. en vektor på 0 og 1).
  • For det andre måles utfallet av følgende sannsynlighetsforbindelsesfunksjon kalt sigmoid på grunn av sin S-formede .:

Funksjonens utgang er alltid mellom 0 og 1. Sjekk bildet nedenfor

Sigmoid-funksjonen returnerer verdier fra 0 til 1. For klassifiseringsoppgaven trenger vi en diskret utgang på 0 eller 1.

For å konvertere en kontinuerlig strøm til diskret verdi, kan vi sette en beslutning bundet til 0,5. Alle verdier over denne terskelen er klassifisert som 1

I denne veiledningen vil du lære

  • Hva er logistisk regresjon?
  • Hvordan lage Generalized Liner Model (GLM)
  • Trinn 1) Kontroller kontinuerlige variabler
  • Trinn 2) Sjekk faktorvariabler
  • Trinn 3) Funksjonsteknikk
  • Trinn 4) Sammendragstatistikk
  • Trinn 5) Tren / testsett
  • Trinn 6) Bygg modellen
  • Trinn 7) Vurder ytelsen til modellen

Hvordan lage Generalized Liner Model (GLM)

La oss bruke datasettet for voksne for å illustrere logistisk regresjon. "Voksen" er et flott datasett for klassifiseringsoppgaven. Målet er å forutsi om den enkeltes årlige inntekt i dollar vil overstige 50.000. Datasettet inneholder 46.033 observasjoner og ti funksjoner:

  • alder: individets alder. Numerisk
  • utdanning: Utdanningsnivå for den enkelte. Faktor.
  • sivilstatus: Sivilstatus for individet. Faktor dvs. aldri gift, gift-civ-ektefelle, ...
  • kjønn: Kjønn til individet. Faktor, dvs. mann eller kvinne
  • inntekt: Målvariabel. Inntekt over eller under 50K. Faktor dvs.> 50K, <= 50K

blant andre

library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)

Produksjon:

Observations: 48,842Variables: 10$ x  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age  25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass  Private, Private, Local-gov, Private, ?, Private,… $ education  11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num  7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status  Never-married, Married-civ-spouse, Married-civ-sp… $ race  Black, White, White, Black, White, White, Black,… $ gender  Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week  40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income  <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5… 

Vi fortsetter som følger:

  • Trinn 1: Kontroller kontinuerlige variabler
  • Trinn 2: Kontroller faktorvariabler
  • Trinn 3: Funksjonsteknikk
  • Trinn 4: Sammendragstatistikk
  • Trinn 5: Tren / testsett
  • Trinn 6: Bygg modellen
  • Trinn 7: Vurder ytelsen til modellen
  • trinn 8: Forbedre modellen

Din oppgave er å forutsi hvilken person som vil ha en inntekt som er høyere enn 50K.

I denne opplæringen vil hvert trinn bli detaljert for å utføre en analyse på et ekte datasett.

Trinn 1) Kontroller kontinuerlige variabler

I det første trinnet kan du se fordelingen av de kontinuerlige variablene.

continuous <-select_if(data_adult, is.numeric)summary(continuous)

Kode Forklaring

  • kontinuerlig <- select_if (data_adult, is.numeric): Bruk funksjonen select_if () fra dplyr-biblioteket for å velge bare de numeriske kolonnene
  • sammendrag (kontinuerlig): Skriv ut sammendragstatistikken

Produksjon:

## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00

Fra tabellen ovenfor kan du se at dataene har helt forskjellige skalaer og timer. Per.weeks har store outliers (. Ser på den siste kvartilen og maksimumsverdien).

Du kan håndtere det ved å følge to trinn:

  • 1: Plott fordeling av timer. Per uke
  • 2: Standardiser de kontinuerlige variablene
  1. Plott fordelingen

La oss se nærmere på fordelingen av timer. Per uke

# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")

Produksjon:

Variabelen har mange avvik og ikke veldefinert fordeling. Du kan takle dette problemet delvis ved å slette de øverste 0,01 prosent av timene per uke.

Grunnleggende syntaks for kvantil:

quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.

Vi beregner de to øverste prosentilene

top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent

Kode Forklaring

  • kvantil (data_adult $ hours.per.week, .99): Beregn verdien av 99 prosent av arbeidstiden

Produksjon:

## 99%## 80 

98 prosent av befolkningen jobber under 80 timer per uke.

Du kan slippe observasjonene over denne terskelen. Du bruker filteret fra dplyr-biblioteket.

data_adult_drop <-data_adult %>%filter(hours.per.week

Produksjon:

## [1] 45537 10 
  1. Standardiser de kontinuerlige variablene

Du kan standardisere hver kolonne for å forbedre ytelsen fordi dataene dine ikke har samme skala. Du kan bruke funksjonen mutate_if fra dplyr-biblioteket. Den grunnleggende syntaksen er:

mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the function

Du kan standardisere de numeriske kolonnene som følger:

data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)

Kode Forklaring

  • mutate_if (is.numeric, funs (scale)): Tilstanden er bare numerisk kolonne og funksjonen er skala

Produksjon:

## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50K

Trinn 2) Sjekk faktorvariabler

Dette trinnet har to mål:

  • Sjekk nivået i hver kategoriske kolonne
  • Definer nye nivåer

Vi vil dele dette trinnet i tre deler:

  • Velg de kategoriske kolonnene
  • Lagre stolpediagrammet til hver kolonne i en liste
  • Skriv ut grafene

Vi kan velge faktorkolonnene med koden nedenfor:

# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)

Kode Forklaring

  • data.frame (select_if (data_adult, is.factor)): Vi lagrer faktorkolonnene i faktor i en datarammetype. Biblioteket ggplot2 krever et datarammeobjekt.

Produksjon:

## [1] 6 

Datasettet inneholder 6 kategoriske variabler

Det andre trinnet er mer dyktig. Du vil plotte et søylediagram for hver kolonne i datarammefaktoren. Det er mer praktisk å automatisere prosessen, spesielt i situasjoner er det mange kolonner.

library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))

Kode Forklaring

  • lapply (): Bruk funksjonen lapply () for å sende en funksjon i alle kolonnene i datasettet. Du lagrer utdataene i en liste
  • funksjon (x): Funksjonen blir behandlet for hver x. Her er x kolonnene
  • ggplot (faktor, aes (get (x))) + geom_bar () + tema (axis.text.x = element_text (vinkel = 90)): Lag et søylediagram for hvert x-element. Merk at for å returnere x som en kolonne, må du ta den med i get ()

Det siste trinnet er relativt enkelt. Du vil skrive ut de 6 grafene.

# Print the graphgraph

Produksjon:

## [[1]]

## ## [[2]]

## ## [[3]]

## ## [[4]]

## ## [[5]]

## ## [[6]]

Merk: Bruk neste knapp for å navigere til neste graf

Trinn 3) Funksjonsteknikk

Omarbeidet utdanning

Fra grafen over kan du se at variabel utdanning har 16 nivåer. Dette er betydelig, og noen nivåer har et relativt lavt antall observasjoner. Hvis du vil forbedre mengden informasjon du kan få fra denne variabelen, kan du omarbeide den til høyere nivå. Du lager nemlig større grupper med tilsvarende utdanningsnivå. For eksempel vil lavt utdanningsnivå konverteres til frafall. Høyere utdanningsnivå vil bli endret til å mestre.

Her er detaljene:

Gammelt nivå

Nytt nivå

Barnehage

frafall

10.

Frafall

11

Frafall

12. plass

Frafall

1.-4

Frafall

5.-6

Frafall

7.-8

Frafall

9. plass

Frafall

HS-Grad

HighGrad

Noen college

Samfunnet

Assoc-acdm

Samfunnet

Assoc-voc

Samfunnet

Bachelor

Bachelor

mestere

mestere

Prof-skole

mestere

Doktorgrad

PhD

recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Kode Forklaring

  • Vi bruker verbet mutere fra dplyr-biblioteket. Vi endrer utdanningsverdiene med uttalelsen ifelse

I tabellen nedenfor lager du en oppsummeringsstatistikk for å se i gjennomsnitt hvor mange års utdanning (z-verdi) det tar å nå Bachelor, Master eller PhD.

recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)

Produksjon:

## # A tibble: 6 x 3## education average_educ_year count##   ## 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557

Omarbeidet Sivilstatus

Det er også mulig å lage lavere nivåer for sivilstanden. I følgende kode endrer du nivået som følger:

Gammelt nivå

Nytt nivå

Aldri gift

Ikke gift

Gift-ektefelle-fraværende

Ikke gift

Gift-AF-ektefelle

Gift

Gift-civ-ektefelle

Separert

Separert

Skilt

Enker

Enke

# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
Du kan sjekke antall personer i hver gruppe.
table(recast_data$marital.status)

Produksjon:

## ## Married Not_married Separated Widow## 21165 15359 7727 1286 

Trinn 4) Sammendragstatistikk

Det er på tide å sjekke litt statistikk om målvariablene våre. I grafen nedenfor teller du prosentandelen av individer som tjener mer enn 50 000 gitt kjønn.

# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()

Produksjon:

Sjekk deretter om opprinnelsen til den enkelte påvirker deres inntjening.

# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))

Produksjon:

Antall arbeidstimer fordelt på kjønn.

# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()

Produksjon:

Rutetomten bekrefter at fordelingen av arbeidstid passer til forskjellige grupper. I boksplottet har ikke begge kjønn homogene observasjoner.

Du kan sjekke tettheten til den ukentlige arbeidstiden etter type utdanning. Distribusjonene har mange forskjellige valg. Det kan sannsynligvis forklares med typen kontrakt i USA.

# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()

Kode Forklaring

  • ggplot (recast_data, aes (x = hours.per.week)): En tetthetsplott krever bare en variabel
  • geom_density (aes (farge = utdannelse), alfa = 0,5): Det geometriske objektet for å kontrollere tettheten

Produksjon:

For å bekrefte tankene dine, kan du utføre en enveis ANOVA-test:

anova <- aov(hours.per.week~education, recast_data)summary(anova)

Produksjon:

## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

ANOVA-testen bekrefter forskjellen i gjennomsnitt mellom grupper.

Ikke-linearitet

Før du kjører modellen kan du se om antall arbeidte timer er relatert til alder.

library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()

Kode Forklaring

  • ggplot (recast_data, aes (x = age, y = hours.per.week)): Angi estetikken til grafen
  • geom_point (aes (farge = inntekt), størrelse = 0,5): Konstruer prikkplottet
  • stat_smooth (): Legg til trendlinjen med følgende argumenter:
    • metode = 'lm': Plott den tilpassede verdien hvis lineær regresjon
    • formel = y ~ poly (x, 2): Pass på en polynomial regresjon
    • se = SANT: Legg til standardfeilen
    • aes (farge = inntekt): Bryt modellen etter inntekt

Produksjon:

I et nøtteskall kan du teste samhandlingsuttrykk i modellen for å plukke opp den ikke-lineære effekten mellom den ukentlige arbeidstiden og andre funksjoner. Det er viktig å oppdage under hvilke forhold arbeidstiden er forskjellig.

Sammenheng

Neste sjekk er å visualisere sammenhengen mellom variablene. Du konverterer faktornivåetypen til numerisk slik at du kan plotte et varmekart som inneholder korrelasjonskoeffisienten beregnet med Spearman-metoden.

library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")

Kode Forklaring

  • data.frame (lapply (recast_data, as.integer)): Konverter data til numerisk
  • ggcorr () plott varmekartet med følgende argumenter:
    • metode: Metode for å beregne korrelasjonen
    • nbrudd = 6: Antall brudd
    • hjust = 0,8: Kontrollposisjon for variabelnavnet i plottet
    • label = TRUE: Legg til etiketter i midten av vinduene
    • label_size = 3: Størrelsesetiketter
    • color = "grey50"): Farge på etiketten

Produksjon:

Trinn 5) Tren / testsett

Enhver overvåket maskinlæringsoppgave krever å dele dataene mellom et togsett og et testsett. Du kan bruke "funksjonen" du opprettet i de andre veiledningsopplæringsveiledningene for å lage et tog / testsett.

set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)

Produksjon:

## [1] 36429 9
dim(data_test)

Produksjon:

## [1] 9108 9 

Trinn 6) Bygg modellen

For å se hvordan algoritmen fungerer, bruker du pakken glm (). Den generaliserte lineære modellen er en samling modeller. Den grunnleggende syntaksen er:

glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")

Du er klar til å estimere den logistiske modellen for å dele inntektsnivået mellom et sett med funksjoner.

formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)

Kode Forklaring

  • formel <- inntekt ~.: Lag modellen slik at den passer
  • logit <- glm (formel, data = data_train, family = 'binomial'): Tilpass en logistisk modell (family = 'binomial') med data_train-dataene.
  • sammendrag (logit): Skriv ut sammendraget av modellen

Produksjon:

#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6

Sammendraget av modellen vår avslører interessant informasjon. Utførelsen av en logistisk regresjon blir evaluert med spesifikke nøkkelverdier.

  • AIC (Akaike Information Criteria): Dette tilsvarer R2 i logistisk regresjon. Den måler passformen når en straff blir brukt på antall parametere. Mindre AIC- verdier indikerer at modellen er nærmere sannheten.
  • Null avvik: Passer bare til modellen med skjæringspunktet. Graden av frihet er n-1. Vi kan tolke det som en Chi-kvadratverdi (tilpasset verdi forskjellig fra den faktiske verdihypotesetesten).
  • Restavvik: Modell med alle variablene. Det tolkes også som en Chi-square-hypotesetesting.
  • Antall Fisher-poeng-iterasjoner: Antall iterasjoner før konvergerende.

Utgangen fra funksjonen glm () er lagret i en liste. Koden nedenfor viser alle elementene som er tilgjengelige i logit-variabelen vi konstruerte for å evaluere den logistiske regresjonen.

# Listen er veldig lang, skriv ut bare de tre første elementene

lapply(logit, class)[1:3]

Produksjon:

## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"

Hver verdi kan ekstraheres med $ -tegnet, og følg navnet på beregningene. For eksempel lagret du modellen som logit. For å trekke ut AIC-kriteriene bruker du:

logit$aic

Produksjon:

## [1] 27086.65

Trinn 7) Vurder ytelsen til modellen

Forvirringsmatrise

Den forvirring matrise er et bedre valg for å vurdere klassifiseringen ytelse sammenlignet med de forskjellige beregningene du har sett før. Den generelle ideen er å telle antall ganger Sanne forekomster blir klassifisert som falske.

For å beregne forvirringsmatrisen må du først ha et sett med spådommer slik at de kan sammenlignes med de faktiske målene.

predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_mat

Kode Forklaring

  • forutsi (logit, data_test, type = 'respons'): Beregn spådommen på testsettet. Sett type = 'respons' for å beregne svarssannsynligheten.
  • tabell (data_test $ inntekt, forutsi> 0,5): Beregn forvirringsmatrisen. forutsi> 0,5 betyr at den returnerer 1 hvis de forutsagte sannsynlighetene er over 0,5, ellers 0.

Produksjon:

#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229

Hver rad i en forvirringsmatrise representerer et faktisk mål, mens hver kolonne representerer et forutsagt mål. Den første raden i denne matrisen vurderer inntekten lavere enn 50k (den falske klassen): 6241 ble korrekt klassifisert som individer med inntekt lavere enn 50k ( sann negativ ), mens den gjenværende ble feil klassifisert som over 50k ( falsk positiv ). Den andre raden vurderer inntekten over 50 000, den positive klassen var 1229 ( Ekte positiv ), mens Den sanne negative var 1074.

Du kan beregne modellnøyaktigheten ved å summere den sanne positive + sanne negative over den totale observasjonen

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test

Kode Forklaring

  • sum (diag (table_mat)): Summen av diagonalen
  • sum (table_mat): Summen av matrisen.

Produksjon:

## [1] 0.8277339 

Modellen ser ut til å lide av ett problem, den overvurderer antall falske negativer. Dette kalles nøyaktighetstestparadokset . Vi uttalte at nøyaktigheten er forholdet mellom riktige spådommer og totalt antall saker. Vi kan ha relativt høy nøyaktighet, men en ubrukelig modell. Det skjer når det er en dominerende klasse. Hvis du ser tilbake på forvirringsmatrisen, kan du se at de fleste tilfellene er klassifisert som sanne negative. Tenk deg nå, modellen klassifiserte alle klassene som negative (dvs. lavere enn 50k). Du vil ha en nøyaktighet på 75 prosent (6718/6718 + 2257). Modellen din klarer seg bedre, men sliter med å skille det sanne positive med det sanne negative.

I en slik situasjon er det å foretrekke å ha en mer kortfattet beregning. Vi kan se på:

  • Presisjon = TP / (TP + FP)
  • Tilbakekall = TP / (TP + FN)

Presisjon vs tilbakekalling

Presisjon ser på nøyaktigheten av den positive spådommen. Tilbakekall er forholdet mellom positive forekomster som blir oppdaget riktig av klassifisereren;

Du kan konstruere to funksjoner for å beregne disse to beregningene

  1. Konstruer presisjon
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}

Kode Forklaring

  • matte [1,1]: Returner den første cellen i den første kolonnen i datarammen, dvs. den sanne positive
  • matte [1,2]; Returner den første cellen i den andre kolonnen i datarammen, dvs. den falske positive
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}

Kode Forklaring

  • matte [1,1]: Returner den første cellen i den første kolonnen i datarammen, dvs. den sanne positive
  • matte [2,1]; Returner den andre cellen i den første kolonnen i datarammen, dvs. den falske negative

Du kan teste funksjonene dine

prec <- precision(table_mat)precrec <- recall(table_mat)rec

Produksjon:

## [1] 0.712877## [2] 0.5336518

Når modellen sier at det er et individ over 50 000, er det riktig i bare 54 prosent av saken, og kan kreve individer over 50 000 i 72 prosent av saken.

Du kan lage er et harmonisk gjennomsnitt av disse to beregningene, noe som betyr at det gir mer vekt til de lavere verdiene.

f1 <- 2 * ((prec * rec) / (prec + rec))f1

Produksjon:

## [1] 0.6103799 

Presisjon mot tilbakekalling av kompromiss

Det er umulig å ha både høy presisjon og høy tilbakekalling.

Hvis vi øker presisjonen, blir det bedre forutsagt riktig individ, men vi vil savne mange av dem (lavere tilbakekalling). I noen situasjoner foretrekker vi høyere presisjon enn tilbakekalling. Det er et konkavt forhold mellom presisjon og tilbakekalling.

  • Tenk deg, du må forutsi om en pasient har en sykdom. Du vil være så presis som mulig.
  • Hvis du trenger å oppdage potensielle falske mennesker på gaten gjennom ansiktsgjenkjenning, ville det være bedre å fange mange mennesker som er merket som falske, selv om presisjonen er lav. Politiet vil være i stand til å løslate den ikke-falske personen.

ROC-kurven

Den Receiver Operating Karakteristisk kurve er en annen vanlig verktøy som benyttes sammen med binær klassifisering. Det er veldig lik presisjon / tilbakekallingskurven, men i stedet for å tegne presisjon versus tilbakekalling viser ROC-kurven den sanne positive hastigheten (dvs. tilbakekallingen) mot den falske positive hastigheten. Den falske positive frekvensen er forholdet mellom negative forekomster som er feil klassifisert som positive. Det er lik en minus den sanne negative hastigheten. Den sanne negative frekvensen kalles også spesifisitet . Derfor ROC kurven plotter følsomhet (tilbakekalling) versus 1-spesifisitet

For å plotte ROC-kurven, må vi installere et bibliotek som heter RORC. Vi finner i condabiblioteket. Du kan skrive inn koden:

conda install -cr r-rocr - ja

Vi kan plotte ROC med prediksjon () og ytelse () funksjoner.

library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Kode Forklaring

  • prediksjon (forutsi, data_test $ inntekt): ROCR-biblioteket må lage et prediksjonsobjekt for å transformere inngangsdataene
  • ytelse (ROCRpred, 'tpr', 'fpr'): Returner de to kombinasjonene som skal produseres i grafen. Her er tpr og fpr konstruert. Tot plott presisjon og husk sammen, bruk "prec", "rec".

Produksjon:

Trinn 8) Forbedre modellen

Du kan prøve å legge til ikke-linearitet i modellen med samspillet mellom

  • alder og timer. per uke
  • kjønn og timer. per uke.

Du må bruke poengtesten for å sammenligne begge modeller

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2

Produksjon:

## [1] 0.6109181 

Poengsummen er litt høyere enn den forrige. Du kan fortsette å jobbe med dataene og prøve å slå poengsummen.

Sammendrag

Vi kan oppsummere funksjonen for å trene en logistisk regresjon i tabellen nedenfor:

Pakke

Objektiv

funksjon

argument

-

Lag tog / test datasett

create_train_set ()

data, størrelse, tog

glm

Tren en generalisert lineær modell

glm ()

formel, data, familie *

glm

Oppsummer modellen

sammendrag()

montert modell

utgangspunkt

Gjør spådommer

forutsi()

montert modell, datasett, type = 'respons'

utgangspunkt

Lag en forvirringsmatrise

bord()

y, forutsi ()

utgangspunkt

Lag nøyaktighetspoeng

sum (diag (tabell ()) / sum (tabell ()

ROCR

Lag ROC: Trinn 1 Lag prediksjon

prediksjon()

forutsi (), y

ROCR

Lag ROC: Trinn 2 Opprett ytelse

opptreden()

prediksjon (), 'tpr', 'fpr'

ROCR

Opprett ROC: Trinn 3 Plotgraf

plott()

opptreden()

De andre GLM- modellene er:

- binomial: (link = "logit")

- gaussisk: (link = "identitet")

- Gamma: (link = "invers")

- inverse.gaussian: (link = "1 / mu 2")

- poisson: (link = "log")

- kvasi: (link = "identitet", varians = "konstant")

- kvasibinomial: (link = "logit")

- quasipoisson: (link = "log")