From 536cc7df5a8f0d2b4f3f324cb269e4c0c4a9baef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Tue, 23 Jan 2024 17:29:53 +0100 Subject: [PATCH] Add samplers documentation --- docs/reference/samplers.md | 73 ++++++++++++++++++++++++++++++++++++++ mkdocs.yml | 1 + 2 files changed, 74 insertions(+) create mode 100644 docs/reference/samplers.md diff --git a/docs/reference/samplers.md b/docs/reference/samplers.md new file mode 100644 index 000000000..3c4a4d27f --- /dev/null +++ b/docs/reference/samplers.md @@ -0,0 +1,73 @@ +# Samplers + +## Multinomial sampling + +Outlines defaults to the multinomial sampler without top-p or top-k sampling, and temperature equal to 1. Not specifying a sampler is equivalent to: + +```python +from outlines import models, generate, samplers + + +model = models.transformers("mistralai/Mistral-7B-0.1") +sampler = samplers.multinomial() + +generator = generate.text(model, sampler) +answer = generator("What is 2+2?") + +print(answer) +# 4 +``` + +You can ask the generator to take multiple samples by passing the number of samples when initializing the sampler: + +```python +from outlines import models, generate, samplers + + +model = models.transformers("mistralai/Mistral-7B-0.1") +sampler = samplers.multinomial(3) + +generator = generate.text(model, sampler) +answer = generator("What is 2+2?") + +print(answer) +# [4, 4, 4] +``` + +If you ask for multiple samples for a batch of prompts, the returned array will be of shape `(num_batches, num_samples)`: + +```python +from outlines import models, generate, samplers + + +model = models.transformers("mistralai/Mistral-7B-0.1") +sampler = samplers.multinomial(3) + +generator = generate.text(model, sampler) +answer = generator(["What is 2+2?", "What is 3+3?"]) + +print(answer) +# [[4, 4, 4], [6, 6, 6]] +``` + + +## Greedy sampler + +You can also use the greedy sampler. 
For this you need to initialize the generator with the sampler: + + +```python +from outlines import models, generate, samplers + + +model = models.transformers("mistralai/Mistral-7B-0.1") +sampler = samplers.greedy() + +generator = generate.text(model, sampler) +answer = generator("What is 2+2?") + +print(answer) +# 4 +``` + +You cannot ask for multiple samples with the greedy sampler since it is not clear what the result should be. diff --git a/mkdocs.yml b/mkdocs.yml index 23079148d..560e0ed93 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -122,6 +122,7 @@ nav: - Grammar: reference/cfg.md - Regex: reference/regex.md - Types: reference/types.md + - Samplers: reference/samplers.md - Utilities: - Serve with vLLM: reference/vllm.md - Prompt templating: reference/prompting.md