Visualize your PyTorch Model. Don’t waste GPU time.

It is very common to spend a significant amount of time training a model, only to find out afterward that there is a bug in its architecture. I don't know if you can relate to that experience, but the following small script helps me quickly visualize every model I include in my training scripts.


A very useful little utility called pytorchviz can come to your aid in this situation and extend your debugging routine. As its GitHub page puts it, it is "a small package to create visualizations of PyTorch execution graphs and traces." It produces a clear visualization of your network, annotated with information about the name, parameters, and inputs of each module.

As the graph shows, the first part of this network takes as input a 4-dimensional tensor of shape (10, 3, 16, 16), which passes through a convolutional block, followed by a batch norm, and so on.
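A network whose first layers match this description can be sketched as below. The class name SmallConvNet, the channel counts, and every layer after the batch norm are hypothetical choices for illustration; only the input shape (10, 3, 16, 16), read as (batch, channels, height, width), comes from the text. Running one dummy batch through the model on the CPU is also a cheap way to catch shape errors before any GPU time is spent.

```python
import torch
import torch.nn as nn


class SmallConvNet(nn.Module):
    """Hypothetical network: conv -> batch norm -> ReLU -> global pooling -> linear."""

    def __init__(self, in_channels: int = 3, num_classes: int = 10):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps the 16x16 spatial size
        self.conv = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)  # (N, 16, H, W) -> (N, 16, 1, 1)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.bn(self.conv(x)))
        x = self.pool(x).flatten(1)  # (N, 16)
        return self.fc(x)


if __name__ == "__main__":
    model = SmallConvNet()
    dummy = torch.randn(10, 3, 16, 16)  # batch of 10 three-channel 16x16 images
    out = model(dummy)
    print(out.shape)  # torch.Size([10, 10])
```

The same model and dummy tensor are exactly what a graph-drawing helper needs: a network plus an input of the right shape.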

To call this utility quickly, you can add a small function to a helper module in your code base. Note that although the GitHub repository is named pytorchviz, the package is installed with pip install torchviz and imported as torchviz.

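    from torchviz import make_dot

    def SaveVariableGraph(inputTensor, network):
        # make_dot traces back from the model *output*, so run a forward pass first
        output = network(inputTensor)
        graph = make_dot(output, params=dict(network.named_parameters()))
        # Save the DOT source to file; render it later with Graphviz
        graph.save(filename='/home/psxdm5/DebugImages/graph.gv')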

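For pytorchviz to work, you must pass an input tensor along with the model. The input is run through the network, and the graph is built from the operations recorded during that forward pass. Under the hood, make_dot starts from the tensor it is given and walks backwards through those recorded operations, so it should receive the network's output rather than the raw input.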
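The reason a forward pass is needed can be illustrated with plain PyTorch. Each operation in the forward pass adds a node reachable from the output's grad_fn; following next_functions backwards eventually reaches AccumulateGrad nodes, whose variable attribute holds the parameter tensor. The sketch below walks that graph and emits Graphviz DOT text. It is a simplified stand-in written for illustration (the function name autograd_to_dot and the tiny model are hypothetical), not pytorchviz's own implementation, and it leaves out the shapes and styling that make_dot adds.

```python
import torch
import torch.nn as nn


def autograd_to_dot(output: torch.Tensor, params: dict) -> str:
    """Walk the autograd graph behind `output` and return Graphviz DOT source."""
    # Map each parameter tensor's id to its name, e.g. "0.weight"
    names = {id(p): name for name, p in params.items()}
    lines = ["digraph G {"]
    seen = set()
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if hasattr(fn, "variable"):
            # AccumulateGrad node: a leaf tensor, normally a model parameter
            label = names.get(id(fn.variable), "leaf")
        else:
            # Backward node for an operation recorded during the forward pass
            label = type(fn).__name__
        lines.append(f'  n{id(fn)} [label="{label}"];')
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                # Edge from the producing node to the operation that consumed it
                lines.append(f"  n{id(next_fn)} -> n{id(fn)};")
                stack.append(next_fn)
    lines.append("}")
    return "\n".join(lines)


if __name__ == "__main__":
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    out = model(torch.randn(10, 3, 16, 16))  # the forward pass records the graph
    print(autograd_to_dot(out, dict(model.named_parameters())))
```

Passing the input tensor instead of out would give nothing to walk, since a freshly created input has no grad_fn. The returned string can be written to a .gv file and opened with graphviz's Source.from_file in the same way as the graph saved by the helper.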

Since I usually work remotely, I prefer to save all my visualizations to file and then open them on my local machine through bash scripts. To open a window visualizing this graph, run the following locally:

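    from graphviz import Source

    # Local copy of the .gv file saved on the remote machine
    path = '/Users/malldimi1/Desktop/DebugImages/graph.gv'
    s = Source.from_file(path)
    # Render the DOT source and open it in the system's default viewer
    s.view()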

