Skip to content

Commit

Permalink
fix decimal vs hex equivalence
Browse files Browse the repository at this point in the history
  • Loading branch information
sslivkoff committed Jan 29, 2024
1 parent bdcb442 commit e6cb41b
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 39 deletions.
2 changes: 1 addition & 1 deletion SPECIFICATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ Requirements:
- All keys of `RpcConfig` and `Endpoint` are required. No additional keys must be present, except within `global_metadata`, `profile_metadata`, and `endpoint_metadata`.
- Every endpoint name specified in `RpcConfig.default_endpoint` and in `RpcConfig.network_defaults` must exist in `RpcConfig.endpoints`.
- These key-value structures can be easily represented in JSON and in most common programming languages.
- EVM `chain_id`'s must be represented using either a decimal string or a hex string. Strings are used because chain id's can be 256 bits and most programming languages do not have native 256 bit integer types. For readability, decimal should be used for small chain id values and hex should be used for values that use the entire 256 bits.
- EVM `chain_id`'s must be represented using either a decimal string or a `0x`-prefixed hex string. Strings are used because chain id's can be 256 bits and most programming languages do not have native 256 bit integer types. For readability, decimal should be used for small chain id values and hex should be used for values that use the entire 256 bits.
- Names of endpoints, networks, and profiles should be composed of characters that are either alphanumeric, dashes, underscores, or periods. Names should be at least 1 character long.

