#KI-Werkbank¶
Trennschärfe bei der Klassifizierung¶
Worum es hier geht¶
Ich habe in meinen letzten Blogs Einstieg in neuronale Netze mit Keras zwei verschiedene Architekturen neuronaler Netzte auf die Sammlung handschriftlicher Ziffern von MNIST losgelassen. Ich möchte hier in diesem Mini Blog einen interessanten Unterschied im Ergebnis der beiden Modelle aufzeigen, nämlich die Trennschärfe
Was ist die Trennschärfe einer Klassifikation?¶
Mit der Trennschärfe möchte ich ein Maß dafür dafür definieren, wie sicher sich ein neuronales Netz bei seinen Entscheidungen ist. Ich wir brauchen so ein Qualitätskriterium, um Netze miteinander vergleichen zu können.
Ich lade dazu die beiden Modelle aus den letzten Versuchen ein und vergleiche sie miteinander um ein griffiges Maß für die Trennschärfe abzuleiten.
Neuronalen Netze sagen ihrer Antwort nicht so: „Du hast mir eine Zwei gezeigt“ sondern sie sagen: „Mit einer Wahrscheinlichkeit von 1% ist es eine Null, mit einer Wahrscheinlichkeit von 4% ist es eine Eins, mit einer Wahrscheinlichkeit von 91% ist es eine Zwei, mit …“. Oft ist sich das Modell nicht ganz sicher, um was es sich bei der Zahl handelt. Dann kann der Maximalwert vielleicht auch nur 21% sein.
Boilerplate Code¶
#tensorflow
import tensorflow as tf
from tensorflow.python.client import device_lib
#keras
import keras
from keras.backend.tensorflow_backend import set_session
from keras.datasets import mnist
from keras import models
from keras.models import model_from_json
#numpy
import numpy as np
#matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
#os
import os
config = tf.ConfigProto()
config.gpu_options.allow_growth = True # dynamically grow the memory used on the GPU
config.log_device_placement = True # to log device placement (on which device the operation ran)
# (nothing gets printed in Jupyter, only if you run it standalone)
sess = tf.Session(config=config)
set_session(sess)
Die trainierten Modelle laden¶
main_dir = os.path.normpath(os.getcwd() + os.sep + os.pardir)
mod_dir = os.path.join(main_dir,'train_model', 'persist_model')
def get_model(model_file, model_weights):
# load json and create model
with open(model_file, 'r', encoding='utf-8') as json_file:
loaded_model_json = json_file.read()
model = model_from_json(loaded_model_json)
# load weights into new model
model.load_weights(model_weights)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
return model
simple_model = get_model(os.path.join(mod_dir, 'simple_MNIST_model.json'), os.path.join(mod_dir, 'simple_MNIST_model.h5'))
cnn_model = get_model(os.path.join(mod_dir, 'convolutional_MNIST_model.json'), os.path.join(mod_dir, 'convolutional_MNIST_model.h5'))
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= reshape_1 (Reshape) (None, 784) 0 _________________________________________________________________ dense_1 (Dense) (None, 324) 254340 _________________________________________________________________ dense_2 (Dense) (None, 36) 11700 _________________________________________________________________ dense_3 (Dense) (None, 10) 370 ================================================================= Total params: 266,410 Trainable params: 266,410 Non-trainable params: 0 _________________________________________________________________ _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_17 (Conv2D) (None, 28, 28, 32) 320 _________________________________________________________________ max_pooling2d_7 (MaxPooling2 (None, 14, 14, 32) 0 _________________________________________________________________ conv2d_18 (Conv2D) (None, 14, 14, 32) 9248 _________________________________________________________________ max_pooling2d_8 (MaxPooling2 (None, 7, 7, 32) 0 _________________________________________________________________ conv2d_19 (Conv2D) (None, 7, 7, 32) 9248 _________________________________________________________________ max_pooling2d_9 (MaxPooling2 (None, 3, 3, 32) 0 _________________________________________________________________ conv2d_20 (Conv2D) (None, 3, 3, 16) 4624 _________________________________________________________________ dropout_11 (Dropout) (None, 3, 3, 16) 0 _________________________________________________________________ flatten_5 (Flatten) (None, 144) 0 _________________________________________________________________ dropout_12 (Dropout) (None, 144) 0 _________________________________________________________________ dense_9 (Dense) (None, 128) 18560 _________________________________________________________________ dropout_13 (Dropout) (None, 128) 0 _________________________________________________________________ dense_10 (Dense) (None, 10) 1290 ================================================================= Total params: 43,290 Trainable params: 43,290 Non-trainable params: 0 _________________________________________________________________
Die Testdaten laden und die Modelle befragen¶
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
simple_results = simple_model.predict(test_images)
cnn_results = cnn_model.predict(test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255)
Wie eindeutig sind die Ergebnisse¶
Zuerst werde ich einen naiven Blick auf die Histogramme der beiden Modelle, getrennt nach richtigen und falschen Vorhersagen.
def show_hist(results, decs):
_x = []
_x_wrong = []
for index, result in enumerate(results):
predict = result.argmax()
expect = test_labels[index]
if predict == expect:
_x.append(np.amax(result))
else:
_x_wrong.append(np.amax(result))
x = np.array(_x)
x_wrong = np.array(_x_wrong)
fig, axes = plt.subplots(nrows = 1 , ncols = 3, figsize=(15, 3))
ax0, ax1, ax2 = axes.flatten()
ax0.hist(x * 100, facecolor='g')
ax0.set_title('Right Predictions')
ax0.set_xlabel('prediction rate in percent')
ax0.set_ylabel('amount of prediction')
ax0.axis([0, 100, 0, 10000])
ax1.hist(x * 100, facecolor='g')
ax1.set_title('Right Predictions - Zoom')
ax1.set_xlabel('prediction rate in percent')
ax1.set_ylabel('amount of prediction')
ax1.axis([0, 100, 0, 400])
ax2.hist(x_wrong * 100, facecolor='b')
ax2.set_title('Wrong Predictions')
ax2.set_xlabel('prediction rate in percent')
ax2.set_ylabel('amount of prediction')
ax2.axis([0, 100, 0, 80])
fig.tight_layout()
plt.show()
print(decs)
show_hist(simple_results, "Abb 1: Trennschärfe der Vorhersage bei einem Fully Connected Neuron Network")
show_hist(cnn_results, "Abb 2: Trennschärfe der Vorhersage bei einem Convolutional Neuron Network")
Abb 1: Trennschärfe der Vorhersage bei einem Fully Connected Neuron Network
Abb 2: Trennschärfe der Vorhersage bei einem Convolutional Neuron Network
Nach rechts ist die Genauigkeit der Vorhersage aufgetragen. Die 100 bedeutet also, dass sich das Modell mit einer Sicherheit von 100% auf eine bestimmte Ziffer festgelegt hat.
Nach oben habe ich die totale Anzahl der Vorhersagen aufgetragen. Für das CNN in Abb 2 sieht man zum Beispiel, dass 10.000 Vorhersagen getroffen wurden, bei dem sich das Modell zu 100% sicher war.
Beim jeweils zweiten Histogramm habe ich ein wenig in die untere rechte Ecke hineingezoomt, um hervorzuheben, wo die Unschärfen liegen. Man erkennt sehr deutlich, dass sich das CNN mit seinen Vorhersagen viel sicherer ist, als das einfache NN.
Beim jeweils dritten Histogramm habe ich das gleiche Spiel für die falsch vorhergesagten Ziffern aufgetragen. Dummerweise ist sich hier das CNN auch viel sicherer mit seiner Vorhersage.
Eine eindeutige Maßeinheit¶
Es stellt sich nun die Frage, wie man möglichst griffig die Trennschärfe eines Modells angibt. Ich denke, der einfache Mittelwert minus des kleinsten Mittelwertes auf [0,1] skaliert ist ein gutes Maß dafür und zwar aus folgenden einfachen Gründen:
- Die Summe aller Wahrscheinlichkeiten pro Vorhersage ist immer 100%
- Ist eine Vorhersage also 100%, sind alle anderen Wahrscheinlichkeiten 0%
- Die kleinste Vorhersage muss immer ein klein wenig größer als 100 % geteilt durch die Anzahl der Möglichkeiten sein (in unserem Fall 100% / 10 verschiedene Ziffern = 10% plus ein kleines Delta)
- Eine Trennschärfe von 100% heißt dann, dass sich das Modell absolut sicher ist
- Eine Trennschärfe von 0% heißt, dass es alles oder nichts sein kann.
Somit ergeben sich für die beiden Modelle folgende Trennschärfen:
def show_hist(results, desc):
_x = []
for index, result in enumerate(results):
predict = result.argmax()
expect = test_labels[index]
if predict == expect:
_x.append(np.amax(result))
x = np.mean(np.array(_x) * 100 - 100 / 10) * 100 / 90
print(desc.format(x))
show_hist(simple_results, "Die Trennschärfe des Fully Connected Neuron Networkfür richtige Vorhersagen ist: {:.2f}%.\n")
show_hist(cnn_results, "Die Trennschärfe des Convolutional Neuron Network für richtige Vorhersagen: {:.2f}%.\n")
Die Trennschärfe des Fully Connected Neuron Networkfür richtige Vorhersagen ist: 97.33%. Die Trennschärfe des Convolutional Neuron Network für richtige Vorhersagen: 99.69%.
Schluss¶
Die Convolutions im CNN haben für den Fully Connecte Teil die Freiheitsgrade stark eingeschränkt. Den Erfolg dieser Übung erkennt man nicht nur in den besseren Testergebnissen, sondern auch in der Trennschärfe. Damit ist er ein wichtiges Maß für die Richtigkeit der gewählten Netzarchitektur.