diff --git a/rttDroneGCS/config.py b/rttDroneGCS/config.py new file mode 100644 index 0000000..5d96333 --- /dev/null +++ b/rttDroneGCS/config.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +import os +import sys +from configparser import ConfigParser +from pathlib import Path +from typing import Any, Dict, Tuple + + +class Configuration: + """Configuration file interface object""" + + def __init__(self, config_path: Path) -> None: + self.__config_path = config_path + + self.__map_extent_nw: Tuple[float, float] = (90.0, -180.0) + self.__map_extent_se: Tuple[float, float] = (-90.0, 180.0) + + self.__lora_port: str = self.__get_default_port() + self.__lora_baud: int = 115200 + self.__lora_frequency: int = 915000000 # Default to 915 MHz + + @staticmethod + def __get_default_port() -> str: + if sys.platform.startswith("win"): + return "COM1" + elif sys.platform.startswith("linux"): + return "/dev/ttyUSB0" + elif sys.platform.startswith("darwin"): + return "/dev/tty.usbserial-0001" + else: + return "" + + def __create_dict(self): + return { + "LastCoords": { + "lat1": self.__map_extent_nw[0], + "lat2": self.__map_extent_se[0], + "lon1": self.__map_extent_nw[1], + "lon2": self.__map_extent_se[1], + }, + "LoRa": { + "port": self.__lora_port, + "baud": self.__lora_baud, + "frequency": self.__lora_frequency, + }, + } + + def load(self) -> None: + """Loads the configuration from the specified file""" + parser = ConfigParser() + parser.read_dict(self.__create_dict()) + parser.read(self.__config_path) + + self.__map_extent_nw = ( + parser["LastCoords"].getfloat("lat1"), + parser["LastCoords"].getfloat("lon1"), + ) + self.__map_extent_se = ( + parser["LastCoords"].getfloat("lat2"), + parser["LastCoords"].getfloat("lon2"), + ) + + self.__lora_port = parser["LoRa"].get("port") + self.__lora_baud = parser["LoRa"].getint("baud") + self.__lora_frequency = parser["LoRa"].getint("frequency") + + def write(self) -> None: + """Writes the configuration to the file""" + parser = ConfigParser() + parser.read_dict(self.__create_dict()) + with open(self.__config_path, "w", encoding="ascii") as handle: + parser.write(handle) + + @property + def lora_port(self) -> str: + """LoRa port + + Returns: + str: LoRa port + """ + return self.__lora_port + + @lora_port.setter + def lora_port(self, value: Any) -> None: + if not isinstance(value, str): + raise TypeError + self.__lora_port = value + + @property + def lora_baud(self) -> int: + """LoRa baud rate + + Returns: + int: LoRa baud rate + """ + return self.__lora_baud + + @lora_baud.setter + def lora_baud(self, value: Any) -> None: + if not isinstance(value, int): + raise TypeError + if value <= 0: + raise ValueError + self.__lora_baud = value + + @property + def lora_frequency(self) -> int: + """LoRa frequency in Hz + + Returns: + int: LoRa frequency + """ + return self.__lora_frequency + + @lora_frequency.setter + def lora_frequency(self, value: Any) -> None: + if not isinstance(value, int): + raise TypeError + if value <= 0: + raise ValueError + self.__lora_frequency = value + + @property + def map_extent(self) -> Tuple[Tuple[float, float], Tuple[float, float]]: + """Map previous extent + + Returns: + Tuple[Tuple[float, float], Tuple[float, float]]: NW and SE map extents in dd.dddddd + """ + return (self.__map_extent_nw, self.__map_extent_se) + + @map_extent.setter + def map_extent(self, value: Any) -> None: + if not isinstance(value, tuple): + raise TypeError + if len(value) != 2: + raise TypeError + for coordinate in value: + if not isinstance(coordinate, tuple): + raise TypeError + if len(value) != 2: + raise TypeError + + if not isinstance(coordinate[0], float): + raise TypeError + if not -90 <= coordinate[0] <= 90: + raise ValueError + + if not isinstance(coordinate[1], float): + raise TypeError + if not -180 <= coordinate[1] <= 180: + raise ValueError + self.__map_extent_nw = value[0] + self.__map_extent_se = value[1] + + def __enter__(self) -> Configuration: + self.load() + return self + + def __exit__(self, exc, exp, exv) -> None: + self.write() + + +__config_instance: Dict[Path, Configuration] = {} + + +def get_instance(path: Path) -> Configuration: + """Retrieves the corresponding configuration instance singleton + + Args: + path (Path): Path to config path + + Returns: + Configuration: Configuration singleton + """ + if path not in __config_instance: + __config_instance[path] = Configuration(path) + return __config_instance[path] + + +def get_config_path() -> Path: + """Retrieves the application configuration path + + Returns: + Path: Path to configuration file + """ + return Path("gcsConfig.ini") diff --git a/rttDroneGCS/gui/__init__.py b/rttDroneGCS/gui/__init__.py new file mode 100644 index 0000000..6fea5c9 --- /dev/null +++ b/rttDroneGCS/gui/__init__.py @@ -0,0 +1 @@ +"""GUI for the RTT Drone GCS.""" diff --git a/rttDroneGCS/gui/dialogs.py b/rttDroneGCS/gui/dialogs.py new file mode 100644 index 0000000..01cac2b --- /dev/null +++ b/rttDroneGCS/gui/dialogs.py @@ -0,0 +1,345 @@ +"""Dialogs for the RTT Drone GCS.""" + +from __future__ import annotations + +from PyQt5.QtCore import QRegExp +from PyQt5.QtGui import QRegExpValidator +from PyQt5.QtWidgets import ( + QGridLayout, + QHBoxLayout, + QLabel, + QLineEdit, + QPushButton, + QVBoxLayout, + QWidget, + QWizard, + QWizardPage, +) + +from rttDroneGCS.config import get_config_path, get_instance + +from .popups import UserPopups + + +class BaseDialog(QWizard): + """Base dialog class.""" + + def __init__(self, parent: QWidget, title: str, page_class: QWizardPage) -> None: + """Initialize the dialog. + + Args: + ---- + parent (QWidget): The parent widget. + title (str): The title of the dialog. + page_class (QWizardPage): The page class. + + """ + super().__init__(parent) + self.parent = parent + self.setWindowTitle(title) + self.page = page_class(self) + self.addPage(self.page) + self.resize(640, 480) + self.button(QWizard.FinishButton).clicked.connect(self.submit) + + def submit(self) -> None: + """Submit the dialog.""" + msg = "Subclasses must implement submit method" + raise NotImplementedError(msg) + + +class BaseDialogPage(QWizardPage): + """Base dialog page class.""" + + def __init__(self, parent: QWizard) -> None: + """Initialize the dialog page. + + Args: + ---- + parent (QWizard): The parent wizard. + + """ + super().__init__(parent) + self._parent = parent + self.user_pops = UserPopups() + self._create_widget() + + def _create_widget(self) -> None: + msg = "Subclasses must implement _create_widget method" + raise NotImplementedError(msg) + + +class ExpertSettingsDialog(BaseDialog): + """Expert settings dialog.""" + + def __init__(self, parent: QWidget, option_vars: dict) -> None: + """Initialize the expert settings dialog. + + Args: + ---- + parent (QWidget): The parent widget. + option_vars (dict): The option variables. + + """ + super().__init__( + parent, + "Expert/Engineering Settings", + ExpertSettingsDialogPage, + ) + self.option_vars = option_vars + self.page.option_vars = option_vars + self.parent.updateGUIOptionVars(0xFF, self.option_vars) + + def submit(self) -> None: + """Submit the expert settings dialog.""" + if not self.page.validate_parameters(): + self.page.user_pops.show_warning( + "Entered information could not be validated", + "Invalid Input", + ) + return + self.parent.submitGUIOptionVars(0xFF) + + +class ExpertSettingsDialogPage(BaseDialogPage): + """Expert settings dialog page.""" + + def _create_widget(self) -> None: + exp_settings_frame = QGridLayout() + + labels = [ + "Expected Ping Width (ms)", + "Min. Width Multiplier", + "Max. Width Multiplier", + "Min. Ping SNR(dB)", + "GPS Port", + "GPS Baud Rate", + "Output Directory", + "GPS Mode", + "SYS Autostart", + ] + input_fields = [ + "DSP_pingWidth", + "DSP_pingMin", + "DSP_pingMax", + "DSP_pingSNR", + "GPS_device", + "GPS_baud", + "SYS_outputDir", + "GPS_mode", + "SYS_autostart", + ] + + for i, (label_text, field_name) in enumerate(zip(labels, input_fields)): + exp_settings_frame.addWidget(QLabel(label_text), i, 0) + self.option_vars[field_name] = QLineEdit() + exp_settings_frame.addWidget(self.option_vars[field_name], i, 1) + + btn_submit = QPushButton("Submit") + btn_submit.clicked.connect(self._parent.submit) + exp_settings_frame.addWidget(btn_submit, len(labels), 0, 1, 2) + + self.setLayout(exp_settings_frame) + + def validate_parameters(self) -> bool: + """Validate the parameters.""" + return True + + +class AddTargetDialog(BaseDialog): + """Add target dialog.""" + + def __init__( + self, + parent: QWidget, + center_frequency: int, + sampling_frequency: int, + ) -> None: + """Initialize the add target dialog. + + Args: + ---- + parent (QWidget): The parent widget. + center_frequency (int): The center frequency. + sampling_frequency (int): The sampling frequency. + + """ + super().__init__(parent, "Add Target", AddTargetDialogPage) + self.center_frequency = center_frequency + self.sampling_frequency = sampling_frequency + self.page.center_frequency = center_frequency + self.page.sampling_frequency = sampling_frequency + + def submit(self) -> None: + """Submit the add target dialog.""" + if not self._validate(): + self.user_pops.show_warning( + "You have entered an invalid target frequency. Please try again.", + "Invalid frequency", + ) + return + self.name = self.page.target_name_entry.text() + self.freq = int(self.page.target_freq_entry.text()) + + def _validate(self) -> bool: + return ( + abs(int(self.page.target_freq_entry.text()) - self.center_frequency) + <= self.sampling_frequency + ) + + +class AddTargetDialogPage(BaseDialogPage): + """Add target dialog page.""" + + def _create_widget(self) -> None: + frm_target_settings = QGridLayout() + + lbl_target_name = QLabel("Target Name:") + frm_target_settings.addWidget(lbl_target_name, 0, 0) + self.target_name_entry = QLineEdit() + frm_target_settings.addWidget(self.target_name_entry, 0, 1) + + lbl_target_freq = QLabel("Target Frequency:") + frm_target_settings.addWidget(lbl_target_freq, 1, 0) + self.target_freq_entry = QLineEdit() + frm_target_settings.addWidget(self.target_freq_entry, 1, 1) + + regex_string = QRegExp(r"^\d{1,9}$") + val = QRegExpValidator(regex_string) + self.target_freq_entry.setValidator(val) + + lbl_center_freq = QLabel(f"Center Frequency: {self.center_frequency} Hz") + frm_target_settings.addWidget(lbl_center_freq, 2, 0, 1, 2) + + lbl_sampling_freq = QLabel(f"Sampling Frequency: {self.sampling_frequency} Hz") + frm_target_settings.addWidget(lbl_sampling_freq, 3, 0, 1, 2) + + lbl_valid_range = QLabel( + f"Valid Range: {self.center_frequency - self.sampling_frequency} to " + f"{self.center_frequency + self.sampling_frequency} Hz", + ) + frm_target_settings.addWidget(lbl_valid_range, 4, 0, 1, 2) + + self.setLayout(frm_target_settings) + + +class ConnectionDialog(BaseDialog): + """Connection dialog.""" + + def __init__(self, parent: QWidget) -> None: + """Initialize the connection dialog. + + Args: + ---- + parent (QWidget): The parent widget. + + """ + super().__init__(parent, "LoRa Connection Settings", ConnectionDialogPage) + + def _submit(self) -> None: + is_valid, error_message = self.page.validate_input() + if not is_valid: + UserPopups().show_warning(error_message, "Invalid Input") + return + self.parent.config.lora_port = self.page.port_entry.text() + self.parent.config.lora_baud = int(self.page.baud_entry.text()) + self.parent.config.lora_frequency = int(self.page.freq_entry.text()) + self.parent.config.write() + + +class ConnectionDialogPage(BaseDialogPage): + """Connection dialog page.""" + + def _create_widget(self) -> None: + frm_holder = QVBoxLayout() + + for label, attr in [ + ("LoRa Port", "lora_port"), + ("LoRa Baud Rate", "lora_baud"), + ("LoRa Frequency", "lora_frequency"), + ]: + frm = QHBoxLayout() + frm.addWidget(QLabel(label)) + entry = QLineEdit() + entry.setText(str(getattr(self._parent.parent.config, attr))) + frm.addWidget(entry) + frm_holder.addLayout(frm) + setattr(self, f"{attr.split('_')[1]}_entry", entry) + + self.setLayout(frm_holder) + + def validate_input(self) -> tuple[bool, str]: + """Validate the input.""" + try: + int(self.baud_entry.text()) + int(self.freq_entry.text()) + except ValueError: + return False, "Invalid baud rate or frequency" + else: + return True, "" + + +class ConfigDialog(BaseDialog): + """Config dialog.""" + + def __init__(self, parent: QWidget) -> None: + """Initialize the config dialog. + + Args: + ---- + parent (QWidget): The parent widget. + + """ + super().__init__(parent, "Edit Configuration Settings", ConfigDialogPage) + self.config = get_instance(get_config_path()) + + def _submit(self) -> None: + try: + self.config.map_extent = ( + (float(self.page.lat_1[1].text()), float(self.page.lon_1[1].text())), + (float(self.page.lat_2[1].text()), float(self.page.lon_2[1].text())), + ) + self.config.lora_port = self.page.lora_port[1].text() + self.config.lora_baud = int(self.page.lora_baud[1].text()) + self.config.lora_frequency = int(self.page.lora_frequency[1].text()) + self.config.write() + UserPopups().show_warning("Configuration updated successfully", "Success") + if self.parent: + self.parent.close() + except ValueError as e: + UserPopups().show_warning(f"Invalid input: {e}", "Error") + except (OSError, AttributeError) as e: + UserPopups().show_warning(f"Configuration error: {e}", "Error") + + +class ConfigDialogPage(BaseDialogPage): + """Config dialog page.""" + + def _create_widget(self) -> None: + frm_holder = QVBoxLayout() + + for label, attr in [ + ("Lat 1", "map_extent[0][0]"), + ("Lon 1", "map_extent[0][1]"), + ("Lat 2", "map_extent[1][0]"), + ("Lon 2", "map_extent[1][1]"), + ("LoRa Port", "lora_port"), + ("LoRa Baud Rate", "lora_baud"), + ("LoRa Frequency", "lora_frequency"), + ]: + value = self._get_nested_attr(self._parent.config, attr) + text_box = self.user_pops.create_text_box(label, str(value)) + frm_holder.addLayout(text_box[0]) + setattr( + self, + attr.split(".")[-1].replace("[", "_").replace("]", ""), + text_box, + ) + + self.setLayout(frm_holder) + + def _get_nested_attr(self, obj: object, attr: str) -> object: + parts = attr.replace("]", "").replace("[", ".").split(".") + for part in parts: + obj = obj[int(part)] if part.isdigit() else getattr(obj, part) + return obj diff --git a/rttDroneGCS/gui/map.py b/rttDroneGCS/gui/map.py new file mode 100644 index 0000000..5727982 --- /dev/null +++ b/rttDroneGCS/gui/map.py @@ -0,0 +1,862 @@ + + + +import csv +import math +import os +from pathlib import Path +from threading import Thread +from typing import Tuple, Optional + +import requests +from PyQt5.QtCore import Qt, QDir +from PyQt5.QtGui import QColor +from PyQt5.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QLabel, QPushButton, QFileDialog, + QAction, QToolBar +) + +from qgis.core import ( + QgsRectangle, QgsPointXY, QgsGeometry, QgsFeature, QgsField, + QgsVectorLayer, QgsRasterLayer, QgsProject, QgsCoordinateTransform, + QgsCoordinateReferenceSystem, QgsSymbol, QgsRendererRange, + QgsGraduatedSymbolRenderer, QgsMarkerSymbol, QgsSvgMarkerSymbolLayer, + QgsWkbTypes, QgsProperty +) +from qgis.gui import QgsMapCanvas, QgsMapToolEmitPoint, QgsRubberBand + +from .popups import UserPopups + +class RectangleMapTool(QgsMapToolEmitPoint): + def __init__(self, canvas): + super().__init__(canvas) + self.canvas = canvas + self.rubber_band = QgsRubberBand(self.canvas, True) + self.rubber_band.setColor(QColor(0, 255, 255, 125)) + self.rubber_band.setWidth(1) + self.reset() + + def reset(self): + self.start_point = self.end_point = None + self.is_emitting_point = False + self.rubber_band.reset(True) + + def canvasPressEvent(self, e): + self.start_point = self.toMapCoordinates(e.pos()) + self.end_point = self.start_point + self.is_emitting_point = True + self.show_rect(self.start_point, self.end_point) + + def canvasReleaseEvent(self, e): + self.is_emitting_point = False + + def canvasMoveEvent(self, e): + if not self.is_emitting_point: + return + self.end_point = self.toMapCoordinates(e.pos()) + self.show_rect(self.start_point, self.end_point) + + def show_rect(self, start_point, end_point): + self.rubber_band.reset(QgsWkbTypes.PolygonGeometry) + if start_point.x() == end_point.x() or start_point.y() == end_point.y(): + return + points = [ + QgsPointXY(start_point.x(), start_point.y()), + QgsPointXY(start_point.x(), end_point.y()), + QgsPointXY(end_point.x(), end_point.y()), + QgsPointXY(end_point.x(), start_point.y()) + ] + for i, point in enumerate(points): + self.rubber_band.addPoint(point, i == len(points) - 1) + self.rubber_band.show() + + def rectangle(self): + if self.start_point is None or self.end_point is None: + return None + if self.start_point.x() == self.end_point.x() or self.start_point.y() == self.end_point.y(): + return None + return QgsRectangle(self.start_point, self.end_point) + +class PolygonMapTool(QgsMapToolEmitPoint): + def __init__(self, canvas): + super().__init__(canvas) + self.canvas = canvas + self.vertices = [] + self.rubber_band = QgsRubberBand(self.canvas, QgsWkbTypes.PolygonGeometry) + self.rubber_band.setColor(Qt.red) + self.rubber_band.setWidth(1) + self.reset() + + def reset(self): + self.start_point = self.end_point = None + self.is_emitting_point = False + self.rubber_band.reset(True) + + def canvasPressEvent(self, e): + self.start_point = self.toMapCoordinates(e.pos()) + self.end_point = self.start_point + self.is_emitting_point = True + self.add_vertex(self.start_point) + self.show_line(self.start_point, self.end_point) + self.show_polygon() + + def canvasReleaseEvent(self, e): + self.is_emitting_point = False + + def canvasMoveEvent(self, e): + if not self.is_emitting_point: + return + self.end_point = self.toMapCoordinates(e.pos()) + self.show_line(self.start_point, self.end_point) + + def add_vertex(self, point): + self.vertices.append(QgsPointXY(point)) + + def show_polygon(self): + if len(self.vertices) > 1: + self.rubber_band.reset(QgsWkbTypes.PolygonGeometry) + for i, vertex in enumerate(self.vertices): + self.rubber_band.addPoint(vertex, i == len(self.vertices) - 1) + self.rubber_band.show() + + def show_line(self, start_point, end_point): + self.rubber_band.reset(QgsWkbTypes.PolygonGeometry) + if start_point.x() == end_point.x() or start_point.y() == end_point.y(): + return + self.rubber_band.addPoint(QgsPointXY(start_point.x(), start_point.y()), True) + self.rubber_band.show() + +class VehicleData: + def __init__(self): + self.ind = 0 + self.last_loc = None + +class MapWidget(QWidget): + def __init__(self, root): + super().__init__() + self.holder = QVBoxLayout() + self.ground_truth = None + self.map_layer = None + self.vehicle = None + self.vehicle_path = None + self.precision = None + self.cones = None + self.vehicle_data = {} + self.ping_layer = None + self.ping_renderer = None + self.estimate = None + self.tool_polygon = None + self.polygon_layer = None + self.polygon_action = None + self.heat_map = None + self.ping_min = 800 + self.ping_max = 0 + self.cone_min = float('inf') + self.cone_max = float('-inf') + self.ind = 0 + self.ind_ping = 0 + self.ind_est = 0 + self.ind_cone = 0 + self.toolbar = QToolBar() + self.canvas = QgsMapCanvas() + self.canvas.setCanvasColor(Qt.white) + + self.transform_to_web = QgsCoordinateTransform( + QgsCoordinateReferenceSystem("EPSG:4326"), + QgsCoordinateReferenceSystem("EPSG:3857"), + QgsProject.instance() + ) + self.transform = QgsCoordinateTransform( + QgsCoordinateReferenceSystem("EPSG:3857"), + QgsCoordinateReferenceSystem("EPSG:4326"), + QgsProject.instance() + ) + + def set_up_heat_map(self): + file_name = QFileDialog.getOpenFileName() + if file_name[0]: + if self.heat_map is not None: + QgsProject.instance().removeMapLayer(self.heat_map) + self.heat_map = QgsRasterLayer(file_name[0], "heat_map") + self._configure_heat_map() + + def _configure_heat_map(self): + stats = self.heat_map.dataProvider().bandStatistics(1) + max_val = stats.maximumValue + fcn = QgsColorRampShader() + fcn.setColorRampType(QgsColorRampShader.Interpolated) + lst = [ + QgsColorRampShader.ColorRampItem(0, QColor(0, 0, 0)), + QgsColorRampShader.ColorRampItem(max_val, QColor(255, 255, 255)) + ] + fcn.setColorRampItemList(lst) + shader = QgsRasterShader() + shader.setRasterShaderFunction(fcn) + + renderer = QgsSingleBandPseudoColorRenderer(self.heat_map.dataProvider(), 1, shader) + self.heat_map.setRenderer(renderer) + + QgsProject.instance().addMapLayer(self.heat_map) + dest_crs = self.map_layer.crs() + raster_crs = self.heat_map.crs() + + self.heat_map.setCrs(raster_crs) + self.canvas.setDestinationCrs(dest_crs) + + self.canvas.setLayers([ + self.heat_map, self.estimate, self.ground_truth, + self.vehicle, self.ping_layer, self.vehicle_path, self.map_layer + ]) + + def plot_precision(self, coord, freq, num_pings): + data_dir = 'holder' + output_file_name = f'/{data_dir}/PRECISION_{freq/1e7:.3f}_{num_pings}_heat_map.tiff' + file_name = QDir().currentPath() + output_file_name + + if self.heat_map is not None: + QgsProject.instance().removeMapLayer(self.heat_map) + + self.heat_map = QgsRasterLayer(file_name, "heat_map") + self._configure_heat_map() + self.heat_map.renderer().setOpacity(0.7) + + self.canvas.setLayers([ + self.heat_map, self.estimate, self.ground_truth, + self.vehicle, self.ping_layer, self.vehicle_path, self.map_layer + ]) + + def adjust_canvas(self): + self.canvas.setExtent(self.map_layer.extent()) + self.canvas.setLayers([ + self.precision, self.estimate, self.ground_truth, + self.vehicle, self.ping_layer, self.cones, + self.vehicle_path, self.polygon_layer, self.map_layer + ]) + self.canvas.zoomToFullExtent() + self.canvas.freeze(True) + self.canvas.show() + self.canvas.refresh() + self.canvas.freeze(False) + self.canvas.repaint() + + def add_toolbar(self): + actions = [ + ("Zoom in", self.zoom_in), + ("Zoom out", self.zoom_out), + ("Pan", self.pan), + ("Polygon", self.polygon) + ] + + for name, func in actions: + action = QAction(name, self) + action.setCheckable(True) + action.triggered.connect(func) + self.toolbar.addAction(action) + setattr(self, f"action_{name.lower().replace(' ', '_')}", action) + + self.tool_pan = QgsMapToolPan(self.canvas) + self.tool_pan.setAction(self.action_pan) + self.tool_zoom_in = QgsMapToolZoom(self.canvas, False) + self.tool_zoom_in.setAction(self.action_zoom_in) + self.tool_zoom_out = QgsMapToolZoom(self.canvas, True) + self.tool_zoom_out.setAction(self.action_zoom_out) + self.tool_polygon = PolygonMapTool(self.canvas) + self.tool_polygon.setAction(self.action_polygon) + + def polygon(self): + self.canvas.setMapTool(self.tool_polygon) + + def zoom_in(self): + self.canvas.setMapTool(self.tool_zoom_in) + + def zoom_out(self): + self.canvas.setMapTool(self.tool_zoom_out) + + def pan(self): + self.canvas.setMapTool(self.tool_pan) + + def plot_vehicle(self, id, coord): + lat, lon = coord[0], coord[1] + point = self.transform_to_web.transform(QgsPointXY(lon, lat)) + if self.vehicle is None: + return + + vehicle_data = self.vehicle_data.get(id, VehicleData()) + self.vehicle_data[id] = vehicle_data + + if vehicle_data.ind > 0: + self._update_vehicle_path(vehicle_data, point) + self._update_vehicle_position(vehicle_data) + + vehicle_data.last_loc = point + self._add_new_vehicle_position(point) + self.ind += 1 + vehicle_data.ind = self.ind + + def _update_vehicle_path(self, vehicle_data, point): + lpr = self.vehicle_path.dataProvider() + lin = QgsGeometry.fromPolylineXY([vehicle_data.last_loc, point]) + line_feat = QgsFeature() + line_feat.setGeometry(lin) + lpr.addFeatures([line_feat]) + + def _update_vehicle_position(self, vehicle_data): + self.vehicle.startEditing() + self.vehicle.deleteFeature(vehicle_data.ind) + self.vehicle.commitChanges() + + def _add_new_vehicle_position(self, point): + vpr = self.vehicle.dataProvider() + pnt = QgsGeometry.fromPointXY(point) + f = QgsFeature() + f.setGeometry(pnt) + vpr.addFeatures([f]) + self.vehicle.updateExtents() + + def plot_cone(self, coord): + lat, lon, heading = coord[0], coord[1], coord[4] + power_arr = [2.4, 4, 5, 2.1, 3, 8, 5.9, 2, 1, 3, 5, 4] + aind = self.ind_cone % 12 + power = power_arr[aind] + + point = self.transform_to_web.transform(QgsPointXY(lon, lat)) + self.cone_min = min(self.cone_min, power) + self.cone_max = max(self.cone_max, power) + + if self.cones is None: + return + + if self.ind_cone > 4: + self._remove_old_cone() + + self._update_cone_colors() + self._add_new_cone(point, heading, power) + + def _remove_old_cone(self): + self.cones.startEditing() + self.cones.deleteFeature(self.ind_cone - 5) + self.cones.commitChanges() + + def _update_cone_colors(self): + updates = {} + opacity = 1 + for update_ind in range(self.ind_cone, max(self.ind_cone - 5, -1), -1): + feature = self.cones.getFeature(update_ind) + amp = feature.attributes()[1] + color = self.calc_color(amp, self.cone_min, self.cone_max, opacity) + height = self.calc_height(amp, self.cone_min, self.cone_max) + updates[update_ind] = {2: color, 3: height} + opacity -= opacity -= 0.2 + + self.cones.dataProvider().changeAttributeValues(updates) + + def _add_new_cone(self, point, heading, power): + cpr = self.cones.dataProvider() + feature = QgsFeature() + feature.setFields(self.cones.fields()) + feature.setGeometry(QgsGeometry.fromPointXY(point)) + feature.setAttribute(0, heading) + feature.setAttribute(1, power) + feature.setAttribute(2, self.calc_color(power, self.cone_min, self.cone_max, 1)) + feature.setAttribute(3, self.calc_height(power, self.cone_min, self.cone_max)) + feature.setAttribute(4, "bottom") + cpr.addFeatures([feature]) + self.cones.updateExtents() + self.ind_cone += 1 + + def calc_color(self, amp, min_amp, max_amp, opac): + if min_amp == max_amp: + color_ratio = 0.5 + else: + color_ratio = (amp - min_amp) / (max_amp - min_amp) + red = int(255 * color_ratio) + blue = int(255 * (1 - color_ratio)) + opacity = int(255 * opac) + return f"#{opacity:02x}{red:02x}00{blue:02x}" + + def calc_height(self, amp, min_amp, max_amp): + if min_amp == max_amp: + return 4.0 + return 3.0 * (amp - min_amp) / (max_amp - min_amp) + 1 + + def plot_ping(self, coord, power): + lat, lon = coord[0], coord[1] + point = self.transform_to_web.transform(QgsPointXY(lon, lat)) + + if self.ping_layer is None: + return + + self._update_ping_range(power) + self._add_new_ping(point, power) + + def _update_ping_range(self, power): + change = False + if power < self.ping_min: + self.ping_min = power + change = True + if power > self.ping_max: + self.ping_max = power + change = True + + if change: + self._update_ping_renderer() + + def _update_ping_renderer(self): + r = self.ping_max - self.ping_min + ranges = [0.14, 0.28, 0.42, 0.56, 0.7, 0.84] + labels = ['Blue', 'Cyan', 'Green', 'Yellow', 'Orange', 'ORed', 'Red'] + + for i, (label, upper) in enumerate(zip(labels, ranges + [1])): + lower = ranges[i-1] if i > 0 else 0 + self.ping_renderer.updateRangeLowerValue(i, self.ping_min + r * lower) + self.ping_renderer.updateRangeUpperValue(i, self.ping_min + r * upper) + + def _add_new_ping(self, point, power): + vpr = self.ping_layer.dataProvider() + feature = QgsFeature() + feature.setFields(self.ping_layer.fields()) + feature.setGeometry(QgsGeometry.fromPointXY(point)) + feature.setAttribute(0, power) + vpr.addFeatures([feature]) + self.ping_layer.updateExtents() + + def plot_estimate(self, coord, frequency): + lat, lon = coord[0], coord[1] + point = self.transform_to_web.transform(QgsPointXY(lon, lat)) + + if self.estimate is None: + return + + if self.ind_est > 0: + self._update_estimate(point) + else: + self._add_new_estimate(point) + + def _update_estimate(self, point): + self.estimate.startEditing() + self.estimate.deleteFeature(self.ind_est) + self.estimate.commitChanges() + self._add_new_estimate(point) + + def _add_new_estimate(self, point): + vpr = self.estimate.dataProvider() + feature = QgsFeature() + feature.setGeometry(QgsGeometry.fromPointXY(point)) + vpr.addFeatures([feature]) + self.estimate.updateExtents() + self.ind_est += 1 + +class MapOptions(QWidget): + def __init__(self): + super().__init__() + self.map_widget = None + self.btn_cache_map = None + self.is_web_map = False + self.lbl_dist = None + self._create_widgets() + self.created = False + self.writer = None + self.has_point = False + self.user_pops = UserPopups() + + def _create_widgets(self): + layout = QVBoxLayout() + layout.addWidget(QLabel('Map Options')) + + self.btn_set_search_area = QPushButton('Set Search Area') + self.btn_set_search_area.setEnabled(False) + layout.addWidget(self.btn_set_search_area) + + self.btn_cache_map = QPushButton('Cache Map') + self.btn_cache_map.clicked.connect(self._cache_map) + self.btn_cache_map.setEnabled(False) + layout.addWidget(self.btn_cache_map) + + self.btn_clear_map = QPushButton('Clear Map') + self.btn_clear_map.clicked.connect(self.clear) + layout.addWidget(self.btn_clear_map) + + export_layout = QVBoxLayout() + for label in ['Pings', 'Vehicle Path', 'Polygon', 'Cones']: + btn = QPushButton(label) + btn.clicked.connect(getattr(self, f'export_{label.lower().replace(" ", "_")}')) + export_layout.addWidget(btn) + + export_widget = QWidget() + export_widget.setLayout(export_layout) + layout.addWidget(export_widget) + + dist_layout = QHBoxLayout() + dist_layout.addWidget(QLabel('Distance from Actual')) + self.lbl_dist = QLabel('') + dist_layout.addWidget(self.lbl_dist) + layout.addLayout(dist_layout) + + self.setLayout(layout) + + def clear(self): + if self.map_widget is None: + return + self.map_widget.tool_polygon.rubber_band.reset(QgsWkbTypes.PolygonGeometry) + self.map_widget.tool_rect.rubber_band.reset(QgsWkbTypes.PolygonGeometry) + self.map_widget.tool_polygon.vertices.clear() + + def _cache_map(self): + if not self.is_web_map: + print("alert") + return + + if self.map_widget.tool_rect.rectangle() is None: + self.user_pops.show_warning( + "Use the rect tool to choose an area on the map to cache", + "No specified area to cache!" + ) + self.map_widget.rect() + else: + cache_thread = Thread(target=self.map_widget.cache_map) + cache_thread.start() + self.map_widget.canvas.refresh() + + def set_map(self, map_widget: 'MapWidget', is_web_map: bool): + self.is_web_map = is_web_map + self.map_widget = map_widget + self.btn_cache_map.setEnabled(is_web_map) + + def est_distance(self, coord, stale, res): + lat1, lon1 = coord[0], coord[1] + lat2, lon2 = 32.885889, -117.234028 + + if not self.has_point: + self._add_ground_truth_point(lat2, lon2) + + dist = self.distance(lat1, lat2, lon1, lon2) + self._write_results(dist, res) + self.lbl_dist.setText(f'{dist:.3f}(m.)') + + def _add_ground_truth_point(self, lat, lon): + point = self.map_widget.transform_to_web.transform(QgsPointXY(lon, lat)) + vpr = self.map_widget.ground_truth.dataProvider() + feature = QgsFeature() + feature.setGeometry(QgsGeometry.fromPointXY(point)) + vpr.addFeatures([feature]) + self.map_widget.ground_truth.updateExtents() + self.has_point = True + + def _write_results(self, dist, res): + field_names = ['Distance', 'res.x', 'residuals'] + mode = 'w' if not self.created else 'a+' + with open('results.csv', mode, newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=field_names) + if not self.created: + writer.writeheader() + self.created = True + writer.writerow({ + 'Distance': str(dist), + 'res.x': str(res.x), + 'residuals': str(res.fun) + }) + + @staticmethod + def distance(lat1, lat2, lon1, lon2): + lon1, lon2, lat1, lat2 = map(math.radians, [lon1, lon2, lat1, lat2]) + dlon = lon2 - lon1 + dlat = lat2 - lat1 + a = math.sin(dlat/2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon/2)**2 + c = 2 * math.asin(math.sqrt(a)) + r = 6371 # Radius of Earth in kilometers + return c * r * 1000 + + def export_layer(self, layer_name: str, file_name: str): + if self.map_widget is None: + self.user_pops.show_warning("Load a map before exporting.") + return + + folder = str(QFileDialog.getExistingDirectory(self, "Select Directory")) + file_path = os.path.join(folder, file_name) + options = QgsVectorFileWriter.SaveVectorOptions() + options.driverName = "ESRI Shapefile" + + layer = getattr(self.map_widget, layer_name) + QgsVectorFileWriter.writeAsVectorFormatV2( + layer, file_path, QgsCoordinateTransformContext(), options + ) + + def export_ping(self): + self.export_layer('ping_layer', 'pings.shp') + + def export_vehicle_path(self): + self.export_layer('vehicle_path', 'vehicle_path.shp') + + def export_polygon(self): + if self.map_widget is None: + self.user_pops.show_warning("Load a map before exporting.") + return + + if self.map_widget.tool_polygon is None: + return + + if not self.map_widget.tool_polygon.vertices: + self.user_pops.show_warning( + "Use the polygon tool to choose an area on the map to export", + "No specified area to export!" + ) + self.map_widget.polygon() + return + + vpr = self.map_widget.polygon_layer.dataProvider() + points = self.map_widget.tool_polygon.vertices + poly_geom = QgsGeometry.fromPolygonXY([points]) + + feature = QgsFeature() + feature.setGeometry(poly_geom) + vpr.addFeatures([feature]) + self.map_widget.polygon_layer.updateExtents() + + folder = str(QFileDialog.getExistingDirectory(self, "Select Directory")) + file_path = os.path.join(folder, 'polygon.shp') + options = QgsVectorFileWriter.SaveVectorOptions() + options.driverName = "ESRI Shapefile" + + QgsVectorFileWriter.writeAsVectorFormatV2( + self.map_widget.polygon_layer, file_path, + QgsCoordinateTransformContext(), options + ) + vpr.truncate() + + def export_cone(self): + self.export_layer('cones', 'cones.shp') + +class WebMap(MapWidget): + def __init__(self, root, p1_lat, p1_lon, p2_lat, p2_lon, load_cached): + super().__init__(root) + self.load_cached = load_cached + self.add_layers() + self.adjust_canvas() + r = QgsRectangle(p1_lon, p2_lat, p2_lon, p1_lat) + rect = self.transform_to_web.transformBoundingBox(r) + self.canvas.zoomToFeatureExtent(rect) + self.add_toolbar() + self.add_rect_tool() + self.pan() + self._setup_layout(root) + + def _setup_layout(self, root): + self.holder.addWidget(self.toolbar) + self.holder.addWidget(self.canvas) + self.setLayout(self.holder) + root.addWidget(self, 0, 1, 1, 2) + self.root = root + + def add_layers(self): + layers = [ + self.set_up_estimate(), + self.set_up_vehicle_layers(), + self.set_up_ping_layer(), + self.set_up_ground_truth(), + self.set_up_cone_layer(), + self.set_up_polygon_layer() + ] + for layer in layers: + if isinstance(layer, tuple): + for l in layer: + QgsProject.instance().addMapLayer(l) + else: + QgsProject.instance().addMapLayer(layer) + + self._setup_base_layer() + + def _setup_base_layer(self): + if self.load_cached: + path = QDir().currentPath() + url_with_params = f'type=xyz&url=file:///{path}/tiles/%7Bz%7D/%7Bx%7D/%7By%7D.png' + else: + url_with_params = 'type=xyz&url=http://a.tile.openstreetmap.org/%7Bz%7D/%7Bx%7D/%7By%7D.png&zmax=19&zmin=0&crs=EPSG3857' + + self.map_layer = QgsRasterLayer(url_with_params, 'OpenStreetMap', 'wms') + + if self.map_layer.isValid(): + crs = QgsCoordinateReferenceSystem("EPSG:3857") + self.map_layer.setCrs(crs) + QgsProject.instance().addMapLayer(self.map_layer) + else: + print('invalid map_layer') + raise RuntimeError("Invalid map layer") + + def add_rect_tool(self): + self.rect_action = QAction("Rect", self) + self.rect_action.setCheckable(True) + self.rect_action.triggered.connect(self.rect) + self.toolbar.addAction(self.rect_action) + self.tool_rect = RectangleMapTool(self.canvas) + self.tool_rect.setAction(self.rect_action) + + def rect(self): + self.canvas.setMapTool(self.tool_rect) + + @staticmethod + def degree_to_tile_num(lat_deg, lon_deg, zoom): + lat_rad = math.radians(lat_deg) + n = 2.0 ** zoom + x = int((lon_deg + 180.0) / 360.0 * n) + y = int((1.0 - math.asinh(math.tan(lat_rad)) / math.pi) / 2.0 * n) + return (x, y) + + def cache_map(self): + if self.tool_rect.rectangle() is None: + return + + rect = self.tool_rect.rectangle() + r = self.transform.transformBoundingBox(rect, QgsCoordinateTransform.ForwardTransform, True) + print(f"Rectangle: {r.xMinimum()}, {r.yMinimum()}, {r.xMaximum()}, {r.yMaximum()}") + + if r is not None: + self._download_tiles(r) + + def _download_tiles(self, r): + zoom_start = 17 + tile_count = 0 + for zoom in range(zoom_start, 19): + x_min, y_min = self.degree_to_tile_num(r.yMinimum(), r.xMinimum(), zoom) + x_max, y_max = self.degree_to_tile_num(r.yMaximum(), r.xMaximum(), zoom) + print(f"Zoom: {zoom}") + print(f"{x_min}, {x_max}, {y_min}, {y_max}") + for x in range(x_min, x_max + 1): + for y in range(y_max, y_min + 1): + if tile_count < 200: + time.sleep(1) + downloaded = self.download_tile(x, y, zoom) + if downloaded: + tile_count += 1 + else: + print("Tile count exceeded, please try again in a few minutes") + return + print("Download Complete") + + def download_tile(self, x_tile, y_tile, zoom): + url = f"http://c.tile.openstreetmap.org/{zoom}/{x_tile}/{y_tile}.png" + dir_path = f"tiles/{zoom}/{x_tile}/" + download_path = f"{dir_path}{y_tile}.png" + + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + if not os.path.isfile(download_path): + print(f"downloading {url}") + headers = { + 'User-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/101.0.4951.54 Safari/537.36' + } + response = requests.get(url, headers=headers) + if response.status_code == 200: + with open(download_path, 'wb') as f: + f.write(response.content) + return True + else: + print(f"Failed to download {url}") + return False + else: + print(f"skipped {url}") + return False + +class StaticMap(MapWidget): + def __init__(self, root): + super().__init__(root) + self.file_name = self._get_file_name() + self._add_layers() + self.adjust_canvas() + self.add_toolbar() + self.pan() + self._setup_layout(root) + + def _get_file_name(self): + return QFileDialog.getOpenFileName()[0] + + def _add_layers(self): + if not self.file_name: + return + + self.map_layer = QgsRasterLayer(self.file_name, "SRTM layer name") + if not self.map_layer.crs().isValid(): + raise FileNotFoundError("Invalid file, loading from web...") + print(self.map_layer.crs()) + + layers = [ + self._setup_estimate_layer(), + self._setup_vehicle_layers(), + self._setup_ping_layer(), + ] + + for layer in layers: + if isinstance(layer, tuple): + for l in layer: + QgsProject.instance().addMapLayer(l) + else: + QgsProject.instance().addMapLayer(layer) + + QgsProject.instance().addMapLayer(self.map_layer) + + def _setup_estimate_layer(self): + uri = "Point?crs=epsg:4326" + layer = QgsVectorLayer(uri, 'Estimate', "memory") + symbol = QgsMarkerSymbol.createSimple({'name': 'diamond', 'color': 'blue'}) + layer.renderer().setSymbol(symbol) + layer.setAutoRefreshInterval(500) + layer.setAutoRefreshEnabled(True) + return layer + + def _setup_vehicle_layers(self): + uri = "Point?crs=epsg:4326" + uri_line = "Linestring?crs=epsg:4326" + vehicle_layer = QgsVectorLayer(uri, 'Vehicle', "memory") + vehicle_path_layer = QgsVectorLayer(uri_line, 'vehicle_path', "memory") + + path = os.path.join(QDir.currentPath(), 'camera.svg') + symbol_svg = QgsSvgMarkerSymbolLayer(path) + symbol_svg.setSize(4) + symbol_svg.setFillColor(QColor('#0000ff')) + symbol_svg.setStrokeColor(QColor('#ff0000')) + symbol_svg.setStrokeWidth(1) + + vehicle_layer.renderer().symbol().changeSymbolLayer(0, symbol_svg) + vehicle_layer.setAutoRefreshInterval(500) + vehicle_layer.setAutoRefreshEnabled(True) + vehicle_path_layer.setAutoRefreshInterval(500) + vehicle_path_layer.setAutoRefreshEnabled(True) + + return vehicle_layer, vehicle_path_layer + + def _setup_ping_layer(self): + uri = "Point?crs=epsg:4326" + layer = QgsVectorLayer(uri, 'Pings', 'memory') + + symbols = { + 'blue': QgsSymbol.defaultSymbol(layer.geometryType()), + 'green': QgsSymbol.defaultSymbol(layer.geometryType()), + 'yellow': QgsSymbol.defaultSymbol(layer.geometryType()), + 'orange': QgsSymbol.defaultSymbol(layer.geometryType()), + 'red': QgsSymbol.defaultSymbol(layer.geometryType()) + } + + for color, symbol in symbols.items(): + symbol.setColor(QColor(color)) + + ranges = [ + QgsRendererRange(0, 20, symbols['blue'], 'Blue'), + QgsRendererRange(20, 40, symbols['green'], 'Green'), + QgsRendererRange(40, 60, symbols['yellow'], 'Yellow'), + QgsRendererRange(60, 80, symbols['orange'], 'Orange'), + QgsRendererRange(80, 100, symbols['red'], 'Red') + ] + + renderer = QgsGraduatedSymbolRenderer('Amp', ranges) + classification_method = QgsApplication.classificationMethodRegistry().method("EqualInterval") + renderer.setClassificationMethod(classification_method) + renderer.setClassAttribute('Amp') + + layer.dataProvider().addAttributes([QgsField(name='Amp', type=QVariant.Double, len=30)]) + layer.updateFields() + + layer.setRenderer(renderer) + layer.setAutoRefreshInterval(500) + layer.setAutoRefreshEnabled(True) + + return layer \ No newline at end of file diff --git a/rttDroneGCS/gui/popups.py b/rttDroneGCS/gui/popups.py new file mode 100644 index 0000000..c1ddbe3 --- /dev/null +++ b/rttDroneGCS/gui/popups.py @@ -0,0 +1,116 @@ +"""Popups for the GUI.""" + +from __future__ import annotations + +from PyQt5.QtCore import Qt, QTimer +from PyQt5.QtWidgets import ( + QButtonGroup, + QGridLayout, + QLabel, + QLineEdit, + QMessageBox, + QRadioButton, +) + + +class UserPopups: + """Creates popup boxes for user display.""" + + def create_text_box(self, name: str, text: str) -> tuple[QGridLayout, QLineEdit]: + """Create a text box. + + Args: + ---- + name (str): The name of the text box. + text (str): The text to display in the text box. + + Returns: + ------- + tuple[QGridLayout, QLineEdit]: The text box. + + """ + form = QGridLayout() + form.setColumnStretch(0, 0) + + label = QLabel(name) + form.addWidget(label) + + line = QLineEdit() + line.setText(text) + + form.addWidget(line, 0, 1) + return form, line + + def create_binary_box( + self, + name: str, + labels_list: list[str], + ) -> tuple[QGridLayout, QRadioButton]: + """Create a binary box. + + Args: + ---- + name (str): The name of the binary box. + labels_list (list[str]): The labels of the binary box. + + Returns: + ------- + tuple[QGridLayout, QRadioButton]: The binary box. + + """ + form = QGridLayout() + form.setColumnStretch(1, 0) + label = QLabel(name) + form.addWidget(label) + true_event = QRadioButton(labels_list[0]) + false_event = QRadioButton(labels_list[1]) + true_event.setChecked(True) + + button = QButtonGroup(parent=form) + button.setExclusive(True) + button.addButton(true_event) + button.addButton(false_event) + form.addWidget(true_event, 0, 1, Qt.AlignLeft) + form.addWidget(false_event, 0, 0, Qt.AlignRight) + return form, true_event + + def show_warning(self, text: str, title: str) -> None: + """Show a warning. + + Args: + ---- + text (str): The text to display in the warning. + title (str): The title of the warning. + + """ + msg = QMessageBox() + msg.setText(title) + msg.setWindowTitle("Alert") + msg.setInformativeText(text) + msg.setIcon(QMessageBox.Critical) + msg.addButton(QMessageBox.Ok) + msg.exec_() + + def show_timed_warning( + self, + text: str, + timeout: int, + title: str = "Warning", + ) -> None: + """Show a timed warning. + + Args: + ---- + text (str): The text to display in the warning. + timeout (int): The timeout of the warning. + title (str): The title of the warning. + + """ + msg = QMessageBox() + QTimer.singleShot(timeout * 1000, lambda: msg.done(0)) + msg.setText(title) + msg.setWindowTitle("Alert") + msg.setInformativeText(text) + msg.setIcon(QMessageBox.Critical) + msg.addButton(QMessageBox.Ok) + msg.exec_() diff --git a/rttDroneGCS/mav/__init__.py b/rttDroneGCS/mav/__init__.py new file mode 100644 index 0000000..3763252 --- /dev/null +++ b/rttDroneGCS/mav/__init__.py @@ -0,0 +1,11 @@ +from .mav_model import MAVModel +from .enums import Events, SDRInitStates, ExtsStates, OutputDirStates, RTTStates + +__all__ = [ + "MAVModel", + "Events", + "SDRInitStates", + "ExtsStates", + "OutputDirStates", + "RTTStates", +] diff --git a/rttDroneGCS/mav/enums.py b/rttDroneGCS/mav/enums.py new file mode 100644 index 0000000..0460e25 --- /dev/null +++ b/rttDroneGCS/mav/enums.py @@ -0,0 +1,61 @@ +"""Enums for the MAVModel class.""" + +from enum import Enum, auto + + +class Events(Enum): + """Callback Events.""" + + Heartbeat = auto() + Exception = auto() + GetFreqs = auto() + GetOptions = auto() + NoHeartbeat = auto() + NewPing = auto() + NewEstimate = auto() + UpgradeStatus = auto() + VehicleInfo = auto() + ConeInfo = auto() + + +class SDRInitStates(Enum): + """SDR Initialization States.""" + + find_devices = 0 + wait_recycle = 1 + usrp_probe = 2 + rdy = 3 + fail = 4 + + +class ExtsStates(Enum): + """GPS Initialization States.""" + + get_tty = 0 + get_msg = 1 + wait_recycle = 2 + rdy = 3 + fail = 4 + + +class OutputDirStates(Enum): + """Output Directory Initialization States.""" + + get_output_dir = 0 + check_output_dir = 1 + check_space = 2 + wait_recycle = 3 + rdy = 4 + fail = 5 + + +class RTTStates(Enum): + """System Initialization States.""" + + init = 0 + wait_init = 1 + wait_start = 2 + start = 3 + wait_end = 4 + finish = 5 + fail = 6 diff --git a/rttDroneGCS/mav/mav_model.py b/rttDroneGCS/mav/mav_model.py new file mode 100644 index 0000000..88754de --- /dev/null +++ b/rttDroneGCS/mav/mav_model.py @@ -0,0 +1,456 @@ +"""Module for the MAVModel class.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Callable + +import rttDroneComms.comms + +from rttDroneGCS.ping import DataManager, RTTPing + +from .enums import Events, ExtsStates, OutputDirStates, RTTStates, SDRInitStates + + +class MAVModel: + """Provides an object-oriented view of the vehicle state.""" + + BASE_OPTIONS = 0x00 + EXP_OPTIONS = 0x01 + ENG_OPTIONS = 0xFF + TGT_PARAMS = 0x100 + + CACHE_GOOD = 0 + CACHE_INVALID = 1 + CACHE_DIRTY = 2 + + def __init__(self, receiver: rttDroneComms.comms.gcsComms) -> None: + """Initialize a new MAVModel. + + Args: + ---- + receiver (rttDroneComms.comms.gcsComms): gcsComms Object + + """ + self._log = logging.getLogger("rttDroneGCS:MavModel") + self._rx = receiver + self._option_cache_dirty = { + self.BASE_OPTIONS: self.CACHE_INVALID, + self.EXP_OPTIONS: self.CACHE_INVALID, + self.ENG_OPTIONS: self.CACHE_INVALID, + self.TGT_PARAMS: self.CACHE_INVALID, + } + self.state = self._initialize_state() + self.pp_options = self._initialize_pp_options() + self.est_mgr = DataManager() + self._callbacks = {event: [] for event in Events} + self._ack_vectors = {} + self._register_rx_callbacks() + + def _initialize_state(self) -> dict: + return { + "STS_sdr_status": 0, + "STS_dir_status": 0, + "STS_gps_status": 0, + "STS_sys_status": 0, + "STS_sw_status": 0, + "UPG_state": -1, + "UPG_msg": "", + "VCL_track": {}, + "CONE_track": {}, + } + + def _initialize_pp_options(self) -> dict: + return { + "TGT_frequencies": [], + "SDR_center_freq": 0, + "SDR_sampling_freq": 0, + "SDR_gain": 0, + "DSP_ping_width": 0, + "DSP_ping_snr": 0, + "DSP_ping_max": 0, + "DSP_ping_min": 0, + "SYS_output_dir": "", + "GPS_mode": 0, + "GPS_baud": 0, + "GPS_device": "", + "SYS_autostart": False, + } + + def _register_rx_callbacks(self) -> None: + callback_mapping = { + rttDroneComms.comms.EVENTS.STATUS_HEARTBEAT: self._process_heartbeat, + rttDroneComms.comms.EVENTS.GENERAL_NO_HEARTBEAT: self._process_no_heartbeat, + rttDroneComms.comms.EVENTS.STATUS_EXCEPTION: self._handle_remote_exception, + rttDroneComms.comms.EVENTS.COMMAND_ACK: self._process_ack, + rttDroneComms.comms.EVENTS.CONFIG_FREQUENCIES: self._process_frequencies, + rttDroneComms.comms.EVENTS.CONFIG_OPTIONS: self._process_options, + rttDroneComms.comms.EVENTS.DATA_PING: self._process_ping, + rttDroneComms.comms.EVENTS.DATA_VEHICLE: self._process_vehicle, + rttDroneComms.comms.EVENTS.DATA_CONE: self._process_cone, + } + for event, callback in callback_mapping.items(): + self._rx.registerCallback(event, callback) + + async def start(self) -> None: + """Initialize the MAVModel object.""" + await self._rx.start() + self._log.info("MAVModel Started") + + async def stop(self) -> None: + """Stop the MAVModel and underlying resources.""" + await self._rx.stop() + self._log.info("MAVModel Stopped") + + def register_callback(self, event: Events, callback: Callable) -> None: + """Register a callback for the specific event. + + Args: + ---- + event (Events): Event to trigger on + callback (Callable): Callback to call + + Raises: + ------ + TypeError: If the event type is invalid + + """ + if not isinstance(event, Events): + msg = "Invalid event type" + raise TypeError(msg) + self._callbacks[event].append(callback) + + async def start_mission(self) -> None: + """Send the start mission command.""" + await self._send_command_and_wait( + 0x07, + rttDroneComms.comms.rttSTARTCommand(), + "START", + ) + + async def stop_mission(self) -> None: + """Send the stop mission command.""" + await self._send_command_and_wait( + 0x09, + rttDroneComms.comms.rttSTOPCommand(), + "STOP", + ) + + async def _send_command_and_wait( + self, + command_id: int, + command: rttDroneComms.comms.BinaryPacket, + command_name: str, + ) -> None: + event = asyncio.Event() + self._ack_vectors[command_id] = [event, 0] + await self._rx.sendPacket(command) + self._log.info(f"Sent {command_name} command") + try: + await asyncio.wait_for(event.wait(), timeout=10) + except asyncio.TimeoutError as err: + msg = f"{command_name} timed out" + raise RuntimeError(msg) from err + if not self._ack_vectors.pop(command_id)[1]: + msg = f"{command_name} NACKED" + raise RuntimeError(msg) + + async def get_frequencies(self) -> list[int]: + """Retrieve the PRX_frequencies from the payload. + + Returns + ------- + List of frequencies + + """ + if self._option_cache_dirty[self.TGT_PARAMS] == self.CACHE_INVALID: + frequency_pack_event = asyncio.Event() + self.register_callback(Events.GetFreqs, frequency_pack_event.set) + + await self._rx.sendPacket(rttDroneComms.comms.rttGETFCommand()) + self._log.info("Sent getF command") + + try: + await asyncio.wait_for( + frequency_pack_event.wait(), + timeout=10, + ) + except asyncio.TimeoutError as err: + msg = "Timeout waiting for frequencies" + raise RuntimeError(msg) from err + return self.pp_options["TGT_frequencies"] + + async def set_frequencies(self, freqs: list[int]) -> None: + """Send the command to set the specific PRX_frequencies. + + Args: + ---- + freqs (list[int]): Frequencies to set + + """ + if not isinstance(freqs, list) or not all( + isinstance(freq, int) for freq in freqs + ): + msg = "Invalid frequencies" + raise TypeError(msg) + + self._option_cache_dirty[self.TGT_PARAMS] = self.CACHE_DIRTY + await self._send_command_and_wait( + 0x03, + rttDroneComms.comms.rttSETFCommand(freqs), + "SETF", + ) + + async def add_frequency(self, frequency: int) -> None: + """Add the specified frequency to the target frequencies. + + If the specified frequency is already in TGT_frequencies, this function does + nothing. Otherwise, this function will update the TGT_frequencies on the + payload. + + Args: + ---- + frequency (int): Frequency to add + + """ + if frequency not in self.pp_options["TGT_frequencies"]: + await self.set_frequencies( + self.pp_options["TGT_frequencies"] + [frequency], + ) + + async def remove_frequency(self, frequency: int) -> None: + """Remove the specified frequency from the target frequencies. + + If the specified frequency is not in TGT_frequencies, this function raises a + RuntimeError. Otherwise, this function will update the TGT_frequencies on the + payload. + + Args: + ---- + frequency (int): Frequency to remove + + """ + if frequency not in self.pp_options["TGT_frequencies"]: + msg = "Invalid frequency" + raise RuntimeError(msg) + new_freqs = [f for f in self.pp_options["TGT_frequencies"] if f != frequency] + await self.set_frequencies(new_freqs) + + async def get_options(self, scope: int) -> dict: + """Retrieve and return the options as a dictionary from the remote. + + Scope should be set to one of MAVModel.BASE_OPTIONS, MAVModel.EXP_OPTIONS, or + MAVModel.ENG_OPTIONS. + + Args: + ---- + scope (int): Scope of options to retrieve + + Returns: + ------- + Dictionary of options + + """ + option_packet_event = asyncio.Event() + self.register_callback(Events.GetOptions, option_packet_event.set) + + await self._rx.sendPacket(rttDroneComms.comms.rttGETOPTCommand(scope)) + self._log.info("Sent GETOPT command") + + try: + await asyncio.wait_for(option_packet_event.wait(), timeout=10) + except asyncio.TimeoutError as err: + msg = "Timeout waiting for options" + raise RuntimeError(msg) from err + + accepted_keywords = [] + if scope >= self.BASE_OPTIONS: + accepted_keywords.extend(self._base_option_keywords) + if scope >= self.EXP_OPTIONS: + accepted_keywords.extend(self._exp_option_keywords) + if scope >= self.ENG_OPTIONS: + accepted_keywords.extend(self._eng_option_keywords) + + return {key: self.pp_options[key] for key in accepted_keywords} + + async def get_option(self, keyword: str) -> any | None: + """Retrieve a specific option by keyword. + + Args: + ---- + keyword: Keyword of the option to retrieve + + Returns: + ------- + The option value or None if the keyword is invalid + + """ + option_groups = [ + (self._base_option_keywords, self.BASE_OPTIONS), + (self._exp_option_keywords, self.EXP_OPTIONS), + (self._eng_option_keywords, self.ENG_OPTIONS), + ] + + for keywords, scope in option_groups: + if ( + keyword in keywords + and self._option_cache_dirty[scope] == self.CACHE_INVALID + ): + options = await self.get_options(scope) + return options.get(keyword) + + return self.pp_options.get(keyword) + + async def set_options(self, **kwargs: dict[str, Any]) -> None: + """Set the specified options on the payload. + + Args: + ---- + kwargs: Options to set by keyword + + """ + scope = self.BASE_OPTIONS + for keyword in kwargs: + if keyword in self._base_option_keywords: + scope = max(scope, self.BASE_OPTIONS) + elif keyword in self._exp_option_keywords: + scope = max(scope, self.EXP_OPTIONS) + elif keyword in self._eng_option_keywords: + scope = max(scope, self.ENG_OPTIONS) + else: + msg = f"Invalid option keyword: {keyword}" + raise KeyError(msg) + + self.pp_options.update(kwargs) + accepted_keywords = [] + if scope >= self.BASE_OPTIONS: + self._option_cache_dirty[self.BASE_OPTIONS] = self.CACHE_DIRTY + accepted_keywords.extend(self._base_option_keywords) + if scope >= self.EXP_OPTIONS: + self._option_cache_dirty[self.EXP_OPTIONS] = self.CACHE_DIRTY + accepted_keywords.extend(self._exp_option_keywords) + if scope >= self.ENG_OPTIONS: + self._option_cache_dirty[self.ENG_OPTIONS] = self.CACHE_DIRTY + accepted_keywords.extend(self._eng_option_keywords) + + await self._send_command_and_wait( + 0x05, + rttDroneComms.comms.rttSETOPTCommand( + scope, + **{key: self.pp_options[key] for key in accepted_keywords}, + ), + "SETOPT", + ) + + async def send_upgrade_packet(self, byte_stream: bytes) -> None: + """Send the upgrade packet to the payload. + + Args: + ---- + byte_stream (bytes): Byte stream to send + + """ + num_packets = -1 + if len(byte_stream) % 1000 != 0: + num_packets = len(byte_stream) // 1000 + 1 + else: + num_packets = len(byte_stream) // 1000 + for i in range(num_packets): + start_idx = i * 1000 + end_idx = start_idx + 1000 + await self._rx.sendPacket( + rttDroneComms.comms.rttUpgradePacket( + i + 1, + num_packets, + byte_stream[start_idx:end_idx], + ), + ) + + async def _process_heartbeat( + self, + packet: rttDroneComms.comms.rttHeartBeatPacket, + ) -> None: + self._log.info("Received heartbeat") + self.state.update( + { + "STS_sdr_status": SDRInitStates(packet.sdrState), + "STS_dir_status": OutputDirStates(packet.storageState), + "STS_gps_status": ExtsStates(packet.sensorState), + "STS_sys_status": RTTStates(packet.systemState), + "STS_sw_status": packet.switchState, + }, + ) + for callback in self._callbacks[Events.Heartbeat]: + callback() + + async def _process_no_heartbeat(self) -> None: + for callback in self._callbacks[Events.NoHeartbeat]: + callback() + + async def _handle_remote_exception( + self, + packet: rttDroneComms.comms.rttRemoteExceptionPacket, + ) -> None: + self._log.exception("Remote Exception: %s", packet.exception) + self._log.exception("Remote Traceback: %s", packet.traceback) + for callback in self._callbacks[Events.Exception]: + callback() + + async def _process_ack(self, packet: rttDroneComms.comms.rttACKCommand) -> None: + command_id = packet.commandID + if command_id in self._ack_vectors: + vector = self._ack_vectors[command_id] + vector[1] = packet.ack + vector[0].set() + + async def _process_frequencies( + self, + packet: rttDroneComms.comms.rttFrequenciesPacket, + ) -> None: + self._log.info("Received frequencies") + self.pp_options["TGT_frequencies"] = packet.frequencies + self._option_cache_dirty[self.TGT_PARAMS] = self.CACHE_GOOD + for callback in self._callbacks[Events.GetFreqs]: + callback() + + async def _process_options( + self, + packet: rttDroneComms.comms.rttOptionsPacket, + ) -> None: + self._log.info("Received options") + for parameter in packet.options: + self.pp_options[parameter] = packet.options[parameter] + if packet.scope >= self.BASE_OPTIONS: + self._option_cache_dirty[self.BASE_OPTIONS] = self.CACHE_GOOD + if packet.scope >= self.EXP_OPTIONS: + self._option_cache_dirty[self.EXP_OPTIONS] = self.CACHE_GOOD + if packet.scope >= self.ENG_OPTIONS: + self._option_cache_dirty[self.ENG_OPTIONS] = self.CACHE_GOOD + + for callback in self._callbacks[Events.GetOptions]: + callback() + + async def _process_ping(self, packet: rttDroneComms.comms.rttPingPacket) -> None: + ping_obj = RTTPing.from_packet(packet) + estimate = self.est_mgr.add_ping(ping_obj) + for callback in self._callbacks[Events.NewPing]: + callback() + if estimate is not None: + for callback in self._callbacks[Events.NewEstimate]: + callback() + + async def _process_vehicle( + self, + packet: rttDroneComms.comms.rttVehiclePacket, + ) -> None: + coordinate = [packet.lat, packet.lon, packet.alt, packet.hdg] + self.state["VCL_track"][packet.timestamp] = coordinate + for callback in self._callbacks[Events.VehicleInfo]: + callback() + + async def _process_cone(self, packet: rttDroneComms.comms.rttConePacket) -> None: + coordinate = [packet.lat, packet.lon, packet.alt, packet.power, packet.angle] + self.state["CONE_track"][packet.timestamp] = coordinate + for callback in self._callbacks[Events.ConeInfo]: + callback() diff --git a/rttDroneGCS/ping/__init__.py b/rttDroneGCS/ping/__init__.py new file mode 100644 index 0000000..005ea59 --- /dev/null +++ b/rttDroneGCS/ping/__init__.py @@ -0,0 +1,18 @@ +from .models import RTTCone, RTTPing +from .data_manager import DataManager +from .location_estimator import LocationEstimator +from .utils import residuals, mse, rssi_to_distance, p_d +from .io_operations import save_csv, save_tiff + +__all__ = [ + "RTTCone", + "RTTPing", + "DataManager", + "LocationEstimator", + "residuals", + "mse", + "rssi_to_distance", + "p_d", + "save_csv", + "save_tiff", +] diff --git a/rttDroneGCS/ping/data_manager.py b/rttDroneGCS/ping/data_manager.py new file mode 100644 index 0000000..48e3ef2 --- /dev/null +++ b/rttDroneGCS/ping/data_manager.py @@ -0,0 +1,158 @@ +"""Module for managing data.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import utm + +from .location_estimator import LocationEstimator + +if TYPE_CHECKING: + import numpy as np + from scipy.optimize import least_squares + + from .models import RTTPing + + +class DataManager: + """Class for managing data.""" + + def __init__(self) -> None: + """Initialize the DataManager.""" + self._estimators: dict[int, LocationEstimator] = {} + self.zone: int | None = None + self.let: str | None = None + self._vehicle_path: list[tuple[float, float, float]] = [] + + def add_ping(self, ping: RTTPing) -> tuple[np.ndarray, bool] | None: + """Add a new ping and update the estimator for the corresponding frequency. + + Args: + ---- + ping (RTTPing): The ping to add. + + Returns: + ------- + Optional[Tuple[np.ndarray, bool]]: The estimation result, if available. + + """ + if ping.freq not in self._estimators: + self._estimators[ping.freq] = LocationEstimator() + + self._estimators[ping.freq].add_ping(ping) + + if self.zone is None: + self.set_zone(ping.lat, ping.lon) + + return self._estimators[ping.freq].do_estimate() + + def add_vehicle_location(self, coord: tuple[float, float, float]) -> None: + """Add a new vehicle location to the path. + + Args: + ---- + coord (Tuple[float, float, float]): The coordinates (lat, lon, alt) of the + vehicle. + + """ + self._vehicle_path.append(coord) + + def set_zone(self, lat: float, lon: float) -> None: + """Set the UTM zone based on the given latitude and longitude. + + Args: + ---- + lat (float): Latitude + lon (float): Longitude + + """ + _, _, zone, let = utm.from_latlon(lat, lon) + self.zone = zone + self.let = let + + def get_estimate( + self, + frequency: int, + ) -> tuple[np.ndarray, bool, least_squares | None] | None: + """Get the current estimate for a specific frequency. + + Args: + ---- + frequency (int): The frequency to get the estimate for. + + Returns: + ------- + Optional[Tuple[np.ndarray, bool, Optional[least_squares]]]: The estimation + result, if available. + + """ + return self._estimators.get(frequency, LocationEstimator()).get_estimate() + + def get_frequencies(self) -> list[int]: + """Get a list of all frequencies with estimators. + + Returns + ------- + List[int]: List of frequencies. + + """ + return list(self._estimators.keys()) + + def get_pings(self, frequency: int) -> list[np.ndarray]: + """Get all pings for a specific frequency. + + Args: + ---- + frequency (int): The frequency to get pings for. + + Returns: + ------- + List[np.ndarray]: List of pings for the given frequency. + + """ + return self._estimators.get(frequency, LocationEstimator()).get_pings() + + def get_num_pings(self, frequency: int) -> int: + """Get the number of pings for a specific frequency. + + Args: + ---- + frequency (int): The frequency to get the ping count for. + + Returns: + ------- + int: Number of pings for the given frequency. + + """ + return self._estimators.get(frequency, LocationEstimator()).get_num_pings() + + def get_vehicle_path(self) -> list[tuple[float, float, float]]: + """Get the vehicle path. + + Returns + ------- + List[Tuple[float, float, float]]: List of vehicle coordinates. + + """ + return self._vehicle_path + + def get_utm_zone(self) -> tuple[int | None, str | None]: + """Get the UTM zone information. + + Returns + ------- + Tuple[Optional[int], Optional[str]]: UTM zone number and letter. + + """ + return self.zone, self.let + + def do_precisions(self, frequency: int) -> None: + """Perform precision calculations for a specific frequency. + + Args: + ---- + frequency (int): The frequency to calculate precisions for. + + """ + self._estimators.get(frequency, LocationEstimator()).do_precision() diff --git a/rttDroneGCS/ping/io_operations.py b/rttDroneGCS/ping/io_operations.py new file mode 100644 index 0000000..c44e2b1 --- /dev/null +++ b/rttDroneGCS/ping/io_operations.py @@ -0,0 +1,100 @@ +"""Module for IO operations.""" + +from __future__ import annotations + +import csv +from pathlib import Path + +import numpy as np +from osgeo import gdal, osr + + +def save_csv( + data_dir: str, + l_tx: np.ndarray, + size: int, + heat_map_area: np.ndarray, +) -> None: + """Save the heat map to a CSV file. + + Args: + ---- + data_dir (str): The directory to save the CSV file in. + l_tx (np.ndarray): The transmit location. + size (int): The size of the heat map. + heat_map_area (np.ndarray): The heat map area. + + """ + ref_x, min_y = l_tx[0] - (size / 2), l_tx[1] - (size / 2) + csv_data = [ + { + "easting": ref_x + x, + "northing": min_y + y, + "value": heat_map_area[y, x], + } + for y in range(size) + for x in range(size) + ] + + csv_path = Path(data_dir) / "query.csv" + with Path.open(csv_path, "w", newline="") as csvfile: + fieldnames = ["easting", "northing", "value"] + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(csv_data) + + +def save_tiff( # noqa: PLR0913 + data_dir: str, + freq: int, + zone_num: int, + zone: str, + l_tx: np.ndarray, + size: int, + heat_map_area: np.ndarray, +) -> None: + """Save the heat map to a TIFF file. + + Args: + ---- + data_dir (str): The directory to save the TIFF file in. + freq (int): The frequency of the heat map. + zone_num (int): The UTM zone number of the heat map. + zone (str): The UTM zone letter of the heat map. + l_tx (np.ndarray): The transmit location. + size (int): The size of the heat map. + heat_map_area (np.ndarray): The heat map area. + + """ + output_file_name = ( + Path(data_dir) / f"PRECISION_{freq / 1e6:.3f}_{len(heat_map_area)}_heatmap.tiff" + ) + driver = gdal.GetDriverByName("GTiff") + dataset = driver.Create( + output_file_name, + size, + size, + 1, + gdal.GDT_Float32, + ["COMPRESS=LZW"], + ) + + spatial_reference = osr.SpatialReference() + spatial_reference.SetUTM(zone_num, zone >= "N") + spatial_reference.SetWellKnownGeogCS("WGS84") + wkt = spatial_reference.ExportToWkt() + dataset.SetProjection(wkt) + + ref_x, ref_y = l_tx[0] - (size / 2), l_tx[1] + (size / 2) + dataset.SetGeoTransform((ref_x, 1, 0, ref_y, 0, -1)) + + band = dataset.GetRasterBand(1) + band.WriteArray(heat_map_area) + band.SetStatistics( + np.amin(heat_map_area), + np.amax(heat_map_area), + np.mean(heat_map_area), + np.std(heat_map_area), + ) + dataset.FlushCache() + dataset = None diff --git a/rttDroneGCS/ping/location_estimator.py b/rttDroneGCS/ping/location_estimator.py new file mode 100644 index 0000000..808f99e --- /dev/null +++ b/rttDroneGCS/ping/location_estimator.py @@ -0,0 +1,161 @@ +"""Module for estimating location based on RTT pings.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import numpy as np +from scipy.optimize import least_squares + +from .io_operations import save_csv, save_tiff +from .utils import p_d, rssi_to_distance + +if TYPE_CHECKING: + from .models import RTTPing + +logger = logging.getLogger(__name__) + + +class LocationEstimator: + """Estimates location based on RTT pings.""" + + def __init__(self) -> None: + """Initialize the LocationEstimator with empty pings and parameters.""" + self._pings: list[np.ndarray] = [] + self._params: np.ndarray | None = None + self._stale_estimate: bool = True + self.result: least_squares | None = None + self.last_l_tx0: float = 0 + self.last_l_tx1: float = 0 + self.index: int = 0 + + def add_ping(self, ping: RTTPing) -> None: + """Add a new RTT ping to the estimator.""" + self._pings.append(ping.to_numpy()) + + def do_estimate(self) -> tuple[np.ndarray, bool] | None: + """Perform location estimation based on collected pings.""" + if len(self._pings) < 4: # noqa: PLR2004 + return None + + pings = np.array(self._pings) + x_tx_0, y_tx_0 = np.mean(pings[:, :2], axis=0) + p_tx_0 = np.max(pings[:, 3]) + n_0 = 2 + self._params = np.array([x_tx_0, y_tx_0, p_tx_0, n_0]) + res_x = least_squares( + self._residuals, + self._params, + bounds=([0, 167000, -np.inf, 2], [833000, 10000000, np.inf, 2.1]), + ) + + if res_x.success: + self._params = res_x.x + self._stale_estimate = False + else: + self._stale_estimate = True + + self.result = res_x + return self._params, self._stale_estimate + + def d_to_prx( + self, + ping_vector: np.ndarray, + param_vector: np.ndarray, + ) -> float: + """Calculate the received power from a ping.""" + l_rx = ping_vector[:3] + l_tx = np.array([param_vector[0], param_vector[1], 0]) + p_tx, n = param_vector[2:4] + + d = max(np.linalg.norm(l_rx - l_tx), 0.01) + return p_tx - 10 * n * np.log10(d) + + def _residuals(self, param_vect: np.ndarray) -> np.ndarray: + return np.array( + [ping[3] - self.d_to_prx(ping, param_vect) for ping in self._pings], + ) + + def do_precision( + self, + data_dir: str = "holder", + freq: int = 17350000, + zone_num: int = 11, + zone: str = "S", + ) -> None: + """Perform precision estimation based on collected pings.""" + logger.info("%.3f MHz has %d pings", freq / 1e6, len(self._pings)) + + l_tx, p, n = self._params[:2], self._params[2], self._params[3] + pings = np.array(self._pings) + + distances = np.linalg.norm(pings[:, :3] - np.array([*l_tx, 0]), axis=1) + calculated_distances = rssi_to_distance(pings[:, 3], p, n) + + distance_errors = calculated_distances - distances + std_distances = np.std(distance_errors) + p_rx = pings[:, 3] + + size = 25 + heat_map_area = self._generate_heat_map( + size, + l_tx, + pings, + n, + p_rx, + p, + std_distances, + ) + + save_csv(data_dir, l_tx, size, heat_map_area) + save_tiff(data_dir, freq, zone_num, zone, l_tx, size, heat_map_area) + + def _generate_heat_map( # noqa: PLR0913 + self, + size: int, + l_tx: np.ndarray, + pings: np.ndarray, + n: float, + p_rx: np.ndarray, + p: float, + std_distances: float, + ) -> np.ndarray: + heat_map_area = np.ones((size, size)) / (size * size) + + min_y = l_tx[1] - (size / 2) + ref_x = l_tx[0] - (size / 2) + + for y in range(size): + for x in range(size): + for i in range(len(pings)): + heat_map_area[y, x] *= p_d( + np.array([x + ref_x, y + min_y, 0]), + pings[i, :3], + n, + p_rx[i], + p, + std_distances, + ) + + return heat_map_area / heat_map_area.sum() + + def get_estimate( + self, + ) -> tuple[np.ndarray, bool, least_squares | None] | None: + """Get the current estimate.""" + if self._params is None: + return None + return self._params, self._stale_estimate, self.result + + def get_pings(self) -> list[np.ndarray]: + """Get the current pings.""" + return self._pings + + def get_num_pings(self) -> int: + """Get the number of pings.""" + return len(self._pings) + + def set_pings(self, pings: list[np.ndarray]) -> None: + """Set the pings.""" + self._pings = pings diff --git a/rttDroneGCS/ping/models.py b/rttDroneGCS/ping/models.py new file mode 100644 index 0000000..81a3222 --- /dev/null +++ b/rttDroneGCS/ping/models.py @@ -0,0 +1,105 @@ +"""Models for RTT ping and cone data structures.""" + +from __future__ import annotations + +import datetime as dt +from datetime import timezone + +import numpy as np +import rttDroneComms.comms +import utm + + +class RTTCone: + """Represents an RTT cone with location, amplitude, frequency, and time data.""" + + def __init__( # noqa: PLR0913 + self, + lat: float, + lon: float, + amplitude: float, + freq: int, + alt: float, + heading: float, + time: float, + ) -> None: + """Initialize an RTTCone instance.""" + self.lat = lat + self.lon = lon + self.amplitude = amplitude + self.heading = heading + self.freq = freq + self.alt = alt + self.time = dt.datetime.fromtimestamp(time, tz=timezone.utc) + + +class RTTPing: + """Represents an RTT ping with location, power, frequency, and time data.""" + + def __init__( # noqa: PLR0913 + self, + lat: float, + lon: float, + power: float, + freq: int, + alt: float, + time: float, + ) -> None: + """Initialize an RTTPing instance.""" + self.lat = lat + self.lon = lon + self.power = power + self.freq = freq + self.alt = alt + self.time = dt.datetime.fromtimestamp(time, tz=timezone.utc) + + def to_numpy(self) -> np.ndarray: + """Convert ping data to a numpy array.""" + easting, northing, _, _ = utm.from_latlon(self.lat, self.lon) + return np.array([easting, northing, self.alt, self.power]) + + def to_dict(self) -> dict[str, int]: + """Convert ping data to a dictionary.""" + return { + "lat": int(self.lat * 1e7), + "lon": int(self.lon * 1e7), + "amp": int(self.power), + "txf": self.freq, + "alt": int(self.alt), + "time": int(self.time.timestamp() * 1e3), + } + + def to_packet(self) -> rttDroneComms.comms.rttPingPacket: + """Convert ping data to an rttPingPacket.""" + return rttDroneComms.comms.rttPingPacket( + self.lat, + self.lon, + self.alt, + self.power, + self.freq, + self.time, + ) + + @classmethod + def from_dict(cls, packet: dict[str, int]) -> RTTPing: + """Create an RTTPing instance from a dictionary.""" + return cls( + lat=float(packet["lat"]) / 1e7, + lon=float(packet["lon"]) / 1e7, + power=float(packet["amp"]), + freq=int(packet["txf"]), + alt=float(packet["alt"]), + time=float(packet["time"]) / 1e3, + ) + + @classmethod + def from_packet(cls, packet: rttDroneComms.comms.rttPingPacket) -> RTTPing: + """Create an RTTPing instance from an rttPingPacket.""" + return cls( + lat=packet.lat, + lon=packet.lon, + power=packet.txp, + freq=packet.txf, + alt=packet.alt, + time=packet.timestamp.timestamp(), + ) diff --git a/rttDroneGCS/ping/utils.py b/rttDroneGCS/ping/utils.py new file mode 100644 index 0000000..5010ad7 --- /dev/null +++ b/rttDroneGCS/ping/utils.py @@ -0,0 +1,70 @@ +"""Utility functions for RSSI-based distance calculations and related operations.""" + +import numpy as np + + +def residuals(x: np.ndarray, data: np.ndarray) -> np.ndarray: + """Calculate the error for the signal propagation model parameterized by x. + + Args: + ---- + x (np.ndarray): Vector of singal model parameters: x[0] is the transmit power, + x[1] is the path loss exponent, x[2] and x[3] are the transmit coordinates in + meters, and x[4] is the system loss constant. + data (np.ndarray): Matrix of signal data. This matrix must have shape (m, 3), + where m is the number of data samples. data[:, 0] is the vector of recieved + signal power, data[:, 1] is the vector of x-coordinates of the recieved signal + in meters, and data[:, 2] is the vector of y-coordinates of the recieved signal + in meters. + + + Returns: + ------- + np.ndarray: A vector of shape (m, ) containing the difference between the + data and estimated data using the provided signal model parameters. + + """ + p, n, tx, ty, k = x + r, dx, dy = data[:, 0], data[:, 1], data[:, 2] + d = np.linalg.norm(np.array([dx - tx, dy - ty]).transpose(), axis=1) + return p - 10 * n * np.log10(d) + k - r + + +def mse(r: float, x: np.ndarray, p: float, n: float, t: np.ndarray, k: float) -> float: # noqa: PLR0913 + """Calculate the mean squared error for the signal propagation model. + + Args: + ---- + r (float): The recieved signal power. + x (np.ndarray): A vector of shape (2,) containing the measurement location for + the recieved signal in meters. + p (float): The transmit power. + n (float): The path loss exponent. + t (np.ndarray): A vector of shape (2,) containing the transmit coordinates in + meters. + k (float): The system loss constant. + + Returns: + ------- + float: The mean squared error of this measurement. + + """ + d = np.linalg.norm(x - t) + return (r - p + 10 * n * np.log10(d) + k) ** 2 + + +def rssi_to_distance(p_rx: float, p_tx: float, n: float, alt: float = 0) -> float: + """Convert RSSI to distance, accounting for altitude if provided.""" + dist = 10 ** ((p_tx - p_rx) / (10 * n)) + if alt != 0: + dist = np.sqrt(dist**2 - alt**2) + return dist + + +def p_d( # noqa: PLR0913 + tx: np.ndarray, dx: np.ndarray, n: float, p_rx: float, p_tx: float, d_std: float, +) -> float: + """Calculate probability density for distance estimation.""" + modeled_distance = rssi_to_distance(p_rx, p_tx, n) + adjusted_distance = (np.linalg.norm(dx - tx) - modeled_distance) / d_std + return np.exp(-(adjusted_distance**2) / 2) / (np.sqrt(2 * np.pi) * d_std) diff --git a/rttDroneGCS/rtt_drone_gcs.py b/rttDroneGCS/rtt_drone_gcs.py new file mode 100644 index 0000000..4006a9d --- /dev/null +++ b/rttDroneGCS/rtt_drone_gcs.py @@ -0,0 +1,27 @@ +import logging +from .mav.mav_model import MAVModel +import rttDroneComms.comms +from .config import get_config_path, get_instance +from .gui.dialogs import ConfigDialog +from PyQt5.QtWidgets import QApplication +import sys + + +def main(): + logging.basicConfig(level=logging.INFO) + app = QApplication(sys.argv) + + config = get_instance(get_config_path()) + config_dialog = ConfigDialog(None) + if config_dialog.exec_() == ConfigDialog.Rejected: + return + + receiver = rttDroneComms.comms.gcsComms() + mav_model = MAVModel(receiver) + mav_model.start() + + sys.exit(app.exec_()) + + +if __name__ == "__main__": + main()