How do I save a trained model in PyTorch?
Found this page on their github repo:
Recommended approach for saving a model
There are two main approaches for serializing and restoring a model.
The first (recommended) saves and loads only the model parameters:
torch.save(the_model.state_dict(), PATH)
Then later:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
The second saves and loads the entire model:
torch.save(the_model, PATH)
Then later:
the_model = torch.load(PATH)
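However, in this case the serialized data is bound to the specific classes and the exact directory structure used, so it can break in various ways when used in other projects or after serious refactors. Also note that recent PyTorch releases (2.6 and later) make torch.load default to weights_only=True, so loading a whole pickled model this way requires passing weights_only=False, which should only be done for files you trust.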
See also: Save and Load the Model section from the official PyTorch tutorials.
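To make the recommended approach concrete, here is a minimal round-trip sketch. The TheModelClass definition below is a hypothetical stand-in (a single linear layer), and the file name and temporary directory are assumptions for illustration only; substitute your own model class and path.

```python
import os
import tempfile

import torch


class TheModelClass(torch.nn.Module):
    """Hypothetical stand-in for your own model class."""

    def __init__(self, in_features=4, out_features=2):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.linear(x)


# Assumed location for the weights file; any writable path works.
PATH = os.path.join(tempfile.mkdtemp(), "model_weights.pt")

# Save only the parameters (the state_dict), not the whole object.
the_model = TheModelClass()
torch.save(the_model.state_dict(), PATH)

# Later: rebuild the architecture in code, then load the parameters into it.
restored = TheModelClass()
restored.load_state_dict(torch.load(PATH, map_location="cpu"))
restored.eval()  # switch to inference mode before evaluating
```

The model class has to be importable wherever you load the weights, which is exactly why this approach survives refactors better than pickling the whole object.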
How can I save a model while training in torch?
You can use the Trainer from transformers to train the model. This trainer will also need you to specify the TrainingArguments, which will allow you to save checkpoints of the model while training.
Some of the parameters you set when creating TrainingArguments are:
save_strategy: The checkpoint save strategy to adopt during training. Possible values are:
- "no": No save is done during training.
- "epoch": Save is done at the end of each epoch.
- "steps": Save is done every save_steps.
save_steps: Number of update steps between two checkpoint saves if save_strategy="steps".
save_total_limit: If a value is passed, will limit the total number of checkpoints. Deletes the older checkpoints in output_dir.
load_best_model_at_end: Whether or not to load the best model found during training at the end of training.
One important thing about load_best_model_at_end is that when it is set to True, the parameter save_strategy needs to be the same as eval_strategy, and in the case it is "steps", save_steps must be a round multiple of eval_steps.
Saving PyTorch model with no access to model class code
If you plan to do inference with the PyTorch library available (i.e. PyTorch in Python, C++, or other platforms it supports), then the best way to do this is via TorchScript.
I think the simplest thing is to use trace = torch.jit.trace(model, typical_input) and then torch.jit.save(trace, path). You can then load the traced model with torch.jit.load(path).
Here's a really simple example. We make two files:
train.py:
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        x = torch.relu(self.linear(x))
        return x

model = Model()
x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
with torch.no_grad():
    print(model(x))
    traced_cell = torch.jit.trace(model, (x))
torch.jit.save(traced_cell, "model.pth")
infer.py:
import torch

x = torch.FloatTensor([[0.2, 0.3, 0.2, 0.7], [0.4, 0.2, 0.8, 0.9]])
loaded_trace = torch.jit.load("model.pth")
with torch.no_grad():
    print(loaded_trace(x))
Running these sequentially gives results:
python train.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
[0.0000, 0.5272, 0.3481, 0.1743]])
python infer.py
tensor([[0.0000, 0.1845, 0.2910, 0.2497],
[0.0000, 0.5272, 0.3481, 0.1743]])
The results are the same, so we are good. (Note that the actual values will differ each time train.py is run, because the nn.Linear layer is randomly initialised.)
TorchScript provides for much more complex architectures and graph definitions (including if statements, while loops, and more) to be saved in a single file, without needing to redefine the graph at inference time. See the docs (linked above) for more advanced possibilities.
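For models with control flow, tracing is not enough, because a trace records only the path taken for the example input. Here is a small sketch using torch.jit.script instead. The Gate module, its branch condition, and the file name are all made up for illustration.

```python
import os
import tempfile

import torch


class Gate(torch.nn.Module):
    """Hypothetical module with a data-dependent branch."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Which branch runs depends on the input values, so it must be
        # compiled with torch.jit.script rather than traced.
        if x.sum() > 0:
            return torch.relu(self.linear(x))
        return torch.zeros_like(x)


scripted = torch.jit.script(Gate())

# Assumed temporary location; any writable path works.
path = os.path.join(tempfile.mkdtemp(), "gate.pt")
torch.jit.save(scripted, path)

# Loading does not need the Gate class definition.
loaded = torch.jit.load(path)
```

Both branches are preserved in the saved file, so the loaded module behaves like the original for inputs that take either path.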
PyTorch: saving both weights and model definition
As PyTorch provides a huge amount of flexibility in how a model is defined, it is challenging to save the architecture along with the weights in a single file. Keras models are usually built solely by stacking Keras components, but PyTorch models are orchestrated by the library consumer in their own way and can therefore contain any sort of logic.
I think you have three choices:
1. Come up with an organised schema for your experiments so that losing the model definition is less likely. You could go for something as simple as a file, named according to a schema, that only defines each model. I would recommend this approach, as this level of organisation would likely benefit your prototyping in other ways and the overhead is minimal.
2. Try to save the code along with the pickle file. Although potentially possible, I think this would lead you down a rabbit hole with a lot of potential problems.
3. Use a different standardised way of saving the model, such as ONNX. I would recommend this route if you do not want to go with option 1. ONNX does allow you to save a PyTorch model's architecture along with its weights, but it comes with a few drawbacks. For example, it only supports some operations, so completely custom forward methods or use of non-matrix operations may not work.
Saving the weights of a Pytorch .pth model into a .txt or .json
When you call str(model.state_dict()), it recursively uses the str method of the elements it contains. So the problem is how the individual element string representations are built. You should increase the limit of lines printed in each individual string representation:
torch.set_printoptions(profile="full")
See the difference with this:
import torch
import torchvision.models as models
mobilenet_v2 = models.mobilenet_v2()
torch.set_printoptions(profile="default")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])
torch.set_printoptions(profile="full")
print(mobilenet_v2.state_dict()['features.2.conv.0.0.weight'])
Tensors are currently not JSON serializable.
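If you do want JSON, one workaround is to convert each tensor to nested Python lists first. The sketch below uses a tiny torch.nn.Linear as a stand-in model to keep the output small. Shapes survive the round trip through the list nesting, but the original dtypes are not stored in the JSON, so torch.tensor picks its defaults when rebuilding.

```python
import json

import torch

# Small stand-in model; any nn.Module's state_dict works the same way.
model = torch.nn.Linear(4, 2)
state = model.state_dict()

# Tensor.tolist() turns each tensor into (nested) lists of Python numbers,
# which the json module can serialise.
serializable = {name: tensor.tolist() for name, tensor in state.items()}
text = json.dumps(serializable)

# Going back: rebuild tensors from the parsed lists.
restored = {name: torch.tensor(values) for name, values in json.loads(text).items()}
```

The text string can be written to a .json or .txt file as usual. For large models this is much bulkier than torch.save, so I would only use it when a human-readable or language-neutral dump is really needed.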