Probability : Naive Bayes classifier ตัวอย่างที่ 1

 จากภาคทฤษฎี เราทราบขั้นตอนและวิธีคิดการทำ classification ด้วย maximum likelihood แล้ว [1]  เป้าหมายคือการประมาณค่าความน่าจะเป็นด้วยสมการ 


(1.0)P(yixi)=P(yi)P(xi1yi)P(xi2yi)P(xi3yi)P(xijyi)

เมื่อ xi=(xi1,xi2,xi3,...,xij)


มาตอนนี้จะว่ากันถึงตัวอย่างการใช้งานกัน โดยจะใช้ในการแบ่งกลุ่ม Iris dataset 

ในทางปฏิบัติข้อมูลที่ได้มาจากการวัดคุณลักษณะที่ปรากฎออกมาของสิ่งมีชีวิต ตามธรรมชาติมีลักษณะเป็นข้อมูลแบบต่อเนื่อง เป็นไปได้ว่าในข้อมูลธรรมชาติอาจไม่ปรากฏในตัวอย่างที่จัดเก็บมา เมื่อลองทำเอาข้อมูลมาทำ histogram ดูก็จะพบว่ามีลักษณะการแจกแจงที่พอจะอนุมานว่าเป็น normal distribution [2]



โดยค่าความน่าจะเป็นของแต่ละ features x_i จะหาได้จากสมการ  


(1.1)p(x)=e12xμσσ2π


เมื่อ μ ค่ากลาง (mean,mode, media) และ σ คือ standard deviation

 จัดเตรียมข้อมูล


import numpy as np
import pandas as pd
from scipy.stats import norm
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt

load dataset


#load data
iris_data = load_iris()

# preview
print(iris_data.data)              # Names of the columns
print(iris_data.feature_names)     # Target variable
print(iris_data.target)            # Target names
print(iris_data.target_names)      # name of target variable

create dataframe


# convert to pandas dataframe
iris_df = pd.DataFrame(iris_data.data, 
                       columns=iris_data.feature_names)

# there would be smarter way, but this is easier
cls = {'class':[]}
for i in range(150):
    if i < 50 :
        cls['class'].append('setosa')
    elif i < 100 :
        cls['class'].append('versicolor')
    else :
        cls['class'].append('verginica')
iris_df['class'] = cls['class']

print(iris_df) # prints labeled data

shuffle data และ split data 


# shuffle data
iris_df = iris_df.sample(frac=1).reset_index(drop=True)

t = int(0.8 * len(iris_df))
train_data = iris_df.loc[:t,:]
test_data = iris_df.loc[t:,:]

# preview
print(train_data)



describe data

# group by class
iris_desc = train_data.groupby("class").describe() 

#preview
print(iris_desc)


หาค่า P(y) ของแต่ละ class 

# get sample count
N = iris_desc['sepal length (cm)']['count'].sum()
seto_n = iris_desc['sepal length (cm)']['count']['setosa']
vers_n = iris_desc['sepal length (cm)']['count']['versicolor']
verg_n = iris_desc['sepal length (cm)']['count']['verginica']

#get probability

prob_seto = seto_n/N # P(y1)
prob_vers = vers_n/N # P(y2)
prob_verg = verg_n/N # P(y3)

#preview 
print(N,seto_n,vers_n,verg_n)
print(prob_seto,prob_verg,prob_vers)

มาถึงขั้นตอนนี้ เราก็มี parameters พื้นฐานครบในการที่จะเริ่มคำนวณตามสมการที่ (1.6) และ (1.7) แล้ว






ส่วนที่ 2 : ทดสอบด้วย test data  



ขั้นตอนการทำงานคือ 
1. ข้อมูลจะถูกดึงจาก dataset มาทีละรายการ แต่ละรายการประกอบค่าของ features ทั้ง 4 (x1,x2,x3,x4)
2. นำค่า x1,x2,x3,x4 ไปหาค่าความน่าจะเป็นด้วยสมการ (1.1) ด้วย parameter ที่ได้จากขั้นตอนแรก 
3. หาค่าความน่าจะเป็น P(y^=setosax),P(y^=versicolorx),P(y^=verginicax) ตามสมการ (1.0)
4. นำค่าที่ได้จากข้อ 3 เปรียบเทียบกัน y^ ใน class ใดที่ให้ค่าความน่าจะเป็นสูงสุดจะถือเอา class นั้นเป็นค่าพยากรณ์


# define probability density function as (1.1)
def norm_pdf(X,mu,sigma) : 
     nominator = np.exp( - (X - mu)**2 / (2 * sigma**2) )
     denominator = (sigma * np.sqrt(2 * np.pi)) 
     return nominator / denominator



feat_list = ['sepal length (cm)','sepal width (cm)','petal length (cm)','petal width (cm)']
correct = 0
history = []
N = len(test_data)
for i in range(N):
    # get target class    
    label = test_data.iloc[i]['class']
    

    # initial prob
    probs = {'setosa':prob_seto,'versicolor':prob_vers,'verginica':prob_verg}
    
    for c,py in probs.items():
        p = py
        for f in feat_list:
            x = test_data.iloc[i][f]
            mu = iris_desc[f]['mean'][c]
            sigma = iris_desc[f]['std'][c]
            p *= norm_pdf(x,mu,sigma)
        probs[c] = p
    y_hat = max(probs, key=probs.get)    
    if label == y_hat :
        correct += 1
    probs['label']=label
    history.append(probs)    
 
