Exporting models
After training an agent, you may want to deploy or use it in another language or framework, such as PyTorch or tensorflowjs. Stable Baselines does not include tools to export models to other frameworks, but this document covers the parts required for exporting, along with more detailed accounts from Stable Baselines users.
Background
In Stable Baselines, the controller is stored inside policies, which convert
observations into actions. Each learning algorithm (e.g. DQN, A2C, SAC) contains
one or more policies, some of which are only used for training. An easy way to find
the policy used for inference is to check the code of the agent's predict function.
This function should only call one policy with simple arguments.
Policies hold the TensorFlow placeholders and tensors needed for inference (i.e. predicting actions), so exporting these policies is enough to run inference in another framework.
Note
Learning algorithms may also contain other TensorFlow placeholders that are used only for training and are not required for inference.
Warning
When using CNN policies, the observation is normalized internally (divided by 255 to have values in [0, 1]), so an exported network must apply the same scaling to its inputs.
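The scaling mentioned in the warning can be reproduced outside TensorFlow with a few lines of NumPy. This is a minimal sketch, assuming the observation is a uint8 image array; the helper name preprocess_image_obs is hypothetical and not part of Stable Baselines.

```python
import numpy as np


def preprocess_image_obs(obs):
    """Scale an image observation to float32 values in [0, 1].

    Hypothetical helper: it mirrors the division by 255 that CNN
    policies apply internally, so an exported network sees the same
    input range as during training.
    """
    return np.asarray(obs, dtype=np.float32) / 255.0


# Stand-in observation (assumed uint8 pixels); not real environment data.
obs = np.array([[0, 128, 255]], dtype=np.uint8)
scaled = preprocess_image_obs(obs)
```

Apply this step once, just before the exported network, and only if the network does not already divide by 255 itself.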
Export to PyTorch
A known working solution is to use the get_parameters
function to obtain the model parameters, construct the network manually in PyTorch, and assign the parameters correctly.
Warning
PyTorch and TensorFlow differ internally in some operations, e.g. 2D convolutions (see the discussion linked below).
See discussion #372 for details.
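One difference that is easy to handle is the weight layout. TensorFlow stores dense kernels as (inputs, outputs) and 2D convolution kernels as (height, width, in_channels, out_channels), while torch.nn.Linear and torch.nn.Conv2d expect (outputs, inputs) and (out_channels, in_channels, height, width). The sketch below only reorders axes with NumPy; the function names and the stand-in arrays are hypothetical, and it does not cover other differences such as padding or the flattening order of NHWC versus NCHW feature maps.

```python
import numpy as np


def dense_kernel_to_torch(kernel):
    """(in_features, out_features) -> (out_features, in_features)."""
    return np.ascontiguousarray(kernel.T)


def conv_kernel_to_torch(kernel):
    """(H, W, in_channels, out_channels) -> (out_channels, in_channels, H, W)."""
    return np.ascontiguousarray(np.transpose(kernel, (3, 2, 0, 1)))


# Stand-in arrays with TensorFlow layouts; real values would come from
# the dictionary returned by get_parameters.
tf_dense = np.arange(4 * 8, dtype=np.float32).reshape(4, 8)
tf_conv = np.arange(3 * 3 * 4 * 32, dtype=np.float32).reshape(3, 3, 4, 32)

torch_dense = dense_kernel_to_torch(tf_dense)
torch_conv = conv_kernel_to_torch(tf_conv)
```

The resulting arrays can then be wrapped with torch.from_numpy and copied into the matching layers of a manually built network.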
Export to tensorflowjs / tfjs
This can be done via TensorFlow's simple_save function and tensorflowjs_converter.
See discussion #474 for details.
Manual export
You can also manually export the required parameters (weights) and construct the network in your desired framework, as described in the PyTorch section above.
You can access the parameters of the model via the agent's
get_parameters
function. If you use default policies, you can find the architecture of the networks in the
source for policies. Otherwise, for DQN, SAC, DDPG or TD3, you need to check the policies.py file located
in their respective folders.
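As an illustration of a manual export, the sketch below runs a small fully connected network using only NumPy and a dictionary of weights. It assumes that get_parameters returns a mapping from variable names to arrays, that each layer has a kernel and a bias stored under name + "/w:0" and name + "/b:0", and that hidden layers use tanh. The variable names, shapes and activation are assumptions for illustration, so check them against the policy source for your model.

```python
import numpy as np


def mlp_forward(params, layer_names, obs):
    """Forward pass of a fully connected network stored as {name: array}."""
    x = np.asarray(obs, dtype=np.float32)
    last = len(layer_names) - 1
    for i, name in enumerate(layer_names):
        # Kernels are assumed to be stored as (inputs, outputs).
        x = x @ params[name + "/w:0"] + params[name + "/b:0"]
        if i < last:
            x = np.tanh(x)  # assumed hidden activation
    return x


# Stand-in for the dictionary returned by get_parameters;
# the names and shapes here are hypothetical.
rng = np.random.default_rng(0)
params = {
    "model/pi_fc0/w:0": rng.standard_normal((4, 8)).astype(np.float32),
    "model/pi_fc0/b:0": np.zeros(8, dtype=np.float32),
    "model/pi/w:0": rng.standard_normal((8, 2)).astype(np.float32),
    "model/pi/b:0": np.zeros(2, dtype=np.float32),
}
layers = ["model/pi_fc0", "model/pi"]
action_output = mlp_forward(params, layers, np.ones((1, 4)))
```

Comparing the output of such a reimplementation with the original agent's predict on a few observations is a simple way to check that the parameters were assigned correctly.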