Commit
Refactor the MLXLM model
RobinPicard committed Feb 24, 2025
1 parent f1ab28b commit 3b2a724
Showing 3 changed files with 157 additions and 223 deletions.
273 changes: 98 additions & 175 deletions outlines/models/mlxlm.py
@@ -1,202 +1,125 @@
-import dataclasses
-from typing import TYPE_CHECKING, Generator, Iterator, List, Optional, Tuple, Union
+from functools import singledispatchmethod
+from typing import TYPE_CHECKING

-from .transformers import TransformerTokenizer
+from outlines.models.base import Model, ModelTypeAdapter
+from outlines.models.transformers import TransformerTokenizer

 if TYPE_CHECKING:
-    import mlx.core as mx
     import mlx.nn as nn
     from transformers import PreTrainedTokenizer

-    from outlines.generate.api import GenerationParameters, SamplingParameters
-    from outlines.processors import OutlinesLogitsProcessor


__all__ = ["MLXLM", "from_mlxlm"]


-class MLXLM:
-    """
-    Represents an `mlx_lm` model
-    """
+class MLXLMTypeAdapter(ModelTypeAdapter):
+    """Type adapter for `mlx_lm` models."""
+
+    @singledispatchmethod
+    def format_input(self, model_input):
+        """Generate the prompt argument to pass to the model.
+
+        Argument
+        --------
+        model_input
+            The input passed by the user.
+        """
+        raise NotImplementedError(
+            f"The input type {type(model_input)} is not available. "
+            "The `mlx_lm` library does not support batch inference."
+        )
+
+    @format_input.register(str)
+    def format_str_input(self, model_input: str):
+        return model_input
+
+    def format_output_type(self, output_type):
+        """Generate the logits processor argument to pass to the model.
+
+        Argument
+        --------
+        output_type
+            The logits processor provided.
+        """
+        if not output_type:
+            return None
+        return [output_type]
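
Note: a minimal sketch of the adapter's behavior (the `processor` object below is a truthy stand-in for a compiled Outlines logits processor; it is illustrative, not part of this diff):

    adapter = MLXLMTypeAdapter()
    processor = object()  # stand-in for an Outlines logits processor

    adapter.format_input("Write a haiku.")  # returns the string unchanged
    adapter.format_output_type(None)  # None: unconstrained generation
    adapter.format_output_type(processor)  # [processor]: mlx-lm expects a list

    try:
        adapter.format_input(["a", "b"])  # any non-string input
    except NotImplementedError:
        pass  # mlx-lm does not support batch inference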


+class MLXLM(Model):
+    """Represents an `mlx_lm` model."""

     def __init__(
         self,
         model: "nn.Module",
         tokenizer: "PreTrainedTokenizer",
     ):
-        self.model = model
-        self.mlx_tokenizer = tokenizer  # returns mlx tensors, used for encode()
-        self.tokenizer = TransformerTokenizer(
-            tokenizer._tokenizer
-        )  # _tokenizer is HF Tokenizer
+        """Create an MLXLM model instance.
+
+        Arguments
+        ---------
+        model
+            An instance of an mlx-lm model.
+        tokenizer
+            An instance of an mlx-lm tokenizer or of a compatible
+            transformers tokenizer.
+        """
+        self.model = model
+        # self.mlx_tokenizer is used by mlx-lm in its generate function
+        self.mlx_tokenizer = tokenizer
+        # self.tokenizer is used by the logits processor
+        self.tokenizer = TransformerTokenizer(tokenizer._tokenizer)
+        self.type_adapter = MLXLMTypeAdapter()
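
Note: wiring this up end to end, assuming the standard `mlx_lm.load` entry point (which returns a `(model, tokenizer)` pair) and the `from_mlxlm` factory defined at the end of this file; the checkpoint name is illustrative:

    import mlx_lm
    from outlines.models.mlxlm import from_mlxlm

    model, tokenizer = mlx_lm.load("mlx-community/Mistral-7B-Instruct-v0.2-4bit")
    mlx_model = from_mlxlm(model, tokenizer)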

-    def generate(
-        self,
-        prompts: Union[str, List[str]],
-        generation_parameters: "GenerationParameters",
-        logits_processor,
-        sampling_parameters: "SamplingParameters",
-    ) -> str:
-        streamer = self.stream(
-            prompts, generation_parameters, logits_processor, sampling_parameters
-        )
-        return "".join(list(streamer))
+    def generate(self, model_input, output_type, **kwargs):
+        """Generate text using `mlx-lm`.
+
+        Arguments
+        ---------
+        model_input
+            The prompt passed by the user.
+        output_type
+            The logits processor provided.
+        kwargs
+            The inference kwargs that can be passed to the `mlx-lm` library.
+        """
+        from mlx_lm import generate
+
+        return generate(
+            self.model,
+            self.mlx_tokenizer,
+            self.type_adapter.format_input(model_input),
+            logits_processors=self.type_adapter.format_output_type(output_type),
+            **kwargs,
+        )
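
Note: a plain generation call now forwards directly to `mlx_lm.generate`, so inference kwargs such as `max_tokens` pass through untouched; a sketch reusing the `mlx_model` instance from the note above:

    answer = mlx_model.generate("What is 2 + 2?", None, max_tokens=16)
    print(answer)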
-    def stream(
-        self,
-        prompts: Union[str, List[str]],
-        generation_parameters: "GenerationParameters",
-        logits_processor,
-        sampling_parameters: "SamplingParameters",
-    ) -> Iterator[str]:
-        """Generate text using `mlx_lm`.
-
-        Parameters
-        ----------
-        prompts
-            A prompt or list of prompts.
-        generation_parameters
-            An instance of `GenerationParameters` that contains the prompt,
-            the maximum number of tokens, stop sequences and seed. All the
-            arguments to `SequenceGeneratorAdapter`'s `__call__` method.
-        logits_processor
-            The logits processor to use when generating text.
-        sampling_parameters
-            An instance of `SamplingParameters`, a dataclass that contains
-            the name of the sampler to use and related parameters as available
-            in Outlines.
-
-        Returns
-        -------
-        The generated text.
-        """
-        import mlx.core as mx
-
-        max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters)
-        sampler, num_samples, top_p, top_k, temperature = dataclasses.astuple(
-            sampling_parameters
-        )
-
-        if max_tokens is None:
-            max_tokens = int(1e9)
-
-        if not isinstance(prompts, str):
-            raise NotImplementedError(
-                "The `mlx-lm` library does not support batch inference."
-            )
-        if sampler == "beam_search":
-            raise NotImplementedError(
-                "The `mlx-lm` library does not support Beam Search."
-            )
-        if num_samples != 1:
-            raise NotImplementedError(
-                "The `mlx-lm` library does not allow to take several samples."
-            )
-        if top_k is not None:
-            raise NotImplementedError("The `mlx-lm` library does not support top_k.")
-        if seed is not None:
-            raise NotImplementedError("The `mlx-lm` library does not support seed.")
-        if stop_at is not None:
-            raise NotImplementedError("The `mlx-lm` library does not support stop_at.")
-
-        generate_kwargs = {
-            "temp": temperature,
-            "top_p": top_p,
-            "sampler": sampler,
-            "logits_processor": logits_processor,
-        }
-
-        # Adapted from
-        # https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L267
-        prompt_tokens = mx.array(self.mlx_tokenizer.encode(prompts))
-
-        detokenizer = self.mlx_tokenizer.detokenizer
-        detokenizer.reset()
-
-        for (token, prob), n in zip(
-            self.generate_step(prompt_tokens, **generate_kwargs),
-            range(max_tokens),
-        ):
-            if token == self.tokenizer.eos_token_id:
-                break
-            detokenizer.add_token(token)
-            yield detokenizer.last_segment
-
-        detokenizer.finalize()
-        yield detokenizer.last_segment
-
-    def generate_step(
-        self,
-        prompt: "mx.array",
-        temp: Optional[float],
-        top_p: Optional[float],
-        sampler: str,
-        logits_processor: "OutlinesLogitsProcessor",
-    ) -> Generator[Tuple[int, float], None, None]:
-        """
-        Adapted from
-        https://github.com/ml-explore/mlx-examples/blob/4872727/llms/mlx_lm/utils.py#L129
-
-        A generator producing token ids based on the given prompt from the model.
-
-        Parameters
-        ----------
-        prompt
-            The input prompt.
-        temp
-            The temperature for sampling, if 0 the argmax is used.
-        top_p
-            Nucleus sampling, higher means the model considers more less-likely words.
-        sampler
-            The sampler string defined by SequenceGeneratorAdapter
-        logits_processor
-            Augment logits before sampling.
-        """
-        import mlx.core as mx
-        import mlx_lm
-
-        temperature: float = temp or 1.0
-
-        def sample(logits: "mx.array") -> Tuple["mx.array", float]:
-            softmax_logits = mx.softmax(logits)
-
-            if temperature == 0.0 or sampler == "greedy":
-                token = mx.argmax(logits, axis=-1)
-            elif sampler == "multinomial":
-                if top_p is not None and top_p > 0 and top_p < 1.0:
-                    token = mlx_lm.sample_utils.top_p_sampling(
-                        logits, top_p, temperature
-                    )
-                else:
-                    token = mx.random.categorical(logits * (1 / temperature))
-            else:
-                raise ValueError(f"Invalid mlx-lm sampler: `{sampler}`")
-
-            prob = softmax_logits[0, token]
-            return token, prob
-
-        cache = mlx_lm.models.cache.make_prompt_cache(self.model)
-
-        # The KV cache contains the processed input IDs; we pass the
-        # unprocessed inputs and the cache to model()
-        unprocessed_input_ids = prompt
-        generated_ids: List[int] = []
-
-        while True:
-            logits = self.model(unprocessed_input_ids[None], cache=cache)
-            logits = logits[:, -1, :]
-
-            if logits_processor is not None:
-                # convert to the logits processor's 1D expectation, apply, then convert back
-                logits_1d = logits.reshape(-1)
-                logits_1d = logits_processor(generated_ids, logits_1d)
-                logits = logits_1d.reshape(1, -1)
-
-            new_token_single, prob = sample(logits)
-            new_token = new_token_single.item()
-            yield new_token, prob
-
-            generated_ids.append(new_token)
-            unprocessed_input_ids = new_token_single
+    def stream(self, model_input, output_type, **kwargs):
+        """Stream text using `mlx-lm`.
+
+        Arguments
+        ---------
+        model_input
+            The prompt passed by the user.
+        output_type
+            The logits processor provided.
+        kwargs
+            The inference kwargs that can be passed to the `mlx-lm` library.
+        """
+        from mlx_lm import stream_generate
+
+        for token in stream_generate(
+            self.model,
+            self.mlx_tokenizer,
+            self.type_adapter.format_input(model_input),
+            logits_processors=self.type_adapter.format_output_type(output_type),
+            **kwargs,
+        ):
+            yield token
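
Note: streaming mirrors this through `mlx_lm.stream_generate`, yielding pieces of text as they are produced; a sketch with an illustrative prompt, reusing the `mlx_model` instance from the earlier note:

    for chunk in mlx_model.stream("Tell me a story.", None, max_tokens=64):
        print(chunk, end="")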


def from_mlxlm(model: "nn.Module", tokenizer: "PreTrainedTokenizer") -> MLXLM:
9 changes: 9 additions & 0 deletions outlines/processors/base_logits_processor.py
@@ -81,6 +81,15 @@ def __call__(
         torch_logits = self._to_torch(logits)
         input_ids = self._to_torch(input_ids)

+        # if input_ids is 1D and logits is 2D with a single sequence,
+        # reshape input_ids to 2D (needed for mlx-lm)
+        if (
+            len(input_ids.shape) == 1
+            and len(torch_logits.shape) == 2
+            and torch_logits.shape[0] == 1
+        ):
+            input_ids = input_ids.unsqueeze(0)
+
         assert torch_logits.shape[:-1] == input_ids.shape[:-1]

         # Guarantee passed as 2D Tensors, then convert back to the original (1D or 2D) shape
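
Note: the new branch bridges the shape mismatch between mlx-lm, which passes a 1D array of generated ids for a single sequence, and the processor's 2D (batch, ...) convention; a self-contained sketch of the same logic with plain torch tensors (shapes are what matter, the values are arbitrary):

    import torch

    input_ids = torch.tensor([5, 7, 11])  # 1D: one sequence, as mlx-lm provides it
    torch_logits = torch.randn(1, 32000)  # 2D logits with batch size 1

    if (
        len(input_ids.shape) == 1
        and len(torch_logits.shape) == 2
        and torch_logits.shape[0] == 1
    ):
        input_ids = input_ids.unsqueeze(0)  # now (1, 3): aligned with the logits batch

    assert torch_logits.shape[:-1] == input_ids.shape[:-1]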
