꺼내먹는지식 준

sklearn StratifiedKFold 용례 본문

AI/PyTorch

sklearn StratifiedKFold 용례

알 수 없는 사용자 2022. 3. 5. 17:12

sklearn StratifieldKFold 의 용례에 대해 정리해본다. 

당연히 official doc 을 보면 되지만, 생각보다 직관적이지 않다. 

 

StratifieldKFold는 기존 Fold와 뭐가 다를까? 

간단하게 설명하면 Fold는 무작정으로 데이터를 섞는다면, 

StratifieldKFold는 데이터 분포에서 적정 비율에 맞게 데이터를 나눠준다. 즉 한쪽으로 치우칠 일을 최대한으로 방지한다. 

 

즉, 분포로 나눠줄 수 있도록 y 값(데이터 분포를 볼)을 제공해주어야 한다. 

 

skf = StratifiedKFold(
    n_splits=5, shuffle=True, random_state=42
)

통상 5개로 나눈다. random_state 즉 random seed는 주로 42로 준다. 

random seed의 경우, 추후 실험을 해도 reproduce 할 수 있도록 고정이 필요하다. 

본 글이 이해가 가지 않는다면 Fold, Random Seed 에 대해 검색해서 읽고 난 후 다시 오자. 

 

for fold, (train_idx, val_idx) in enumerate(skf.split(df["id"], df["age"])):
        print("train_idx", train_idx)
        print("val_idx", val_idx)
        #train_idx [0 1 2 ... 2697 2698 2699]

        print("df.loc", df.loc[train_idx])
        train_df = df.loc[train_idx].reset_index(drop=True)
        val_df = df.loc[val_idx].reset_index(drop=True)

해당 부분 이해가 쉽지 않다. 

official doc을 살펴보면, 

split에 인자로 X, y 

그리고 X 는 (표본의 개수, feature의 개수)를 넣어줘야 한다고 되어 있다. 

하지만 다시 글을 잘 살펴보면, y값만 제공해주어도 되고, x에는 개수만 맞춰주면 된다고 되어있다. 

 

y값은 말그대로 분포로서 fold를 형성할 때 참고가 된다. 

 

이로 인해 유추해보면 X에 들어갈 수 있는 feature의 개수를 명시적으로 제공해주면 이에 맞춰 y값에서 데이터를 뽑을 때 참고하는 것 아닌가 싶다. 

 

또한 생기는 의문점은 X 에 데이터 개수만 맞춰주면 되고, y에는 분포 참고 값을 넣어준다면 도대체 fold에 정말로 들어가야 하는 data 값은 어떻게 들어가나이다. 

 

알고보니

구석에 다음과 같은 설명이 적혀있다. 

분명 return값이 뭔지 따로 작성은 안해놓고 다음의 글만 띡 적어놨다. 

 

아무튼 이를 통해 아, fold. 로 나누고 이에 맞춰서 indicies 를 return 해주는구나를 알 수 있다. 

 

다음의 결과 값을 참고해보자. 

df['id'] 는 '000001' ... '006569' 이다. len(df['id']) = 2699이다. 즉 id 중간에는 생략값들이 있어 개수대비 숫자가 크다.

df['age']는 19 ... 60 이다. 

 

fold 중 하나만 출력을 해보면 

train_idx는 

val_idx는 

다음과 같이 구성된 걸 볼 수 있다. 

 

즉 train_idx 와 val_idx가 잘 나눠졌다. 

 

그리고, 이 idx 를 기반으로 

train_df 와 val_idx를 구성하는 것이다. 

 

reset_index(drop = True)

는 index 에 작성된 값들 날려버리고 0,1,2,3,4, ... 숫자로 대채하는 것이다. 

 

drop = Fasle aus index의 값들이 column의 feature 중 하나로 대체되고, index 에는 숫자로 대체한다. 

 

 

이글과의 별개 Tip 

현재 해당 train data는 폴더명으로 구성되어 있다. 

무슨 말인고 하면, 실제로 폴더안에는 더 많은 데이터가 있다. 

하지만 폴더 내부의 데이터들끼리는 정보가 곂친다. 

이로 인해 학습시 폴더 내부의 데이터를 학습후, validation에 사용하면 당연히 overfitting 결과가 일어난다. 

이에 따라, 폴더기반으로 fold 를 설정해준 후, 추후 학습시 폴더 내부의 데이터를 불러와서 곂치는 현상이 일어나지 않도록 구현되어있다. 

Comments