yedaoxiaodi 2020-05-20
概率定义:
概率定义为一件事情发生的可能性,例如,随机抛硬币,正面朝上的概率。
联合概率:
包含多个条件,且所有条件同时成立的概率,记作:??(??,??) 。
条件概率:
事件A在另外一个事件B已经发生条件下的发生概率,记作:??(??|??) 。P(A1,A2|B) = P(A1|B)P(A2|B),需要注意的是:此条件概率的成立,是由于A1,A2相互独立的结果。
公式:
其中c可以是不同类别。
公式分为三个部分:
??(??):每个文档类别的概率(某文档类别词数/总文档词数)
??(??│??):给定类别下特征(被预测文档中出现的词)的概率
计算方法:??(??1│??)=????/?? (训练文档中去计算)
????为该??1词在C类别所有文档中出现的次数
N为所属类别C下的文档所有词出现的次数和
??(??1,??2,…): 预测文档中每个词的概率
举个栗子:
现有一篇被预测文档:出现了都江宴,武汉,武松,计算属于历史,地理的类别概率?
历史:??(都江宴,武汉,武松│历史)∗P(历史)=(10/108)∗(22/108)∗(65/108)∗(108/235) =0.00563435
地理:??(都江宴,武汉,武松│地理)∗P(地理)=(58/127)∗(17/127)∗(0/127)∗(127/235)=0
拉普拉斯平滑:
思考:属于某个类别为0,合适吗?
从上面的例子我们得到地理概率为0,这是不合理的,如果词频列表里面有很多出现次数都为0,很可能计算结果都为零。
解决方法:拉普拉斯平滑系数。
sklearn朴素贝叶斯实现API:
sklearn.naive_bayes.MultinomialNB(alpha = 1.0) alpha:拉普拉斯平滑系数
案例:新闻分类
from sklearn.datasets import fetch_20newsgroups from sklearn.model_selection import train_test_split from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.naive_bayes import MultinomialNB from sklearn.metrics import classification_report news = fetch_20newsgroups(subset=‘all‘) # 进行数据分割 x_train, x_test, y_train, y_test = train_test_split(news.data, news.target, test_size=0.25) # 对数据集进行特征抽取 tf = TfidfVectorizer() # 以训练集当中的词的列表进行每篇文章重要性统计[‘a‘,‘b‘,‘c‘,‘d‘] x_train = tf.fit_transform(x_train) x_test = tf.transform(x_test) # 进行朴素贝叶斯算法的预测 mlt = MultinomialNB(alpha=1.0) print(x_train)
(0, 120993) 0.0838226531816039 (0, 36277) 0.028728297074726128 (0, 118261) 0.051733692584494416 (0, 118605) 0.08660213360333731 (0, 78914) 0.10725171098177662 (0, 120174) 0.07226288195761017 (0, 146730) 0.03649798864200877 (0, 49960) 0.09535813190987927 (0, 108029) 0.10406938034117505 (0, 151947) 0.1081016719923428 (0, 120110) 0.13513684031456163 (0, 34588) 0.06453595223748614 (0, 133893) 0.04993313285348771 (0, 31218) 0.07845873103784344 (0, 108032) 0.08430822316250115 (0, 30921) 0.11806736198114927 (0, 33267) 0.030864914635712264 (0, 36137) 0.0714722249527062 (0, 57776) 0.07110907374703304 (0, 77937) 0.026514922107534245 (0, 90944) 0.09746338158610199 (0, 135824) 0.09394365947415394 (0, 49956) 0.09183375914922258 (0, 151957) 0.07203295034824395 (0, 33356) 0.07203295034824395 : : (14133, 45099) 0.030803124311834594 (14133, 135309) 0.02305588722190138 (14133, 135472) 0.06570104508511963 (14133, 52014) 0.05222321951090842 (14133, 108029) 0.05584161408783517 (14133, 36137) 0.07670122356304401 (14133, 34063) 0.12187079805145053 (14133, 106978) 0.0851182715752145 (14133, 106534) 0.03378056586331488 (14133, 105921) 0.09707364301640503 (14133, 103839) 0.07144955527096918 (14133, 136535) 0.03801377630817533 (14133, 42966) 0.028558472354146207 (14133, 81075) 0.02180715538325887 (14133, 135641) 0.025875408277197205 (14133, 148185) 0.028450089379106706 (14133, 78894) 0.020030955308174968 (14133, 147914) 0.047259202253661425 (14133, 90152) 0.017166154294786778 (14133, 45598) 0.05645818387150284 (14133, 135325) 0.03667700550640032 (14133, 118218) 0.02343357701502816 (14133, 131632) 0.01710795977554328 (14133, 59957) 0.0485327006460036 (14133, 67480) 0.01710795977554328
mlt.fit(x_train, y_train) #MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True) y_predict = mlt.predict(x_test) print("预测的文章类别为:", y_predict) #预测的文章类别为: [ 3 16 5 ... 0 5 8] # 得出准确率 print("准确率为:", mlt.score(x_test, y_test)) #准确率为: 0.8414685908319185 print("每个类别的精确率和召回率:", classification_report(y_test, y_predict, target_names=news.target_names))
每个类别的精确率和召回率: precision recall f1-score support alt.atheism 0.89 0.75 0.81 210 comp.graphics 0.87 0.81 0.84 225 comp.os.ms-windows.misc 0.77 0.90 0.83 209 comp.sys.ibm.pc.hardware 0.77 0.78 0.78 258 comp.sys.mac.hardware 0.86 0.88 0.87 223 comp.windows.x 0.97 0.76 0.85 260 misc.forsale 0.92 0.68 0.78 233 rec.autos 0.91 0.89 0.90 263 rec.motorcycles 0.94 0.96 0.95 260 rec.sport.baseball 0.93 0.92 0.92 230 rec.sport.hockey 0.89 0.97 0.93 234 sci.crypt 0.64 0.99 0.78 235 sci.electronics 0.94 0.68 0.79 275 sci.med 0.96 0.89 0.93 241 sci.space 0.89 0.97 0.93 246 soc.religion.christian 0.56 0.99 0.72 257 talk.politics.guns 0.84 0.94 0.89 256 talk.politics.mideast 0.92 0.98 0.94 245 talk.politics.misc 0.98 0.67 0.80 182 talk.religion.misc 1.00 0.17 0.29 170 accuracy 0.84 4712 macro avg 0.87 0.83 0.83 4712 weighted avg 0.87 0.84 0.84 4712
优点:
缺点: