데이터마이닝 05강 | R로 의사결정나무 분석하는 법 (rpart, plot, prune, predict 함수 정리)
이번 시간에는 R을 활용해 의사결정나무를 분석하는 함수와 그 사용법을 정리해보겠습니다.
특히 rpart
패키지로 나무모형을 생성하고, 예측하고, 시각화하는 법까지 자세히 알아보겠습니다.
R의사결정나무 주요 함수 정리
rpart()
함수 — 나무모형 생성
1
|
rpart(formula, data, method, control, …)
|
cs |
-
formula
: Y ~ X1 + X2 형태 -
data
: 분석할 데이터 프레임 -
method
: 분석 종류 지정-
“class” : 분류나무
-
“anova” : 회귀나무
-
-
control
: 분할 규칙 설정 (아래 함수 참고)
rpart.control()
함수 — 분할 조건 설정
1
|
rpart.control(minsplit=20, minbucket=7, cp=0.01, xval=10, maxdepth=30)
|
cs |
-
minsplit
: 노드 분할 최소 관측치 수 -
minbucket
: 최종노드 최소 관측치 수 -
cp
: 비용복잡도 벌점계수 -
xval
: 교차타당성 fold 수 -
maxdepth
: 나무 깊이 제한
cp
: 비용 복잡도 벌점계수 (Complexity Parameter)
의미
트리모델(rpart 등)에서 복잡한 모델에 패널티(벌점)를 주는 계수입니다.
트리의 복잡도와 오차의 trade-off를 조절하는 역할을 합니다.
작동 방식
분할을 하면 모델의 오차는 줄어들지만 트리는 복잡해집니다.
이때 분할로 인한 오차 감소율이 cp
값보다 크지 않으면 그 분할은 수행하지 않는다는 원리입니다.
해석
-
cp
값이 작으면 : 세부적으로 많이 나뉨 (과적합 위험↑) -
cp
값이 크면 : 트리가 단순해짐 (과소적합 위험↑)
공식
트리 모델의 성능 측정값
: 트리의 오차
: 리프노드 수 (트리 복잡도)
복잡해질수록 패널티가 커지니, 적당한 선에서 트리를 가지치기(pruning)하게 됩니다.
xval
: 교차타당성 (Cross Validation) fold 수
의미
트리모형의 일반화 성능을 평가하기 위해 데이터셋을 몇 개로 나눠 교차검증할지 결정하는 값입니다.
작동 방식
-
전체 데이터를
xval
개로 나누고 -
그 중 하나를 검증용, 나머지를 학습용으로 트리모형을 만든 다음
-
이걸
xval
번 반복해서 평균 예측오차 계산
이 과정을 통해 모델의 과적합 여부와 가지치기 지점을 정하는데 활용됩니다.
해석
-
xval=10
→ 10-fold 교차검증 -
값이 너무 작으면 신뢰도↓
-
값이 너무 크면 계산량↑
보통 xval=10
정도가 많이 쓰입니다.
printcp()
함수 — cp값 확인
1
|
printcp(model)
|
cs |
rpart()로 만든 의사결정트리 모델의 가지치기(cp) 관련 정보를 요약해서 보여주는 함수입니다.
트리 가지치기를 할 때 어떤 cp 값에서 가지를 자를지 판단하는 근거로 사용됩니다.
가장 좋은 cp 찾기
가장 작은 xerror(교차검증오차) 값을 기준으로 결정
혹은 1-SE rule을 적용해서 가장 작은 xerror + 1 표준오차 이내의 가장 단순한 모델 선택 가능.
prune()
함수 — 가지치기
1
|
prune(model, cp=0.02)
|
cs |
의사결정트리(rpart로 만든 트리)를 가지치기(pruning)하는 함수입니다.
초기에 만든 트리는 복잡하고 과적합일 수 있어서,
교차검증 결과를 바탕으로 적절한 cp 값 기준으로 불필요한 가지를 잘라내 단순하고 일반화 성능 좋은 트리로 만드는 과정!
prune(model, cp=0.02)
의미
-
model
: 이미rpart()
로 학습한 트리 모델 -
cp=0.02
: 비용 복잡도 벌점계수(cp) 값이 0.02 이상인 분할만 유지하고 나머지는 가지치기
“분할로 인한 오차 감소율이 2% 이상일 때만 그 분할을 남기고,
그 이하로 줄어드는 분할은 과적합 가능성이 높으니 없애버리겠다” 는 뜻!
-
복잡한 트리를 단순하게 정리
-
과적합 방지
-
일반화 성능 향상
결국 적당한 크기의 트리로 최적화 하기 위함입니다.
plot()
+ text()
함수 — 기본 시각화
1
2
|
plot(model)
text(model, use.n=TRUE)
|
cs |
-
plot : 나무구조 그림
-
text : 분할 조건과 노드정보 표시
predict()
함수 — 예측값 계산
1
|
predict(model, newdata, type=”class”)
|
cs |
-
type
: “class”(범주형 예측), “prob”(확률), “vector”(회귀값)
prp()
함수 — 고급 시각화
1
|
prp(model, type=2, extra=1)
|
cs |
-
type=2 : 중간노드 아래에 분할조건
-
extra=1 : 관측치 수 등 출력
rpart()로 만든 의사결정트리(recursive partitioning model)를 시각화하는 함수
기본 plot()
+ text()
조합보다 더 직관적이고 예쁜 트리를 그릴 수 있다.
중요내용 정리
-
rpart()
: 나무모형 생성 -
rpart.control()
: 분할 규칙 설정 -
printcp()
: cp값 확인 -
prune()
: 가지치기 -
plot()
+text()
: 기본 시각화 -
prp()
: 고급 시각화 -
predict()
: 예측값 산출
객관식 문제 & 해설
Q1. rpart() 함수에서 분류나무를 생성하려면 method 옵션에 입력할 값은?
① “anova”
② “class”
③ “poisson”
④ “exp”
정답: ②
해설: “class”를 지정하면 분류형 의사결정나무가 생성됩니다.
Q2. rpart.control() 함수의 cp 옵션 설명으로 맞는 것은?
① 노드 최소 관측치 수
② 가지치기 벌점계수
③ 나무 최대 깊이
④ 교차타당성 fold 수
정답: ②
해설: cp값은 오분류율이 충분히 개선되지 않으면 분할을 중단하게 하는 벌점계수입니다.
Q3. rpart로 만든 나무모형을 시각화하는 함수로 올바른 것은?
① prune()
② predict()
③ plot()
④ printcp()
정답: ③
해설: plot() 함수로 기본적인 나무모형 그림을 그릴 수 있습니다.