Get order of class labels for Keras predict function
noobalert, to get the top 2 predictions, as you asked in reply to Matias Valdenegro's question in the comments, you can use the following code:
prediction1 = model.predict(your_data)
# sort the class indices by predicted probability, in descending order
sorting = (-prediction1).argsort()
# take the indices of the top 2 predictions for the first sample
sorted_ = sorting[0][:2]
for value in sorted_:
    # you can get your classes from the encoder (your_classes = encoder.classes_)
    # or from a dictionary that you created before,
    # and then access them with the predicted index.
    predicted_label = your_classes[value]
    # convert to a percentage rounded to 2 decimals
    prob = prediction1[0][value] * 100
    prob = "%.2f" % round(prob, 2)
    print("I am %s%% sure that it belongs to %s." % (prob, predicted_label))
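The loop above only looks at the first sample (`prediction1[0]`). As a sketch, assuming `probs` is the 2-D array returned by `model.predict` (one row per sample, one column per class), the top-k classes for every row can be computed at once with NumPy. The function name `top_k_predictions` and the toy probabilities are illustrative only, not part of Keras.

```python
import numpy as np

def top_k_predictions(probs, class_names, k=2):
    """Return, for each row of probs, the k most likely (label, probability) pairs."""
    probs = np.asarray(probs)
    # argsort on the negated array gives class indices in descending probability order
    top_idx = np.argsort(-probs, axis=-1)[:, :k]
    results = []
    for row, indices in zip(probs, top_idx):
        results.append([(class_names[i], float(row[i])) for i in indices])
    return results

# Made-up softmax output for 2 samples and 3 classes (stand-in for model.predict)
probs = np.array([[0.1, 0.7, 0.2],
                  [0.5, 0.2, 0.3]])
result = top_k_predictions(probs, ["dog", "cat", "automobile"], k=2)
for sample in result:
    for label, p in sample:
        print("I am %.2f%% sure that it belongs to %s." % (p * 100, label))
```

For a very large number of classes, `np.argpartition` avoids a full sort, but the k selected indices then have to be sorted separately.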
Keras: how to get predicted labels for more than two classes
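For a multi-class classification problem with k labels, you can retrieve the index of the predicted class for each sample by using model.predict_classes(). Note that predict_classes() only exists on Sequential models in older Keras versions and was removed in TensorFlow 2.6; in newer versions, use model.predict(x).argmax(axis=-1) instead. Toy example: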
import keras
import numpy as np
# Simple model with 3 softmax output nodes
model = keras.Sequential()
model.add(keras.layers.Dense(3, input_shape=(10,), activation='softmax'))
# 10 random input data points
x = np.random.rand(10, 10)
model.predict_classes(x)
> array([1, 1, 2, 1, 2, 1, 2, 1, 1, 1])
If you have the labels in a list, you can use the predicted classes to get the predicted labels:
labels = ['label1', 'label2', 'label3']
[labels[i] for i in model.predict_classes(x)]
> ['label2', 'label2', 'label3', 'label2', 'label3', 'label2', 'label3', 'label2', 'label2', 'label2']
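Because the predicted classes come back as a NumPy array of indices, the list comprehension can also be written as a single NumPy indexing operation. A minimal sketch, using a made-up index array in place of a real model's output:

```python
import numpy as np

labels = np.array(['label1', 'label2', 'label3'])
# Made-up stand-in for model.predict_classes(x) or model.predict(x).argmax(axis=-1)
predicted_indices = np.array([1, 1, 2, 0])
# Fancy indexing maps every index to its label in one vectorized step
predicted_labels = labels[predicted_indices]
print(predicted_labels.tolist())  # ['label2', 'label2', 'label3', 'label1']
```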
Under the hood, model.predict_classes returns the index of the maximum predicted class probability for each row in the predictions:
model.predict_classes(x)
> array([1, 1, 2, 1, 2, 1, 2, 1, 1, 1])
model.predict(x).argmax(axis=-1) # same thing
> array([1, 1, 2, 1, 2, 1, 2, 1, 1, 1])
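In older Keras releases, `Sequential.predict_classes` took the argmax when the output had more than one column and thresholded a single sigmoid column at 0.5. Where the method is no longer available, a hypothetical helper `predict_classes_compat` can reproduce that behavior from the output of `model.predict`. This is a sketch, not a Keras API; the arrays below are made-up stand-ins for real predictions.

```python
import numpy as np

def predict_classes_compat(proba):
    """Mimic the old Sequential.predict_classes on an array from model.predict."""
    proba = np.asarray(proba)
    if proba.shape[-1] > 1:
        # Multi-class (e.g. softmax): index of the largest probability per row
        return proba.argmax(axis=-1)
    # Single sigmoid output: threshold at 0.5; the result keeps shape (n, 1)
    return (proba > 0.5).astype('int32')

multi = np.array([[0.2, 0.5, 0.3], [0.1, 0.1, 0.8]])
binary = np.array([[0.9], [0.3]])
print(predict_classes_compat(multi))
print(predict_classes_compat(binary))
```

With a live model, this would be called as `predict_classes_compat(model.predict(x))`.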
Attaching class labels to a Keras model
So I tried my hand at a solution myself, and this seems to work, though I was hoping for something simpler.
Opening the model file a second time is not really optimal, I think. If anyone can do better, by all means, do.
import h5py
from keras.models import load_model, save_model

def load_model_ext(filepath, custom_objects=None):
    # pass custom_objects through instead of discarding it
    model = load_model(filepath, custom_objects=custom_objects)
    meta_data = None
    # reopen the HDF5 file read-only to fetch the custom attribute
    with h5py.File(filepath, mode='r') as f:
        if 'my_meta_data' in f.attrs:
            meta_data = f.attrs.get('my_meta_data')
    return model, meta_data

def save_model_ext(model, filepath, overwrite=True, meta_data=None):
    save_model(model, filepath, overwrite=overwrite)
    if meta_data is not None:
        # append mode keeps the saved model and adds the attribute
        with h5py.File(filepath, mode='a') as f:
            f.attrs['my_meta_data'] = meta_data
Since HDF5 attributes cannot store arbitrary Python containers (a dict, for example), you should convert the metadata into a string. Assuming that your metadata is a dictionary or a list, you can use json to do the conversion. This also lets you store more complex, nested data structures alongside your model.
Full usage example:
import json
import keras
# assumes save_model_ext and load_model_ext from above are defined
# prepare model and label lookup
model = keras.Sequential()
model.add(keras.layers.Dense(10, input_dim=8, activation='relu'))
model.add(keras.layers.Dense(3, activation='softmax'))
# standalone Keras 2 requires an optimizer argument
model.compile(optimizer='adam', loss='categorical_crossentropy')
filepath = "mymodel.h5"
labels = ["dog", "cat", "automobile"]
# save
labels_string = json.dumps(labels)
save_model_ext(model, filepath, meta_data=labels_string)
# load
loaded_model, loaded_labels_string = load_model_ext(filepath)
loaded_labels = json.loads(loaded_labels_string)
# label of class 0: "dog"
print(loaded_labels[0])
If you prefer to have a dictionary for your classes, be aware that json will convert numeric dictionary keys to strings, so you will have to convert them back to numbers after loading.
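A minimal sketch of that round trip, using a made-up index-to-label dict: JSON object keys are always strings, so a dict comprehension restores the integer keys after `json.loads`.

```python
import json

class_dict = {0: "dog", 1: "cat", 2: "automobile"}
labels_string = json.dumps(class_dict)  # keys are written as "0", "1", "2"
loaded = json.loads(labels_string)
# Convert the string keys back to integers so that lookups like loaded[0] work again
class_dict_restored = {int(k): v for k, v in loaded.items()}
print(class_dict_restored[0])  # dog
```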