config.py
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from ...import_utils import py_txi_version
from ...system_utils import is_nvidia_system, is_rocm_system
from ...task_utils import TEXT_EMBEDDING_TASKS, TEXT_GENERATION_TASKS
from ..config import BackendConfig


@dataclass
class PyTXIConfig(BackendConfig):
    name: str = "py-txi"
    version: Optional[str] = py_txi_version()
    _target_: str = "optimum_benchmark.backends.py_txi.backend.PyTXIBackend"

    # optimum-benchmark specific
    no_weights: bool = False

    # Image to use for the container
    image: Optional[str] = None
    # Shared memory size for the container
    shm_size: Optional[str] = None
    # List of custom devices to forward to the container e.g. ["/dev/kfd", "/dev/dri"] for ROCm
    devices: Optional[List[str]] = None
    # NVIDIA-docker GPU device options e.g. "all" (all) or "0,1,2,3" (ids) or 4 (count)
    gpus: Optional[Union[str, int]] = None
    # Things to forward to the container
    ports: Optional[Dict[str, Any]] = None
    environment: Optional[List[str]] = None
    volumes: Optional[Dict[str, Any]] = None

    # First connection/request
    connection_timeout: Optional[int] = None
    first_request_timeout: Optional[int] = None
    max_concurrent_requests: Optional[int] = None

    # Common options
    dtype: Optional[str] = None

    # TEI specific
    pooling: Optional[str] = None

    # TGI specific
    sharded: Optional[str] = None
    quantize: Optional[str] = None
    num_shard: Optional[int] = None
    speculate: Optional[int] = None
    cuda_graphs: Optional[int] = None
    trust_remote_code: Optional[bool] = None
    disable_custom_kernels: Optional[bool] = None

    def __post_init__(self):
        super().__post_init__()

        if self.task not in TEXT_GENERATION_TASKS + TEXT_EMBEDDING_TASKS:
            raise NotImplementedError(f"TXI does not support task {self.task}")

        # Device options
        if self.device_ids is not None and is_nvidia_system() and self.gpus is None:
            self.gpus = self.device_ids

        if self.device_ids is not None and is_rocm_system() and self.devices is None:
            ids = list(map(int, self.device_ids.split(",")))
            renderDs = [file for file in os.listdir("/dev/dri") if file.startswith("renderD")]
            self.devices = ["/dev/kfd"] + [f"/dev/dri/{renderDs[i]}" for i in ids]

        self.trust_remote_code = self.model_kwargs.get("trust_remote_code", None)
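

# Illustration (not part of the upstream file): a standalone sketch of the
# ROCm branch in __post_init__ above, with the render nodes passed in
# explicitly so the mapping can be exercised without a /dev/dri directory.
# The helper name `rocm_devices_for` is hypothetical.
def rocm_devices_for(device_ids: str, render_nodes: List[str]) -> List[str]:
    # Comma-separated ids select entries from the renderD* node list,
    # and /dev/kfd is always forwarded alongside them.
    ids = list(map(int, device_ids.split(",")))
    return ["/dev/kfd"] + [f"/dev/dri/{render_nodes[i]}" for i in ids]


# e.g. rocm_devices_for("0,2", ["renderD128", "renderD129", "renderD130"])
# -> ["/dev/kfd", "/dev/dri/renderD128", "/dev/dri/renderD130"]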
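

# Usage sketch (assumption, not part of the upstream file): constructing the
# config directly in Python. The `model`, `device` and `device_ids` keyword
# names come from the inherited BackendConfig and are assumed here based on
# their use in __post_init__; optimum-benchmark normally builds this config
# from Hydra/CLI overrides rather than by direct instantiation.
if __name__ == "__main__":
    config = PyTXIConfig(
        task="text-generation",  # must be a text-generation or text-embedding task
        model="gpt2",
        device="cuda",
        device_ids="0,1",  # forwarded to `gpus` on NVIDIA systems by __post_init__
        no_weights=True,
        dtype="float16",
        max_concurrent_requests=128,
    )
    print(config.gpus)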