Top > Python

Python:Pandas+Scikit-learnによるデータ分析・可視化 w/ Jupyter



データ分析でよく使われている”Scikit-learn”(機械学習ライブラリ)と、その結果をmatplotlib+seaborn+highchartsで可視化する例です。
スマホから見ると、Outputエリアが見にくくなっているようです。ご容赦ください。

In [1]:
%matplotlib inline

# 必要なライブラリはpip install よろしくです。
import pandas as pd
import numpy as np
from sklearn import linear_model
from sklearn import cluster
from sklearn import tree
from sklearn.metrics import accuracy_score,classification_report

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import pydotplus
from sklearn.externals.six import StringIO
from IPython.display import Image
from pandas_highcharts.display import display_charts

data = "http://aima.cs.berkeley.edu/data/iris.csv"
df = pd.read_csv(data, index_col=None, header=None)

names = ['Sepal.Length', 'Sepal.Width', 'Petal.Length', 'Petal.Width', 'Species']
df.columns = names
df.head()
Out[1]:
Sepal.Length Sepal.Width Petal.Length Petal.Width Species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa

以下分析・検定例

分析内容自体は適当ですのであしからず。主にscikit-learnを使います。

1.相関分析

In [2]:
# 単相関
corr = df['Sepal.Length'].corr(df['Petal.Length'])
print(corr)

# 以下散布図作成
plt.scatter(df['Sepal.Length'], df['Petal.Length'], color="orange")
plt.xlabel('Sepal.Length')
plt.ylabel('Petal.Length')
plt.show()
0.871754157305
In [3]:
# 総相関
mask = np.zeros_like(df.corr(), dtype=np.bool)
mask[np.triu_indices_from(mask)] = True
plt.subplots(figsize=(7, 6))
sns.set_context("talk")
ax = sns.heatmap(df.corr(), vmin=-1, vmax=1, mask=mask,
                 square=True, annot=True, linewidths=.5, fmt='g')
plt.show()

2.単回帰分析

個人的には係数類がちゃんと出力されるOrangeライブラリが好きだったりします

In [4]:
X, Y = df[['Petal.Length']], df[['Petal.Width']]
LinerRegr = linear_model.LinearRegression()
LinerRegr.fit(X, Y)
print("- R^2: %s" %LinerRegr.score(X, Y))
print("- 回帰式: y = %s*x %s" %(round(float(LinerRegr.coef_), 5), round(float(LinerRegr.intercept_), 5)))
px = np.arange(X.min(),X.max(), 0.1)[:,np.newaxis]
py = LinerRegr.predict(px)
plt.subplots(figsize=(7, 6))
plt.plot(px,py,color="blue", linewidth=3)
plt.scatter(X, Y, color="black")
plt.xlabel(X.columns[0]);plt.ylabel(Y.columns[0])
plt.show()
- R^2: 0.926901227922
- 回帰式: y = 0.41642*x -0.36651

3.クラスター分析

In [5]:
x, y = df[['Sepal.Length', 'Sepal.Width']], df['Species']

km = cluster.KMeans(n_clusters=3).fit(x)
c = km.predict(x)
plt.subplots(figsize=(7, 6))
plt.scatter(x.iloc[:,0], x.iloc[:,1], c=c, s=30, linewidths=0, cmap=plt.cm.jet)
plt.xlabel(x.iloc[:,0].name);plt.ylabel(x.iloc[:,1].name)
plt.show()

4.ロジスティック回帰分析による分類

In [6]:
LogRegr = linear_model.LogisticRegression(C=1e5)
LogRegr.fit(x, y)
print(" - accuracy")
print(classification_report(y, LogRegr.predict(x)))
 - accuracy
             precision    recall  f1-score   support

     setosa       1.00      1.00      1.00        50
 versicolor       0.72      0.68      0.70        50
  virginica       0.70      0.74      0.72        50

avg / total       0.81      0.81      0.81       150

5.決定木分析

In [7]:
dt = tree.DecisionTreeClassifier(criterion="entropy", max_depth=2, random_state=0).fit(x, y)
print("- accuracy")
print(classification_report(y, dt.predict(x)))
- accuracy
             precision    recall  f1-score   support

     setosa       0.94      0.98      0.96        50
 versicolor       0.83      0.20      0.32        50
  virginica       0.55      0.94      0.69        50

avg / total       0.77      0.71      0.66       150

In [8]:
dot_data = StringIO()
tree.export_graphviz(dt, out_file=dot_data, feature_names=x.columns, filled=True)
graph1 = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph1.create_png())
Out[8]:

以下、可視化例

1.ヴァイオリン

In [9]:
plt.subplots(figsize=(7, 4))
sns.violinplot(df, palette="Set2")
plt.show()

2.ヒストグラム

In [10]:
df['Petal.Length'].hist(normed=False, alpha=.3, color='b', label='Petal')
df['Sepal.Length'].hist(normed=False, alpha=.3, color='r', label='Sepal')
plt.xlabel('count');plt.ylabel('length');plt.legend();plt.show()

3.Highcharts連携

Highchartsを知らない方はコチラ→http://www.highcharts.com/
PandasとHighchartsの連携→https://github.com/gtnx/pandas-highcharts

In [11]:
display_charts(df, kind="line", title="line-chart")
In [12]:
df_pivot = df.pivot_table(['Sepal.Length', 'Petal.Length'],
                index=['Species'],
                columns=None, aggfunc='mean', margins=True)

display_charts(df_pivot, kind="bar", title="bar-chart")


Ads by Google