ค่าของ history 


Item : 1
P(y='setosa')=0.000000
P(y='versicolor')=0.101163
P(y='verginica')=0.000002
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 2
P(y='setosa')=0.000000
P(y='versicolor')=0.000095
P(y='verginica')=0.052913
Predict : verginica, Label : verginica, Correct => True
-----
Item : 3
P(y='setosa')=0.001799
P(y='versicolor')=0.000000
P(y='verginica')=0.000000
Predict : setosa, Label : setosa, Correct => True
-----
Item : 4
P(y='setosa')=1.369032
P(y='versicolor')=0.000000
P(y='verginica')=0.000000
Predict : setosa, Label : setosa, Correct => True
-----
Item : 5
P(y='setosa')=0.000000
P(y='versicolor')=0.006312
P(y='verginica')=0.062817
Predict : verginica, Label : verginica, Correct => True
-----
Item : 6
P(y='setosa')=0.000000
P(y='versicolor')=0.364024
P(y='verginica')=0.000578
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 7
P(y='setosa')=0.000000
P(y='versicolor')=0.000099
P(y='verginica')=0.044583
Predict : verginica, Label : verginica, Correct => True
-----
Item : 8
P(y='setosa')=0.079939
P(y='versicolor')=0.000000
P(y='verginica')=0.000000
Predict : setosa, Label : setosa, Correct => True
-----
Item : 9
P(y='setosa')=0.000000
P(y='versicolor')=0.002316
P(y='verginica')=0.047282
Predict : verginica, Label : verginica, Correct => True
-----
Item : 10
P(y='setosa')=0.068473
P(y='versicolor')=0.000000
P(y='verginica')=0.000000
Predict : setosa, Label : setosa, Correct => True
-----
Item : 11
P(y='setosa')=0.000000
P(y='versicolor')=0.370740
P(y='verginica')=0.000135
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 12
P(y='setosa')=0.000000
P(y='versicolor')=0.000388
P(y='verginica')=0.000000
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 13
P(y='setosa')=0.000000
P(y='versicolor')=0.024761
P(y='verginica')=0.012653
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 14
P(y='setosa')=0.000000
P(y='versicolor')=0.168323
P(y='verginica')=0.002193
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 15
P(y='setosa')=0.000000
P(y='versicolor')=0.000001
P(y='verginica')=0.150882
Predict : verginica, Label : verginica, Correct => True
-----
Item : 16
P(y='setosa')=0.000000
P(y='versicolor')=0.348247
P(y='verginica')=0.000188
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 17
P(y='setosa')=0.000000
P(y='versicolor')=0.054217
P(y='verginica')=0.000001
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 18
P(y='setosa')=0.000000
P(y='versicolor')=0.000001
P(y='verginica')=0.132567
Predict : verginica, Label : verginica, Correct => True
-----
Item : 19
P(y='setosa')=0.000000
P(y='versicolor')=0.000000
P(y='verginica')=0.050565
Predict : verginica, Label : verginica, Correct => True
-----
Item : 20
P(y='setosa')=0.000000
P(y='versicolor')=0.110623
P(y='verginica')=0.000024
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 21
P(y='setosa')=0.000000
P(y='versicolor')=0.105628
P(y='verginica')=0.000004
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 22
P(y='setosa')=0.000000
P(y='versicolor')=0.159142
P(y='verginica')=0.001874
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 23
P(y='setosa')=3.601586
P(y='versicolor')=0.000000
P(y='verginica')=0.000000
Predict : setosa, Label : setosa, Correct => True
-----
Item : 24
P(y='setosa')=0.000000
P(y='versicolor')=0.248433
P(y='verginica')=0.003330
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 25
P(y='setosa')=0.000000
P(y='versicolor')=0.017088
P(y='verginica')=0.046130
Predict : verginica, Label : verginica, Correct => True
-----
Item : 26
P(y='setosa')=0.000000
P(y='versicolor')=0.287213
P(y='verginica')=0.000667
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 27
P(y='setosa')=0.000000
P(y='versicolor')=0.120046
P(y='verginica')=0.001342
Predict : versicolor, Label : versicolor, Correct => True
-----
Item : 28
P(y='setosa')=0.510845
P(y='versicolor')=0.000000
P(y='verginica')=0.000000
Predict : setosa, Label : setosa, Correct => True
-----
Item : 29
P(y='setosa')=0.000000
P(y='versicolor')=0.000000
P(y='verginica')=0.112886
Predict : verginica, Label : verginica, Correct => True
-----
Item : 30
P(y='setosa')=0.000000
P(y='versicolor')=0.000000
P(y='verginica')=0.168677
Predict : verginica, Label : verginica, Correct => True
-----


ผลการทดสอบกับข้อมูลทดสอบจำนวน 30 รายการออกมาพยากรณ์ได้ถูกทุกรายการ น่าสนใจมากนะ

เอกสารอ้างอิง

[1] https://smarter-machine.blogspot.com/2020/11/probability-naive-bayes-classifier_20.html

[2] https://en.wikipedia.org/wiki/Normal_distribution

ความคิดเห็น