Skip to content

Commit

Permalink
Remove global API endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
pinkisemils committed Dec 18, 2024
1 parent 785bdfb commit 2701374
Show file tree
Hide file tree
Showing 22 changed files with 283 additions and 153 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import java.io.File
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride

const val PROBLEM_REPORT_LOGS_FILE = "problem_report.txt"

Expand Down Expand Up @@ -39,7 +40,10 @@ class MullvadProblemReport(context: Context, val dispatcher: CoroutineDispatcher
collectReport(logDirectory.absolutePath, logsPath.absolutePath)
}

suspend fun sendReport(userReport: UserReport): SendProblemReportResult {
suspend fun sendReport(
userReport: UserReport,
apiEndpointOverride: ApiEndpointOverride?,
): SendProblemReportResult {
// If report is not collected then, collect it, if it fails then return error
if (!logsExists() && !collectLogs()) {
return SendProblemReportResult.Error.CollectLog
Expand All @@ -52,6 +56,7 @@ class MullvadProblemReport(context: Context, val dispatcher: CoroutineDispatcher
userReport.description,
logsPath.absolutePath,
cacheDirectory.absolutePath,
apiEndpointOverride
)
}

Expand Down Expand Up @@ -89,5 +94,6 @@ class MullvadProblemReport(context: Context, val dispatcher: CoroutineDispatcher
userMessage: String,
reportPath: String,
cacheDirectory: String,
apiEndpointOverride: ApiEndpointOverride?,
): Boolean
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import net.mullvad.mullvadvpn.applist.ApplicationsProvider
import net.mullvad.mullvadvpn.compose.state.RelayListType
import net.mullvad.mullvadvpn.constant.IS_PLAY_BUILD
import net.mullvad.mullvadvpn.dataproxy.MullvadProblemReport
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride
import net.mullvad.mullvadvpn.lib.payment.PaymentProvider
import net.mullvad.mullvadvpn.lib.shared.VoucherRepository
import net.mullvad.mullvadvpn.receiver.BootCompletedReceiver
Expand All @@ -29,6 +30,7 @@ import net.mullvad.mullvadvpn.repository.SettingsRepository
import net.mullvad.mullvadvpn.repository.SplashCompleteRepository
import net.mullvad.mullvadvpn.repository.SplitTunnelingRepository
import net.mullvad.mullvadvpn.repository.WireguardConstraintsRepository
import net.mullvad.mullvadvpn.service.DaemonConfig
import net.mullvad.mullvadvpn.ui.MainActivity
import net.mullvad.mullvadvpn.ui.serviceconnection.AppVersionInfoRepository
import net.mullvad.mullvadvpn.ui.serviceconnection.ServiceConnectionManager
Expand Down Expand Up @@ -223,7 +225,7 @@ val uiModule = module {
viewModel { VoucherDialogViewModel(get()) }
viewModel { VpnSettingsViewModel(get(), get(), get(), get(), get()) }
viewModel { WelcomeViewModel(get(), get(), get(), get(), isPlayBuild = IS_PLAY_BUILD) }
viewModel { ReportProblemViewModel(get(), get()) }
viewModel { ReportProblemViewModel(get(), get(), get<DaemonConfig>().apiEndpointOverride) }
viewModel { ViewLogsViewModel(get()) }
viewModel { OutOfTimeViewModel(get(), get(), get(), get(), get(), isPlayBuild = IS_PLAY_BUILD) }
viewModel { PaymentViewModel(get()) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import net.mullvad.mullvadvpn.constant.MINIMUM_LOADING_TIME_MILLIS
import net.mullvad.mullvadvpn.dataproxy.MullvadProblemReport
import net.mullvad.mullvadvpn.dataproxy.SendProblemReportResult
import net.mullvad.mullvadvpn.dataproxy.UserReport
import net.mullvad.mullvadvpn.lib.endpoint.ApiEndpointOverride
import net.mullvad.mullvadvpn.repository.ProblemReportRepository

data class ReportProblemUiState(
Expand All @@ -38,6 +39,7 @@ sealed interface ReportProblemSideEffect {
class ReportProblemViewModel(
private val mullvadProblemReporter: MullvadProblemReport,
private val problemReportRepository: ProblemReportRepository,
private val apiEndpointOverride: ApiEndpointOverride?,
) : ViewModel() {

private val sendingState: MutableStateFlow<SendingReportUiState?> = MutableStateFlow(null)
Expand Down Expand Up @@ -66,7 +68,10 @@ class ReportProblemViewModel(

// Ensure we show loading for at least MINIMUM_LOADING_TIME_MILLIS
val deferredResult = async {
mullvadProblemReporter.sendReport(UserReport(nullableEmail, description))
mullvadProblemReporter.sendReport(
UserReport(nullableEmail, description),
apiEndpointOverride,
)
}
delay(MINIMUM_LOADING_TIME_MILLIS)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import kotlinx.parcelize.Parcelize
data class ApiEndpointOverride(
val hostname: String,
val port: Int = CUSTOM_ENDPOINT_HTTPS_PORT,
val disableAddressCache: Boolean = true,
val disableTls: Boolean = false,
val forceDirectConnection: Boolean = true,
) : Parcelable {
Expand Down
9 changes: 1 addition & 8 deletions mullvad-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ tokio = { workspace = true, features = ["macros", "time", "rt-multi-thread", "ne
tokio-rustls = { version = "0.26.0", features = ["logging", "tls12", "ring"], default-features = false}
tokio-socks = "0.5.1"
rustls-pemfile = "2.1.3"
uuid = { version = "1.4.1", features = ["v4"] }

mullvad-encrypted-dns-proxy = { path = "../mullvad-encrypted-dns-proxy" }
mullvad-fs = { path = "../mullvad-fs" }
Expand All @@ -50,14 +51,6 @@ httpmock = "0.7.0-rc1"
[build-dependencies]
cbindgen = { version = "0.24.3", default-features = false }

[target.'cfg(target_os = "ios")'.dependencies]
uuid = { version = "1.4.1", features = ["v4"] }

[lib]
crate-type = [ "rlib", "staticlib" ]
bench = false

[[test]]
name = "ffi"
# required-features = [ "api-override" ]
features = [ "api-override" ]
29 changes: 14 additions & 15 deletions mullvad-api/src/address_cache.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! This module keeps track of the last known good API IP address and reads and stores it on disk.
use super::API;
use crate::DnsResolver;
use crate::{ApiEndpoint, DnsResolver};
use async_trait::async_trait;
use std::{io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
Expand Down Expand Up @@ -38,42 +37,42 @@ impl DnsResolver for AddressCache {

#[derive(Clone)]
pub struct AddressCache {
hostname: String,
inner: Arc<Mutex<AddressCacheInner>>,
write_path: Option<Arc<Path>>,
}

impl AddressCache {
/// Initialize cache using the hardcoded address, and write changes to `write_path`.
pub fn new(write_path: Option<Box<Path>>) -> Self {
Self::new_inner(API.address(), write_path)
}

pub fn with_static_addr(address: SocketAddr) -> Self {
Self::new_inner(address, None)
pub fn new(endpoint: &ApiEndpoint, write_path: Option<Box<Path>>) -> Self {
Self::new_inner(endpoint.address(), endpoint.host().to_owned(), write_path)
}

/// Initialize cache using `read_path`, and write changes to `write_path`.
pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
pub async fn from_file(
read_path: &Path,
write_path: Option<Box<Path>>,
hostname: String,
) -> Result<Self, Error> {
log::debug!("Loading API addresses from {}", read_path.display());
Ok(Self::new_inner(
read_address_file(read_path).await?,
write_path,
))
let address = read_address_file(read_path).await?;
Ok(Self::new_inner(address, hostname, write_path))
}

fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Self {
fn new_inner(address: SocketAddr, hostname: String, write_path: Option<Box<Path>>) -> Self {
let cache = AddressCacheInner::from_address(address);
log::debug!("Using API address: {}", cache.address);

Self {
inner: Arc::new(Mutex::new(cache)),
write_path: write_path.map(Arc::from),
hostname,
}
}

/// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
if hostname.eq_ignore_ascii_case(API.host()) {
if hostname.eq_ignore_ascii_case(&self.hostname) {
Some(self.get_address().await)
} else {
None
Expand Down
3 changes: 1 addition & 2 deletions mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ use talpid_types::ErrorExt;

#[tokio::main]
async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current());

let relay_list_request =
RelayListProxy::new(runtime.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider()))
Expand Down
8 changes: 8 additions & 0 deletions mullvad-api/src/ffi/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub enum MullvadApiErrorKind {

/// MullvadApiErrorKind contains a description and an error kind. If the error kind is
/// `MullvadApiErrorKind` is NoError, the pointer will be nil.
#[derive(Debug)]
#[repr(C)]
pub struct MullvadApiError {
description: *mut libc::c_char,
Expand Down Expand Up @@ -47,6 +48,13 @@ impl MullvadApiError {
}
}

pub fn unwrap(&self) {
if !matches!(self.kind, MullvadApiErrorKind::NoError) {
let desc = unsafe { std::ffi::CStr::from_ptr(self.description) };
panic!("API ERROR - {:?} - {}", self.kind, desc.to_str().unwrap());
}
}

pub fn drop(self) {
if self.description.is_null() {
return;
Expand Down
91 changes: 79 additions & 12 deletions mullvad-api/src/ffi/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![cfg(not(target_os = "android"))]
use std::{
ffi::{CStr, CString},
net::SocketAddr,
Expand All @@ -6,8 +7,9 @@ use std::{
};

use crate::{
proxy::ApiConnectionMode,
rest::{self, MullvadRestHandle},
AccountsProxy, DevicesProxy,
AccountsProxy, ApiEndpoint, DevicesProxy,
};

mod device;
Expand Down Expand Up @@ -48,13 +50,13 @@ impl MullvadApiClient {
struct FfiClient {
tokio_runtime: tokio::runtime::Runtime,
api_runtime: crate::Runtime,
api_hostname: String,
}

impl FfiClient {
unsafe fn new(
api_address_ptr: *const libc::c_char,
hostname: *const libc::c_char,
#[cfg(any(feature = "api-override", test))] disable_tls: bool,
) -> Result<Self, MullvadApiError> {
// SAFETY: addr_str must be a valid pointer to a null-terminated string.
let addr_str = unsafe { string_from_raw_ptr(api_address_ptr)? };
Expand All @@ -68,12 +70,15 @@ impl FfiClient {
)
})?;

// The call site guarantees that
// api_hostname and api_address will never change after the first call to new.
std::env::set_var(crate::env::API_HOST_VAR, &api_hostname);
std::env::set_var(crate::env::API_ADDR_VAR, &addr_str);
std::env::set_var(crate::env::API_FORCE_DIRECT_VAR, "0");
std::env::set_var(crate::env::DISABLE_TLS_VAR, "0");
let endpoint = ApiEndpoint {
host: Some(api_hostname.clone()),
address: Some(api_address),
#[cfg(feature = "api-override")]
force_direct: false,
#[cfg(any(feature = "api-override", test))]
disable_tls,
};

let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();

runtime_builder.worker_threads(2).enable_all();
Expand All @@ -84,13 +89,12 @@ impl FfiClient {
// It is imperative that the REST runtime is created within an async context, otherwise
// ApiAvailability panics.
let api_runtime = tokio_runtime.block_on(async {
crate::Runtime::with_static_addr(tokio_runtime.handle().clone(), api_address)
crate::Runtime::with_custom_endpoint(tokio_runtime.handle().clone(), &endpoint)
});

let context = FfiClient {
tokio_runtime,
api_runtime,
api_hostname,
};

Ok(context)
Expand Down Expand Up @@ -204,7 +208,7 @@ impl FfiClient {
fn rest_handle(&self) -> MullvadRestHandle {
self.tokio_handle().block_on(async {
self.api_runtime
.static_mullvad_rest_handle(self.api_hostname.clone())
.mullvad_rest_handle(ApiConnectionMode::Direct.into_provider())
})
}

Expand Down Expand Up @@ -239,8 +243,16 @@ pub unsafe extern "C" fn mullvad_api_client_initialize(
client_ptr: *mut MullvadApiClient,
api_address_ptr: *const libc::c_char,
hostname: *const libc::c_char,
#[cfg(any(feature = "api-override", test))] disable_tls: bool,
) -> MullvadApiError {
match unsafe { FfiClient::new(api_address_ptr, hostname) } {
match unsafe {
FfiClient::new(
api_address_ptr,
hostname,
#[cfg(any(feature = "api-override", test))]
disable_tls,
)
} {
Ok(client) => {
unsafe {
std::ptr::write(client_ptr, MullvadApiClient::new(client));
Expand Down Expand Up @@ -443,3 +455,58 @@ unsafe fn string_from_raw_ptr(ptr: *const libc::c_char) -> Result<String, Mullva
})?
.to_owned())
}

#[cfg(test)]
mod test {
use httpmock::prelude::*;
use std::{mem::MaybeUninit, net::Ipv4Addr};

use super::*;
const STAGING_HOSTNAME: &[u8] = b"api-app.stagemole.eu\0";

#[test]
fn test_initialization() {
let _ = create_client(&SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 1));
}

fn create_client(addr: &SocketAddr) -> MullvadApiClient {
let mut client = MaybeUninit::<MullvadApiClient>::uninit();
let cstr_address = CString::new(addr.to_string()).unwrap();
let _client = unsafe {
mullvad_api_client_initialize(
client.as_mut_ptr(),
cstr_address.as_ptr().cast(),
STAGING_HOSTNAME.as_ptr().cast(),
true,
)
.unwrap();
};
unsafe { client.assume_init() }
}

#[test]
fn test_create_delete_account() {
let server = test_server();
let client = create_client(server.address());

let mut account_buf = vec![0 as libc::c_char; 100];
unsafe { mullvad_api_create_account(client, (&account_buf.as_mut_ptr()).cast()).unwrap() };
}

fn test_server() -> MockServer {
let server = MockServer::start();
let expected_create_account_response = br#"{"id":"085df870-0fc2-47cb-9e8c-cb43c1bdaac0","expiry":"2024-12-11T12:56:32+00:00","max_ports":0,"can_add_ports":false,"max_devices":5,"can_add_devices":true,"number":"6705749539195318"}"#;
let _mock = server.mock(|when, then| {
when.path("/".to_string() + crate::ACCOUNTS_URL_PREFIX + "/accounts")
.method(POST);
then.status(201)
.body(expected_create_account_response)
.header("content-type", "application/json")
.header(
"content-length",
&format!("{}", expected_create_account_response.len()),
);
});
server
}
}
Loading

0 comments on commit 2701374

Please sign in to comment.