##### Metadata
Expand Down
17 changes: 12 additions & 5 deletions python/mesc/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,18 @@ def get_endpoint_by_network(
raise ValueError('chain_id must be a str or int')
chain_id = str(chain_id)
network_defaults = config['network_defaults']
default_name = network_defaults.get(chain_id)
default_name = network_utils.get_by_chain_id(network_defaults, chain_id)

# get profile default for network
if profile and profile in config['profiles']:
if profile is not None and profile in config['profiles']:
if not config['profiles'][profile]['use_mesc']:
return None
name = config['profiles'][profile]['network_defaults'].get(
chain_id, default_name
name = network_utils.get_by_chain_id(
config['profiles'][profile]['network_defaults'],
chain_id,
)
if name is None:
name = default_name
else:
name = default_name

Expand Down Expand Up @@ -143,8 +146,12 @@ def find_endpoints(
if chain_id is not None:
if isinstance(chain_id, int):
chain_id = str(chain_id)
chain_id = network_utils.chain_id_to_standard_hex(chain_id)
endpoints = [
endpoint for endpoint in endpoints if endpoint['chain_id'] == chain_id
endpoint
for endpoint in endpoints
if endpoint['chain_id'] is not None
and network_utils.chain_id_to_standard_hex(endpoint['chain_id']) == chain_id
]

# check name_contains
Expand Down
29 changes: 29 additions & 0 deletions python/mesc/network_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import typing
from .types import RpcConfig
from . import network_names

Expand Down Expand Up @@ -31,3 +32,31 @@ def network_name_to_chain_id(
return chain_id
else:
return None


def chain_id_to_standard_hex(chain_id: str) -> str | None:
if chain_id.startswith('0x'):
if len(chain_id) > 2:
as_hex = chain_id
else:
try:
as_hex = hex(int(chain_id))
except ValueError:
return None

return '0x' + as_hex[2:].lstrip('0')


T = typing.TypeVar('T')


def get_by_chain_id(mapping: typing.Mapping[str, T], chain_id: str) -> T | None:
if chain_id in mapping:
return mapping[chain_id]

standard_mapping = {chain_id_to_standard_hex(k): v for k, v in mapping.items()}
return standard_mapping.get(chain_id_to_standard_hex(chain_id))


def chain_ids_equal(lhs: str, rhs: str) -> bool:
return chain_id_to_standard_hex(lhs) == chain_id_to_standard_hex(rhs)
32 changes: 29 additions & 3 deletions python/mesc/validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing_extensions import Any
from typing import Sequence

from .exceptions import InvalidConfig
from .types import rpc_config_types, endpoint_types, profile_types
Expand Down Expand Up @@ -100,7 +101,9 @@ def validate(config: Any) -> None:

# default endpoints of each network actually use that specified network
for chain_id, endpoint_name in config['network_defaults'].items():
if chain_id != config['endpoints'][endpoint_name]['chain_id']:
if not network_utils.chain_ids_equal(
chain_id, config['endpoints'][endpoint_name]['chain_id']
):
raise InvalidConfig(
'Endpoint is set as the default endpoint of network '
+ chain_id
Expand All @@ -109,7 +112,9 @@ def validate(config: Any) -> None:
)
for profile_name, profile in config['profiles'].items():
for chain_id, endpoint_name in profile['network_defaults'].items():
if chain_id != config['endpoints'][endpoint_name]['chain_id']:
if not network_utils.chain_ids_equal(
chain_id, config['endpoints'][endpoint_name]['chain_id']
):
raise InvalidConfig(
'Endpoint is set as the default endpoint of network '
+ chain_id
Expand Down Expand Up @@ -165,7 +170,28 @@ def validate(config: Any) -> None:
)

# no duplicate default network entries using decimal vs hex
pass
ensure_no_chain_id_collisions(
list(config['network_defaults'].keys()), 'network defaults'
)
for profile_name, profile in config['profiles'].items():
ensure_no_chain_id_collisions(
list(config['network_defaults'].keys()), 'profile ' + profile_name
)


def ensure_no_chain_id_collisions(chain_ids: Sequence[str], name: str) -> None:
hex_numbers = set()
for chain_id in chain_ids:
as_hex = network_utils.chain_id_to_standard_hex(chain_id)
if as_hex in hex_numbers:
raise Exception(
'chain_id collision, '
+ str(name)
+ ' has multiple decimal/hex values for chain_id: '
+ str(chain_id)
)
else:
hex_numbers.add(as_hex)


def _check_type(
Expand Down
2 changes: 1 addition & 1 deletion rust/crates/mesc/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub fn get_default_endpoint(profile: Option<&str>) -> Result<Option<Endpoint>, M
}

/// get endpoint by network
pub fn get_endpoint_by_network<T: TryIntoChainId>(
pub fn get_endpoint_by_network<T: TryIntoChainId + std::fmt::Debug + std::clone::Clone>(
chain_id: T,
profile: Option<&str>,
) -> Result<Option<Endpoint>, MescError> {
Expand Down
43 changes: 31 additions & 12 deletions rust/crates/mesc/src/query.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
directory,
types::{Endpoint, MescError, RpcConfig},
MultiEndpointQuery, TryIntoChainId,
ChainId, MultiEndpointQuery, TryIntoChainId,
};
use std::collections::HashMap;

Expand All @@ -10,14 +10,14 @@ pub fn get_default_endpoint(
config: &RpcConfig,
profile: Option<&str>,
) -> Result<Option<Endpoint>, MescError> {
// if using a profile, check if that profile has a default endpoint for chain_id
// if using a profile, check if that profile has a default endpoint
if let Some(profile) = profile {
if let Some(profile_data) = config.profiles.get(profile) {
if !profile_data.use_mesc {
return Ok(None)
return Ok(None);
}
if let Some(endpoint_name) = profile_data.default_endpoint.as_deref() {
return get_endpoint_by_name(config, endpoint_name)
return get_endpoint_by_name(config, endpoint_name);
}
}
};
Expand All @@ -29,7 +29,7 @@ pub fn get_default_endpoint(
}

/// get endpoint by network
pub fn get_endpoint_by_network<T: TryIntoChainId>(
pub fn get_endpoint_by_network<T: TryIntoChainId + std::fmt::Debug + std::clone::Clone>(
config: &RpcConfig,
chain_id: T,
profile: Option<&str>,
Expand All @@ -40,21 +40,40 @@ pub fn get_endpoint_by_network<T: TryIntoChainId>(
if let Some(profile) = profile {
if let Some(profile_data) = config.profiles.get(profile) {
if !profile_data.use_mesc {
return Ok(None)
return Ok(None);
}
if let Some(endpoint_name) = profile_data.network_defaults.get(&chain_id) {
return get_endpoint_by_name(config, endpoint_name)
return get_endpoint_by_name(config, endpoint_name);
}
}
};

// check if base configuration has a default endpoint for that chain_id
match config.network_defaults.get(&chain_id) {
Some(name) => get_endpoint_by_name(config, name),
match get_by_chain_id(&config.network_defaults, chain_id)? {
Some(name) => get_endpoint_by_name(config, name.as_str()),
None => Ok(None),
}
}

fn get_by_chain_id<T: TryIntoChainId, S: std::fmt::Debug + Clone>(
mapping: &HashMap<ChainId, S>,
chain_id: T,
) -> Result<Option<S>, MescError> {
let chain_id = chain_id.try_into_chain_id()?;
if let Some(value) = mapping.get(&chain_id) {
Ok(Some(value.clone()))
} else {
let standard_chain_id = chain_id.to_hex_256()?;
let results: Result<HashMap<String, S>, _> = mapping
.iter()
.map(|(k, v)| k.to_hex_256().map(|hex| (hex, v.clone())))
.collect::<Result<Vec<_>, _>>() // Collect into a Result<Vec<(String, S)>, Error>
.map(|pairs| pairs.into_iter().collect::<HashMap<_, _>>());
let standard_mapping = results?;
Ok(standard_mapping.get(&standard_chain_id).cloned())
}
}

/// get endpoint by name
pub fn get_endpoint_by_name(config: &RpcConfig, name: &str) -> Result<Option<Endpoint>, MescError> {
if let Some(endpoint) = config.endpoints.get(name) {
Expand All @@ -73,7 +92,7 @@ pub fn get_endpoint_by_query(
if let Some(profile) = profile {
if let Some(profile_data) = config.profiles.get(profile) {
if !profile_data.use_mesc {
return Ok(None)
return Ok(None);
}
}
}
Expand Down Expand Up @@ -108,7 +127,7 @@ pub fn find_endpoints(
let mut candidates: Vec<Endpoint> = config.endpoints.clone().into_values().collect();

if let Some(chain_id) = query.chain_id {
candidates.retain(|endpoint| endpoint.chain_id.as_ref() == Some(&chain_id))
candidates.retain(|endpoint| endpoint.chain_id.as_ref() == Some(&chain_id));
}

if let Some(name) = query.name_contains {
Expand All @@ -133,7 +152,7 @@ pub fn get_global_metadata(
if let Some(profile) = profile {
if let Some(profile_data) = config.profiles.get(profile) {
if !profile_data.use_mesc {
return Ok(HashMap::new())
return Ok(HashMap::new());
}
metadata.extend(profile_data.profile_metadata.clone())
}
Expand Down
54 changes: 46 additions & 8 deletions rust/crates/mesc/src/types/chain_ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};

/// ChainId is a string representation of an integer chain id
/// - TryFrom conversions allow specifying as String, &str, uint, or binary data
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq, Hash)]
#[derive(Serialize, Deserialize, Debug, Clone, Eq)]
pub struct ChainId(String);

impl ChainId {
Expand All @@ -23,12 +23,14 @@ impl ChainId {
/// convert to hex representation, zero-padded to 256 bits
pub fn to_hex_256(&self) -> Result<String, MescError> {
let ChainId(chain_id) = self;
if chain_id.starts_with("0x") {
Ok(chain_id.clone())
if let Some(stripped) = chain_id.strip_prefix("0x") {
Ok(format!("0x{:0>64}", stripped))
} else {
match chain_id.parse::<u64>() {
Ok(number) => Ok(format!("0x{:016x}", number)),
Err(_) => Err(MescError::IntegrityError("bad chain_id".to_string())),
match chain_id.parse::<u128>() {
Ok(number) => Ok(format!("0x{:064x}", number)),
Err(_) => {
Err(MescError::InvalidChainId("cannot convert chain_id to hex".to_string()))
}
}
}
}
Expand All @@ -40,12 +42,42 @@ impl ChainId {
}
}

impl std::hash::Hash for ChainId {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self.to_hex_256() {
Ok(as_hex) => {
as_hex.hash(state);
}
_ => {
let ChainId(contents) = self;
contents.hash(state);
}
}
}
}

impl PartialOrd for ChainId {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl PartialEq for ChainId {
fn eq(&self, other: &Self) -> bool {
let self_string: String = match self.to_hex() {
Ok(s) => s[2..].to_string(),
Err(_) => return self == other,
};
let other_string = match other.to_hex() {
Ok(s) => s[2..].to_string(),
Err(_) => return self == other,
};
let self_str = format!("{:0>79}", self_string);
let other_str = format!("{:0>79}", other_string);
self_str.eq(&other_str)
}
}

impl Ord for ChainId {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
let self_string: String = match self.to_hex() {
Expand Down Expand Up @@ -96,7 +128,10 @@ impl TryIntoChainId for ChainId {

impl TryIntoChainId for String {
fn try_into_chain_id(self) -> Result<ChainId, MescError> {
if !self.is_empty() && self.chars().all(|c| c.is_ascii_digit()) {
if !self.is_empty() &&
(self.chars().all(|c| c.is_ascii_digit()) ||
(self.starts_with("0x") && self[2..].chars().all(|c| c.is_ascii_hexdigit())))
{
Ok(ChainId(self))
} else {
Err(MescError::InvalidChainId(self))
Expand All @@ -106,7 +141,10 @@ impl TryIntoChainId for String {

impl TryIntoChainId for &str {
fn try_into_chain_id(self) -> Result<ChainId, MescError> {
if self.chars().all(|c| c.is_ascii_digit()) {
if !self.is_empty() &&
(self.chars().all(|c| c.is_ascii_digit()) ||
(self.starts_with("0x") && self[2..].chars().all(|c| c.is_ascii_hexdigit())))
{
Ok(ChainId(self.to_string()))
} else {
Err(MescError::InvalidChainId(self.to_string()))
Expand Down
Loading

0 comments on commit e6cb41b

Please sign in to comment.