diff --git a/torchsig/__init__.py b/torchsig/__init__.py index 3dc1f76..493f741 100644 --- a/torchsig/__init__.py +++ b/torchsig/__init__.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.3.0" diff --git a/torchsig/datasets/sig53.py b/torchsig/datasets/sig53.py index cb0f27a..d3dd6db 100644 --- a/torchsig/datasets/sig53.py +++ b/torchsig/datasets/sig53.py @@ -110,7 +110,7 @@ def __getitem__(self, idx: int) -> tuple: class_index=mod, snr=snr, ) - data = SignalData( + data: SignalData = SignalData( data=deepcopy(iq_data.tobytes()), item_type=np.dtype(np.float64), data_type=np.dtype(np.complex128), diff --git a/torchsig/datasets/synthetic.py b/torchsig/datasets/synthetic.py index 0717ff1..330c631 100644 --- a/torchsig/datasets/synthetic.py +++ b/torchsig/datasets/synthetic.py @@ -589,6 +589,8 @@ def _generate_samples(self, item: Tuple) -> np.ndarray: # 2x for symbol length # 2x for number of symbols for at least 1 transition # 4x for largest burst duration option + sym_mult = 1 + if self.num_iq_samples <= 4 * 2 * 2 * num_subcarriers: sym_mult = self.num_iq_samples / (2 * 2 * num_subcarriers) + 1e-6 sym_mult = ( @@ -596,14 +598,11 @@ def _generate_samples(self, item: Tuple) -> np.ndarray: if sym_mult < 1.0 else int(np.ceil(sym_mult)) ) - else: - sym_mult = 1 + if self.num_iq_samples > 32768: # assume wideband task and reduce data for speed sym_mult = 0.3 - wideband = True - else: - wideband = False + if mod_type == "random": # Randomized subcarrier modulations @@ -740,9 +739,7 @@ def _generate_samples(self, item: Tuple) -> np.ndarray: flattened = cyclic_prefixed.T.flatten() # Generate randomized LPF cutoff = np.random.uniform(0.95, 0.95) - num_taps = int( - np.ceil(50 * 2 * np.pi / cutoff / 0.125 / 22) - ) # fred harris rule of thumb + num_taps = estimate_filter_length(cutoff) taps = sp.firwin( num_taps, cutoff, diff --git a/torchsig/utils/types.py b/torchsig/utils/types.py index 5b1b822..777e0ca 100644 --- a/torchsig/utils/types.py +++ b/torchsig/utils/types.py @@ -36,8 +36,9 @@ class SignalDescription: Name of the signal's class class_index (:obj:`Optional[int]`): Index of the signal's class - + """ + def __init__( self, sample_rate: Optional[int] = 1, @@ -58,10 +59,18 @@ def __init__( ): self.sample_rate = sample_rate self.num_iq_samples = num_iq_samples - self.lower_frequency = lower_frequency if lower_frequency else center_frequency - bandwidth / 2 - self.upper_frequency = upper_frequency if upper_frequency else center_frequency + bandwidth / 2 + self.lower_frequency = ( + lower_frequency if lower_frequency else center_frequency - bandwidth / 2 + ) + self.upper_frequency = ( + upper_frequency if upper_frequency else center_frequency + bandwidth / 2 + ) self.bandwidth = bandwidth if bandwidth else upper_frequency - lower_frequency - self.center_frequency = center_frequency if center_frequency else lower_frequency + self.bandwidth / 2 + self.center_frequency = ( + center_frequency + if center_frequency + else lower_frequency + self.bandwidth / 2 + ) self.start = start self.stop = stop self.duration = duration if duration else stop - start @@ -87,36 +96,39 @@ class SignalData: signal_description: Optional[Union[List[SignalDescription], SignalDescription]] Either a SignalDescription of signal data or a list of multiple SignalDescription objects describing multiple signals - + """ + def __init__( self, data: Optional[bytes], item_type: np.dtype, data_type: np.dtype, - signal_description: Optional[Union[List[SignalDescription], SignalDescription]] = None + signal_description: Optional[ + Union[List[SignalDescription], SignalDescription] + ] = None, ): + self.iq_data = None + self.signal_description = signal_description if data is not None: # No matter the underlying item type, we convert to double-precision - self.iq_data = np.frombuffer(data, dtype=item_type).astype(np.float64).view(data_type) - else: - # Allow for empty rf data object - self.iq_data = None - if isinstance(signal_description, list): - self.signal_description = signal_description - else: + self.iq_data = ( + np.frombuffer(data, dtype=item_type).astype(np.float64).view(data_type) + ) + + if not isinstance(signal_description, list): self.signal_description = [signal_description] class SignalCapture: def __init__( - self, - absolute_path: str, - num_bytes: int, - item_type: np.dtype, - is_complex: bool, - byte_offset: int = 0, - signal_description: Optional[SignalDescription] = None + self, + absolute_path: str, + num_bytes: int, + item_type: np.dtype, + is_complex: bool, + byte_offset: int = 0, + signal_description: Optional[SignalDescription] = None, ): self.absolute_path = absolute_path self.num_bytes = num_bytes