Last active
December 8, 2024 16:10
-
-
Save vvolhejn/e265665c65d3df37e381316bf57b8421 to your computer and use it in GitHub Desktop.
Convert PyTorch convolutional layer to fully connected layer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| The function `torch_conv_layer_to_affine` takes a `torch.nn.Conv2d` layer `conv` | |
| and produces an equivalent `torch.nn.Linear` layer `fc`. | |
| Specifically, this means that the following holds for `x` of a valid shape: | |
| torch.flatten(conv(x)) == fc(torch.flatten(x)) | |
| Or equivalently: | |
| conv(x) == fc(torch.flatten(x)).reshape(conv(x).shape) | |
| allowing of course for some floating-point error. | |
| """ | |
| from typing import Tuple | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| def torch_conv_layer_to_affine( | |
| conv: torch.nn.Conv2d, input_size: Tuple[int, int] | |
| ) -> torch.nn.Linear: | |
| w, h = input_size | |
| # Formula from the Torch docs: | |
| # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html | |
| output_size = [ | |
| (input_size[i] + 2 * conv.padding[i] - conv.kernel_size[i]) // conv.stride[i] | |
| + 1 | |
| for i in [0, 1] | |
| ] | |
| in_shape = (conv.in_channels, w, h) | |
| out_shape = (conv.out_channels, output_size[0], output_size[1]) | |
| fc = nn.Linear(in_features=np.product(in_shape), out_features=np.product(out_shape)) | |
| fc.weight.data.fill_(0.0) | |
| # Output coordinates | |
| for xo, yo in range2d(output_size[0], output_size[1]): | |
| # The upper-left corner of the filter in the input tensor | |
| xi0 = -conv.padding[0] + conv.stride[0] * xo | |
| yi0 = -conv.padding[1] + conv.stride[1] * yo | |
| # Position within the filter | |
| for xd, yd in range2d(conv.kernel_size[0], conv.kernel_size[1]): | |
| # Output channel | |
| for co in range(conv.out_channels): | |
| fc.bias[enc_tuple((co, xo, yo), out_shape)] = conv.bias[co] | |
| for ci in range(conv.in_channels): | |
| # Make sure we are within the input image (and not in the padding) | |
| if 0 <= xi0 + xd < w and 0 <= yi0 + yd < h: | |
| cw = conv.weight[co, ci, xd, yd] | |
| # Flatten the weight position to 1d in "canonical ordering", | |
| # i.e. guaranteeing that: | |
| # FC(img.reshape(-1)) == Conv(img).reshape(-1) | |
| fc.weight[ | |
| enc_tuple((co, xo, yo), out_shape), | |
| enc_tuple((ci, xi0 + xd, yi0 + yd), in_shape), | |
| ] = cw | |
| return fc | |
| def range2d(to_a, to_b): | |
| for a in range(to_a): | |
| for b in range(to_b): | |
| yield a, b | |
| def enc_tuple(tup: Tuple, shape: Tuple) -> int: | |
| res = 0 | |
| coef = 1 | |
| for i in reversed(range(len(shape))): | |
| assert tup[i] < shape[i] | |
| res += coef * tup[i] | |
| coef *= shape[i] | |
| return res | |
| def dec_tuple(x: int, shape: Tuple) -> Tuple: | |
| res = [] | |
| for i in reversed(range(len(shape))): | |
| res.append(x % shape[i]) | |
| x //= shape[i] | |
| return tuple(reversed(res)) | |
| def test_tuple_encoding(): | |
| x = enc_tuple((3, 2, 1), (5, 6, 7)) | |
| assert dec_tuple(x, (5, 6, 7)) == (3, 2, 1) | |
| print("Tuple encoding ok") | |
| def test_layer_conversion(): | |
| for stride in [1, 2]: | |
| for padding in [0, 1, 2]: | |
| for filter_size in [3, 4]: | |
| img = torch.rand((1, 2, 6, 7)) | |
| conv = nn.Conv2d(2, 5, filter_size, stride=stride, padding=padding) | |
| fc = torch_conv_layer_to_affine(conv, img.shape[2:]) | |
| # Also checks that our encoding flattens the inputs/outputs such that | |
| # FC(flatten(img)) == flatten(Conv(img)) | |
| res1 = fc(img.reshape((-1))).reshape(conv(img).shape) | |
| res2 = conv(img) | |
| worst_error = (res1 - res2).max() | |
| print("Output shape", res2.shape, "Worst error: ", float(worst_error)) | |
| assert worst_error <= 1.0e-6 | |
| print("Layer conversion ok") | |
| if __name__ == "__main__": | |
| test_tuple_encoding() | |
| test_layer_conversion() |
Author
@PadLex Glad it helped someone four years later (and someone who's also from ETH lol), thanks for the fix!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for script! :)
When I run it with a modern version of pytorch I get a numpy deprecation warning and one of those gradient update errors. Both can be fixed pretty easily: https://gist.github.com/PadLex/236d8178db45d950c5d4e93899fa608a