A nice progress bar for training code with tqdm

In this blog post we show how to add a simple progress bar to your training code.

PyTorch comes with a great ecosystem of tools and utilities you can use to wrap your training scripts. Libraries such as tnt offer different kinds of wrappers that let you write only the “business” part of your training routine without worrying about the low-level stuff.

I usually prefer the flexibility of managing every aspect of my code, and a double loop (one over epochs and one over batches) is usually the way to go.

for epoch in range(numOfEpochs):
    for i_batch, sample in enumerate(self.params.dataloader):
        # training code
        ...

This style of code can make logging the training process harder, especially in the terminal. If you simply print the loss, average precision and so on at every iteration, you can easily end up with a 5 MB text file.
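To see how quickly per-iteration prints add up, here is a back-of-the-envelope calculation. The run size (100 epochs of 1,000 batches) and the `log_line` format are hypothetical assumptions chosen for illustration, not numbers from a real training job.

```python
# Illustrative assumptions, not measurements: 100 epochs of 1,000 batches,
# with one print() line per batch.
num_epochs = 100
batches_per_epoch = 1000


def log_line(epoch, i_batch, loss, avg_precision):
    """Hypothetical per-iteration log line, as a naive print() might produce."""
    return f"epoch {epoch:3d} batch {i_batch:4d} loss {loss:.4f} AP {avg_precision:.4f}\n"


# With these field widths every line is 43 bytes, so the total is simply
# epochs x batches x line length; summing over all lines makes that explicit.
total_bytes = sum(
    len(log_line(epoch, i_batch, 0.1234, 0.5678).encode("utf-8"))
    for epoch in range(num_epochs)
    for i_batch in range(batches_per_epoch)
)

print(f"{total_bytes:,} bytes = {total_bytes / 1e6:.1f} MB")  # 4,300,000 bytes = 4.3 MB
```

Even with short, fixed-width lines, a modest run lands in the same few-megabyte range, most of it noise you will never read.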

A better option is a simple progress bar for every epoch. Training utilities like tnt usually have this functionality built in, but here’s how you can quickly add it to your code by hand using tqdm.

Just install it:

pip install tqdm

and make the following modification to your code:

from tqdm import trange

for epoch in range(numOfEpochs):
    with trange(len(self.params.dataloader)) as t:
        for i_batch, sample in enumerate(self.params.dataloader):
            # training code (computes `loss` for this batch)
            ...

            # show the current loss on the right side of the bar
            t.set_postfix(loss=loss.item())
            # advance the bar by one step
            t.update()
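For a version you can run without PyTorch, here is a minimal sketch of the same pattern. The `train_step` function and the list of lists used as a dataloader are hypothetical stand-ins; `trange`, `set_postfix` and `update` are the tqdm calls from the snippet, plus `desc`, a standard tqdm option used here to label each bar with its epoch.

```python
from tqdm import trange


def train_step(sample):
    """Hypothetical stand-in for a forward/backward pass; returns a fake loss."""
    return sum(sample) / len(sample)


def train(num_epochs, dataloader):
    """Double loop with one progress bar per epoch; returns the steps shown per epoch."""
    steps_per_epoch = []
    for epoch in range(num_epochs):
        # The bar's total is the number of batches in one epoch.
        with trange(len(dataloader), desc=f"epoch {epoch}") as t:
            for i_batch, sample in enumerate(dataloader):
                loss = train_step(sample)  # a plain Python float
                t.set_postfix(loss=loss)   # shown on the right side of the bar
                t.update()                 # advance by one batch
            steps_per_epoch.append(t.n)
    return steps_per_epoch


# Hypothetical data: a list of batches stands in for a PyTorch DataLoader.
fake_dataloader = [[0.9, 0.7], [0.5, 0.4], [0.3, 0.1]]
print(train(num_epochs=2, dataloader=fake_dataloader))  # [3, 3]
```

Each epoch gets its own bar, and by the end of the epoch `t.n` equals the number of batches.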

All in all, you just added a wrapper around your batch training loop, for which you specify the size of the dataloader (the number of batches per epoch):

with trange(len(self.params.dataloader)) as t:
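A small sketch of what that size argument does: `trange(n)` is tqdm's shorthand for `tqdm(range(n))`, so the bar's total is `n`. The sketch also shows a common alternative that is not the approach used in this post, wrapping the dataloader itself with `tqdm(...)` so the total comes from `len(dataloader)` and the bar advances automatically. The list-based `dataloader` is a hypothetical stand-in, and output goes to an `io.StringIO` buffer only to keep the demo quiet.

```python
import io

from tqdm import tqdm, trange

# Hypothetical stand-in for self.params.dataloader: any object with len() works.
dataloader = [[0.9, 0.7], [0.5, 0.4], [0.3, 0.1]]
buffer = io.StringIO()  # keep the bars out of the terminal for this demo

# trange(n) is shorthand for tqdm(range(n)), so the bar's total is n.
with trange(len(dataloader), file=buffer) as t:
    manual_total = t.total

# Alternative: wrap the dataloader itself. tqdm takes the total from
# len(dataloader) and advances once per iteration, so no t.update() is needed.
with tqdm(dataloader, file=buffer) as t:
    for sample in t:
        pass  # training code would go here
    wrapped_total, wrapped_steps = t.total, t.n

print(manual_total, wrapped_total, wrapped_steps)  # 3 3 3
```

The manual `trange` version keeps the loop body untouched, which is handy when you already iterate with `enumerate` as above.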

Then, every time you call t.update(), the bar advances one step.
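To check the stepping behaviour in isolation, the sketch below calls `update()` once per hypothetical batch and then inspects the counter `t.n` and the rendered bar. Writing to an `io.StringIO` via tqdm's `file` argument is only a convenience for the demo; in training you would let the bar print to the terminal.

```python
import io

from tqdm import trange

buffer = io.StringIO()  # capture the bar instead of drawing it on the terminal
num_batches = 5  # hypothetical dataloader length

with trange(num_batches, file=buffer) as t:
    for _ in range(num_batches):
        t.update()  # advance the bar by exactly one step
    steps_done = t.n

print(steps_done)                  # 5
print("5/5" in buffer.getvalue())  # True: the final bar reads 5/5
```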

Moreover, by calling set_postfix() you can display the current loss, or any other metric of interest, on the right side of the progress bar.

set_postfix() accepts a Python dictionary (as well as keyword arguments), so if you want to display multiple metrics you can do the following:

metrics = {'loss': loss.item(), 'precision': precision}
t.set_postfix(metrics)
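Here is a self-contained sketch of both calling styles, using made-up metric values in place of a real loss and precision. tqdm stores the formatted postfix string in `t.postfix`, which makes it easy to inspect; plain Python numbers are shortened (0.123456 becomes 0.123), whereas a PyTorch tensor would be shown through its `str()` representation, which is why the snippets convert the loss with `.item()`.

```python
import io

from tqdm import trange

buffer = io.StringIO()  # keep the example's bar output off the terminal

with trange(3, file=buffer) as t:
    for i_batch in range(3):
        # Made-up values standing in for real loss/precision numbers.
        metrics = {'loss': 0.123456, 'precision': 0.9}
        t.set_postfix(metrics)  # dictionary as the first positional argument
        t.update()
    dict_postfix = t.postfix

    # The same metrics passed as keyword arguments.
    t.set_postfix(loss=0.123456, precision=0.9)
    kwargs_postfix = t.postfix

print(dict_postfix)    # loss=0.123, precision=0.9
print(kwargs_postfix)  # loss=0.123, precision=0.9
```

The dictionary form is convenient when the metrics are already collected in one place; keyword arguments read better for one or two values.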

 

 
