PyTorch Pool1d + Squeeze bug, with cpp source code insights #1054

Closed
4 tasks done
sei-jgwohlbier opened this issue Aug 21, 2024 · 6 comments

@sei-jgwohlbier

sei-jgwohlbier commented Aug 21, 2024

Prerequisites

Please make sure to check off these prerequisites before submitting a bug report.

  • Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
  • Check that the issue hasn't already been reported, by checking the currently open issues.
  • If there are steps to reproduce the problem, make sure to write them down below.
  • If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.

Quick summary

For a simple PyTorch model with an AvgPool1d followed by a squeeze, the generated code is incorrect. The correct values are computed during execution, but they are not written to the output correctly.

Details

Steps to Reproduce

  1. Clone the hls4ml repository
  2. Check out the master branch at commit 2898ab2.
  3. Run test_pool_squeeze.py provided below.

Expected behavior

The test should pass.

Actual behavior

The test fails. Details of the generated source code's behavior are provided below.

Optional

This is the test file, test_pool_squeeze.py:

from pathlib import Path

import numpy as np
import os
import torch
import torch.nn as nn

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model

test_root_path = Path(__file__).parent

if __name__ == "__main__":

    class test(nn.Module):
        def __init__(self, size_in, momentum=0.2):
            super().__init__()
            self.avgpool = nn.AvgPool1d(size_in)
            self.relu = nn.ReLU()

        def forward(self, x):
            z = self.avgpool(x)
            z = torch.squeeze(z)
            z = self.relu(z)
            return z

    n_in = 2
    n_out = 1
    size_in = 4
    n_batch = 3

    model = test(size_in)
    model.eval()
    print(model)

    X_input = np.random.rand(n_batch, n_in, size_in)
    with torch.no_grad():
        pytorch_prediction = model(torch.Tensor(X_input))

    config = config_from_pytorch_model(model,
                                       channels_last_conversion='internal',
                                       transpose_outputs=False)
    config['Model']['Strategy'] = 'Resource'
    config['Model']['Precision'] = 'ap_fixed<64,24>'
    print(config)

    backend='Vivado'
    output_dir = str(test_root_path / f'hls4mlprj_pool_squeeze_{backend}_io_stream')
    hls_model = convert_from_pytorch_model(
        model,
        (None, n_in, size_in),
        hls_config=config,
        output_dir=output_dir,
        backend=backend,
        io_type='io_stream',
    )
    print(list(hls_model.get_layers()))
    hls_model.compile()

    # X_input_hls is channels last
    X_input_hls = np.ascontiguousarray(X_input.transpose(0, 2, 1))
    # write tb data
    ipf = output_dir + "/tb_data/tb_input_features.dat"
    if os.path.isfile(ipf):
        os.remove(ipf)
    np.savetxt(ipf, X_input_hls.flatten(), newline=" ")
    hls_prediction = hls_model.predict(X_input_hls)

    print("pytorch_prediction")
    print(pytorch_prediction)
    # write tb data
    opf = output_dir + "/tb_data/tb_output_predictions.dat"
    if os.path.isfile(opf):
        os.remove(opf)
    with open(opf, "ab") as f:
        for p in pytorch_prediction:
            np.savetxt(f, p.flatten(), newline=" ")
    print("hls_prediction")
    print(hls_prediction)

    rtol = 1.0e-5
    atol = 5.0e-2
    for p, h in zip(pytorch_prediction, hls_prediction):
        np.testing.assert_allclose(p,
                                   h,
                                   rtol=rtol, atol=atol)

The problem is in the generated myproject.cpp file. I added some print statements to see what is going on; the instrumented file is included below.

#include <iostream>

#include "myproject.h"
#include "parameters.h"


void myproject(
    hls::stream<input_t> &x,
    hls::stream<result_t> &layer4_out
) {

    // hls-fpga-machine-learning insert IO
    #pragma HLS INTERFACE axis port=x,layer4_out
    #pragma HLS DATAFLOW

    // hls-fpga-machine-learning insert load weights
#ifndef __SYNTHESIS__
    static bool loaded_weights = false;
    if (!loaded_weights) {
        loaded_weights = true;
    }
#endif
    // ****************************************
    // NETWORK INSTANTIATION
    // ****************************************

    // hls-fpga-machine-learning insert layers

    // Inputs are channels last, so every other value belongs to the same channel; use that to check the results.
    std::cerr << "input\n";
    nnet::print_result<layer2_t,8>
        (x, std::cerr, /*keep=*/true);

    hls::stream<layer2_t> layer2_out("layer2_out");
    #pragma HLS STREAM variable=layer2_out depth=1
    nnet::pooling1d_cl<input_t, layer2_t, config2>(x, layer2_out); // avgpool

    // same data size with averages per channel repeating.
    std::cerr << "layer2\n";
    nnet::print_result<layer2_t,2>
        (layer2_out, std::cerr, /*keep=*/true);

    std::cerr << "CONFIG_T::height " << config5::height << std::endl;
    std::cerr << "CONFIG_T::width " << config5::width << std::endl;

    hls::stream<layer5_t> layer5_out("layer5_out");
    #pragma HLS STREAM variable=layer5_out depth=2
    nnet::transpose_2d<layer2_t, layer5_t, config5>(layer2_out, layer5_out); // transpose_input_for_squeeze

    // This output looks right only when layer5_t is defined with size 1*1.
    std::cerr << "layer5\n";
    nnet::print_result<layer5_t,8>
        (layer5_out, std::cerr, /*keep=*/true);

    auto& layer3_out = layer5_out;

    nnet::relu<layer5_t, result_t, ReLU_config4>(layer3_out, layer4_out); // relu

}

Below is the output of running the compiled myproject_test.cpp.

Processing input 0
input
0.723548 0.842939 0.875767 0.878541 0.299715 0.416514 0.539979 0.0951735 
layer2
0.609752 0.558292 
CONFIG_T::height 1
CONFIG_T::width 2
layer5
0.609752 0.558292 0.609752 0.558292 0.609752 0.558292 0.609752 0.558292 
WARNING: Hls::stream 'layer5_out' contains leftover data, which may result in RTL simulation hanging.
Predictions
0.609752 0.558292 
Quantized predictions
0.609752 0.609752 
INFO: Saved inference results to file: tb_data/csim_results.log

Note the WARNING about leftover data in layer5_out; more on this below.

The first line is the input: two channels of length 4 in channels-last format.
layer2 is the output of nnet::pooling1d_cl. I have confirmed that the two values shown for layer2 are the correct per-channel averages, accounting for the channels-last layout, i.e.,

0.609752 = avg(0.723548 0.875767 0.299715 0.539979)
0.558292 = avg(0.842939 0.878541 0.416514 0.0951735)

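Indeed, these are the two values that should appear in the output. The "Predictions" line (the reference values the test bench reads from tb_output_predictions.dat) shows them correctly, but the "Quantized predictions" line (the HLS output) shows 0.609752 twice, so the data is not being extracted correctly.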
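The averages above can be double-checked with a few lines of NumPy. This is a minimal sketch that reuses the eight input values printed by the C simulation and assumes, as described above, that they are stored channels last (width 4, 2 channels), so consecutive values alternate between the two channels.

```python
import numpy as np

# Input values as printed by the csim ("input" line), flattened channels last:
# consecutive values alternate between channel 0 and channel 1.
x_flat = np.array([0.723548, 0.842939, 0.875767, 0.878541,
                   0.299715, 0.416514, 0.539979, 0.0951735])

# Recover the (width, channels) layout and average over the width axis,
# which is what AvgPool1d with kernel size 4 computes for each channel.
x_cl = x_flat.reshape(4, 2)
expected = x_cl.mean(axis=0)

print(np.round(expected, 6))  # prints [0.609752 0.558292]
```

The result matches the layer2 values, which supports the conclusion that the pooling itself is correct and the problem is downstream of it.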

Other notes:

  • It seems that the definition of layer5_t in defines.h is incorrect. Based on the expected output, I think it should be typedef nnet::array<ap_fixed<64,24>, 2*1> layer5_t; (2*1, not 1*1), but with that change the results are still incorrect. The layer5_out warning does go away in that case, though.
  • I have been trying to correct the source, but have not succeeded yet. If you have suggestions, I would be happy to try them and work up a PR if I can track the problem back to the hls4ml source.
@JanFSchulte
Contributor

Hi @sei-jgwohlbier,

thanks for reporting this and for the detailed studies you did. I was able to trace the issue to the conversion to channels_last for streamed inputs. Before reshape layers, we insert a transpose that undoes the channels_last conversion of the input data, so that the reshape operation is done correctly. In this case, that transpose messes up the dimensions of the squeeze layer's output, because there actually wasn't any transpose to undo.
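The reason for the inserted transpose can be illustrated with a small NumPy sketch (plain arrays, not hls4ml code): when a reshape flattens data that is held in channels-last order, the element order only matches PyTorch's after transposing back to channels first.

```python
import numpy as np

# A (channels=2, width=3) tensor in PyTorch's channels-first layout.
x_cf = np.arange(6).reshape(2, 3)
# The same data stored channels last.
x_cl = x_cf.T

# A row-major reshape of the channels-first tensor gives PyTorch's order.
flat_torch = x_cf.flatten()    # [0 1 2 3 4 5]
# Flattening the channels-last data directly interleaves the channels.
flat_wrong = x_cl.flatten()    # [0 3 1 4 2 5]
# Transposing back to channels first restores PyTorch's order.
flat_fixed = x_cl.T.flatten()  # [0 1 2 3 4 5]

print(flat_torch, flat_wrong, flat_fixed)
```

For a general flatten this transpose is necessary; the issue here is that it is also applied to the squeeze, where it should not be.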

From what I can tell, this can be fixed by changing this line https://github.com/fastmachinelearning/hls4ml/blob/main/hls4ml/model/optimizer/passes/convert_to_channels_last.py#L97

from:

    if isinstance(node, Reshape) and len(node.attributes['target_shape']) == 1:

to:

    if isinstance(node, Reshape) and len(node.attributes['target_shape']) == 1 and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal":
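To make the effect of the proposed condition explicit, here is a hypothetical stand-in written with plain Python values. The function name inserts_transpose_before_reshape and the helper make_config are invented for illustration; the real pass inspects an hls4ml Reshape node and model.config. The value "full" is assumed here as an example of a non-internal conversion setting, and `!=` is used in place of the equivalent `not ... ==`.

```python
# Hypothetical stand-in for the check in convert_to_channels_last.py.
def inserts_transpose_before_reshape(is_reshape, target_shape, hls_config):
    conversion = hls_config['HLSConfig']['Model']['ChannelsLastConversion']
    # Original condition: any Reshape to a 1D target shape gets a transpose.
    # Proposed fix: skip the transpose when the conversion is 'internal'.
    return is_reshape and len(target_shape) == 1 and conversion != "internal"


# Hypothetical helper that builds the nested config structure used above.
def make_config(conversion):
    return {'HLSConfig': {'Model': {'ChannelsLastConversion': conversion}}}


# The squeeze in this issue maps (2, 1) to (2,), i.e. a 1D target shape.
print(inserts_transpose_before_reshape(True, (2,), make_config("internal")))  # False
print(inserts_transpose_before_reshape(True, (2,), make_config("full")))      # True
```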

The one thing I don't quite understand is why there are no problems in the io_parallel case, where the result is identical with and without this additional transpose. So before making a PR with a fix, I would like to hear what @vloncar thinks about this.
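One possible explanation, offered only as an unverified guess: after the pooling, the width dimension is 1, so transposing between (1, 2) and (2, 1) does not change the flat element order. If io_parallel operates on a flat array, the extra transpose would be a no-op there, whereas in io_stream the transpose also determines how elements are grouped into stream words, which could explain why only that case breaks. The NumPy sketch below checks only the flat-order part of this argument.

```python
import numpy as np

# Pooled output for one sample, channels last: (width=1, channels=2).
pooled_cl = np.array([[0.609752, 0.558292]])
# The extra transpose inserted before the squeeze gives (channels=2, width=1).
transposed = pooled_cl.T

# In a flat buffer the element order is unchanged, so a squeeze
# would see the same values in the same order either way.
print(pooled_cl.shape, transposed.shape)                    # (1, 2) (2, 1)
print(np.array_equal(pooled_cl.ravel(), transposed.ravel()))  # True
```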

@sei-jgwohlbier
Author

Awesome, thanks very much. Indeed my test passes with your fix. Let me know if you need me to test anything.

@sei-jgwohlbier
Author

@JanFSchulte do you know if anyone else has looked at this? I continue to use your patch in my version.
Thanks.
jgw

@JanFSchulte
Contributor

Hi! Ah, no, this hasn't been fixed in the main branch yet. I'll get around to making a PR in the next few days. Thanks for reminding me!

@JanFSchulte
Contributor

I ended up including the fix in PR #1086, which also contains other fixes to the PyTorch parser.

@sei-jgwohlbier
Author

Sorry, forgot to close. Thanks!
