diff --git a/iroh-relay/src/relay_map.rs b/iroh-relay/src/relay_map.rs index 4f8740b7735..53fb4448be0 100644 --- a/iroh-relay/src/relay_map.rs +++ b/iroh-relay/src/relay_map.rs @@ -1,6 +1,10 @@ //! based on tailscale/tailcfg/derpmap.go -use std::{collections::BTreeMap, fmt, sync::Arc}; +use std::{ + collections::BTreeMap, + fmt, + sync::{Arc, RwLock}, +}; use iroh_base::RelayUrl; use serde::{Deserialize, Serialize}; @@ -8,16 +12,34 @@ use serde::{Deserialize, Serialize}; use crate::defaults::DEFAULT_RELAY_QUIC_PORT; /// Configuration of all the relay servers that can be used. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone)] pub struct RelayMap { /// A map of the different relay IDs to the [`RelayNode`] information - nodes: Arc>>, + nodes: Arc>>>, } +impl PartialEq for RelayMap { + fn eq(&self, other: &Self) -> bool { + let this = self.nodes.read().expect("poisoned"); + let that = other.nodes.read().expect("poisoned"); + this.eq(&*that) + } +} + +impl Eq for RelayMap {} + impl RelayMap { /// Returns the sorted relay URLs. - pub fn urls(&self) -> impl Iterator { - self.nodes.keys() + pub fn urls(&self) -> T + where + T: FromIterator, + { + self.nodes + .read() + .expect("poisoned") + .keys() + .cloned() + .collect::() } /// Create an empty relay map. @@ -28,39 +50,57 @@ impl RelayMap { } /// Returns an `Iterator` over all known nodes. - pub fn nodes(&self) -> impl Iterator> { - self.nodes.values() + pub fn nodes(&self) -> T + where + T: FromIterator>, + { + self.nodes + .read() + .expect("poisoned") + .values() + .cloned() + .collect::() } /// Is this a known node? pub fn contains_node(&self, url: &RelayUrl) -> bool { - self.nodes.contains_key(url) + self.nodes.read().expect("poisoned").contains_key(url) } /// Get the given node. - pub fn get_node(&self, url: &RelayUrl) -> Option<&Arc> { - self.nodes.get(url) + pub fn get_node(&self, url: &RelayUrl) -> Option> { + self.nodes.read().expect("poisoned").get(url).cloned() } /// How many nodes are known? pub fn len(&self) -> usize { - self.nodes.len() + self.nodes.read().expect("poisoned").len() } /// Are there any nodes in this map? pub fn is_empty(&self) -> bool { - self.nodes.is_empty() + self.nodes.read().expect("poisoned").is_empty() + } + + /// Insert a new relay. + pub fn insert(&self, url: RelayUrl, node: Arc) -> Option> { + self.nodes.write().expect("poisoned").insert(url, node) + } + + /// Removes an existing relay by `RelayUrl`. + pub fn remove(&self, url: &RelayUrl) -> Option> { + self.nodes.write().expect("poisoned").remove(url) } } impl FromIterator for RelayMap { fn from_iter>(iter: T) -> Self { Self { - nodes: Arc::new( + nodes: Arc::new(RwLock::new( iter.into_iter() .map(|node| (node.url.clone(), Arc::new(node))) .collect(), - ), + )), } } } @@ -72,7 +112,9 @@ impl From for RelayMap { /// discovery ports. fn from(value: RelayUrl) -> Self { Self { - nodes: Arc::new([(value.clone(), Arc::new(value.into()))].into()), + nodes: Arc::new(RwLock::new( + [(value.clone(), Arc::new(value.into()))].into(), + )), } } } @@ -80,7 +122,7 @@ impl From for RelayMap { impl From for RelayMap { fn from(value: RelayNode) -> Self { Self { - nodes: Arc::new([(value.url.clone(), Arc::new(value))].into()), + nodes: Arc::new(RwLock::new([(value.url.clone(), Arc::new(value))].into())), } } } @@ -92,11 +134,11 @@ impl FromIterator for RelayMap { /// discovery ports. fn from_iter>(iter: T) -> Self { Self { - nodes: Arc::new( + nodes: Arc::new(RwLock::new( iter.into_iter() .map(|url| (url.clone(), Arc::new(url.into()))) .collect(), - ), + )), } } } diff --git a/iroh/src/endpoint.rs b/iroh/src/endpoint.rs index e24a13367ad..271aa865d5b 100644 --- a/iroh/src/endpoint.rs +++ b/iroh/src/endpoint.rs @@ -21,8 +21,8 @@ use std::{ }; use ed25519_dalek::{VerifyingKey, pkcs8::DecodePublicKey}; -use iroh_base::{NodeAddr, NodeId, SecretKey}; -use iroh_relay::RelayMap; +use iroh_base::{NodeAddr, NodeId, RelayUrl, SecretKey}; +use iroh_relay::{RelayMap, RelayNode}; use n0_future::time::Duration; use n0_watcher::Watcher; use nested_enum_utils::common_fields; @@ -632,6 +632,24 @@ impl Endpoint { self.msock.endpoint().set_server_config(Some(server_config)); } + /// Adds the provided configuration to the [`RelayMap`]. + /// + /// Replacing and returning any existing configuration for [`RelayUrl`]. + pub async fn insert_relay( + &self, + relay: RelayUrl, + node: Arc, + ) -> Option> { + self.msock.insert_relay(relay, node).await + } + + /// Removes the configuration from the [`RelayMap`] for the provided [`RelayUrl`]. + /// + /// Returns any existing configuration. + pub async fn remove_relay(&self, relay: &RelayUrl) -> Option> { + self.msock.remove_relay(relay).await + } + // # Methods for establishing connectivity. /// Connects to a remote [`Endpoint`]. @@ -2430,6 +2448,115 @@ mod tests { Ok(()) } + #[tokio::test] + #[traced_test] + async fn endpoint_relay_map_change() -> Result { + let (relay_map, relay_url, _guard1) = run_relay_server().await?; + let client = Endpoint::builder() + .insecure_skip_relay_cert_verify(true) + .relay_mode(RelayMode::Custom(relay_map.clone())) + .bind() + .await?; + let server = Endpoint::builder() + .insecure_skip_relay_cert_verify(true) + .relay_mode(RelayMode::Custom(relay_map)) + .alpns(vec![TEST_ALPN.to_vec()]) + .bind() + .await?; + + let task = tokio::spawn({ + let server = server.clone(); + async move { + for i in 0..2 { + println!("accept: round {i}"); + let Some(conn) = server.accept().await else { + snafu::whatever!("Expected an incoming connection"); + }; + let conn = conn.await.e()?; + let (mut send, mut recv) = conn.accept_bi().await.e()?; + let data = recv.read_to_end(1000).await.e()?; + send.write_all(&data).await.e()?; + send.finish().e()?; + conn.closed().await; + } + Ok::<_, Error>(()) + } + }); + + server.online().await; + + let mut addr = server.node_addr(); + println!("round1: {:?}", addr); + + // remove direct addrs to force relay usage + addr.direct_addresses.clear(); + + let conn = client.connect(addr, TEST_ALPN).await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; + send.write_all(b"Hello, world!").await.e()?; + send.finish().e()?; + let data = recv.read_to_end(1000).await.e()?; + conn.close(0u32.into(), b"bye!"); + + assert_eq!(&data, b"Hello, world!"); + + // setup a second relay server + let (new_relay_map, new_relay_url, _guard2) = run_relay_server().await?; + let new_node = new_relay_map + .get_node(&new_relay_url) + .expect("missing node") + .clone(); + dbg!(&new_relay_map); + + let addr_watcher = server.watch_node_addr(); + + // add new new relay + assert!( + server + .insert_relay(new_relay_url.clone(), new_node.clone()) + .await + .is_none() + ); + // remove the old relay + assert!(server.remove_relay(&relay_url).await.is_some()); + + println!("------- changed ----- "); + + let mut addr = tokio::time::timeout(Duration::from_secs(10), async move { + let mut stream = addr_watcher.stream(); + while let Some(addr) = stream.next().await { + if addr.relay_url.as_ref() != Some(&relay_url) { + return addr; + } + } + panic!("failed to change relay"); + }) + .await + .e()?; + + println!("round2: {:?}", addr); + assert_eq!(addr.relay_url, Some(new_relay_url)); + + // remove direct addrs to force relay usage + addr.direct_addresses.clear(); + + let conn = client.connect(addr, TEST_ALPN).await?; + let (mut send, mut recv) = conn.open_bi().await.e()?; + send.write_all(b"Hello, world!").await.e()?; + send.finish().e()?; + let data = recv.read_to_end(1000).await.e()?; + conn.close(0u32.into(), b"bye!"); + + task.await.e()??; + + client.close().await; + server.close().await; + + assert_eq!(&data, b"Hello, world!"); + + Ok(()) + } + #[tokio::test] #[traced_test] async fn endpoint_bidi_send_recv() -> Result { diff --git a/iroh/src/magicsock.rs b/iroh/src/magicsock.rs index 25488d9c49e..699a36ce260 100644 --- a/iroh/src/magicsock.rs +++ b/iroh/src/magicsock.rs @@ -31,7 +31,7 @@ use std::{ use bytes::Bytes; use data_encoding::HEXLOWER; use iroh_base::{NodeAddr, NodeId, PublicKey, RelayUrl, SecretKey}; -use iroh_relay::RelayMap; +use iroh_relay::{RelayMap, RelayNode}; use n0_future::{ task::{self, AbortOnDropHandle}, time::{self, Duration, Instant}, @@ -204,6 +204,7 @@ pub(crate) struct MagicSock { /// The DNS resolver to be used in this magicsock. #[cfg(not(wasm_browser))] dns_resolver: DnsResolver, + relay_map: RelayMap, /// Disco disco: DiscoState, @@ -297,6 +298,28 @@ impl MagicSock { self.node_map.remote_info(node_id) } + pub(crate) async fn insert_relay( + &self, + relay: RelayUrl, + node: Arc, + ) -> Option> { + let res = self.relay_map.insert(relay, node); + self.actor_sender + .send(ActorMessage::RelayMapChange) + .await + .ok(); + res + } + + pub(crate) async fn remove_relay(&self, relay: &RelayUrl) -> Option> { + let res = self.relay_map.remove(relay); + self.actor_sender + .send(ActorMessage::RelayMapChange) + .await + .ok(); + res + } + /// Returns a [`Watcher`] for this socket's direct addresses. /// /// The [`MagicSock`] continuously monitors the direct addresses, the network addresses @@ -1205,11 +1228,12 @@ enum UpdateReason { PortmapUpdated, LinkChangeMajor, LinkChangeMinor, + RelayMapChange, } impl UpdateReason { fn is_major(self) -> bool { - matches!(self, Self::LinkChangeMajor) + matches!(self, Self::LinkChangeMajor | Self::RelayMapChange) } } @@ -1429,6 +1453,7 @@ impl Handle { node_map, ip_mapped_addrs: ip_mapped_addrs.clone(), discovery, + relay_map: relay_map.clone(), discovery_user_data: RwLock::new(discovery_user_data), direct_addrs: DiscoveredDirectAddrs::default(), net_report: Watchable::new((None, UpdateReason::None)), @@ -1780,6 +1805,7 @@ enum ActorMessage { EndpointPingExpired(usize, stun_rs::TransactionId), NetworkChange, ScheduleDirectAddrUpdate(UpdateReason, Option<(NodeId, RelayUrl)>), + RelayMapChange, #[cfg(test)] ForceNetworkChange(bool), } @@ -2034,6 +2060,10 @@ impl Actor { } } + fn handle_relay_map_change(&mut self) { + self.re_stun(UpdateReason::RelayMapChange); + } + fn re_stun(&mut self, why: UpdateReason) { let state = self.netmon_watcher.get(); self.direct_addr_update_state @@ -2068,6 +2098,9 @@ impl Actor { self.direct_addr_update_state .schedule_run(why, state.into()); } + ActorMessage::RelayMapChange => { + self.handle_relay_map_change(); + } #[cfg(test)] ActorMessage::ForceNetworkChange(is_major) => { self.handle_network_change(is_major).await; diff --git a/iroh/src/net_report.rs b/iroh/src/net_report.rs index bf734c9ab14..80b0450ff01 100644 --- a/iroh/src/net_report.rs +++ b/iroh/src/net_report.rs @@ -440,7 +440,8 @@ impl Client { let mut v6_buf = JoinSet::new(); let cancel_v6 = CancellationToken::new(); - for relay_node in self.relay_map.nodes().take(MAX_RELAYS) { + let nodes = self.relay_map.nodes::>(); + for relay_node in nodes.into_iter().take(MAX_RELAYS) { if if_state.have_v4 && needs_v4_probe { debug!(?relay_node.url, "v4 QAD probe"); let ip_mapped_addrs = self.socket_state.ip_mapped_addrs.clone(); diff --git a/iroh/src/net_report/probes.rs b/iroh/src/net_report/probes.rs index f25f7053a31..5977b3032cc 100644 --- a/iroh/src/net_report/probes.rs +++ b/iroh/src/net_report/probes.rs @@ -101,7 +101,7 @@ impl ProbePlan { pub(super) fn initial(relay_map: &RelayMap, protocols: &BTreeSet) -> Self { let mut plan = Self::default(); - for relay_node in relay_map.nodes() { + for relay_node in relay_map.nodes::>() { let mut https_probes = ProbeSet::new(Probe::Https); for attempt in 0u32..3 { @@ -194,8 +194,8 @@ mod tests { #[tokio::test] async fn test_initial_probeplan() { let (_servers, relay_map) = test_utils::relay_map(2).await; - let relay_node_1 = relay_map.nodes().next().unwrap(); - let relay_node_2 = relay_map.nodes().nth(1).unwrap(); + let relay_node_1 = &relay_map.nodes::>()[0]; + let relay_node_2 = &relay_map.nodes::>()[1]; let plan = ProbePlan::initial(&relay_map, &default_protocols()); let expected_plan: ProbePlan = [ @@ -234,8 +234,8 @@ mod tests { #[tokio::test] async fn test_initial_probeplan_some_protocols() { let (_servers, relay_map) = test_utils::relay_map(2).await; - let relay_node_1 = relay_map.nodes().next().unwrap(); - let relay_node_2 = relay_map.nodes().nth(1).unwrap(); + let relay_node_1 = &relay_map.nodes::>()[0]; + let relay_node_2 = &relay_map.nodes::>()[1]; let plan = ProbePlan::initial(&relay_map, &BTreeSet::from([Probe::Https])); let expected_plan: ProbePlan = [ diff --git a/iroh/src/net_report/reportgen.rs b/iroh/src/net_report/reportgen.rs index 34db5f9b88d..2d29c160b69 100644 --- a/iroh/src/net_report/reportgen.rs +++ b/iroh/src/net_report/reportgen.rs @@ -566,7 +566,7 @@ async fn check_captive_portal( let url = match preferred_relay { Some(url) => url, None => { - let urls: Vec<_> = dm.nodes().map(|n| n.url.clone()).collect(); + let urls: Vec<_> = dm.urls(); if urls.is_empty() { debug!("No suitable relay node for captive portal check"); return Ok(false);