Blog

Practice makes perfect.

Linear Discriminant Analysis in Python

Zhijun / 2022-11-07


Load libraries

from turtle import shape
import pandas as pd
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.model_selection import cross_val_score
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis 
import matplotlib.pyplot as plt
from numpy import mean, unicode_
from numpy import std

Load the data into a pandas DataFrame, `oil_data`, and preview its first rows

# Preview the first five rows of the dataset.
# NOTE(review): `oil_data` is not defined anywhere in this excerpt -- the
# loading step (presumably a pandas reader such as pd.read_csv) must run
# before this line; confirm the data source.
oil_data.head()
##    Number  Group Countries  808.5395  ...  1881.3715  1883.301  1885.2305   1887.16
## 0       1      1    Greece  0.103753  ...   0.003744  0.003461   0.002774  0.002253
## 1       1      1    Greece  0.100083  ...   0.004165  0.002345   0.000598 -0.000700
## 2       2      1    Greece  0.098488  ...   0.003454  0.003874   0.003860  0.002053
## 3       2      1    Greece  0.097094  ...   0.004128  0.002877   0.001644  0.000653
## 4       3      1    Greece  0.098733  ...   0.002996  0.001952   0.001568  0.001747
## 
## [5 rows x 563 columns]

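Define the dataset: the numeric measurement columns (the fourth column onward) form the feature matrix `X`, and the `Countries` column is the target `y`.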

# Features: every column from the fourth onward (the numeric measurement
# columns). Slicing the columns *before* converting keeps the result a
# float array; `oil_data.values` on the whole mixed-type frame would give an
# object-dtype array, which is slower and hides non-numeric values until
# much later.
X = oil_data.iloc[:, 3:].to_numpy(dtype=float)
# Target: the country label of each sample, selected by name rather than
# by position so a column reorder cannot silently pick the wrong field.
y = oil_data['Countries'].to_numpy()

Summarize the dataset

# Report the dimensions of the feature matrix and of the target vector.
print(*(arr.shape for arr in (X, y)))
## (120, 560) (120,)

Fit the LDA model and evaluate it with repeated stratified 10-fold cross-validation (3 repeats)

# Linear discriminant classifier with scikit-learn's default settings.
model = LinearDiscriminantAnalysis()

# Stratified 10-fold split, repeated 3 times (30 train/test rounds);
# the fixed seed makes the fold assignment reproducible.
cv = RepeatedStratifiedKFold(random_state=1, n_repeats=3, n_splits=10)

# One held-out accuracy value per train/test round.
scores = cross_val_score(model, X, y, cv=cv, scoring='accuracy')

# Mean and standard deviation of the per-round accuracies.
print(f'Mean Accuracy: {mean(scores):.3f} ({std(scores):.3f})')
## Mean Accuracy: 0.975 (0.049)

We can see that the model achieves a mean cross-validated accuracy of 97.5% (standard deviation 4.9%).

Project the data onto the linear discriminants and plot the first two

# Fit LDA on the full dataset and project every sample onto the
# discriminant axes; column k of data_plot is discriminant LD(k+1).
lda = LinearDiscriminantAnalysis()
data_plot = lda.fit_transform(X, y)

# Legend labels taken from the fitted model's own class list, so the label
# attached to each scatter series is always the class whose points it
# draws. (Zipping a hard-coded class list with labels in order of first
# appearance mislabels points whenever those two orders differ, and drops
# any class missing from the hard-coded list.)
target_names = lda.classes_.astype(str)

# Plot
plt.figure()
colors = ['red', 'green', 'blue', 'orange']
for k, (cls, target_name) in enumerate(zip(lda.classes_, target_names)):
    mask = y == cls
    # Classes beyond the predefined colours fall back to matplotlib's
    # default colour cycle instead of being silently skipped.
    color = colors[k] if k < len(colors) else None
    # With only two classes LDA yields a single discriminant; plot it
    # against zero rather than indexing a non-existent second column.
    if data_plot.shape[1] > 1:
        second_axis = data_plot[mask, 1]
    else:
        second_axis = np.zeros(int(mask.sum()))
    plt.scatter(
        data_plot[mask, 0], second_axis, alpha=0.8, color=color, label=target_name
    )
plt.xlabel("LD1")
plt.ylabel("LD2")
plt.legend(loc="best", shadow=False, scatterpoints=1)
plt.title("LDA of oil dataset")

plt.show()