How to Save a Trained Model in PyTorch

How do I save a trained model in PyTorch?

Found this page on their GitHub repo:

Recommended approach for saving a model

There are two main approaches for serializing and restoring a model.

The first (recommended) saves and loads only the model parameters:

torch.save(the_model.state_dict(), PATH)

Then later:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

The second saves and loads the entire model:

torch.save(the_model, PATH)

Then later:

the_model = torch.load(PATH)

However, in this case the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects or after some serious refactors.
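In practice, a common pattern built on the first approach is to save a dictionary holding the model's state_dict alongside anything else you need to resume training (a rough sketch; optimizer, epoch, and PATH are placeholders here):

import torch

checkpoint = {
    "epoch": epoch,
    "model_state_dict": the_model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, PATH)

# Restoring later:
checkpoint = torch.load(PATH)
the_model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
the_model.eval()  # switch to inference mode; use .train() to resume training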


See also: Save and Load the Model section from the official PyTorch tutorials.

How can I save a model while training in torch?

You can use the Trainer from transformers to train the model. The Trainer also needs you to specify TrainingArguments, which let you save checkpoints of the model while training.

Some of the parameters you set when creating TrainingArguments are:

  • save_strategy: The checkpoint save strategy to adopt during training. Possible values are:
    • "no": No save is done during training.
    • "epoch": Save is done at the end of each epoch.
    • "steps": Save is done every save_steps.
  • save_steps: Number of update steps between two checkpoint saves if save_strategy="steps".
  • save_total_limit: If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in output_dir.
  • load_best_model_at_end: Whether or not to load the best model found during training at the end of training.

One important thing about load_best_model_at_end is that when it is set to True, save_strategy needs to be the same as eval_strategy, and when that is "steps", save_steps must be a round multiple of eval_steps.
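For example, a rough sketch of wiring these together (the model and datasets are placeholders, and on older transformers versions the evaluation argument is called evaluation_strategy rather than eval_strategy):

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="checkpoints",      # checkpoints are written to subfolders here
    save_strategy="steps",
    save_steps=500,
    eval_strategy="steps",         # must match save_strategy for load_best_model_at_end
    eval_steps=500,                # save_steps must be a round multiple of this
    save_total_limit=2,            # keep only the two most recent checkpoints
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,                   # your model, e.g. a pretrained transformer
    args=training_args,
    train_dataset=train_dataset,   # placeholder datasets
    eval_dataset=eval_dataset,
)
trainer.train()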

Saving PyTorch model with no access to model class code

If you plan to do inference with the PyTorch library available (i.e. PyTorch in Python, C++, or any other platform it supports), then the best way to do this is via TorchScript.

I think the simplest thing is to use trace = torch.jit.trace(model, typical_input) and then torch.jit.save(trace, path). You can then load the traced model with torch.jit.load(path).

Here's a really simple example. We make two files:

train.py :

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = torch.relu(self.linear(x))
        return x

model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
    print(model(x))

traced_cell = torch.jit.trace(model, (x))
torch.jit.save(traced_cell, "model.pth")

infer.py :

import torch
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
    print(loaded_trace(x))

Running these sequentially gives results:

python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
        [0.0000, 0.5272, 0.3481, 0.1743]])

The results are the same, so we are good. (Note that the numbers will differ each time you run train.py, due to the random initialisation of the nn.Linear layer.)

TorchScript allows much more complex architectures and graph definitions (including if statements, while loops, and more) to be saved in a single file, without needing to redefine the graph at inference time. See the docs (linked above) for more advanced possibilities.
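Note that tracing records only the operations executed on the example input, so data-dependent control flow is better captured with torch.jit.script, which compiles the Python code itself. A minimal sketch (the Gate module here is a made-up example):

import torch

class Gate(torch.nn.Module):
    def forward(self, x):
        # a data-dependent branch: tracing would bake in a single path,
        # scripting preserves the if statement in the saved graph
        if x.sum() > 0:
            return x * 2
        return -x

scripted = torch.jit.script(Gate())
torch.jit.save(scripted, "gate.pth")
loaded = torch.jit.load("gate.pth")  # no access to the Gate class needed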

PyTorch: saving both weights and model definition

Because PyTorch gives you a huge amount of flexibility in defining a model, it is challenging to save the architecture along with the weights in a single file. Keras models are usually built solely by stacking Keras components, but PyTorch models are orchestrated by the library consumer in their own way and can therefore contain any sort of logic.

I think you have three choices:

  1. Come up with an organised schema for your experiments so that losing the model definition is less likely. You could go for something as simple as a file naming schema that uniquely identifies each model. I would recommend this approach, as this level of organisation would likely benefit your prototyping in other ways, and the overhead is minimal.

  2. Try to save the code along with the pickle file. Although potentially possible, I think this would lead you down a rabbit hole with a lot of potential problems.

  3. Use a different standardised way of saving the model, such as ONNX. I would recommend this route if you do not want to go with option 1. ONNX does allow you to save a PyTorch model's architecture along with its weights, but it comes with a few drawbacks. For example, it only supports some operations, so completely custom forward methods or use of non-matrix operations may not work; a minimal export sketch follows this list.
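As a rough illustration of option 3 (the Linear layer stands in for your trained model, and the input shape is a placeholder):

import torch

model = torch.nn.Linear(4, 2)        # stand-in for your trained model
dummy_input = torch.randn(1, 4)      # an example input with the right shape
torch.onnx.export(model, dummy_input, "model.onnx")

The resulting model.onnx file carries both the graph and the weights, and can be loaded by any ONNX runtime (for example onnxruntime) without the original Python class.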

Saving the weights of a Pytorch .pth model into a .txt or .json

When you do str(model.state_dict()), it recursively uses the str method of the elements it contains, so the problem is how the individual elements' string representations are built. You should increase the limit on how much of each tensor is printed in its string representation:

torch.set_printoptions(profile="full")

See the difference with this:

import torch
import torchvision.models as models
mobilenet_v2 = models.mobilenet_v2()
torch.set_printoptions(profile="default")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])
torch.set_printoptions(profile="full")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])

Tensors are currently not JSON serializable.
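If you do need JSON, one workaround (a sketch, not the only option) is to convert each tensor to nested Python lists first:

import json
import torch
import torchvision.models as models

mobilenet_v2 = models.mobilenet_v2()
# tolist() turns each tensor into (nested) Python lists/numbers,
# which json can serialize
state = {name: tensor.tolist() for name, tensor in mobilenet_v2.state_dict().items()}
with open("weights.json", "w") as f:
    json.dump(state, f)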


