Skip to content

Commit

Permalink
fix: address review feedback and add test for read/write functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
ChetanXpro committed Dec 29, 2024
1 parent 79c1bd3 commit c9a5295
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 57 deletions.
9 changes: 4 additions & 5 deletions crates/sshx-server/src/web/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use bytes::Bytes;
use futures_util::SinkExt;
use sshx_core::proto::{server_update::ServerMessage, NewShell, TerminalInput, TerminalSize};
use sshx_core::Sid;
use subtle::ConstantTimeEq;
use subtle::{Choice, ConstantTimeEq};
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use tracing::{error, info_span, warn, Instrument};
Expand Down Expand Up @@ -98,9 +98,8 @@ async fn handle_socket(socket: &mut WebSocket, session: Arc<Session>) -> Result<

let (user_guard, _) = match recv(socket).await? {
Some(WsClient::Authenticate(bytes, write_password_bytes)) => {
// `ct_eq` returns a `Choice`, and `unwrap_u8()` converts it to 1 (equal) or 0
// (not equal).
if bytes.ct_eq(metadata.encrypted_zeros.as_ref()).unwrap_u8() != 1 {
// Constant-time comparison of bytes, converting Choice to bool
if !<Choice as Into<bool>>::into(bytes.ct_eq(metadata.encrypted_zeros.as_ref())) {
send(socket, WsServer::InvalidAuth()).await?;
return Ok(());
}
Expand All @@ -112,7 +111,7 @@ async fn handle_socket(socket: &mut WebSocket, session: Arc<Session>) -> Result<
// Both password provided and stored, validate they match using constant-time
// comparison.
(Some(provided_password), Some(stored_password)) => {
if provided_password.ct_eq(stored_password).unwrap_u8() != 1 {
if !<Choice as Into<bool>>::into(provided_password.ct_eq(stored_password)) {
send(socket, WsServer::InvalidAuth()).await?;
return Ok(());
}
Expand Down
8 changes: 6 additions & 2 deletions crates/sshx-server/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ impl Drop for TestServer {
pub struct ClientSocket {
inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
encrypt: Encrypt,
write_encrypt: Option<Encrypt>,

pub user_id: Uid,
pub users: BTreeMap<Uid, WsUser>,
Expand All @@ -93,13 +94,14 @@ pub struct ClientSocket {

impl ClientSocket {
/// Connect to a WebSocket endpoint.
pub async fn connect(uri: &str, key: &str) -> Result<Self> {
pub async fn connect(uri: &str, key: &str, write_password: Option<&str>) -> Result<Self> {
let (stream, resp) = tokio_tungstenite::connect_async(uri).await?;
ensure!(resp.status() == StatusCode::SWITCHING_PROTOCOLS);

let mut this = Self {
inner: stream,
encrypt: Encrypt::new(key),
write_encrypt: write_password.map(Encrypt::new),
user_id: Uid(0),
users: BTreeMap::new(),
shells: BTreeMap::new(),
Expand All @@ -113,7 +115,9 @@ impl ClientSocket {

async fn authenticate(&mut self) {
let encrypted_zeros = self.encrypt.zeros().into();
self.send(WsClient::Authenticate(encrypted_zeros, None))
let write_zeros = self.write_encrypt.as_ref().map(|e| e.zeros().into());

self.send(WsClient::Authenticate(encrypted_zeros, write_zeros))
.await;
}

Expand Down
4 changes: 2 additions & 2 deletions crates/sshx-server/tests/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async fn test_basic_restore() -> Result<()> {
let key = controller.encryption_key().to_owned();
tokio::spawn(async move { controller.run().await });

let mut s = ClientSocket::connect(&server.ws_endpoint(&name), &key).await?;
let mut s = ClientSocket::connect(&server.ws_endpoint(&name), &key, None).await?;
s.flush().await;
assert_eq!(s.user_id, Uid(1));

Expand All @@ -47,7 +47,7 @@ async fn test_basic_restore() -> Result<()> {
.state()
.insert(&name, Arc::new(Session::restore(&data)?));

let mut s = ClientSocket::connect(&server.ws_endpoint(&name), &key).await?;
let mut s = ClientSocket::connect(&server.ws_endpoint(&name), &key, None).await?;
s.send(WsClient::Subscribe(Sid(1), 0)).await;
s.flush().await;

Expand Down
79 changes: 68 additions & 11 deletions crates/sshx-server/tests/with_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ async fn test_ws_missing() -> Result<()> {
let server = TestServer::new().await;

let bad_endpoint = format!("ws://{}/not/an/endpoint", server.local_addr());
assert!(ClientSocket::connect(&bad_endpoint, "").await.is_err());
assert!(ClientSocket::connect(&bad_endpoint, "", None)
.await
.is_err());

let mut s = ClientSocket::connect(&server.ws_endpoint("foobar"), "").await?;
let mut s = ClientSocket::connect(&server.ws_endpoint("foobar"), "", None).await?;
s.expect_close(4404).await;

Ok(())
Expand All @@ -74,7 +76,7 @@ async fn test_ws_basic() -> Result<()> {
let key = controller.encryption_key().to_owned();
tokio::spawn(async move { controller.run().await });

let mut s = ClientSocket::connect(&server.ws_endpoint(&name), &key).await?;
let mut s = ClientSocket::connect(&server.ws_endpoint(&name), &key, None).await?;
s.flush().await;
assert_eq!(s.user_id, Uid(1));

Expand Down Expand Up @@ -106,7 +108,7 @@ async fn test_ws_resize() -> Result<()> {
let key = controller.encryption_key().to_owned();
tokio::spawn(async move { controller.run().await });

let mut s = ClientSocket::connect(&server.ws_endpoint(&name), &key).await?;
let mut s = ClientSocket::connect(&server.ws_endpoint(&name), &key, None).await?;

s.send(WsClient::Move(Sid(1), None)).await; // error: does not exist yet!
s.flush().await;
Expand Down Expand Up @@ -151,16 +153,16 @@ async fn test_users_join() -> Result<()> {
tokio::spawn(async move { controller.run().await });

let endpoint = server.ws_endpoint(&name);
let mut s1 = ClientSocket::connect(&endpoint, &key).await?;
let mut s1 = ClientSocket::connect(&endpoint, &key, None).await?;
s1.flush().await;
assert_eq!(s1.users.len(), 1);

let mut s2 = ClientSocket::connect(&endpoint, &key).await?;
let mut s2 = ClientSocket::connect(&endpoint, &key, None).await?;
s2.flush().await;
assert_eq!(s2.users.len(), 2);

drop(s2);
let mut s3 = ClientSocket::connect(&endpoint, &key).await?;
let mut s3 = ClientSocket::connect(&endpoint, &key, None).await?;
s3.flush().await;
assert_eq!(s3.users.len(), 2);

Expand All @@ -180,7 +182,7 @@ async fn test_users_metadata() -> Result<()> {
tokio::spawn(async move { controller.run().await });

let endpoint = server.ws_endpoint(&name);
let mut s = ClientSocket::connect(&endpoint, &key).await?;
let mut s = ClientSocket::connect(&endpoint, &key, None).await?;
s.flush().await;
assert_eq!(s.users.len(), 1);
assert_eq!(s.users.get(&s.user_id).unwrap().cursor, None);
Expand All @@ -205,8 +207,8 @@ async fn test_chat_messages() -> Result<()> {
tokio::spawn(async move { controller.run().await });

let endpoint = server.ws_endpoint(&name);
let mut s1 = ClientSocket::connect(&endpoint, &key).await?;
let mut s2 = ClientSocket::connect(&endpoint, &key).await?;
let mut s1 = ClientSocket::connect(&endpoint, &key, None).await?;
let mut s2 = ClientSocket::connect(&endpoint, &key, None).await?;

s1.send(WsClient::SetName("billy".into())).await;
s1.send(WsClient::Chat("hello there!".into())).await;
Expand All @@ -219,10 +221,65 @@ async fn test_chat_messages() -> Result<()> {
(s1.user_id, "billy".into(), "hello there!".into())
);

let mut s3 = ClientSocket::connect(&endpoint, &key).await?;
let mut s3 = ClientSocket::connect(&endpoint, &key, None).await?;
s3.flush().await;
assert_eq!(s1.messages.len(), 1);
assert_eq!(s3.messages.len(), 0);

Ok(())
}

#[tokio::test]
async fn test_read_write_permissions() -> Result<()> {
let server = TestServer::new().await;

// create controller with read-only mode enabled
let mut controller = Controller::new(&server.endpoint(), "", Runner::Echo, true).await?;
let name = controller.name().to_owned();
let key = controller.encryption_key().to_owned();
let write_url = controller
.write_url()
.expect("Should have write URL when enable_readers is true")
.clone();

tokio::spawn(async move { controller.run().await });

let write_password = write_url
.split(',')
.nth(1)
.expect("Write URL should contain password");

// connect with write access
let mut writer =
ClientSocket::connect(&server.ws_endpoint(&name), &key, Some(write_password)).await?;
writer.flush().await;

// test write permissions
writer.send(WsClient::Create(0, 0)).await;
writer.flush().await;
assert_eq!(
writer.shells.len(),
1,
"Writer should be able to create a shell"
);
assert!(writer.errors.is_empty(), "Writer should not receive errors");

// connect with read-only access
let mut reader = ClientSocket::connect(&server.ws_endpoint(&name), &key, None).await?;
reader.flush().await;

// test read-only restrictions
reader.send(WsClient::Create(0, 0)).await;
reader.flush().await;
assert!(
!reader.errors.is_empty(),
"Reader should receive an error when attempting to create shell"
);
assert_eq!(
reader.shells.len(),
1,
"Reader should still see the existing shell"
);

Ok(())
}
20 changes: 12 additions & 8 deletions crates/sshx/src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,23 +55,27 @@ impl Controller {
) -> Result<Self> {
debug!(%origin, "connecting to server");
let encryption_key = rand_alphanumeric(14); // 83.3 bits of entropy
let write_password = rand_alphanumeric(14); // 83.3 bits of entropy

let kdf_task = {
let encryption_key = encryption_key.clone();
task::spawn_blocking(move || Encrypt::new(&encryption_key))
};

let kdf_write_password_task = {
let write_password = write_password.clone();
task::spawn_blocking(move || Encrypt::new(&write_password))
let (write_password, kdf_write_password_task) = if enable_readers {
let write_password = rand_alphanumeric(14); // 83.3 bits of entropy
let task = {
let write_password = write_password.clone();
task::spawn_blocking(move || Encrypt::new(&write_password))
};
(Some(write_password), Some(task))
} else {
(None, None)
};

let mut client = Self::connect(origin).await?;
let encrypt = kdf_task.await?;
let encrypt_write_password = kdf_write_password_task.await?;
let encrypted_write_zeros = if enable_readers {
Some(encrypt_write_password.zeros().into())
let encrypted_write_zeros = if let Some(task) = kdf_write_password_task {
Some(task.await?.zeros().into())
} else {
None
};
Expand All @@ -86,7 +90,7 @@ impl Controller {
let mut resp = client.open(req).await?.into_inner();
resp.url = resp.url + "#" + &encryption_key;

let write_url = if enable_readers {
let write_url = if let (true, Some(write_password)) = (enable_readers, write_password) {
Some(resp.url.clone() + "," + &write_password)
} else {
None
Expand Down
10 changes: 5 additions & 5 deletions crates/sshx/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ fn print_greeting(shell: &str, controller: &Controller) {
r#"
{sshx} {version}
{arr} Read-only link: {link_v}
{arr} Writable link: {link_e}
{arr} Shell: {shell_v}
"#,
{arr} Read-only link: {link_v}
{arr} Writable link: {link_e}
{arr} Shell: {shell_v}
"#,
sshx = Green.bold().paint("sshx"),
version = Green.paint(&version_str),
arr = Green.paint("➜"),
Expand All @@ -63,7 +63,7 @@ fn print_greeting(shell: &str, controller: &Controller) {
{arr} Link: {link_v}
{arr} Shell: {shell_v}
"#,
"#,
sshx = Green.bold().paint("sshx"),
version = Green.paint(&version_str),
arr = Green.paint("➜"),
Expand Down
29 changes: 5 additions & 24 deletions src/lib/Session.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import { TouchZoom, INITIAL_ZOOM } from "./action/touchZoom";
import { arrangeNewTerminal } from "./arrange";
import { settings } from "./settings";
import { EyeIcon } from "svelte-feather-icons";
export let id: string;
Expand Down Expand Up @@ -93,7 +94,6 @@
}
let encrypt: Encrypt;
let write_password_encrypt: Encrypt | null = null;
let srocket: Srocket<WsServer, WsClient> | null = null;
let connected = false;
Expand Down Expand Up @@ -135,12 +135,9 @@
encrypt = await Encrypt.new(key);
const encryptedZeros = await encrypt.zeros();
write_password_encrypt = writePassword
? await Encrypt.new(writePassword)
: null;
const writeEncryptedZeros = write_password_encrypt
? await write_password_encrypt.zeros()
: null;
const writeEncryptedZeros = writePassword
? await (await Encrypt.new(writePassword)).zeros()
: null;
srocket = new Srocket<WsServer, WsClient>(`/api/s/${id}`, {
onMessage(message) {
Expand Down Expand Up @@ -473,23 +470,7 @@
<div
class="bg-yellow-900 text-yellow-200 px-1 py-0.5 rounded ml-3 inline-flex items-center gap-1"
>
<svg
class="w-3.5 h-3.5"
xmlns="http://www.w3.org/2000/svg"
width="24"
height="24"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
><path d="M1 12s4-8 11-8 11 8 11 8-4 8-11 8-11-8-11-8z" /><circle
cx="12"
cy="12"
r="3"
/></svg
>
<EyeIcon size="14" />
<span class="text-xs">Read-only</span>
</div>
{/if}
Expand Down

0 comments on commit c9a5295

Please sign in to comment.