Debug in mesh Tensorflow #235

Open
patrickvonplaten opened this issue Nov 11, 2020 · 3 comments
patrickvonplaten commented Nov 11, 2020

Hey guys,

thanks so much for releasing all the T5.1.1 and mT5 weights!
I'm currently working on porting all these models to Hugging Face's transformers.
Is there any way to run Mesh TensorFlow in eager mode, by any chance?

E.g. if I run the following predict command:

import t5
from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

t5_model = t5.models.MtfModel(
    model_dir="./checkpoint",
    batch_size=16,
    sequence_length={"inputs": 128, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=None,
    iterations_per_loop=100,
    tpu=None
)

vocab_model_path = '<path/to/spm_vocab>'
vocab = SentencePieceVocabulary(vocab_model_path, extra_ids=100)

t5_model.predict(
    input_file="input.txt",
    output_file="output.txt",
    vocabulary=vocab,
    temperature=0
)

is there any way that I can run the prediction in eager mode so that I can print out the actual values of the tensors, e.g. the values of the input to the cross-attention layer?

I had a hard time finding tests in the repo that run a small transformer network.

I'd be super happy for some pointers :-)

Also pinging @craffel in case you have any good pointers for good debugging tools :-)

Contributor

craffel commented Nov 11, 2020

Hey Patrick, unfortunately I believe that because the mesh tf transformer uses tf.Estimator it is not eager-friendly. In the past when we've needed to do similar things, I'm sad to say that we just used Print ops.

FWIW we will soon be releasing a JAX implementation of T5(.1.1) which should make this kind of debugging and inspection a lot easier.
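
For readers unfamiliar with the "Print ops" approach mentioned above, here is a minimal sketch of the idea in plain Python. It is a toy deferred graph, not Mesh TensorFlow's API: `Node`, `constant`, `add`, and `print_op` are hypothetical names invented for illustration. The point it tries to show is that in graph mode a tensor has no value while the graph is being built, so a print op is an identity node that logs its input only when the graph is executed (in Mesh TensorFlow the analogous op is `mtf.Print`).

```python
import io


class Node:
    """A deferred computation: nothing runs until evaluate() is called."""

    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs

    def evaluate(self):
        # Recursively evaluate the inputs, then apply this node's function.
        return self.fn(*(node.evaluate() for node in self.inputs))


def constant(value):
    return Node(lambda: value)


def add(a, b):
    # Element-wise addition of two equally sized lists.
    return Node(lambda x, y: [p + q for p, q in zip(x, y)], a, b)


def print_op(node, message, sink):
    """Identity node that logs its input value when the graph is executed."""

    def _log_and_pass_through(value):
        sink.write(f"{message} {value}\n")
        return value

    return Node(_log_and_pass_through, node)


log = io.StringIO()

hidden = add(constant([1, 2]), constant([10, 20]))
# Building the graph logs nothing: `hidden` is a Node, not a value.
hidden = print_op(hidden, "cross-attention input:", log)
output = add(hidden, constant([100, 200]))

# The log line is written here, at run time, as a side effect.
result = output.evaluate()
```

The key detail is that the rest of the graph has to consume the node returned by `print_op`; if downstream ops keep using the original node, the logging node is never executed and nothing is printed.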

Author

patrickvonplaten commented Nov 11, 2020

Thanks a lot for your answer! :-) Print ops it is then


tungvx commented Mar 14, 2022

How do we use Print ops? The function says "WARNING:tensorflow:Warning - mtf.Print not implemented for this mesh type"
