diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py
index 380fb0150..deeb42d59 100644
--- a/outlines/models/mlxlm.py
+++ b/outlines/models/mlxlm.py
@@ -1,202 +1,125 @@
-import dataclasses
-from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union
+from functools import singledispatchmethod
+from typing import TYPE_CHECKING

-from .transformers import TransformerTokenizer
+from outlines.models.base import Model, ModelTypeAdapter
+from outlines.models.transformers import TransformerTokenizer

 if TYPE_CHECKING:
-    import mlx.core as mx
     import mlx.nn as nn
     from transformers import PreTrainedTokenizer

-    from outlines.generate.api import GenerationParameters, SamplingParameters
-    from outlines.processors import OutlinesLogitsProcessor
-
 __all__ = ["MLXLM", "from_mlxlm"]


-class MLXLM:
-    """
-    Represents an `mlx_lm` model
-    """
+class MLXLMTypeAdapter(ModelTypeAdapter):
+    """Type adapter for `mlx_lm` models."""
+
+    @singledispatchmethod
+    def format_input(self, model_input):
+        """Generate the prompt argument to pass to the model.
+
+        Argument
+        --------
+        model_input
+            The input passed by the user.
+
+        """
+        raise NotImplementedError(
+            f"The input type {type(model_input)} is not available. "
+            "The `mlx_lm` library does not support batch inference."
+        )
+
+    @format_input.register(str)
+    def format_str_input(self, model_input: str):
+        return model_input
+
+    def format_output_type(self, output_type):
+        """Generate the logits processor argument to pass to the model.
+
+        Argument
+        --------
+        output_type
+            The logits processor provided.
+
+        """
+        if not output_type:
+            return None
+        return [output_type]
+
+
+class MLXLM(Model):
+    """Represents an `mlx_lm` model."""

     def __init__(
         self,
         model: "nn.Module",
         tokenizer: "PreTrainedTokenizer",
     ):
-        self.model = model
-        self.mlx_tokenizer = tokenizer  # returns mlx tensors, used for encode()
-        self.tokenizer = TransformerTokenizer(
-            tokenizer._tokenizer
-        )  # _tokenizer is HF Tokenizer
+        """Create a MLXLM model instance.

-    def generate(
-        self,
-        prompts: Union[str, List[str]],
-        generation_parameters: "GenerationParameters",
-        logits_processor,
-        sampling_parameters: "SamplingParameters",
-    ) -> str:
-        streamer = self.stream(
-            prompts, generation_parameters, logits_processor, sampling_parameters
-        )
-        return "".join(list(streamer))
+        Arguments
+        ---------
+        model
+            An instance of an mlx-lm model.
+        tokenizer
+            An instance of an mlx-lm tokenizer or of a compatible
+            transformers tokenizer.

-    def stream(
-        self,
-        prompts: Union[str, List[str]],
-        generation_parameters: "GenerationParameters",
-        logits_processor,
-        sampling_parameters: "SamplingParameters",
-    ) -> Iterator[str]:
-        """Generate text using `mlx_lm`.
-
-        Parameters
-        ----------
-        prompts
-            A prompt or list of prompts.
-        generation_parameters
-            An instance of `GenerationParameters` that contains the prompt,
-            the maximum number of tokens, stop sequences and seed. All the
-            arguments to `SequenceGeneratorAdapter`'s `__cal__` method.
-        logits_processor
-            The logits processor to use when generating text.
-        sampling_parameters
-            An instance of `SamplingParameters`, a dataclass that contains
-            the name of the sampler to use and related parameters as available
-            in Outlines.
-
-        Returns
-        -------
-        The generated text.
""" - import mlx.core as mx + self.model = model + # self.mlx_tokenizer is used by the mlx-lm in its generate function + self.mlx_tokenizer = tokenizer + # self.tokenizer is used by the logits processor + self.tokenizer = TransformerTokenizer(tokenizer._tokenizer) + self.type_adapter = MLXLMTypeAdapter() + + def generate(self, model_input, output_type, **kwargs): + """Generate text using `mlx-lm`. + + Arguments + --------- + model_input + The prompt passed by the user. + output_type + The logits processor provided. + kwargs + The inference kwargs that can be passed to the `mlx-lm` library. - max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) - sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple( - sampling_parameters + """ + from mlx_lm import generate + + return generate( + self.model, + self.mlx_tokenizer, + self.type_adapter.format_input(model_input), + logits_processors=self.type_adapter.format_output_type(output_type), + **kwargs, ) - if max_tokens is None: - max_tokens = int(1e9) - - if not isinstance(prompts, str): - raise NotImplementedError( - "The `mlx-lm` library does not support batch inference." - ) - if sampler == "beam_search": - raise NotImplementedError( - "The `mlx-lm` library does not support Beam Search." - ) - if num_samples != 1: - raise NotImplementedError( - "The `mlx-lm` library does not allow to take several samples." - ) - if top_k is not None: - raise NotImplementedError("The `mlx-lm` library does not support top_k.") - if seed is not None: - raise NotImplementedError("The `mlx-lm` library does not support seed.") - if stop_at is not None: - raise NotImplementedError("The `mlx-lm` library does not support stop_at.") - - generate_kwargs = { - "temp": temperature, - "top_p": top_p, - "sampler": sampler, - "logits_processor": logits_processor, - } - - # Adapted from - # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267 - prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts)) - - detokenizer = self.mlx_tokenizer.detokenizer - detokenizer.reset() - - for (token, prob), n in zip( - self.generate_step(prompt_tokens, **generate_kwargs), - range(max_tokens), - ): - if token == self.tokenizer.eos_token_id: - break - detokenizer.add_token(token) - yield detokenizer.last_segment - detokenizer.finalize() - yield detokenizer.last_segment + def stream(self, model_input, output_type, **kwargs): + """Stream text using `mlx-lm`. + + Arguments + --------- + model_input + The prompt passed by the user. + output_type + The logits processor provided. + kwargs + The inference kwargs that can be passed to the `mlx-lm` library. - def generate_step( - self, - prompt: "mx.array", - temp: Optional[float], - top_p: Optional[float], - sampler: str, - logits_processor: "OutlinesLogitsProcessor", - ) -> Generator[Tuple[int, float], None, None]: - """ - Adapted from - https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129 - - A generator producing token ids based on the given prompt from the model. - - Parameters - ---------- - prompt - The input prompt. - temp - The temperature for sampling, if 0 the argmax is used. - top_p - Nulceus sampling, higher means model considers more less likely words. - sampler - The sampler string defined by SequenceGeneratorAdapter - logits_processor - Augment logits before sampling. 
""" - import mlx.core as mx - import mlx_lm - - temperature: float = temp or 1.0 - - def sample(logits: "mx.array") -> Tuple["mx.array", float]: - softmax_logits = mx.softmax(logits) - - if temperature == 0.0 or sampler == "greedy": - token = mx.argmax(logits, axis=-1) - elif sampler == "multinomial": - if top_p is not None and top_p > 0 and top_p < 1.0: - token = mlx_lm.sample_utils.top_p_sampling( - logits, top_p, temperature - ) - else: - token = mx.random.categorical(logits * (1 / temperature)) - else: - raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`") - - prob = softmax_logits[0, token] - return token, prob - - cache = mlx_lm.models.cache.make_prompt_cache(self.model) - - # kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model() - unprocessed_input_ids = prompt - generated_ids: List[int] = [] - - while True: - logits = self.model(unprocessed_input_ids[None], cache=cache) - logits = logits[:, -1, :] - - if logits_processor is not None: - # convert to logits_processor 1d expectation, apply, then convert back - logits_1d = logits.reshape(-1) - logits_1d = logits_processor(generated_ids, logits_1d) - logits = logits_1d.reshape(1, -1) - - new_token_single, prob = sample(logits) - new_token = new_token_single.item() - yield new_token, prob - - generated_ids.append(new_token) - unprocessed_input_ids = new_token_single + from mlx_lm import stream_generate + + for token in stream_generate( + self.model, + self.mlx_tokenizer, + self.type_adapter.format_input(model_input), + logits_processors=self.type_adapter.format_output_type(output_type), + **kwargs, + ): + yield token def from_mlxlm(model: "nn.Module", tokenizer: "PreTrainedTokenizer") -> MLXLM: diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index d6fe346e0..a74e7bf1f 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -81,6 +81,15 @@ def __call__( torch_logits = self._to_torch(logits) input_ids = self._to_torch(input_ids) + # if input_ids is 1D and logits is 2D with a single sequence, + # reshape input_ids to 2D (needed for mlx-lm) + if ( + len(input_ids.shape) == 1 + and len(torch_logits.shape) == 2 + and torch_logits.shape[0] == 1 + ): + input_ids = input_ids.unsqueeze(0) + assert torch_logits.shape[:-1] == input_ids.shape[:-1] # Guarantee passed as 2D Tensors, then covert back to original (1D or 2D) shape diff --git a/tests/models/test_mlxlm.py b/tests/models/test_mlxlm.py index 10280b756..3a4d64681 100644 --- a/tests/models/test_mlxlm.py +++ b/tests/models/test_mlxlm.py @@ -1,11 +1,18 @@ import pytest +import re +from enum import Enum import outlines +from outlines.types import Choice, JsonType, Regex from outlines.models.transformers import TransformerTokenizer +from pydantic import BaseModel +from transformers import PreTrainedTokenizer try: import mlx_lm import mlx.core as mx + from mlx_lm import load + from mlx_lm.tokenizer_utils import TokenizerWrapper HAS_MLX = mx.metal.is_available() except ImportError: @@ -22,8 +29,9 @@ def model(tmp_path_factory): @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") -def test_mlxlm_model(model): +def test_mlxlm_model_initialization(model): assert hasattr(model, "model") + assert hasattr(model, "mlx_tokenizer") assert hasattr(model, "tokenizer") assert isinstance(model.tokenizer, TransformerTokenizer) @@ -32,69 +40,63 @@ def test_mlxlm_model(model): def test_mlxlm_tokenizer(model): # Test single string 
     test_text = "Hello, world!"
-    token_ids = mx.array(model.mlx_tokenizer.encode(test_text))
+    token_ids, _ = model.tokenizer.encode(test_text)
+    token_ids = mx.array(token_ids)
     assert isinstance(token_ids, mx.array)


 @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
-def test_mlxlm_generate(model):
-    from outlines.generate.api import GenerationParameters, SamplingParameters
+def test_mlxlm_simple(model):
+    result = model.generate("Respond with one word. Not more.", None)
+    assert isinstance(result, str)

-    prompt = "Write a haiku about programming:"

-    # Test with basic generation parameters
-    gen_params = GenerationParameters(max_tokens=50, stop_at=None, seed=None)
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_call(model):
+    result = model("Respond with one word. Not more.")
+    assert isinstance(result, str)
+
+
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_invalid_input_type(model):
+    with pytest.raises(NotImplementedError, match="is not available"):
+        model(["Respond with one word. Not more."])

-    # Test with different sampling parameters
-    sampling_params = SamplingParameters(
-        sampler="multinomial", num_samples=1, top_p=0.9, top_k=None, temperature=0.7
-    )

-    # Test generation
-    output = model.generate(prompt, gen_params, None, sampling_params)
-    assert isinstance(output, str)
-    assert len(output) > 0
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_invalid_inference_kwargs(model):
+    with pytest.raises(TypeError):
+        model("Respond with one word. Not more.", foo="bar")


 @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
-def test_mlxlm_stream(model):
-    from outlines.generate.api import GenerationParameters, SamplingParameters
+def test_mlxlm_inference_kwargs(model):
+    result = model("Write a short story about a cat.", max_tokens=2)
+    assert isinstance(result, str)
+    assert len(result) < 20

-    prompt = "Count from 1 to 5:"

-    gen_params = GenerationParameters(max_tokens=20, stop_at=None, seed=None)
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_regex(model):
+    result = model("Give a number between 0 and 9.", Regex(r"[0-9]"))
+    assert isinstance(result, str)
+    assert re.match(r"[0-9]", result)

-    sampling_params = SamplingParameters(
-        sampler="greedy",  # Use greedy sampling for deterministic output
-        num_samples=1,
-        top_p=None,
-        top_k=None,
-        temperature=0.0,
-    )

-    # Test streaming
-    stream = model.stream(prompt, gen_params, None, sampling_params)
-    tokens = list(stream)
-    assert len(tokens) > 0
-    assert all(isinstance(token, str) for token in tokens)
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_JsonType(model):
+    class Character(BaseModel):
+        name: str

-    # Test that concatenated streaming output matches generate output
-    streamed_text = "".join(tokens)
-    generated_text = model.generate(prompt, gen_params, None, sampling_params)
-    assert streamed_text == generated_text
+    result = model("Create a character with a name.", JsonType(Character))
+    assert "name" in result


 @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
-def test_mlxlm_errors(model):
-    # Test batch inference (should raise NotImplementedError)
-    with pytest.raises(NotImplementedError):
-        from outlines.generate.api import GenerationParameters, SamplingParameters
-
-        gen_params = GenerationParameters(max_tokens=10, stop_at=None, seed=None)
-        sampling_params = SamplingParameters("multinomial", 1, None, None, 1.0)
-        model.generate(["prompt1", "prompt2"], gen_params, None, sampling_params)
-
-    # Test beam search (should raise NotImplementedError)
-    with pytest.raises(NotImplementedError):
-        sampling_params = SamplingParameters("beam_search", 1, None, None, 1.0)
-        model.generate("test prompt", gen_params, None, sampling_params)
+def test_mlxlm_choice(model):
+    class Foo(Enum):
+        cat = "cat"
+        dog = "dog"
+
+    result = model("Cat or dog?", Choice(Foo))
+    assert result in ["cat", "dog"]
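
For reference, a minimal usage sketch of the refactored interface, pieced together from the tests above. The checkpoint path is a placeholder, not something defined in this PR, and it assumes `mlx_lm.load` returns a `(model, tokenizer)` pair as in the existing `model` fixture.

    # Sketch only: exercises the new MLXLM / MLXLMTypeAdapter path end to end.
    import mlx_lm

    from outlines.models.mlxlm import from_mlxlm
    from outlines.types import Regex

    # "mlx-community/some-model" is a placeholder checkpoint name.
    model, tokenizer = mlx_lm.load("mlx-community/some-model")
    mlxlm = from_mlxlm(model, tokenizer)

    # Plain generation; extra kwargs are forwarded to mlx_lm.generate.
    answer = mlxlm("Respond with one word. Not more.", max_tokens=10)

    # Constrained generation: the output type is turned into a logits processor
    # (see MLXLMTypeAdapter.format_output_type) and passed as `logits_processors`.
    digit = mlxlm("Give a number between 0 and 9.", Regex(r"[0-9]"))

    # Streaming goes through mlx_lm.stream_generate; the exact chunk type
    # depends on the installed mlx-lm version.
    for chunk in mlxlm.stream("Count from 1 to 5.", None, max_tokens=20):
        print(chunk)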