Skip to content

Commit

Permalink
refactor: move punctuation hash to default constant
Browse files Browse the repository at this point in the history
  • Loading branch information
roedoejet committed Feb 10, 2025
1 parent c4418bc commit edb0ef1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 23 deletions.
28 changes: 19 additions & 9 deletions everyvoice/text/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,24 @@
from everyvoice.config.text_config import TextConfig

N_PHONOLOGICAL_FEATURES = 43
DEFAULT_PUNCTUATION_HASH = {
"exclamations": "<EXCL>",
"ellipses": "<EPS>",
"question_symbols": "<QINT>",
"quotemarks": "<QUOTE>",
"periods": "<PERIOD>",
"commas": "<COMMA>",
"colons": "<COLON>",
"semi_colons": "<SEMICOL>",
"hyphens": "<HYPHEN>",
"parentheses": "<PAREN>",
}


class PhonologicalFeatureCalculator:
def __init__(self, text_config: TextConfig, punctuation_hash: dict):
def __init__(
self, text_config: TextConfig, punctuation_hash: dict = DEFAULT_PUNCTUATION_HASH
):
self.config = text_config
self.punctuation_hash = punctuation_hash
self.feature_table = FeatureTable()
Expand Down Expand Up @@ -38,8 +52,7 @@ def get_punctuation_features(self, tokens: list[str]) -> npt.NDArray[np.float32]
Returns:
npt.NDArray[np.float32]: a seven-dimensional one-hot encoding of punctuation, white space and silence
>>> punc_hash = punc_hash = {"exclamations": "<EXCL>", "ellipses": "<EPS>", "question_symbols": "<QINT>", "quotemarks": "<QUOTE>", "periods": "<PERIOD>", "commas": "<COMMA>", "colons": "<COLON>", "semi_colons": "<SEMICOL>", "hyphens": "<HYPHEN>", "parentheses": "<PAREN>"}
>>> pf = PhonologicalFeatureCalculator(TextConfig(), punc_hash)
>>> pf = PhonologicalFeatureCalculator(TextConfig())
>>> pf.get_punctuation_features(['h', 'ʌ', 'l', 'o', 'ʊ', '<EXCL>'])
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
Expand Down Expand Up @@ -88,8 +101,7 @@ def get_stress_features(self, tokens: list[str]) -> npt.NDArray[np.float32]:
Returns:
npt.NDArray[np.float32]: a two-dimensional one-hot encoding of primary and secondary stress
>>> punc_hash = punc_hash = {"exclamations": "<EXCL>", "question_symbols": "<QINT>", "quotemarks": "<QUOTE>", "periods": "<PERIOD>", "commas": "<COMMA>", "colons": "<COLON>", "semi_colons": "<SEMICOL>", "hyphens": "<HYPHEN>", "parentheses": "<PAREN>"}
>>> pf = PhonologicalFeatureCalculator(TextConfig(), punc_hash)
>>> pf = PhonologicalFeatureCalculator(TextConfig())
>>> pf.get_stress_features(['ˈ', 'ˌ' ])
array([[1., 0.],
[0., 1.]], dtype=float32)
Expand All @@ -114,8 +126,7 @@ def get_special_token_features(self, tokens: list[str]) -> npt.NDArray[np.float3
Returns:
npt.NDArray[np.float32]: a five-dimensional one-hot encoding of special tokens
>>> punc_hash = punc_hash = {"exclamations": "<EXCL>", "question_symbols": "<QINT>", "quotemarks": "<QUOTE>", "periods": "<PERIOD>", "commas": "<COMMA>", "colons": "<COLON>", "semi_colons": "<SEMICOL>", "hyphens": "<HYPHEN>", "parentheses": "<PAREN>"}
>>> pf = PhonologicalFeatureCalculator(TextConfig(), punc_hash)
>>> pf = PhonologicalFeatureCalculator(TextConfig())
>>> pf.get_special_token_features(['\x80', '[MASK]', '[CLS]', '[SEP]', '[UNK]' ])
array([[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
Expand Down Expand Up @@ -148,8 +159,7 @@ def token_to_segmental_features(self, token: str) -> npt.NDArray[np.float32]:
Returns:
npt.NDArray[np.float32]: a list of place and manner of articulation feature values
>>> punc_hash = {"exclamations": "<EXCL>", "question_symbols": "<QINT>", "quotemarks": "<QUOTE>", "periods": "<PERIOD>", "commas": "<COMMA>", "colons": "<COLON>", "semi_colons": "<SEMICOL>", "hyphens": "<HYPHEN>", "parentheses": "<PAREN>"}
>>> pf = PhonologicalFeatureCalculator(TextConfig(), punc_hash)
>>> pf = PhonologicalFeatureCalculator(TextConfig())
>>> pf.token_to_segmental_features('\x80') # pad symbol is all zeros
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0.], dtype=float32)
Expand Down
20 changes: 6 additions & 14 deletions everyvoice/text/text_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from everyvoice.config.text_config import TextConfig
from everyvoice.exceptions import OutOfVocabularySymbolError
from everyvoice.text.features import PhonologicalFeatureCalculator
from everyvoice.text.features import (
DEFAULT_PUNCTUATION_HASH,
PhonologicalFeatureCalculator,
)
from everyvoice.text.phonemizer import AVAILABLE_G2P_ENGINES, get_g2p_engine
from everyvoice.text.utils import (
apply_cleaners_helper,
Expand Down Expand Up @@ -77,7 +80,7 @@ class TextProcessor:
'h/e/l/l/o/!'
"""

def __init__(self, config: TextConfig):
def __init__(self, config: TextConfig, punctuation_hash=DEFAULT_PUNCTUATION_HASH):
self.config = config
self.phonological_feature_calculator: Optional[
PhonologicalFeatureCalculator
Expand All @@ -86,18 +89,7 @@ def __init__(self, config: TextConfig):

# Add punctuation
# Add an internal hash to convert from the type of Punctuation to the internal representation
self.punctuation_internal_hash = {
"exclamations": "<EXCL>",
"question_symbols": "<QINT>",
"quotemarks": "<QUOTE>",
"colons": "<COLON>",
"semi_colons": "<SEMICOL>",
"hyphens": "<HYPHEN>",
"commas": "<COMMA>",
"periods": "<PERIOD>",
"ellipses": "<EPS>",
"parentheses": "<PAREN>",
}
self.punctuation_internal_hash = punctuation_hash
# Create a hash table from punctuation to the internal ID
self.punctuation_to_internal_id = {
v: self.punctuation_internal_hash[punctuation_type]
Expand Down

0 comments on commit edb0ef1

Please sign in to comment.