Merge pull request #34 from eriknovak/feature/quant

Add support for generator model quantization

eriknovak authored Feb 10, 2025
2 parents 206c27d + efebda0 commit 54f61aa
Showing 6 changed files with 178 additions and 46 deletions.
109 changes: 91 additions & 18 deletions anonipy/anonymize/generators/llm_label_generator.py
@@ -1,14 +1,13 @@
import re
import warnings
from typing import Tuple, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from ...utils.package import is_installed_with
from .interface import GeneratorInterface
from ...definitions import Entity


# =====================================
# Main class
# =====================================
@@ -36,13 +35,15 @@ def __init__(
*args,
model_name: str = "HuggingFaceTB/SmolLM2-1.7B-Instruct",
use_gpu: bool = False,
use_quant: bool = False,
**kwargs,
):
"""Initializes the LLM label generator.
Args:
model_name: The name of the model to use.
use_gpu: Whether to use GPU or not.
use_quant: Whether to use quantization or not.
Examples:
>>> from anonipy.anonymize.generators import LLMLabelGenerator
@@ -59,8 +60,14 @@ def __init__(
)
use_gpu = False

if use_quant and not is_installed_with(["quant", "all"]):
warnings.warn(
"The use_quant=True flag requires the 'quant' extra dependencies, but they are not installed. Setting use_quant=False."
)
use_quant = False

self.model, self.tokenizer = self._prepare_model_and_tokenizer(
model_name, use_gpu
model_name, use_gpu, use_quant
)

def generate(
@@ -108,7 +115,7 @@ def generate(
# =================================

def _prepare_model_and_tokenizer(
self, model_name: str, use_gpu: bool
self, model_name: str, use_gpu: bool, use_quant: bool
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
"""Prepares the model and tokenizer.
@@ -125,12 +132,66 @@ def _prepare_model_and_tokenizer(
device = torch.device(
"cuda" if use_gpu and torch.cuda.is_available() else "cpu"
)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
# prepare the tokenizer
tokenizer = AutoTokenizer.from_pretrained(
dtype = torch.float32 if device.type == "cpu" else torch.float16

model = self._load_model(model_name, device, dtype, use_quant, use_gpu)
tokenizer = self._load_tokenizer(model_name)

return model, tokenizer

def _load_model(
self,
model_name: str,
device: torch.device,
dtype: torch.dtype,
use_quant: bool,
use_gpu: bool,
) -> AutoModelForCausalLM:
"""Load the model with appropriate configuration.
Args:
model_name: The name of the model to use.
device: The device to use for the model.
dtype: The data type to use for the model.
use_quant: Whether to use quantization or not.
use_gpu: Whether to use GPU or not.
Returns:
The huggingface model.
"""
if use_quant and use_gpu:
quant_config = BitsAndBytesConfig(
load_in_8bit=True, bnb_4bit_compute_dtype=dtype
)
return AutoModelForCausalLM.from_pretrained(
model_name,
device_map=device,
torch_dtype=dtype,
quantization_config=quant_config,
)

if use_quant:
warnings.warn(
"Quantization is only supported on GPU, but use_gpu=False. Loading model without quantization."
)

return AutoModelForCausalLM.from_pretrained(
model_name, device_map=device, torch_dtype=dtype
)

def _load_tokenizer(self, model_name: str) -> AutoTokenizer:
"""Load the tokenizer with appropriate configuration.
Args:
model_name: The name of the model to use.
Returns:
The huggingface tokenizer.
"""
return AutoTokenizer.from_pretrained(
model_name, padding_side="right", use_fast=False
)
return model, tokenizer

def _generate_response(
self, message: List[dict], temperature: float, top_p: float
@@ -152,15 +213,27 @@ def _generate_response(
message, tokenize=True, return_tensors="pt", add_generation_prompt=True
).to(self.model.device)

# generate the response
with torch.no_grad():
output_ids = self.model.generate(
input_ids,
max_new_tokens=50,
temperature=temperature,
top_p=top_p,
do_sample=True,
)
# create attention mask (1 for all tokens)
attention_mask = torch.ones_like(input_ids)

# set pad token id if not set
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=".*the `logits` model output.*")

# generate the response
with torch.no_grad():
output_ids = self.model.generate(
input_ids,
attention_mask=attention_mask,
max_new_tokens=50,
temperature=temperature,
top_p=top_p,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)

# decode the response
response = self.tokenizer.decode(
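
Taken together, this file's changes route model loading through an 8-bit bitsandbytes path when both `use_gpu` and `use_quant` are set, warn and fall back to a plain load otherwise, and harden generation with an explicit attention mask and pad token. A minimal initialization sketch, assuming a CUDA-capable machine and the `quant` extras installed; the comments describe the fallback paths added above:

```python
from anonipy.anonymize.generators import LLMLabelGenerator

# 8-bit quantized load via bitsandbytes. If the 'quant' extras are not
# installed, or use_gpu=False, the generator warns and falls back to a
# full-precision (non-quantized) load instead of raising.
llm_generator = LLMLabelGenerator(
    model_name="HuggingFaceTB/SmolLM2-1.7B-Instruct",
    use_gpu=True,
    use_quant=True,
)
```
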
34 changes: 34 additions & 0 deletions anonipy/utils/package.py
@@ -0,0 +1,34 @@
from importlib.metadata import requires
from typing import Union


def is_installed_with(extra: Union[str, list[str]]) -> bool:
"""Check if anonipy was installed with specific optional dependencies.
Args:
extra: The optional dependency or list of dependencies to check.
Valid values are: 'dev', 'test', 'quant', 'all'
Returns:
True if the package was installed with the specified optional dependencies,
False otherwise.
Example:
>>> from anonipy.utils.package import is_installed_with
>>> is_installed_with('dev') # check if dev dependencies are installed
>>> is_installed_with(['dev', 'test']) # check multiple dependency groups
"""
if isinstance(extra, str):
extra = [extra]

try:
package_requires = requires("anonipy") or []
installed_extras = set()

for req in package_requires:
if "extra == " in req:
installed_extras.add(req.split("extra == ")[1].strip("\"'"))

return any(e in installed_extras for e in extra)
except Exception:
return False
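
Note that `importlib.metadata.requires` reports the distribution's declared dependencies, so this check reads the package metadata for `extra == "..."` markers rather than probing whether the optional packages themselves are importable, and any metadata failure conservatively returns `False`. A short sketch of the gating pattern the generator now uses (the variable name is illustrative):

```python
from anonipy.utils.package import is_installed_with

use_quant = True  # requested by the caller
if use_quant and not is_installed_with(["quant", "all"]):
    # Mirror the generator's behaviour: disable quantization up front
    # rather than failing later at model-load time.
    use_quant = False
```
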
12 changes: 10 additions & 2 deletions docs/how-to-guides/posts/generators-overview.md
@@ -83,7 +83,15 @@ The [LLMLabelGenerator][anonipy.anonymize.generators.LLMLabelGenerator] is a one
from anonipy.anonymize.generators import LLMLabelGenerator
```

The `LLMLabelGenerator` currently does not require any input parameters at initialization.
The `LLMLabelGenerator` requires the following input parameters at initialization:

::: anonipy.anonymize.generators.LLMLabelGenerator.__init__
options:
show_root_heading: False
show_docstring_description: False
show_docstring_examples: False
show_docstring_returns: False
show_source: False

Let us now initialize the LLM label generator.

@@ -92,7 +100,7 @@ llm_generator = LLMLabelGenerator()
```

!!! info "Initialization warnings"
The initialization of `LLMLabelGenerator` will throw some warnings. Ignore them. These are expected due to the use of package dependencies.
The initialization of `LLMLabelGenerator` will throw some warnings. Ignore them. These are expected due to the use of package dependencies.

To use the generator, we can call the `generate` method. The `generate` method receives the following parameters:

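
Since the guide's initialization example above still constructs the generator with defaults, a variant with the new flags may be useful to readers; a sketch, assuming a GPU and the `quant` extras are available (otherwise both flags fall back with warnings):

```python
# Quantized variant of the guide's example; requires a GPU and the
# 'quant' extras, otherwise the generator warns and loads unquantized.
llm_generator = LLMLabelGenerator(use_gpu=True, use_quant=True)
```
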
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -42,7 +42,10 @@ test = [
"pytest",
"pytest-cov",
]
all = ["anonipy[dev,test]"]
quant = [
"bitsandbytes",
]
all = ["anonipy[dev,test,quant]"]

[tool.setuptools.packages.find]
where = ["."]
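
With these extras in place, the bitsandbytes dependency stays opt-in: `pip install anonipy[quant]` pulls in just the quantization support, and `anonipy[all]` now bundles it together with the `dev` and `test` groups.
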
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
# NLP and LLMs
spacy==3.8.2
gliner==0.2.13
gliner==0.2.16
gliner-spacy==0.0.10
transformers==4.45.2
accelerate>=0.26.0
62 changes: 38 additions & 24 deletions test/test_extractors.py
@@ -171,7 +171,7 @@
start_index=86,
end_index=96,
type="date",
regex=r"Date of Examination: (.*)"
regex=r"Date of Examination: (.*)",
),
# Repeated entity
Entity(
@@ -180,7 +180,7 @@
start_index=759,
end_index=769,
type="date",
regex=r"Date of Examination: (.*)"
regex=r"Date of Examination: (.*)",
),
]
TEST_MULTI_REPEATS = [
@@ -224,6 +224,7 @@
),
]


@pytest.fixture(autouse=True)
def suppress_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
@@ -270,7 +271,7 @@ def pattern_extractor():
{"SHAPE": "dddd"},
]
],
}
},
]
return PatternExtractor(labels=labels, lang=LANGUAGES.ENGLISH)

@@ -412,6 +413,7 @@ def test_pattern_extractor_extract_default(pattern_extractor):
assert p_entity.regex == t_entity.regex
assert p_entity.score == 1.0


def test_pattern_extractor_detect_repeats_false():
extractor = PatternExtractor(
labels=[
@@ -434,6 +436,7 @@ def test_pattern_extractor_detect_repeats_false():
assert excepted_entity.regex == entities[0].regex
assert excepted_entity.score >= 0.5


def test_pattern_extractor_detect_repeats_true():
extractor = PatternExtractor(
labels=[
@@ -455,6 +458,7 @@ def test_pattern_extractor_detect_repeats_true():
assert p_entity.regex == t_entity.regex
assert p_entity.score >= 0.5


def test_multi_extractor_init():
with pytest.raises(TypeError):
MultiExtractor()
@@ -568,16 +572,21 @@ def test_multi_extractor_extract_single_extractor_pattern(multi_extractor):

def test_multi_extractor_detect_repeats_false():
extractors = [
NERExtractor(labels=[
{"label": "name", "type": "string"},
]),
PatternExtractor(labels=[
{
"label": "date",
"type": "date",
"regex": r"Date of Examination: (.*)",
},
])]
NERExtractor(
labels=[
{"label": "name", "type": "string"},
]
),
PatternExtractor(
labels=[
{
"label": "date",
"type": "date",
"regex": r"Date of Examination: (.*)",
},
]
),
]
extractor = MultiExtractor(extractors)
_, joint_entities = extractor(TEST_ORIGINAL_TEXT, detect_repeats=False)
for p_entity, t_entity in zip(joint_entities, TEST_MULTI_REPEATS[:3]):
Expand All @@ -592,16 +601,21 @@ def test_multi_extractor_detect_repeats_false():

def test_multi_extractor_detect_repeats_true():
extractors = [
NERExtractor(labels=[
{"label": "name", "type": "string"},
]),
PatternExtractor(labels=[
{
"label": "date",
"type": "date",
"regex": r"Date of Examination: (.*)",
},
])]
NERExtractor(
labels=[
{"label": "name", "type": "string"},
]
),
PatternExtractor(
labels=[
{
"label": "date",
"type": "date",
"regex": r"Date of Examination: (.*)",
},
]
),
]
extractor = MultiExtractor(extractors)
_, joint_entities = extractor(TEST_ORIGINAL_TEXT, detect_repeats=True)
for p_entity, t_entity in zip(joint_entities, TEST_MULTI_REPEATS):
Expand All @@ -611,4 +625,4 @@ def test_multi_extractor_detect_repeats_true():
assert p_entity.end_index == t_entity.end_index
assert p_entity.type == t_entity.type
assert p_entity.regex == t_entity.regex
assert p_entity.score >= 0.5
assert p_entity.score >= 0.5
