"logits must be 2-dimensional" error on TF 1.9 #15

Open
alpsholic opened this issue Apr 17, 2020 · 3 comments

Comments

@alpsholic

I am getting the following exception on TF 1.9 when I load a saved BERT model in Java. I saved the model in Python using easy-bert and loaded it in Java, again with easy-bert. The model is https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1.

The model loads fine, but when I try to extract embeddings, it throws the exception below. I have to stay on TF 1.9.

020-04-16 18:00:31.003625: I tensorflow/cc/saved_model/loader.cc:291] SavedModel load for tags { serve }; Status: success. Took 1197645 microseconds.
Exception in thread "main" java.lang.IllegalArgumentException: logits must be 2-dimensional
[[Node: module_apply_tokens/bert/encoder/layer_0/attention/self/Softmax = Softmax[T=DT_FLOAT, _output_shapes=[[?,12,?,?]], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
at com.robrua.nlp.bert.Bert.embedSequence(Bert.java:252)
at .bert.TestEasyBert.main(TestEasyBert.java:17)
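The failing node is the self-attention Softmax, and its output shape `[?,12,?,?]` is 4-D (BERT's attention probabilities are `[batch, heads, seq, seq]`, with 12 heads in this model), while the error says the TF 1.9 kernel only accepts 2-D logits. A common way to run a softmax over the last axis of an N-D tensor with a 2-D-only kernel is to collapse the outer dimensions into one, apply the softmax row by row, and reshape back. The pure-Python sketch below (no TensorFlow; `softmax_2d` and `softmax_last_axis` are hypothetical names) only illustrates that equivalence. It is not a patch for easy-bert or the TF Hub module.

```python
import math
from functools import reduce
from operator import mul


def softmax_2d(rows):
    """Row-wise softmax over a 2-D list of lists (the only rank the
    error message says this Softmax kernel accepts)."""
    out = []
    for row in rows:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def softmax_last_axis(flat, shape):
    """Softmax over the last axis of a row-major tensor with `shape`,
    e.g. attention scores shaped [batch, heads, seq, seq].

    The outer dimensions are collapsed into one, giving a
    [prod(shape[:-1]), shape[-1]] matrix; the 2-D softmax is applied per row
    and the rows are flattened back. In row-major order this reshape moves
    no data, so the result keeps the original shape."""
    last = shape[-1]
    outer = reduce(mul, shape[:-1], 1)
    if len(flat) != outer * last:
        raise ValueError("flat buffer does not match shape")
    matrix = [flat[i * last:(i + 1) * last] for i in range(outer)]
    return [x for row in softmax_2d(matrix) for x in row]


# Usage: a toy [batch=1, heads=2, seq=2, seq=2] score tensor.
scores = [0.0, 1.0, 2.0, 2.0, -1.0, 3.0, 0.5, 0.5]
probs = softmax_last_axis(scores, (1, 2, 2, 2))
print(probs[2], probs[3])  # 0.5 0.5 (equal logits in that row)
```

Because the reshape is free in row-major layout, this flatten/softmax/reshape pattern costs essentially nothing beyond the softmax itself, which is why it is a typical fallback when a kernel only supports rank 2.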
@robrua
Owner

robrua commented Apr 25, 2020

First, I want to figure out whether this is a version compatibility issue or a general bug:

Is this a problem with TF 1.9 specifically? Has it worked on earlier versions, and you now need to run it on 1.9 for some other reason?

If it's only on 1.9, did you bump the TF version on both the Python & Java ends?

There's a decent chance this is just fragility in how the input/output nodes in the graph are passed between Python and Java.

@robrua
Owner

robrua commented Apr 25, 2020

If you can share more specifics about how you're obtaining, saving, and loading the model, that would help in figuring out what's wrong.

@alpsholic
Author

I didn't try lower versions of TF. It works fine from TF 1.11 onwards, but not on TF 1.9. I used your sample code to save the model on the Python side (and printed the TF version there). I then loaded it on the Java side and tried to extract embeddings with your Java sample code (printing the TF version there as well).
Later I also tried the models from your pom entries (i.e., without saving in Python and then loading in Java).

I got this exception on both occasions.
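Since the TF version was printed on both ends (e.g. `tf.__version__` on the Python side), a small guard that compares those strings numerically could fail fast before `Session.run` is reached. This is a sketch under loud assumptions: `parse_version`, `is_known_working`, and the `(1, 11)` threshold are hypothetical, based only on the observation in this thread that 1.11+ works and 1.9 does not. It compares integer tuples because plain string comparison ranks "1.9.0" above "1.11.0".

```python
def parse_version(version):
    """Turn a version string such as '1.9.0' or '1.11.0-rc1' into a tuple
    of ints, keeping only the leading digits of each dot-separated piece."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


# Hypothetical lower bound from this thread: 1.11 works, 1.9 does not
# (1.10 was not tested, so it is treated as not known to work).
MIN_WORKING = (1, 11)


def is_known_working(version):
    """True if the major.minor part is at or above MIN_WORKING."""
    return parse_version(version)[:2] >= MIN_WORKING


print("1.9.0" > "1.11.0")          # True: lexicographic order is wrong here
print(is_known_working("1.9.0"))   # False
print(is_known_working("1.11.0"))  # True
```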
