Hi, I am trying to implement a total variation function for tensors, or more precisely, for multichannel images. For the total variation in the picture, I found source code like this:

```python
def compute_total_variation_loss(img, weight):
    tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
    tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()
    return weight * (tv_h + tv_w)
```

Since I am a beginner in Python, I don't understand how the indices correspond to i and j in the image. I also want to add total variation over c (besides i and j), but I don't know which index refers to c.

Or, more concisely: how can I write the equation from the picture in Python?

## Answer

This function assumes batched images, so `img` is a 4-dimensional tensor of dimensions `(B, C, H, W)` (`B` is the number of images in the batch, `C` the number of color channels, `H` the height and `W` the width).
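A minimal sketch of that layout, using a made-up batch of random RGB images (the sizes are arbitrary). It also shows one common way to give a single `(C, H, W)` image the batch dimension the function expects:

```python
import torch

# A hypothetical batch of 8 RGB images, each 32 pixels high and 48 pixels wide.
img = torch.rand(8, 3, 32, 48)
B, C, H, W = img.shape  # dim 0: batch, dim 1: channels, dim 2: height, dim 3: width
print(B, C, H, W)  # 8 3 32 48

# A single image of shape (C, H, W) gets a batch dimension of size 1
# so it matches the (B, C, H, W) layout.
single = torch.rand(3, 32, 48)
batched = single.unsqueeze(0)
print(batched.shape)  # torch.Size([1, 3, 32, 48])
```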

So `img[0, 1, 2, 3]` is the pixel `(2, 3)` (row 2, column 3, counting from 0) of the second color channel (green in RGB) in the first image.
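To check the index order, the sketch below fills a small hypothetical `(2, 3, 4, 5)` tensor with the numbers 0, 1, 2, ... so every entry is distinct. The value at `[0, 1, 2, 3]` then follows from PyTorch's default row-major (C-contiguous) layout:

```python
import torch

# (B, C, H, W) = (2, 3, 4, 5), filled with 0, 1, 2, ... in row-major order.
img = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# First image, second channel (green in RGB), row 2, column 3.
value = img[0, 1, 2, 3].item()
# Offset: 0*(3*4*5) + 1*(4*5) + 2*5 + 3 = 33
print(value)  # 33

# Fixing only the first two indices selects the whole green channel of image 0.
green = img[0, 1]
print(green.shape)  # torch.Size([4, 5])
```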

In Python (and NumPy and PyTorch), a **slice** of elements can be selected with the notation `i:j`, meaning that the elements `i, i + 1, i + 2, ..., j - 1` are selected. In your example, `:` means *all elements*, `1:` means *all elements but the first*, and `:-1` means *all elements but the last* (negative indices count backward from the end). Please refer to tutorials on “slicing in NumPy”.

So `img[:,:,1:,:] - img[:,:,:-1,:]` is equivalent to the (batch of) images minus themselves shifted by one pixel vertically (both operands have `H - 1` rows, so the shapes match), or, in your notation, `X(i + 1, j, k) - X(i, j, k)`. Then the tensor is squared (`.pow(2)`) and summed (`.sum()`). Note that the sum is also over the batch in this case, so you receive the total variation of the whole batch, not of each image.
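A possible extension, sketched under a few assumptions: it keeps the squared differences of the question's code (the equation in the picture is not reproduced here, so this may not match it exactly), reads `k` in `X(i, j, k)` as the channel index, and uses the hypothetical names `total_variation`, `include_channels` and `vertical_tv_loop`. Because the channel axis is dimension 1, the same slicing pattern on that dimension gives differences between neighbouring channels, and summing over dimensions `(1, 2, 3)` only keeps one value per image:

```python
import torch

def total_variation(img, weight, include_channels=False):
    """Squared total variation of a (B, C, H, W) batch, one value per image (shape (B,))."""
    # Vertical (dim 2) and horizontal (dim 3) neighbour differences, as in the question.
    tv_h = (img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2).sum(dim=(1, 2, 3))
    tv_w = (img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2).sum(dim=(1, 2, 3))
    tv = tv_h + tv_w
    if include_channels:
        # Differences between neighbouring channels: the channel axis is dim 1.
        tv_c = (img[:, 1:, :, :] - img[:, :-1, :, :]).pow(2).sum(dim=(1, 2, 3))
        tv = tv + tv_c
    return weight * tv


def vertical_tv_loop(img):
    """Vertical term written as explicit loops over (X(i + 1, j, k) - X(i, j, k))**2."""
    B, C, H, W = img.shape
    out = torch.zeros(B)
    for b in range(B):
        for k in range(C):          # channel
            for i in range(H - 1):  # row
                for j in range(W):  # column
                    out[b] += (img[b, k, i + 1, j] - img[b, k, i, j]) ** 2
    return out


img = torch.zeros(2, 3, 4, 5)
img[0, 0] = 1.0         # image 0: its first channel differs from the other two
img[1, :, :, 2:] = 1.0  # image 1: a vertical edge between columns 1 and 2

print(total_variation(img, 1.0).tolist())                         # [0.0, 12.0]
print(total_variation(img, 1.0, include_channels=True).tolist())  # [20.0, 12.0]

# Sanity check: the sliced vertical term matches the loop on random data.
r = torch.rand(2, 3, 4, 5)
sliced = (r[:, :, 1:, :] - r[:, :, :-1, :]).pow(2).sum(dim=(1, 2, 3))
print(torch.allclose(sliced, vertical_tv_loop(r)))  # True
```

Calling `.sum()` on the per-image result gives back a single batch-wide value like the original function's. The loop version is only there to make the index mapping explicit; it is far slower than the sliced version.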