About Keras models
There are two main types of models available in Keras: the Sequential model, and the Model class used with the functional API.
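The sketch below shows both styles side by side. It assumes a recent Keras installation that provides keras.Input, keras.Sequential and keras.Model. It builds the same two-layer classifier once as a Sequential stack and once with the functional API. The layer sizes are arbitrary choices for illustration.

```python
import keras

# Sequential: a plain linear stack of layers with one input and one output.
seq_model = keras.Sequential([
    keras.Input(shape=(16,)),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(10, activation="softmax"),
])

# Functional API: the same architecture, built by calling layers on tensors
# and handing the resulting input/output tensors to Model.
inputs = keras.Input(shape=(16,))
x = keras.layers.Dense(32, activation="relu")(inputs)
outputs = keras.layers.Dense(10, activation="softmax")(x)
func_model = keras.Model(inputs=inputs, outputs=outputs)

print(seq_model.count_params(), func_model.count_params())  # 874 874
```

Both models report 874 parameters (16*32 + 32 for the hidden layer, 32*10 + 10 for the output layer), because they describe the same architecture.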
These models have a number of methods and attributes in common:
- model.layers is a flattened list of the layers comprising the model.
- model.inputs is the list of input tensors of the model.
- model.outputs is the list of output tensors of the model.
- model.summary() prints a summary representation of your model. For layers with multiple outputs, "multiple" is displayed instead of each individual output shape due to size limitations. Shortcut for utils.print_summary.
- model.get_config() returns a dictionary containing the configuration of the model. The model can be reinstantiated from its config via:
config = model.get_config()
model = Model.from_config(config)
# or, for Sequential:
model = Sequential.from_config(config)
- model.get_weights() returns a list of all weight tensors in the model, as NumPy arrays.
- model.set_weights(weights) sets the values of the weights of the model from a list of NumPy arrays. The arrays in the list should have the same shapes as those returned by get_weights().
- model.to_json() returns a representation of the model as a JSON string. Note that the representation does not include the weights, only the architecture. You can reinstantiate the same model (with reinitialized weights) from the JSON string via:
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
- model.to_yaml() returns a representation of the model as a YAML string. Note that the representation does not include the weights, only the architecture. You can reinstantiate the same model (with reinitialized weights) from the YAML string via:
from keras.models import model_from_yaml
yaml_string = model.to_yaml()
model = model_from_yaml(yaml_string)
- model.save_weights(filepath) saves the weights of the model as an HDF5 file.
- model.load_weights(filepath, by_name=False) loads the weights of the model from an HDF5 file (created by save_weights). By default, the architecture is expected to be unchanged. To load weights into a different architecture (with some layers in common), use by_name=True to load only those layers with the same name.
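The sketch below exercises several of these attributes and methods on one small functional model. It assumes a recent Keras installation (keras.Input, keras.Model, keras.models.model_from_json). The layer names features, hidden and logits are hypothetical. The code rebuilds the architecture from get_config() and from to_json(), then copies weights across with get_weights()/set_weights(). to_yaml() is left out because newer Keras releases have removed it.

```python
import numpy as np
import keras
from keras.models import model_from_json

# A small functional model; layer names and sizes are illustrative choices.
inputs = keras.Input(shape=(4,), name="features")
hidden = keras.layers.Dense(3, activation="relu", name="hidden")(inputs)
outputs = keras.layers.Dense(2, name="logits")(hidden)
model = keras.Model(inputs=inputs, outputs=outputs)

# Attributes shared by Sequential and functional models.
model.summary()
print([layer.name for layer in model.layers])  # ['features', 'hidden', 'logits']
print(len(model.inputs), len(model.outputs))   # 1 1

# Architecture only: both rebuilt models start from fresh random weights.
from_config = keras.Model.from_config(model.get_config())
from_json = model_from_json(model.to_json())

# Weights travel separately, as NumPy arrays (kernel then bias per Dense layer).
weights = model.get_weights()
print([w.shape for w in weights])  # [(4, 3), (3,), (3, 2), (2,)]
from_json.set_weights(weights)

# With identical weights, the JSON-rebuilt model computes the same outputs.
x = np.random.rand(5, 4).astype("float32")
print(np.allclose(model.predict(x, verbose=0), from_json.predict(x, verbose=0)))  # True
```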
Note: see "How can I install HDF5 or h5py to save my models in Keras?" in the FAQ for instructions on installing h5py.
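The sketch below saves weights with save_weights() and restores them with load_weights(). It assumes Keras and h5py are installed. It covers only the default case, where the architecture is unchanged; by_name=True is not exercised. The file name is an assumption. Newer Keras releases require weight files to end in .weights.h5. That name also ends in .h5, which older Keras 2 releases use to pick the HDF5 format.

```python
import os
import tempfile

import numpy as np
import keras


def build_model():
    # Running the same code twice gives the same architecture with new random weights.
    return keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(2),
    ])


original = build_model()
restored = build_model()

with tempfile.TemporaryDirectory() as tmp_dir:
    # Assumed file name, chosen to satisfy both newer and older Keras releases.
    path = os.path.join(tmp_dir, "model.weights.h5")
    original.save_weights(path)
    restored.load_weights(path)

same = all(np.array_equal(a, b) for a, b in zip(original.get_weights(), restored.get_weights()))
print(same)  # True
```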
Model subclassing
In addition to these two types of models, you may create your own fully customizable models by subclassing the Model class and implementing your own forward pass in the call method (the Model subclassing API was introduced in Keras 2.2.0).
Here's an example of a simple multi-layer perceptron model written as a Model subclass:
import keras

class SimpleMLP(keras.Model):

    def __init__(self, use_bn=False, use_dp=False, num_classes=10):
        super(SimpleMLP, self).__init__(name='mlp')
        self.use_bn = use_bn
        self.use_dp = use_dp
        self.num_classes = num_classes

        # Layers are created once here and reused on every call.
        self.dense1 = keras.layers.Dense(32, activation='relu')
        self.dense2 = keras.layers.Dense(num_classes, activation='softmax')
        if self.use_dp:
            self.dp = keras.layers.Dropout(0.5)
        if self.use_bn:
            self.bn = keras.layers.BatchNormalization(axis=-1)

    def call(self, inputs):
        # Forward pass: plain Python control flow decides which optional layers run.
        x = self.dense1(inputs)
        if self.use_dp:
            x = self.dp(x)
        if self.use_bn:
            x = self.bn(x)
        return self.dense2(x)

model = SimpleMLP()
model.compile(...)  # compile arguments elided
model.fit(...)  # training data and fit arguments elided
Layers are defined in __init__(self, ...), and the forward pass is specified in call(self, inputs). In call, you may specify custom losses by calling self.add_loss(loss_tensor) (like you would in a custom layer).
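The sketch below uses add_loss() inside call() of a subclassed model. The hypothetical RegularizedMLP penalizes large hidden activations. It assumes a recent Keras installation. The class name, the 1e-3 penalty weight and the random training data are illustrative. keras.regularizers.l2 is used only as a convenient way to compute an L2 term as a scalar tensor; it is not the only way to build such a loss.

```python
import numpy as np
import keras


class RegularizedMLP(keras.Model):
    """Hypothetical SimpleMLP variant that adds an extra loss term from call()."""

    def __init__(self, num_classes=10, activity_penalty=1e-3):
        super().__init__(name="regularized_mlp")
        self.dense1 = keras.layers.Dense(32, activation="relu")
        self.dense2 = keras.layers.Dense(num_classes, activation="softmax")
        # L2 penalty object: calling it on a tensor returns a scalar loss tensor.
        self.penalty = keras.regularizers.l2(activity_penalty)

    def call(self, inputs):
        x = self.dense1(inputs)
        # Custom loss on the hidden activations; Keras adds it to the
        # compiled loss during training.
        self.add_loss(self.penalty(x))
        return self.dense2(x)


model = RegularizedMLP()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Random stand-in data: 64 samples, 20 features, integer labels in [0, 10).
x_train = np.random.rand(64, 20).astype("float32")
y_train = np.random.randint(0, 10, size=(64,))
model.fit(x_train, y_train, epochs=1, batch_size=16, verbose=0)

# After a forward pass, the added term is listed in model.losses.
predictions = model(x_train[:4])
print(len(model.losses))  # 1
```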
In subclassed models, the model's topology is defined as Python code (rather than as a static graph of layers). That means the model's topology cannot be inspected or serialized. As a result, the following methods and attributes are not available for subclassed models:
- model.inputs and model.outputs
- model.to_yaml() and model.to_json()
- model.get_config() and model.save()
Key point: use the right API for the job. The Model subclassing API can provide you with greater flexibility for implementing complex models, but it comes at a cost (in addition to these missing features): it is more verbose, more complex, and has more opportunities for user errors. If possible, prefer using the functional API, which is more user-friendly.
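Following that advice, SimpleMLP can also be written with the functional API. This sketch assumes a recent Keras installation. build_mlp is a hypothetical helper. Its input_dim argument is an addition the subclassed version did not need, because a functional model is built from a known input shape.

```python
import numpy as np
import keras


def build_mlp(input_dim, use_bn=False, use_dp=False, num_classes=10):
    """Functional counterpart of SimpleMLP (input_dim is an extra, assumed argument)."""
    inputs = keras.Input(shape=(input_dim,))
    x = keras.layers.Dense(32, activation="relu")(inputs)
    # The optional layers are chosen once, at build time, not on every call.
    if use_dp:
        x = keras.layers.Dropout(0.5)(x)
    if use_bn:
        x = keras.layers.BatchNormalization(axis=-1)(x)
    outputs = keras.layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs=inputs, outputs=outputs, name="mlp")


model = build_mlp(input_dim=20, use_bn=True, use_dp=True)

# Unlike the subclassed version, the topology can be inspected and serialized.
print(len(model.layers))  # 5: InputLayer, Dense, Dropout, BatchNormalization, Dense
json_string = model.to_json()
config = model.get_config()

# It still runs on data like any other model.
predictions = model.predict(np.random.rand(3, 20).astype("float32"), verbose=0)
print(predictions.shape)  # (3, 10)
```

The trade-off is that branching on use_bn and use_dp now happens when the graph is built rather than on every forward pass. That fixed graph is what makes the model's topology available to get_config(), to_json() and model.inputs/model.outputs.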