diff --git a/.cargo/config.toml b/.cargo/config.toml index 20711f1b9..13c84ae1c 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -3,4 +3,4 @@ target-dir = "target" rustflags = ["--cfg", "tokio_unstable"] [env] -RUST_LOG = { value = "clash=trace" } \ No newline at end of file +RUST_LOG = { value = "clash=trace" } diff --git a/Cargo.lock b/Cargo.lock index 2201850b1..ebdb557f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -251,7 +251,7 @@ dependencies = [ "derive-deftly", "derive_builder_fork_arti", "derive_more", - "educe", + "educe 0.4.23", "fs-mistrust", "futures", "hostname-validator", @@ -1148,6 +1148,7 @@ dependencies = [ "aes", "aes-gcm", "anyhow", + "arc-swap", "arti-client", "async-recursion", "async-trait", @@ -1171,6 +1172,8 @@ dependencies = [ "digest", "dirs 6.0.0", "downcast-rs", + "educe 0.6.0", + "enum_dispatch", "env_logger", "erased-serde", "filetime", @@ -1989,12 +1992,24 @@ version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0f0042ff8246a363dbe77d2ceedb073339e85a804b9a47636c6e016a9a32c05f" dependencies = [ - "enum-ordinalize", + "enum-ordinalize 3.1.15", "proc-macro2", "quote", "syn 1.0.109", ] +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize 4.3.0", + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "either" version = "1.13.0" @@ -2065,6 +2080,26 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "enum_dispatch" version = "0.3.13" @@ -5148,7 +5183,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78c9c2fb898b8b41e90b84234baf8075a7f30cf120101e42afe34acbf4c50ac8" dependencies = [ "derive_more", - "educe", + "educe 0.4.23", "either", "fluid-let", "thiserror 2.0.11", @@ -6193,7 +6228,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ebea3f15b119f60ac4de81fe2e57ed281040bf1ce6e520b38c913fcbeb02b68" dependencies = [ "derive-deftly", - "educe", + "educe 0.4.23", "futures", "oneshot-fused-workaround", "pin-project", @@ -6230,7 +6265,7 @@ dependencies = [ "bytes", "derive-deftly", "digest", - "educe", + "educe 0.4.23", "getrandom 0.2.15", "safelog", "thiserror 2.0.11", @@ -6251,7 +6286,7 @@ dependencies = [ "caret", "derive-deftly", "derive_more", - "educe", + "educe 0.4.23", "paste", "rand 0.8.5", "smallvec", @@ -6294,7 +6329,7 @@ dependencies = [ "caret", "derive_builder_fork_arti", "derive_more", - "educe", + "educe 0.4.23", "futures", "oneshot-fused-workaround", "postage", @@ -6345,7 +6380,7 @@ dependencies = [ "derive_more", "downcast-rs", "dyn-clone", - "educe", + "educe 0.4.23", "futures", "humantime-serde", "itertools 0.14.0", @@ -6389,7 +6424,7 @@ dependencies = [ "cfg-if", "derive-deftly", "derive_builder_fork_arti", - "educe", + "educe 0.4.23", "either", "figment", "fs-mistrust", @@ -6479,7 +6514,7 @@ dependencies = [ "derive_builder_fork_arti", "derive_more", "digest", - "educe", + "educe 0.4.23", "event-listener", "fs-mistrust", "fslock", @@ -6561,7 +6596,7 @@ dependencies = [ "derive_builder_fork_arti", "derive_more", "dyn-clone", - "educe", + "educe 0.4.23", "futures", "humantime", "humantime-serde", @@ -6600,7 +6635,7 @@ dependencies = [ "async-trait", "derive-deftly", "derive_more", - "educe", + "educe 0.4.23", "either", "futures", "itertools 0.14.0", @@ -6762,7 +6797,7 @@ dependencies = [ "derive_more", "digest", "ed25519-dalek", - "educe", + "educe 0.4.23", "getrandom 0.2.15", "hex", "rand_core 0.6.4", @@ -6806,7 +6841,7 @@ dependencies = [ "derive-deftly", "derive_more", "dyn-clone", - "educe", + "educe 0.4.23", "futures", "itertools 0.14.0", "paste", @@ -6871,7 +6906,7 @@ dependencies = [ "derive_builder_fork_arti", "derive_more", "digest", - "educe", + "educe 0.4.23", "hex", "humantime", "itertools 0.14.0", @@ -6946,7 +6981,7 @@ dependencies = [ "derive_builder_fork_arti", "derive_more", "digest", - "educe", + "educe 0.4.23", "futures", "hkdf", "hmac", @@ -7019,7 +7054,7 @@ dependencies = [ "coarsetime", "derive_more", "dyn-clone", - "educe", + "educe 0.4.23", "futures", "futures-rustls", "libc", @@ -7046,7 +7081,7 @@ dependencies = [ "async-trait", "derive-deftly", "derive_more", - "educe", + "educe 0.4.23", "futures", "humantime", "itertools 0.14.0", @@ -7073,7 +7108,7 @@ dependencies = [ "amplify", "caret", "derive-deftly", - "educe", + "educe 0.4.23", "safelog", "subtle", "thiserror 2.0.11", diff --git a/clash/tests/data/config/listeners/tunnel.yaml b/clash/tests/data/config/listeners/tunnel.yaml new file mode 100644 index 000000000..7f1b1ace9 --- /dev/null +++ b/clash/tests/data/config/listeners/tunnel.yaml @@ -0,0 +1,10 @@ +port: 8080 +socks-port: 8081 +log-level: trace +listeners: +- name: tunnel-in + type: tunnel + port: 15201 + listen: 127.0.0.1 + network: [tcp, udp] + target: 127.0.0.1:5201 \ No newline at end of file diff --git a/clash_lib/Cargo.toml b/clash_lib/Cargo.toml index 4336c02f4..948ac1b26 100644 --- a/clash_lib/Cargo.toml +++ b/clash_lib/Cargo.toml @@ -61,6 +61,7 @@ uuid = { version = "1", features = ["v4", "fast-rng", "macro-diagnostics", "serd foreign-types-shared = "0.3" network-interface = { version = "2", git = "https://github.com/Watfaq/network-interface.git", rev = "db9f1bcd2faa5122f694283bd300bf42d166d2f3" } # temp workaround for https://github.com/cross-rs/cross/issues/1627 base64 = "0.22" +arc-swap = "1" httparse = "1" h2 = "0.4" @@ -100,6 +101,11 @@ serde_yaml = "0.9" serde_json = "1" erased-serde = "0.4" + +# Macro Magic +educe = "0.6" +enum_dispatch = "0.3" + # DNS watfaq-dns = { version = "0.1.1" } hickory-client = "0.25.0-alpha.2" diff --git a/clash_lib/src/app/api/handlers/config.rs b/clash_lib/src/app/api/handlers/config.rs index cdd8ceb4b..be2e5d586 100644 --- a/clash_lib/src/app/api/handlers/config.rs +++ b/clash_lib/src/app/api/handlers/config.rs @@ -18,21 +18,21 @@ use crate::{ api::AppState, dispatcher, dns::ThreadSafeDNSResolver, - inbound::manager::{Ports, ThreadSafeInboundManager}, + inbound::manager::{InboundManager, Ports}, }, config::{def, internal::config::BindAddress}, }; #[derive(Clone)] struct ConfigState { - inbound_manager: ThreadSafeInboundManager, + inbound_manager: Arc, dispatcher: Arc, global_state: Arc>, dns_resolver: ThreadSafeDNSResolver, } pub fn routes( - inbound_manager: ThreadSafeInboundManager, + inbound_manager: Arc, dispatcher: Arc, global_state: Arc>, dns_resolver: ThreadSafeDNSResolver, @@ -51,12 +51,11 @@ pub fn routes( } async fn get_configs(State(state): State) -> impl IntoResponse { - let inbound_manager = state.inbound_manager.lock().await; let run_mode = state.dispatcher.get_mode().await; let global_state = state.global_state.lock().await; let dns_resolver = state.dns_resolver; - let ports = inbound_manager.get_ports(); + let ports = state.inbound_manager.get_ports().await; axum::response::Json(PatchConfigRequest { port: ports.port, @@ -64,18 +63,12 @@ async fn get_configs(State(state): State) -> impl IntoResponse { redir_port: ports.redir_port, tproxy_port: ports.tproxy_port, mixed_port: ports.mixed_port, - bind_address: Some(inbound_manager.get_bind_address().to_string()), + bind_address: Some(state.inbound_manager.get_bind_address().0.to_string()), mode: Some(run_mode), log_level: Some(global_state.log_level), ipv6: Some(dns_resolver.ipv6()), - allow_lan: Some(match inbound_manager.get_bind_address() { - BindAddress::Any => true, - BindAddress::One(one) => match one { - crate::proxy::utils::Interface::IpAddr(ip) => !ip.is_loopback(), - crate::proxy::utils::Interface::Name(iface) => iface != "lo", - }, - }), + allow_lan: Some(state.inbound_manager.get_bind_address().0.is_unspecified()), }) } @@ -187,12 +180,13 @@ async fn patch_configs( ); } - let mut inbound_manager = state.inbound_manager.lock().await; - + let inbound_manager = state.inbound_manager.clone(); + let mut need_restart = false; if let Some(bind_address) = payload.bind_address.clone() { match bind_address.parse::() { Ok(bind_address) => { - inbound_manager.set_bind_address(bind_address); + inbound_manager.set_bind_address(bind_address).await; + need_restart = true; } Err(_) => { return ( @@ -208,7 +202,7 @@ async fn patch_configs( if payload.rebuild_listeners() { // TODO: maybe buggy - let current_ports = inbound_manager.get_ports(); + let current_ports = inbound_manager.get_ports().await; let ports = Ports { port: payload.port.or(current_ports.port), @@ -217,14 +211,11 @@ async fn patch_configs( tproxy_port: payload.tproxy_port.or(current_ports.tproxy_port), mixed_port: payload.mixed_port.or(current_ports.mixed_port), }; - - inbound_manager.rebuild_listeners(ports); - - global_state.inbound_listener_handle.abort(); - - let r = inbound_manager.get_runner().unwrap(); - - global_state.inbound_listener_handle = tokio::spawn(r); + inbound_manager.change_ports(ports).await; + need_restart = true; + } + if need_restart { + inbound_manager.restart().await; } if let Some(mode) = payload.mode { diff --git a/clash_lib/src/app/api/mod.rs b/clash_lib/src/app/api/mod.rs index 259a2d753..53e2429d3 100644 --- a/clash_lib/src/app/api/mod.rs +++ b/clash_lib/src/app/api/mod.rs @@ -19,9 +19,12 @@ use tracing::{error, info}; use crate::{GlobalState, Runner, config::internal::config::Controller}; use super::{ - dispatcher, dispatcher::StatisticsManager, dns::ThreadSafeDNSResolver, - inbound::manager::ThreadSafeInboundManager, logging::LogEvent, - outbound::manager::ThreadSafeOutboundManager, profile::ThreadSafeCacheFile, + dispatcher::{self, StatisticsManager}, + dns::ThreadSafeDNSResolver, + inbound::manager::InboundManager, + logging::LogEvent, + outbound::manager::ThreadSafeOutboundManager, + profile::ThreadSafeCacheFile, router::ThreadSafeRouter, }; @@ -37,7 +40,7 @@ pub struct AppState { pub fn get_api_runner( controller_cfg: Controller, log_source: Sender, - inbound_manager: ThreadSafeInboundManager, + inbound_manager: Arc, dispatcher: Arc, global_state: Arc>, dns_resolver: ThreadSafeDNSResolver, diff --git a/clash_lib/src/app/dispatcher/dispatcher_impl.rs b/clash_lib/src/app/dispatcher/dispatcher_impl.rs index 2981de331..09e098d46 100644 --- a/clash_lib/src/app/dispatcher/dispatcher_impl.rs +++ b/clash_lib/src/app/dispatcher/dispatcher_impl.rs @@ -238,6 +238,7 @@ impl Dispatcher { /// Dispatch a UDP packet to outbound handler /// returns the close sender #[instrument] + #[must_use] pub async fn dispatch_datagram( &self, sess: Session, @@ -458,10 +459,13 @@ impl Dispatcher { let (close_sender, close_receiver) = tokio::sync::oneshot::channel::(); tokio::spawn(async move { - let _ = close_receiver.await; - trace!("UDP close signal for {} received", s); - t1.abort(); - t2.abort(); + if (close_receiver.await).is_ok() { + trace!("UDP close signal for {} received", s); + t1.abort(); + t2.abort(); + } else { + error!("UDP close signal dropped!"); + } }); return close_sender; diff --git a/clash_lib/src/app/inbound/manager.rs b/clash_lib/src/app/inbound/manager.rs index 63281c5d2..45621a6cd 100644 --- a/clash_lib/src/app/inbound/manager.rs +++ b/clash_lib/src/app/inbound/manager.rs @@ -1,27 +1,22 @@ +use arc_swap::ArcSwap; use serde::{Deserialize, Serialize}; -use tokio::sync::Mutex; +use tokio::{ + sync::{RwLock, oneshot}, + task::{JoinHandle, JoinSet}, +}; +use tracing::error; use crate::{ - Error, Runner, + Result, app::{ - dispatcher::Dispatcher, - inbound::network_listener::{ListenerType, NetworkInboundListener}, + dispatcher::Dispatcher, inbound::network_listener::NetworkInboundHandler, }, - common::{auth::ThreadSafeAuthenticator, errors::new_io_error}, - config::internal::config::{BindAddress, Inbound}, + common::auth::ThreadSafeAuthenticator, + config::internal::{config::BindAddress, listener::InboundOpts}, }; use std::{collections::HashMap, sync::Arc}; -pub struct InboundManager { - network_listeners: HashMap, - dispatcher: Arc, - bind_address: BindAddress, - authenticator: ThreadSafeAuthenticator, -} - -pub type ThreadSafeInboundManager = Arc>; - -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct Ports { pub port: Option, #[serde(rename = "socks-port")] @@ -34,153 +29,217 @@ pub struct Ports { pub mixed_port: Option, } +type TaskHandle = RwLock>, oneshot::Sender<()>)>>; + +pub struct InboundManager { + dispatcher: Arc, + bind_address: ArcSwap, + authenticator: ThreadSafeAuthenticator, + + inbounds_opt: RwLock>, + inbounds_handler: RwLock>, + + task_handle: TaskHandle, +} + impl InboundManager { - pub fn new( - inbound: Inbound, + pub async fn new( + bind_address: BindAddress, + _authentication: Vec, // TODO dispatcher: Arc, authenticator: ThreadSafeAuthenticator, - ) -> Result { - let network_listeners = HashMap::new(); - - let mut s = Self { - network_listeners, + inbounds_opt: HashMap, + ) -> Result { + let s = Self { + inbounds_handler: HashMap::with_capacity(3).into(), dispatcher, - bind_address: inbound.bind_address, + bind_address: ArcSwap::new(bind_address.into()), authenticator, + inbounds_opt: inbounds_opt.into(), + task_handle: RwLock::new(None), }; - - let ports = Ports { - port: inbound.port, - socks_port: inbound.socks_port, - redir_port: inbound.redir_port, - tproxy_port: inbound.tproxy_port, - mixed_port: inbound.mixed_port, - }; - - s.rebuild_listeners(ports); + s.build_handlers().await; Ok(s) } - pub fn get_runner(&self) -> Result { - let mut runners = Vec::new(); - for r in self.network_listeners.values() { - runners.append(&mut r.listen()?); + pub async fn start(self: &Arc) { + let mut guard = self.task_handle.write().await; + if let Some((handle, signal)) = guard.take() { + _ = signal.send(()); + handle.abort(); } - Ok(Box::pin(async move { - let mut errors = Vec::new(); - let _ = futures::future::join_all(runners) - .await - .into_iter() - .filter_map(|r| r.map_err(|e| errors.push(e)).ok()) - .collect::>(); - if errors.is_empty() { - Ok(()) - } else { - Err(new_io_error(format!( - "failed to start inbound listeners: {:?}", - errors - )) - .into()) - } - })) + let v = self.clone(); + let (signal_tx, signal_rx) = oneshot::channel(); + let handle = tokio::spawn(async move { v.build_task(signal_rx).await }); + *guard = Some((handle, signal_tx)); } - /// API handlers below - pub fn get_bind_address(&self) -> &BindAddress { - &self.bind_address + // FIXME: This is not working if + // 1. Inner nested spawned tasks. + // 2. spawn_blocking + pub async fn shutdown(&self) { + if let Some((handle, signal)) = self.task_handle.write().await.take() { + _ = signal.send(()); + handle.abort(); + } } - pub fn set_bind_address(&mut self, bind_address: BindAddress) { - self.bind_address = bind_address; + pub async fn restart(self: &Arc) { + self.build_handlers().await; + self.start().await; } - pub fn get_ports(&self) -> Ports { - let mut ports = Ports { - port: None, - socks_port: None, - redir_port: None, - tproxy_port: None, - mixed_port: None, - }; - self.network_listeners - .values() - .for_each(|x| match x.listener_type { - ListenerType::Http => { - ports.port = Some(x.port); - } - ListenerType::Socks5 => { - ports.socks_port = Some(x.port); - } - ListenerType::Mixed => { - ports.mixed_port = Some(x.port); + // Build `inbounds_handler` tasks + async fn build_task( + self: &Arc, + mut signal: oneshot::Receiver<()>, + ) -> Result<()> { + let mut runners = JoinSet::new(); + for (_, handler) in self.inbounds_handler.read().await.iter() { + handler.listen(&mut runners)?; + } + loop { + tokio::select! { + Some(result) = runners.join_next() => { + match result { + Ok(Err(e)) => error!("failed to start inbound listeners: {e:?}"), + Err(e) => { + if let Ok(reason) = e.try_into_panic() { + std::panic::resume_unwind(reason); + } + } + _ => {} + } } - ListenerType::Tproxy => { - ports.tproxy_port = Some(x.port); + _ = &mut signal => { + runners.shutdown().await; + break; } - }); - - ports + } + } + Ok(()) } - pub fn rebuild_listeners(&mut self, ports: Ports) { - let mut network_listeners = HashMap::new(); - if let Some(http_port) = ports.port { + // Sync `inbounds_handler` with `inbounds_opt` + async fn build_handlers(&self) { + let mut network_listeners = HashMap::with_capacity(3); + let guard = self.inbounds_opt.read().await; + for (name, inbound) in guard.iter() { network_listeners.insert( - ListenerType::Http, - NetworkInboundListener { - name: "HTTP".to_string(), - bind_addr: self.bind_address.clone(), - port: http_port, - listener_type: ListenerType::Http, + name.clone(), + NetworkInboundHandler { + name: name.to_string(), dispatcher: self.dispatcher.clone(), authenticator: self.authenticator.clone(), + listener: inbound.clone(), // TODO use Arc }, ); } - if let Some(socks_port) = ports.socks_port { - network_listeners.insert( - ListenerType::Socks5, - NetworkInboundListener { - name: "SOCKS5".to_string(), - bind_addr: self.bind_address.clone(), - port: socks_port, - listener_type: ListenerType::Socks5, - dispatcher: self.dispatcher.clone(), - authenticator: self.authenticator.clone(), - }, - ); - } + *self.inbounds_handler.write().await = network_listeners; + } - if let Some(mixed_port) = ports.mixed_port { - network_listeners.insert( - ListenerType::Mixed, - NetworkInboundListener { - name: "Mixed".to_string(), - bind_addr: self.bind_address.clone(), - port: mixed_port, - listener_type: ListenerType::Mixed, - dispatcher: self.dispatcher.clone(), - authenticator: self.authenticator.clone(), - }, - ); + // RESTFUL API handlers below + pub async fn get_ports(&self) -> Ports { + let mut ports = Ports::default(); + let guard = self.inbounds_opt.read().await; + for (_, opts) in guard.iter() { + match &opts { + InboundOpts::Http { + common_opts, + inherited, + } => { + if *inherited { + ports.port = Some(common_opts.port) + } + } + InboundOpts::Socks { + common_opts, + inherited, + .. + } => { + if *inherited { + ports.socks_port = Some(common_opts.port) + } + } + InboundOpts::Mixed { + common_opts, + inherited, + .. + } => { + if *inherited { + ports.mixed_port = Some(common_opts.port) + } + } + InboundOpts::TProxy { + common_opts, + inherited, + .. + } => { + if *inherited { + ports.tproxy_port = Some(common_opts.port) + } + } + InboundOpts::Redir { + common_opts, + inherited, + } => { + if *inherited { + ports.redir_port = Some(common_opts.port) + } + } + _ => {} + } } + ports + } - if let Some(tproxy_port) = ports.tproxy_port { - network_listeners.insert( - ListenerType::Tproxy, - NetworkInboundListener { - name: "TProxy".to_string(), - bind_addr: self.bind_address.clone(), - port: tproxy_port, - listener_type: ListenerType::Tproxy, - dispatcher: self.dispatcher.clone(), - authenticator: self.authenticator.clone(), - }, - ); + pub fn get_bind_address(&self) -> BindAddress { + **self.bind_address.load() + } + + pub async fn set_bind_address(&self, bind_address: BindAddress) { + self.bind_address.store(Arc::new(bind_address)); + let mut guard = self.inbounds_opt.write().await; + for (_, opts) in guard.iter_mut() { + if opts.inherited() { + opts.common_opts_mut().listen = bind_address + } } + } - self.network_listeners = network_listeners; + pub async fn change_ports(&self, ports: Ports) { + let mut guard = self.inbounds_opt.write().await; + for (_, opts) in guard.iter_mut() { + match &opts { + InboundOpts::Http { inherited, .. } => { + if *inherited && let Some(port) = ports.port { + *opts.port_mut() = port + } + } + InboundOpts::Socks { inherited, .. } => { + if *inherited && let Some(port) = ports.socks_port { + *opts.port_mut() = port + } + } + InboundOpts::Mixed { inherited, .. } => { + if *inherited && let Some(port) = ports.mixed_port { + *opts.port_mut() = port + } + } + InboundOpts::TProxy { inherited, .. } => { + if *inherited && let Some(port) = ports.tproxy_port { + *opts.port_mut() = port + } + } + InboundOpts::Redir { inherited, .. } => { + if *inherited && let Some(port) = ports.redir_port { + *opts.port_mut() = port + } + } + _ => {} + } + } } } diff --git a/clash_lib/src/app/inbound/network_listener.rs b/clash_lib/src/app/inbound/network_listener.rs index 08da1c8cf..9f5903e60 100644 --- a/clash_lib/src/app/inbound/network_listener.rs +++ b/clash_lib/src/app/inbound/network_listener.rs @@ -1,178 +1,128 @@ use crate::{ - common::auth::ThreadSafeAuthenticator, config::internal::config::BindAddress, + common::auth::ThreadSafeAuthenticator, + config::listener::InboundOpts, + proxy::{ + http::HttpInbound, + inbound::{InboudHandler, InboundHandlerTrait as _}, + mixed::MixedInbound, + socks::SocksInbound, + tunnel::TunnelInbound, + }, }; -use crate::proxy::{AnyInboundListener, http, mixed, socks}; - #[cfg(target_os = "linux")] -use crate::proxy::tproxy; +use crate::proxy::tproxy::TproxyInbound; -use crate::{Dispatcher, Error, Runner, proxy::utils::Interface}; -use futures::FutureExt; -use network_interface::{Addr, NetworkInterfaceConfig}; +use crate::Dispatcher; +use tokio::task::JoinSet; use tracing::{info, warn}; -use std::{ - net::{IpAddr, Ipv4Addr}, - sync::Arc, -}; - -#[derive(Eq, PartialEq, Hash, Clone, Debug)] -pub enum ListenerType { - Http, - Socks5, - Mixed, - Tproxy, -} +use std::sync::Arc; -pub struct NetworkInboundListener { +pub struct NetworkInboundHandler { pub name: String, - pub bind_addr: BindAddress, - pub port: u16, - pub listener_type: ListenerType, + pub listener: InboundOpts, pub dispatcher: Arc, pub authenticator: ThreadSafeAuthenticator, } -impl NetworkInboundListener { - pub fn listen(&self) -> Result, Error> { - let mut runners = Vec::::new(); - - match &self.bind_addr { - BindAddress::Any => { - #[cfg(target_os = "ios")] - { - let all_ifaces = network_interface::NetworkInterface::show() - .expect("list interfaces"); - - for iface in all_ifaces.into_iter() { - let ip = iface - .addr - .into_iter() - .filter_map(|x| match x { - Addr::V4(v4) => { - if v4.ip.is_unspecified() - || v4.ip.is_link_local() - || v4.ip.is_multicast() - { - None - } else { - Some(v4.ip) - } - } - Addr::V6(_) => None, - }) - .next(); - - if !ip.is_some() { - continue; - } - - self.build_and_insert_listener(&mut runners, ip.unwrap()); - } - } - #[cfg(not(target_os = "ios"))] - { - let ip = "0.0.0.0".parse().expect("must parse"); - self.build_and_insert_listener(&mut runners, ip); - } - } - BindAddress::One(iface) => match iface { - Interface::IpAddr(ip) => match ip { - IpAddr::V4(ip) => { - self.build_and_insert_listener(&mut runners, *ip) - } - IpAddr::V6(_) => unreachable!("unsupported listening v6"), - }, - Interface::Name(iface) => { - let ip = network_interface::NetworkInterface::show() - .expect("list interfaces") - .into_iter() - .filter(|x| &x.name == iface) - .flat_map(|x| x.addr) - .map(|x| match x { - Addr::V4(v4) => v4.ip, - Addr::V6(_) => unreachable!(), - }) - .find(|x| { - !x.is_unspecified() - && !x.is_link_local() - && !x.is_multicast() - }) - .expect("no valid ip"); - - self.build_and_insert_listener(&mut runners, ip); - } - }, - }; - - Ok(runners) +impl NetworkInboundHandler { + pub fn listen( + &self, + set: &mut JoinSet>, + ) -> crate::Result<()> { + self.build_and_insert_listener(set) } - fn build_and_insert_listener(&self, runners: &mut Vec, ip: Ipv4Addr) { - let listener: AnyInboundListener = match self.listener_type { - ListenerType::Http => Arc::new(http::Listener::new( - (ip, self.port).into(), + fn build_and_insert_listener( + &self, + set: &mut JoinSet>, + ) -> crate::Result<()> { + let handler: InboudHandler = match &self.listener { + InboundOpts::Http { common_opts, .. } => HttpInbound::new( + (common_opts.listen.0, common_opts.port).into(), self.dispatcher.clone(), self.authenticator.clone(), - )), - ListenerType::Socks5 => Arc::new(socks::Listener::new( - (ip, self.port).into(), + ) + .into(), + InboundOpts::Socks { common_opts, .. } => SocksInbound::new( + (common_opts.listen.0, common_opts.port).into(), self.dispatcher.clone(), self.authenticator.clone(), - )), - ListenerType::Mixed => Arc::new(mixed::Listener::new( - (ip, self.port).into(), + ) + .into(), + InboundOpts::Mixed { common_opts, .. } => MixedInbound::new( + (common_opts.listen.0, common_opts.port).into(), self.dispatcher.clone(), self.authenticator.clone(), - )), - ListenerType::Tproxy => { + ) + .into(), + #[allow(unused)] + InboundOpts::TProxy { common_opts, .. } => { #[cfg(target_os = "linux")] { - Arc::new(tproxy::Listener::new( - (ip, self.port).into(), + TproxyInbound::new( + (common_opts.listen.0, common_opts.port).into(), self.dispatcher.clone(), - )) + ) + .into() } + #[cfg(not(target_os = "linux"))] { warn!("tproxy is not supported on this platform"); - return; + return Ok(()); } } + InboundOpts::Redir { .. } => unimplemented!(), + InboundOpts::Tunnel { + common_opts, + network, + target, + } => TunnelInbound::new( + (common_opts.listen.0, common_opts.port).into(), + self.dispatcher.clone(), + network.clone(), + target.clone(), + )? + .into(), }; + let handler = Arc::new(handler); + if handler.handle_tcp() { + info!( + "{} TCP listening at: {}:{}", + self.name, + self.listener.common_opts().listen.0, + self.listener.common_opts().port + ); - if listener.handle_tcp() { - let listener_type = self.listener_type.clone(); - info!("{} TCP listening at: {}:{}", self.name, ip, self.port); + let tcp_listener = handler.clone(); - let tcp_listener = listener.clone(); - runners.push( - async move { - tcp_listener.listen_tcp().await.map_err(|e| { - warn!( - "handler of {:?} tcp listen failed: {}", - listener_type, e - ); - e.into() - }) - } - .boxed(), - ); + let name = self.name.clone(); + set.spawn(async move { + tcp_listener.listen_tcp().await.map_err(|e| { + warn!("handler {} tcp listen failed: {e}", name); + e.into() + }) + }); } - if listener.handle_udp() { - info!("{} UDP listening at: {}:{}", self.name, ip, self.port); - let udp_listener = listener.clone(); - runners.push( - async move { - udp_listener.listen_udp().await.map_err(|e| { - warn!("handler udp listen failed: {}", e); - e.into() - }) - } - .boxed(), + if handler.handle_udp() { + info!( + "{} UDP listening at: {}:{}", + self.name, + self.listener.common_opts().listen.0, + self.listener.common_opts().port ); + let udp_listener = handler.clone(); + let name = self.name.clone(); + set.spawn(async move { + udp_listener.listen_udp().await.map_err(|e| { + warn!("handler {} udp listen failed: {e}", name); + e.into() + }) + }); } + Ok(()) } } diff --git a/clash_lib/src/config/def.rs b/clash_lib/src/config/def.rs index d2012797c..be96a7e74 100644 --- a/clash_lib/src/config/def.rs +++ b/clash_lib/src/config/def.rs @@ -1,9 +1,12 @@ use crate::Error; use std::{collections::HashMap, fmt::Display, path::PathBuf, str::FromStr}; -use serde::{Deserialize, Serialize}; +use educe::Educe; +use serde::{Deserialize, Deserializer, Serialize}; use serde_yaml::Value; +use super::config::BindAddress; + fn default_tun_address() -> String { "198.18.0.1/32".to_string() } @@ -92,26 +95,6 @@ impl Display for LogLevel { } } -#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] -#[serde(untagged)] -pub enum Port { - Str(String), - Num(u16), -} - -impl TryInto for Port { - type Error = Error; - - fn try_into(self) -> Result { - match self { - Port::Num(x) => Ok(x), - Port::Str(x) => x - .parse() - .map_err(|_| Error::InvalidConfig(format!("invalid port: {}", x))), - } - } -} - /// Example /// ```yaml /// --- @@ -275,10 +258,12 @@ impl TryInto for Port { /// - MATCH, DIRECT /// ... /// ``` -#[derive(Serialize, Deserialize)] +#[derive(Deserialize, Educe)] #[serde(rename_all = "kebab-case", default)] +#[educe(Default)] pub struct Config { /// The HTTP proxy port + #[serde(alias = "http_port")] pub port: Option, /// The SOCKS5 proxy port pub socks_port: Option, @@ -296,8 +281,8 @@ pub struct Config { /// HTTP and SOCKS5 proxy authentication pub authentication: Vec, /// Allow connections to the local-end server from other LAN IP addresses - #[deprecated = "dont use. see `bind_address`"] - pub allow_lan: bool, + /// Deprecated see `bind_address` + pub allow_lan: Option, /// The address that the inbound listens on /// # Note /// - setting this to `*` will listen on all interfaces, which is @@ -305,7 +290,7 @@ pub struct Config { /// - setting this to non local IP will enable `allow_lan` automatically /// - and if you don't want `allow_lan` to be enabled, you should set this /// to `localhost` or `127.1` - pub bind_address: String, + pub bind_address: BindAddress, /// Clash router working mode /// Either `rule`, `global` or `direct` pub mode: RunMode, @@ -318,26 +303,32 @@ pub struct Config { pub profile: Profile, /// Proxy settings #[serde(rename = "proxies")] - pub proxy: Vec>, + pub proxy: Option>>, #[serde(rename = "proxy-groups")] /// Proxy group settings - pub proxy_group: Vec>, + pub proxy_group: Option>>, #[serde(rename = "rules")] /// Rule settings - pub rule: Vec, + pub rule: Option>, /// Hosts pub hosts: HashMap, /// Country database path relative to the $CWD + #[educe(Default = "Country.mmdb")] pub mmdb: String, /// Country database download url + // TODO not compatiable with clash-meta + #[educe(Default = Some("https://github.com/Loyalsoldier/geoip/releases/download/202307271745/Country.mmdb".into()))] pub mmdb_download_url: Option, - /// Optional ASN database path relative to the $CWD + /// Optional ASN database path relative to the working dir + #[educe(Default = "Country-asn.mmdb")] pub asn_mmdb: String, /// Optional ASN database download url pub asn_mmdb_download_url: Option, /// Geosite database path relative to the $CWD + #[educe(Default = "geosite.dat")] pub geosite: String, /// Geosite database download url + #[educe(Default = Some("https://github.com/Loyalsoldier/v2ray-rules-dat/releases/download/202406182210/geosite.dat".into()))] pub geosite_download_url: Option, // these options has default vals, @@ -345,7 +336,6 @@ pub struct Config { /// whether your network environment supports IPv6 /// this will affect the DNS server response to AAAA questions /// default is `false` - #[serde(default = "Default::default")] pub ipv6: bool, /// external controller address pub external_controller: Option, @@ -379,6 +369,9 @@ pub struct Config { /// device-id: "dev://utun1989" /// ``` pub tun: Option, + + #[serde(rename = "listeners")] + pub listener: Option>>, } impl TryFrom for Config { @@ -395,72 +388,26 @@ impl FromStr for Config { type Err = Error; fn from_str(s: &str) -> Result { - let mut val: Value = serde_yaml::from_str(s).map_err(|x| { + let mut val: Value = serde_yaml::from_str(s).map_err(|e| { Error::InvalidConfig(format!( - "couldn't not parse config content {}: {}", - s, x + "couldn't not parse config content {s}: {e}" )) })?; - val.apply_merge().map_err(|x| { + val.apply_merge().map_err(|e| { Error::InvalidConfig(format!( - "failed to process anchors in config content {}: {}", - s, x + "failed to process anchors in config content {s}: {e}" )) })?; - serde_yaml::from_value(val).map_err(|x| { + serde_yaml::from_value(val).map_err(|e| { Error::InvalidConfig(format!( - "counldn't not parse config content {}: {}", - s, x + "counldn't not parse config content {s}: {e}" )) }) } } -impl Default for Config { - fn default() -> Self { - #[allow(deprecated)] - Self { - port: Default::default(), - socks_port: Default::default(), - redir_port: Default::default(), - tproxy_port: Default::default(), - mixed_port: Default::default(), - authentication: Default::default(), - allow_lan: Default::default(), - bind_address: String::from("*"), - mode: Default::default(), - log_level: Default::default(), - ipv6: Default::default(), - external_controller: Default::default(), - external_ui: Default::default(), - secret: Default::default(), - interface: Default::default(), - routing_mask: Default::default(), - proxy_provider: Default::default(), - rule_provider: Default::default(), - hosts: Default::default(), - dns: Default::default(), - experimental: Default::default(), - profile: Default::default(), - proxy: Default::default(), - proxy_group: Default::default(), - rule: Default::default(), - mmdb: "Country.mmdb".to_string(), - mmdb_download_url: Some( - "https://github.com/Loyalsoldier/geoip/releases/download/202307271745/Country.mmdb" - .to_owned(), - ), - asn_mmdb: "Country-asn.mmdb".to_string(), - asn_mmdb_download_url: None, // can be downloaded from the same release but let's not make it default - geosite: "geosite.dat".to_string(), - geosite_download_url: Some("https://github.com/Loyalsoldier/v2ray-rules-dat/releases/download/202406182210/geosite.dat".to_owned()), - tun: Default::default(), - } - } -} - #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] #[serde(untagged)] pub enum DNSListen { @@ -489,8 +436,9 @@ pub enum DNSListen { /// ca-key: dns.key /// ``` -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Educe)] #[serde(rename_all = "kebab-case", default)] +#[educe(Default)] pub struct DNS { /// When disabled, system DNS config will be used /// All other DNS related options will only be used when this is enabled @@ -498,6 +446,7 @@ pub struct DNS { /// When false, response to AAAA questions will be empty pub ipv6: bool, /// Whether to `Config::hosts` as when resolving hostnames + #[educe(Default = true)] pub user_hosts: bool, /// DNS servers pub nameserver: Vec, @@ -511,37 +460,20 @@ pub struct DNS { /// Whether to use fake IP addresses pub enhanced_mode: DNSMode, /// Fake IP addresses pool CIDR + #[educe(Default = "198.18.0.1/16")] pub fake_ip_range: String, /// Fake IP addresses filter pub fake_ip_filter: Vec, /// Default nameservers, used to resolve DoH hostnames + #[educe(Default = vec![ + String::from("114.114.114.114"), + String::from("8.8.8.8")] + )] pub default_nameserver: Vec, /// Lookup domains via specific nameservers pub nameserver_policy: HashMap, } -impl Default for DNS { - fn default() -> Self { - Self { - enable: Default::default(), - ipv6: Default::default(), - user_hosts: true, - nameserver: Default::default(), - fallback: Default::default(), - fallback_filter: Default::default(), - listen: Default::default(), - enhanced_mode: Default::default(), - fake_ip_range: String::from("198.18.0.1/16"), - fake_ip_filter: Default::default(), - default_nameserver: vec![ - String::from("114.114.114.114"), - String::from("8.8.8.8"), - ], - nameserver_policy: Default::default(), - } - } -} - #[derive(Serialize, Deserialize, Default, Clone, Debug)] #[serde(rename_all = "kebab-case")] pub enum DNSMode { @@ -551,29 +483,23 @@ pub enum DNSMode { RedirHost, } -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Educe)] #[serde(default)] +#[educe(Default)] pub struct FallbackFilter { #[serde(rename = "geoip")] + #[educe(Default = true)] pub geo_ip: bool, + #[serde(rename = "geoip-code")] + #[educe(Default = "CN")] pub geo_ip_code: String, + #[serde(rename = "ipcidr")] pub ip_cidr: Vec, pub domain: Vec, } -impl Default for FallbackFilter { - fn default() -> Self { - Self { - geo_ip: true, - geo_ip_code: String::from("CN"), - ip_cidr: Default::default(), - domain: Default::default(), - } - } -} - #[derive(Serialize, Deserialize, Default)] #[serde(rename_all = "kebab-case")] pub struct Experimental { @@ -600,6 +526,46 @@ impl Default for Profile { } } +#[derive(PartialEq, Debug, Clone, Serialize, Copy)] +pub struct Port(pub u16); + +impl From for u16 { + fn from(val: Port) -> Self { + val.0 + } +} + +impl<'de> Deserialize<'de> for Port { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(untagged)] + enum StrOrNum { + Str(String), + Num(u64), + Other, + } + + let value = StrOrNum::deserialize(deserializer)?; + + match value { + StrOrNum::Num(num) => u16::try_from(num) + .map(Port) + .map_err(|_| serde::de::Error::custom("Port number out of range")), + + StrOrNum::Str(s) => { + s.parse::().map(Port).map_err(serde::de::Error::custom) + } + + StrOrNum::Other => { + Err(serde::de::Error::custom("Invalid type for port")) + } + } + } +} + #[cfg(test)] mod tests { use serde_yaml::Value; @@ -614,7 +580,7 @@ mod tests { port: 9090 "#; let c = cfg.parse::().expect("should parse"); - assert_eq!(c.port, Some(Port::Num(9090))); + assert_eq!(c.port, Some(Port(9090))); } #[test] @@ -623,7 +589,7 @@ mod tests { port: "9090" "#; let c = cfg.parse::().expect("should parse"); - assert_eq!(c.port, Some(Port::Str("9090".to_string()))); + assert_eq!(c.port, Some(Port(9090))); } #[test] @@ -1096,19 +1062,28 @@ rules: let des: Config = serde_yaml::from_str(example_cfg).expect("should parse yaml"); - assert_eq!(des.port.expect("invalid port"), Port::Num(7890)); + assert_eq!(des.port.expect("invalid port"), Port(7890)); assert_eq!(des.dns.fallback_filter.geo_ip_code, String::from("CN")); - assert_eq!(des.proxy.len(), 14); - assert_eq!(des.proxy[2].get("name").unwrap().as_str(), Some("ss3")); + assert_eq!(des.proxy.as_ref().map(|v| v.len()).unwrap_or(0), 14); + assert_eq!( + des.proxy + .as_ref() + .map(|v| v[2].get("name").unwrap().as_str()) + .unwrap(), + Some("ss3") + ); assert_eq!( - des.proxy[2] - .get("plugin-opts") - .unwrap() - .as_mapping() - .unwrap() - .get(Value::String("mode".into())) - .unwrap() - .as_str(), + des.proxy + .as_ref() + .map(|v| v[2] + .get("plugin-opts") + .unwrap() + .as_mapping() + .unwrap() + .get(Value::String("mode".into())) + .unwrap() + .as_str()) + .unwrap(), Some("websocket") ); } diff --git a/clash_lib/src/config/internal/config.rs b/clash_lib/src/config/internal/config.rs index 6679c96af..c0e0f1f30 100644 --- a/clash_lib/src/config/internal/config.rs +++ b/clash_lib/src/config/internal/config.rs @@ -1,10 +1,12 @@ use std::collections::HashMap; -use std::{fmt::Display, net::IpAddr, str::FromStr}; +use std::{ + net::{IpAddr, Ipv4Addr}, + str::FromStr, +}; use ipnet::IpNet; -use serde::{Deserialize, Serialize, de::value::MapDeserializer}; -use serde_yaml::Value; +use serde::{Deserialize, Serialize}; use crate::{ Error, @@ -12,17 +14,12 @@ use crate::{ common::auth, config::{ def::{self, LogLevel, RunMode}, - internal::{ - proxy::{OutboundProxy, PROXY_DIRECT, PROXY_REJECT}, - rule::RuleType, - }, + internal::{proxy::OutboundProxy, rule::RuleType}, }, proxy::utils::Interface, }; -use super::proxy::{ - OutboundProxyProtocol, OutboundProxyProviderDef, map_serde_error, -}; +use super::{listener::InboundOpts, proxy::OutboundProxyProviderDef}; pub struct Config { pub general: General, @@ -38,10 +35,11 @@ pub struct Config { pub proxies: HashMap, pub proxy_groups: HashMap, pub proxy_providers: HashMap, + pub listeners: HashMap, } impl Config { - fn validate(self) -> Result { + pub fn validate(self) -> Result { for r in self.rules.iter() { if !self.proxies.contains_key(r.target()) && !self.proxy_groups.contains_key(r.target()) @@ -56,252 +54,9 @@ impl Config { } } -impl TryFrom for Config { - type Error = crate::Error; - - fn try_from(c: def::Config) -> Result { - let mut proxy_names = - vec![String::from(PROXY_DIRECT), String::from(PROXY_REJECT)]; - #[allow(deprecated)] - Self { - general: General { - inbound: Inbound { - port: c.port.clone().map(|x| x.try_into()).transpose()?, - socks_port: c - .socks_port - .clone() - .map(|x| x.try_into()) - .transpose()?, - redir_port: c - .redir_port - .clone() - .map(|x| x.try_into()) - .transpose()?, - tproxy_port: c - .tproxy_port - .clone() - .map(|x| x.try_into()) - .transpose()?, - mixed_port: c - .mixed_port - .clone() - .map(|x| x.try_into()) - .transpose()?, - authentication: c.authentication.clone(), - bind_address: c.bind_address.parse()?, - }, - controller: Controller { - external_controller: c.external_controller.clone(), - external_ui: c.external_ui.clone(), - secret: c.secret.clone(), - }, - mode: c.mode, - log_level: c.log_level, - ipv6: c.ipv6, - interface: c.interface.as_ref().map(|iface| { - if let Ok(addr) = iface.parse::() { - Interface::IpAddr(addr) - } else { - Interface::Name(iface.to_string()) - } - }), - routing_mask: c.routing_mask, - mmdb: c.mmdb.to_owned(), - mmdb_download_url: c.mmdb_download_url.to_owned(), - asn_mmdb: c.asn_mmdb.to_owned(), - asn_mmdb_download_url: c.asn_mmdb_download_url.to_owned(), - geosite: c.geosite.to_owned(), - geosite_download_url: c.geosite_download_url.to_owned(), - }, - dns: (&c).try_into()?, - experimental: c.experimental, - tun: match c.tun { - Some(t) => TunConfig { - enable: t.enable, - device_id: t.device_id, - route_all: t.route_all, - routes: t - .routes - .map(|r| { - r.into_iter() - .map(|x| x.parse()) - .collect::, _>>() - }) - .transpose() - .map_err(|x| { - Error::InvalidConfig(format!("parse tun routes: {}", x)) - })? - .unwrap_or_default(), - gateway: t.gateway.parse().map_err(|x| { - Error::InvalidConfig(format!("parse tun gateway: {}", x)) - })?, - mtu: t.mtu, - so_mark: t.so_mark, - route_table: t.route_table, - dns_hijack: match t.dns_hijack { - def::DnsHijack::Switch(b) => b, - def::DnsHijack::List(_) => true, - }, - }, - None => TunConfig::default(), - }, - profile: Profile { - store_selected: c.profile.store_selected, - }, - rules: c - .rule - .into_iter() - .map(|x| { - x.parse::() - .map_err(|x| Error::InvalidConfig(x.to_string())) - }) - .collect::, _>>()?, - rule_providers: c - .rule_provider - .map(|m| { - m.into_iter() - .try_fold(HashMap::new(), |mut rv, (name, mut body)| { - body.insert( - "name".to_owned(), - serde_yaml::Value::String(name.clone()), - ); - let provider = RuleProviderDef::try_from(body) - .map_err(|x| { - Error::InvalidConfig(format!( - "invalid rule provider {}: {}", - name, x - )) - })?; - rv.insert(name, provider); - Ok::< - HashMap, - Error, - >(rv) - }) - .expect("proxy provider parse error") - }) - .unwrap_or_default(), - users: c - .authentication - .into_iter() - .map(|u| { - let mut parts = u.splitn(2, ':'); - let username = parts.next().unwrap().to_string(); - let password = parts.next().unwrap_or("").to_string(); - auth::User::new(username, password) - }) - .collect(), - proxies: c.proxy.into_iter().try_fold( - HashMap::from([ - ( - String::from(PROXY_DIRECT), - OutboundProxy::ProxyServer(OutboundProxyProtocol::Direct), - ), - ( - String::from(PROXY_REJECT), - OutboundProxy::ProxyServer(OutboundProxyProtocol::Reject), - ), - ]), - |mut rv, x| { - let proxy = OutboundProxy::ProxyServer( - OutboundProxyProtocol::try_from(x)?, - ); - let name = proxy.name(); - if rv.contains_key(name.as_str()) { - return Err(Error::InvalidConfig(format!( - "duplicated proxy name: {}", - name, - ))); - } - proxy_names.push(name.clone()); - rv.insert(name, proxy); - Ok(rv) - }, - )?, - proxy_groups: c.proxy_group.into_iter().try_fold( - HashMap::::new(), - |mut rv, mapping| { - let group = OutboundProxy::ProxyGroup( - mapping.clone().try_into().map_err( - |x: Error| match mapping.get("name") { - Some(name) => Error::InvalidConfig(format!( - "proxy group: {}: {}", - name.as_str() - .expect("proxy group name must be string"), - x - )), - _ => Error::InvalidConfig( - "proxy group name missing".to_string(), - ), - }, - )?, - ); - proxy_names.push(group.name()); - rv.insert(group.name().to_string(), group); - Ok::, Error>(rv) - }, - )?, - // https://stackoverflow.com/a/62001313/1109167 - proxy_names, - proxy_providers: c - .proxy_provider - .map(|m| { - m.into_iter() - .try_fold(HashMap::new(), |mut rv, (name, mut body)| { - body.insert( - "name".to_owned(), - serde_yaml::Value::String(name.clone()), - ); - let provider = OutboundProxyProviderDef::try_from(body) - .map_err(|x| { - Error::InvalidConfig(format!( - "invalid proxy provider {}: {}", - name, x - )) - })?; - rv.insert(name, provider); - Ok::< - HashMap< - std::string::String, - OutboundProxyProviderDef, - >, - Error, - >(rv) - }) - .expect("proxy provider parse error") - }) - .unwrap_or_default(), - } - .validate() - } -} - -#[cfg(test)] -mod tests { - use crate::def; - - use super::Config; - - #[test] - fn from_def_config() { - let cfg = r#" - port: 9090 - mixed-port: "9091" - "#; - let c = cfg.parse::().expect("should parse"); - assert_eq!(c.port.clone().map(|x| x.try_into().unwrap()), Some(9090)); - assert_eq!( - c.mixed_port.clone().map(|x| x.try_into().unwrap()), - Some(9091) - ); - let cc: Config = c.try_into().expect("should into"); - assert_eq!(cc.general.inbound.port, Some(9090)); - assert_eq!(cc.general.inbound.mixed_port, Some(9091)); - } -} - pub struct General { - pub inbound: Inbound, + pub authentication: Vec, + pub bind_address: BindAddress, pub(crate) controller: Controller, pub mode: RunMode, pub log_level: LogLevel, @@ -336,55 +91,64 @@ pub struct TunConfig { pub dns_hijack: bool, } -#[derive(Clone, Default)] -pub enum BindAddress { - #[default] - Any, - One(Interface), +#[derive(Serialize, Clone, Debug, Copy, PartialEq)] +#[serde(transparent)] +pub struct BindAddress(pub IpAddr); +impl BindAddress { + pub fn all() -> Self { + Self(IpAddr::V4(Ipv4Addr::UNSPECIFIED)) + } + + pub fn local() -> Self { + Self(IpAddr::V4(Ipv4Addr::LOCALHOST)) + } +} +impl Default for BindAddress { + fn default() -> Self { + Self::local() + } } -impl Display for BindAddress { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - BindAddress::Any => write!(f, "*"), - BindAddress::One(one) => match one { - Interface::IpAddr(ip) => write!(f, "{}", ip), - Interface::Name(name) => write!(f, "{}", name), - }, +impl<'de> Deserialize<'de> for BindAddress { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let str = String::deserialize(deserializer)?; + match str.as_str() { + "*" => Ok(Self(IpAddr::V4(Ipv4Addr::UNSPECIFIED))), + "localhost" => Ok(Self(IpAddr::from([127, 0, 0, 1]))), + _ => { + if let Ok(ip) = str.parse::() { + Ok(Self(ip)) + } else { + Err(serde::de::Error::custom(format!( + "Invalid BindAddress value {str}" + ))) + } + } } } } impl FromStr for BindAddress { - type Err = Error; + type Err = anyhow::Error; - fn from_str(s: &str) -> Result { - match s { - "*" => Ok(Self::Any), - "localhost" => { - Ok(Self::One(Interface::IpAddr(IpAddr::from([127, 0, 0, 1])))) - } + fn from_str(str: &str) -> Result { + match str { + "*" => Ok(Self(IpAddr::V4(Ipv4Addr::UNSPECIFIED))), + "localhost" => Ok(Self(IpAddr::from([127, 0, 0, 1]))), _ => { - if let Ok(ip) = s.parse::() { - Ok(BindAddress::One(Interface::IpAddr(ip))) + if let Ok(ip) = str.parse::() { + Ok(Self(ip)) } else { - Ok(BindAddress::One(Interface::Name(s.to_string()))) + Err(anyhow!("Invalid BindAddress value {str}")) } } } } } -pub struct Inbound { - pub port: Option, - pub socks_port: Option, - pub redir_port: Option, - pub tproxy_port: Option, - pub mixed_port: Option, - pub authentication: Vec, - pub bind_address: BindAddress, -} - #[derive(Serialize, Deserialize, Default, Clone)] pub struct Controller { pub external_controller: Option, @@ -415,18 +179,42 @@ pub struct FileRuleProvider { pub behavior: RuleSetBehavior, } -impl TryFrom> for RuleProviderDef { - type Error = crate::Error; - - fn try_from(mapping: HashMap) -> Result { - let name = mapping - .get("name") - .and_then(|x| x.as_str()) - .ok_or(Error::InvalidConfig( - "rule provider name is required".to_owned(), - ))? - .to_owned(); - RuleProviderDef::deserialize(MapDeserializer::new(mapping.into_iter())) - .map_err(map_serde_error(name)) +#[cfg(test)] +mod tests { + use crate::config::{def, internal::convert::convert, listener::InboundOpts}; + #[test] + fn from_def_config() { + let cfg = r#" + port: 9090 + mixed-port: "9091" + "#; + let c = cfg.parse::().expect("should parse"); + assert_eq!(c.port.clone().map(|x| x.try_into().unwrap()), Some(9090)); + assert_eq!( + c.mixed_port.clone().map(|x| x.try_into().unwrap()), + Some(9091) + ); + let cc = convert(c).expect("should convert"); + + assert!( + cc.listeners + .iter() + .find(|(_, listener)| match listener { + InboundOpts::Http { common_opts, .. } => + common_opts.port == 9090, + _ => false, + }) + .is_some() + ); + assert!( + cc.listeners + .iter() + .find(|(_, listener)| match listener { + InboundOpts::Mixed { common_opts, .. } => + common_opts.port == 9091, + _ => false, + }) + .is_some() + ); } } diff --git a/clash_lib/src/config/internal/convert/general.rs b/clash_lib/src/config/internal/convert/general.rs new file mode 100644 index 000000000..04d2c7985 --- /dev/null +++ b/clash_lib/src/config/internal/convert/general.rs @@ -0,0 +1,38 @@ +use std::net::IpAddr; + +use crate::{ + config::{ + config::{Controller, General}, + def, + }, + proxy::utils::Interface, +}; + +pub(super) fn convert(c: &def::Config) -> Result { + Ok(General { + authentication: c.authentication.clone(), + controller: Controller { + external_controller: c.external_controller.clone(), + external_ui: c.external_ui.clone(), + secret: c.secret.clone(), + }, + mode: c.mode, + log_level: c.log_level, + ipv6: c.ipv6, + interface: c.interface.as_ref().map(|iface| { + if let Ok(addr) = iface.parse::() { + Interface::IpAddr(addr) + } else { + Interface::Name(iface.to_string()) + } + }), + routing_mask: c.routing_mask, + mmdb: c.mmdb.to_owned(), + mmdb_download_url: c.mmdb_download_url.to_owned(), + asn_mmdb: c.asn_mmdb.to_owned(), + asn_mmdb_download_url: c.asn_mmdb_download_url.to_owned(), + geosite: c.geosite.to_owned(), + geosite_download_url: c.geosite_download_url.to_owned(), + bind_address: c.bind_address, + }) +} diff --git a/clash_lib/src/config/internal/convert/listener.rs b/clash_lib/src/config/internal/convert/listener.rs new file mode 100644 index 000000000..bd341d559 --- /dev/null +++ b/clash_lib/src/config/internal/convert/listener.rs @@ -0,0 +1,180 @@ +use std::collections::HashMap; + +use serde::{Deserialize as _, de::value::MapDeserializer}; +use serde_yaml::Value; +use tracing::warn; + +use crate::{ + Error, + config::{ + def::{self, Port}, + listener::{CommonInboundOpts, InboundOpts}, + proxy::map_serde_error, + }, +}; + +pub(super) fn convert( + raw: Option>>, + c: &def::Config, +) -> Result, crate::Error> { + let http_port = c.port; + let socks_port = c.socks_port; + let mixed_port = c.mixed_port; + let tpoxy_port = c.tproxy_port; + let redir_port = c.redir_port; + let bind_address = c.bind_address; + let mut inbounds = raw.unwrap_or_default().into_iter().try_fold( + HashMap::with_capacity(3), + |mut accum, raw| { + let inbound = InboundOpts::try_from(raw)?; + let mut duplicate = false; + match &inbound { + InboundOpts::Http { common_opts, .. } => { + if let Some(Port(port)) = http_port { + if port == common_opts.port + && c.bind_address == common_opts.listen + { + duplicate = true; + } + } + } + InboundOpts::Socks { common_opts, .. } => { + if let Some(Port(port)) = socks_port { + if port == common_opts.port + && c.bind_address == common_opts.listen + { + duplicate = true; + } + } + } + InboundOpts::Mixed { common_opts, .. } => { + if let Some(Port(port)) = mixed_port { + if port == common_opts.port + && c.bind_address == common_opts.listen + { + duplicate = true; + } + } + } + InboundOpts::TProxy { common_opts, .. } => { + if let Some(Port(port)) = mixed_port { + if port == common_opts.port + && c.bind_address == common_opts.listen + { + duplicate = true; + } + } + } + InboundOpts::Redir { common_opts, .. } => { + if let Some(Port(port)) = mixed_port { + if port == common_opts.port + && c.bind_address == common_opts.listen + { + duplicate = true; + } + } + } + _ => {} + } + if !duplicate { + accum.insert(inbound.common_opts().name.clone(), inbound); + } else { + // TODO improve log, more info + warn!("duplicate inbound detected"); + } + Ok::, Error>(accum) + }, + )?; + // Add top-level proxies to inbounds if they are not overrided + if let Some(Port(http_port)) = http_port { + inbounds.insert( + "HTTP-IN".into(), + InboundOpts::Http { + common_opts: CommonInboundOpts { + name: "HTTP-IN".into(), + listen: bind_address, + port: http_port, + ..Default::default() + }, + inherited: true, + }, + ); + } + if let Some(Port(socks_port)) = socks_port { + inbounds.insert( + "SOCKS-IN".into(), + InboundOpts::Socks { + common_opts: CommonInboundOpts { + name: "SOCKS-IN".into(), + listen: bind_address, + port: socks_port, + ..Default::default() + }, + udp: true, + inherited: true, + }, + ); + } + if let Some(Port(mixed_port)) = mixed_port { + inbounds.insert( + "MIXED-IN".into(), + InboundOpts::Mixed { + common_opts: CommonInboundOpts { + name: "MIXED-IN".into(), + listen: bind_address, + port: mixed_port, + ..Default::default() + }, + udp: true, + inherited: true, + }, + ); + } + if let Some(Port(redir_port)) = redir_port { + inbounds.insert( + "REDIR-IN".into(), + InboundOpts::Redir { + common_opts: CommonInboundOpts { + name: "REDIR-IN".into(), + listen: bind_address, + port: redir_port, + ..Default::default() + }, + inherited: true, + }, + ); + } + if let Some(Port(tproxy_port)) = tpoxy_port { + inbounds.insert( + "TPROXY-IN".into(), + InboundOpts::TProxy { + common_opts: CommonInboundOpts { + name: "TPROXY-IN".into(), + listen: bind_address, + port: tproxy_port, + ..Default::default() + }, + udp: true, + inherited: true, + }, + ); + } + Ok(inbounds) +} + +impl TryFrom> for InboundOpts { + type Error = crate::Error; + + fn try_from(mapping: HashMap) -> Result { + let name = mapping + .get("name") + .and_then(|x| x.as_str()) + .and_then(|v| if v.is_empty() { None } else { Some(v) }) + .ok_or(Error::InvalidConfig( + "missing field `name` in inbound listener".to_owned(), + ))? + .to_owned(); + InboundOpts::deserialize(MapDeserializer::new(mapping.into_iter())) + .map_err(map_serde_error(name)) + } +} diff --git a/clash_lib/src/config/internal/convert/mod.rs b/clash_lib/src/config/internal/convert/mod.rs new file mode 100644 index 000000000..4e6cd8061 --- /dev/null +++ b/clash_lib/src/config/internal/convert/mod.rs @@ -0,0 +1,151 @@ +use std::collections::HashMap; + +use serde::{Deserialize, de::value::MapDeserializer}; +use serde_yaml::Value; + +use crate::{ + Error, + common::auth, + config::{ + def, + internal::{ + proxy::{OutboundProxy, PROXY_DIRECT, PROXY_REJECT}, + rule::RuleType, + }, + }, +}; + +mod general; +mod listener; +mod proxy_group; +mod rule_provider; +mod tun; + +use super::{ + config::{self, BindAddress, Profile}, + proxy::{ + OutboundGroupProtocol, OutboundProxyProtocol, OutboundProxyProviderDef, + map_serde_error, + }, +}; + +impl TryFrom for config::Config { + type Error = crate::Error; + + fn try_from(value: def::Config) -> Result { + convert(value) + } +} + +pub(super) fn convert(mut c: def::Config) -> Result { + let mut proxy_names = + vec![String::from(PROXY_DIRECT), String::from(PROXY_REJECT)]; + + if c.allow_lan.unwrap_or(false) { + c.bind_address = BindAddress::all() + } else { + c.bind_address = BindAddress::local() + } + + config::Config { + general: general::convert(&c)?, + dns: (&c).try_into()?, + experimental: c.experimental.take(), + tun: tun::convert(c.tun.take())?, + profile: Profile { + store_selected: c.profile.store_selected, + }, + rules: c + .rule + .take() + .unwrap_or_default() + .into_iter() + .map(|x| { + x.parse::() + .map_err(|x| Error::InvalidConfig(x.to_string())) + }) + .collect::, _>>()?, + rule_providers: rule_provider::convert(c.rule_provider.take()), + users: c + .authentication + .clone() + .into_iter() + .map(|u| { + let mut parts = u.splitn(2, ':'); + let username = parts.next().unwrap().to_string(); + let password = parts.next().unwrap_or("").to_string(); + auth::User::new(username, password) + }) + .collect(), + proxies: c.proxy.take().unwrap_or_default().into_iter().try_fold( + HashMap::from([ + ( + String::from(PROXY_DIRECT), + OutboundProxy::ProxyServer(OutboundProxyProtocol::Direct), + ), + ( + String::from(PROXY_REJECT), + OutboundProxy::ProxyServer(OutboundProxyProtocol::Reject), + ), + ]), + |mut rv, x| { + let proxy = + OutboundProxy::ProxyServer(OutboundProxyProtocol::try_from(x)?); + let name = proxy.name(); + if rv.contains_key(name.as_str()) { + return Err(Error::InvalidConfig(format!( + "duplicated proxy name: {name}" + ))); + } + proxy_names.push(name.clone()); + rv.insert(name, proxy); + Ok(rv) + }, + )?, + proxy_groups: proxy_group::concert(c.proxy_group.take(), &mut proxy_names)?, + proxy_names, + proxy_providers: c + .proxy_provider + .take() + .map(|m| { + m.into_iter() + .try_fold(HashMap::new(), |mut rv, (name, mut body)| { + body.insert( + "name".to_owned(), + serde_yaml::Value::String(name.clone()), + ); + let provider = OutboundProxyProviderDef::try_from(body) + .map_err(|x| { + Error::InvalidConfig(format!( + "invalid proxy provider {name}: {x}" + )) + })?; + rv.insert(name, provider); + Ok::< + HashMap, + Error, + >(rv) + }) + .expect("proxy provider parse error") + }) + .unwrap_or_default(), + listeners: listener::convert(c.listener.take(), &c)?, + } + .validate() +} + +impl TryFrom> for OutboundGroupProtocol { + type Error = Error; + + fn try_from(mapping: HashMap) -> Result { + let name = mapping + .get("name") + .and_then(|x| x.as_str()) + .ok_or(Error::InvalidConfig( + "missing field `name` in outbound proxy grouop".to_owned(), + ))? + .to_owned(); + OutboundGroupProtocol::deserialize(MapDeserializer::new(mapping.into_iter())) + .map_err(map_serde_error(name)) + } +} diff --git a/clash_lib/src/config/internal/convert/proxy_group.rs b/clash_lib/src/config/internal/convert/proxy_group.rs new file mode 100644 index 000000000..95ac5b676 --- /dev/null +++ b/clash_lib/src/config/internal/convert/proxy_group.rs @@ -0,0 +1,28 @@ +use std::collections::HashMap; + +use serde_yaml::Value; + +use crate::{Error, config::proxy::OutboundProxy}; + +pub fn concert( + before: Option>>, + proxy_names: &mut Vec, +) -> Result, crate::Error> { + before.unwrap_or_default().into_iter().try_fold( + HashMap::::new(), + |mut rv, mapping| { + let group = OutboundProxy::ProxyGroup( + mapping.clone().try_into().map_err(|x| { + if let Some(name) = mapping.get("name") { + Error::InvalidConfig(format!("proxy group: {name:#?}: {x}")) + } else { + Error::InvalidConfig("proxy group name missing".to_string()) + } + })?, + ); + proxy_names.push(group.name()); + rv.insert(group.name().to_string(), group); + Ok::, Error>(rv) + }, + ) +} diff --git a/clash_lib/src/config/internal/convert/rule_provider.rs b/clash_lib/src/config/internal/convert/rule_provider.rs new file mode 100644 index 000000000..61369770a --- /dev/null +++ b/clash_lib/src/config/internal/convert/rule_provider.rs @@ -0,0 +1,49 @@ +use std::collections::HashMap; + +use serde::{Deserialize as _, de::value::MapDeserializer}; +use serde_yaml::Value; + +use crate::{ + Error, + config::{config::RuleProviderDef, proxy::map_serde_error}, +}; + +pub(super) fn convert( + before: Option>>, +) -> HashMap { + before + .map(|m| { + m.into_iter() + .try_fold(HashMap::new(), |mut rv, (name, mut body)| { + body.insert( + "name".to_owned(), + serde_yaml::Value::String(name.clone()), + ); + let provider = RuleProviderDef::try_from(body).map_err(|x| { + Error::InvalidConfig(format!( + "invalid rule provider {name}: {x}" + )) + })?; + rv.insert(name, provider); + Ok::, Error>(rv) + }) + .expect("proxy provider parse error") + }) + .unwrap_or_default() +} + +impl TryFrom> for RuleProviderDef { + type Error = crate::Error; + + fn try_from(mapping: HashMap) -> Result { + let name = mapping + .get("name") + .and_then(|x| x.as_str()) + .ok_or(Error::InvalidConfig( + "rule provider name is required".to_owned(), + ))? + .to_owned(); + RuleProviderDef::deserialize(MapDeserializer::new(mapping.into_iter())) + .map_err(map_serde_error(name)) + } +} diff --git a/clash_lib/src/config/internal/convert/tun.rs b/clash_lib/src/config/internal/convert/tun.rs new file mode 100644 index 000000000..06363ac2e --- /dev/null +++ b/clash_lib/src/config/internal/convert/tun.rs @@ -0,0 +1,39 @@ +use crate::{ + Error, + config::{config, def}, +}; + +pub fn convert( + before: Option, +) -> Result { + match before { + Some(t) => Ok(config::TunConfig { + enable: t.enable, + device_id: t.device_id, + route_all: t.route_all, + routes: t + .routes + .map(|r| { + r.into_iter() + .map(|x| x.parse()) + .collect::, _>>() + }) + .transpose() + .map_err(|x| { + Error::InvalidConfig(format!("parse tun routes: {}", x)) + })? + .unwrap_or_default(), + gateway: t.gateway.parse().map_err(|x| { + Error::InvalidConfig(format!("parse tun gateway: {}", x)) + })?, + mtu: t.mtu, + so_mark: t.so_mark, + route_table: t.route_table, + dns_hijack: match t.dns_hijack { + def::DnsHijack::Switch(b) => b, + def::DnsHijack::List(_) => true, + }, + }), + None => Ok(config::TunConfig::default()), + } +} diff --git a/clash_lib/src/config/internal/listener.rs b/clash_lib/src/config/internal/listener.rs new file mode 100644 index 000000000..df54aa889 --- /dev/null +++ b/clash_lib/src/config/internal/listener.rs @@ -0,0 +1,100 @@ +use educe::Educe; +use serde::{Deserialize, Serialize}; + +use super::config::BindAddress; + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(tag = "type")] +#[serde(rename_all = "kebab-case")] +pub enum InboundOpts { + Http { + #[serde(flatten)] + common_opts: CommonInboundOpts, + inherited: bool, // TODO users + }, + Socks { + #[serde(flatten)] + common_opts: CommonInboundOpts, + udp: bool, + inherited: bool, // TODO users + }, + Mixed { + #[serde(flatten)] + common_opts: CommonInboundOpts, + udp: bool, // TODO users + inherited: bool, + }, + TProxy { + #[serde(flatten)] + common_opts: CommonInboundOpts, + udp: bool, + inherited: bool, + }, + Redir { + #[serde(flatten)] + common_opts: CommonInboundOpts, + inherited: bool, + }, + Tunnel { + #[serde(flatten)] + common_opts: CommonInboundOpts, + network: Vec, + target: String, + }, +} + +impl InboundOpts { + pub fn common_opts(&self) -> &CommonInboundOpts { + match self { + InboundOpts::Http { common_opts, .. } => common_opts, + InboundOpts::Socks { common_opts, .. } => common_opts, + InboundOpts::Mixed { common_opts, .. } => common_opts, + InboundOpts::TProxy { common_opts, .. } => common_opts, + InboundOpts::Tunnel { common_opts, .. } => common_opts, + InboundOpts::Redir { common_opts, .. } => common_opts, + } + } + + pub fn common_opts_mut(&mut self) -> &mut CommonInboundOpts { + match self { + InboundOpts::Http { common_opts, .. } => common_opts, + InboundOpts::Socks { common_opts, .. } => common_opts, + InboundOpts::Mixed { common_opts, .. } => common_opts, + InboundOpts::TProxy { common_opts, .. } => common_opts, + InboundOpts::Tunnel { common_opts, .. } => common_opts, + InboundOpts::Redir { common_opts, .. } => common_opts, + } + } + + pub fn inherited(&self) -> bool { + match self { + InboundOpts::Http { inherited, .. } => *inherited, + InboundOpts::Socks { inherited, .. } => *inherited, + InboundOpts::Mixed { inherited, .. } => *inherited, + InboundOpts::TProxy { inherited, .. } => *inherited, + InboundOpts::Tunnel { .. } => false, + InboundOpts::Redir { inherited, .. } => *inherited, + } + } + + pub fn port(&self) -> u16 { + self.common_opts().port + } + + pub fn port_mut(&mut self) -> &mut u16 { + &mut self.common_opts_mut().port + } +} + +#[derive(Serialize, Deserialize, Debug, Educe, Clone)] +#[educe(Default)] +#[serde(rename_all = "kebab-case")] +pub struct CommonInboundOpts { + pub name: String, + pub listen: BindAddress, + #[educe(Default = 0)] + pub port: u16, + // TODO opts down below is unimplemented + pub rule: Option, + pub proxy: Option, +} diff --git a/clash_lib/src/config/internal/mod.rs b/clash_lib/src/config/internal/mod.rs index 0f081977b..0fcee88e8 100644 --- a/clash_lib/src/config/internal/mod.rs +++ b/clash_lib/src/config/internal/mod.rs @@ -1,5 +1,8 @@ pub mod config; +pub mod listener; pub mod proxy; pub mod rule; pub use config::Config as InternalConfig; + +mod convert; diff --git a/clash_lib/src/config/internal/proxy.rs b/clash_lib/src/config/internal/proxy.rs index be44871c6..c988a182e 100644 --- a/clash_lib/src/config/internal/proxy.rs +++ b/clash_lib/src/config/internal/proxy.rs @@ -363,22 +363,6 @@ impl OutboundGroupProtocol { } } -impl TryFrom> for OutboundGroupProtocol { - type Error = Error; - - fn try_from(mapping: HashMap) -> Result { - let name = mapping - .get("name") - .and_then(|x| x.as_str()) - .ok_or(Error::InvalidConfig( - "missing field `name` in outbound proxy grouop".to_owned(), - ))? - .to_owned(); - OutboundGroupProtocol::deserialize(MapDeserializer::new(mapping.into_iter())) - .map_err(map_serde_error(name)) - } -} - impl Display for OutboundGroupProtocol { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { diff --git a/clash_lib/src/config/mod.rs b/clash_lib/src/config/mod.rs index 6bb77fea1..900b52f4e 100644 --- a/clash_lib/src/config/mod.rs +++ b/clash_lib/src/config/mod.rs @@ -2,4 +2,4 @@ pub mod def; pub mod internal; mod utils; pub use def::DNSListen; -pub use internal::InternalConfig as RuntimeConfig; +pub use internal::{InternalConfig as RuntimeConfig, *}; diff --git a/clash_lib/src/lib.rs b/clash_lib/src/lib.rs index 041c69675..080dc33da 100644 --- a/clash_lib/src/lib.rs +++ b/clash_lib/src/lib.rs @@ -1,6 +1,7 @@ #![feature(ip)] #![feature(sync_unsafe_cell)] #![feature(unbounded_shifts)] +#![feature(let_chains)] #[macro_use] extern crate anyhow; @@ -67,9 +68,11 @@ pub enum Error { Crypto(String), #[error("operation error: {0}")] Operation(String), + #[error(transparent)] + Other(#[from] anyhow::Error), } - -pub type Runner = futures::future::BoxFuture<'static, Result<(), Error>>; +pub type Result = std::result::Result; +pub type Runner = futures::future::BoxFuture<'static, Result<()>>; pub struct Options { pub config: Config, @@ -92,7 +95,7 @@ pub enum Config { } impl Config { - pub fn try_parse(self) -> Result { + pub fn try_parse(self) -> Result { match self { Config::Def(c) => c.try_into(), Config::Internal(c) => Ok(c), @@ -106,12 +109,10 @@ impl Config { pub struct GlobalState { log_level: LogLevel, - // must be Some otherwise we'll refuse to start - inbound_listener_handle: JoinHandle>, - tunnel_listener_handle: Option>>, - api_listener_handle: Option>>, - dns_listener_handle: Option>>, + tunnel_listener_handle: Option>>, + api_listener_handle: Option>>, + dns_listener_handle: Option>>, reload_tx: mpsc::Sender<(Config, oneshot::Sender<()>)>, cwd: String, } @@ -122,7 +123,7 @@ pub struct RuntimeController { static RUNTIME_CONTROLLER: OnceCell = OnceCell::new(); -pub fn start_scaffold(opts: Options) -> Result<(), Error> { +pub fn start_scaffold(opts: Options) -> Result<()> { let rt = match opts.rt.as_ref().unwrap_or(&TokioRuntime::MultiThread) { TokioRuntime::MultiThread => tokio::runtime::Builder::new_multi_thread() .enable_all() @@ -174,7 +175,7 @@ pub async fn start( config: InternalConfig, cwd: String, log_tx: broadcast::Sender, -) -> Result<(), Error> { +) -> Result<()> { let (shutdown_tx, mut shutdown_rx) = mpsc::channel(1); let _ = RUNTIME_CONTROLLER.get_or_init(|| RuntimeController { shutdown_tx }); @@ -190,8 +191,8 @@ pub async fn start( let components = create_components(cwd.clone(), config).await?; - let inbound_runner = components.inbound_manager.lock().await.get_runner()?; - let inbound_listener_handle = tokio::spawn(inbound_runner); + let inbound_manager = components.inbound_manager.clone(); + inbound_manager.start().await; let tun_runner_handle = components.tun_runner.map(tokio::spawn); let dns_listener_handle = components.dns_listener.map(tokio::spawn); @@ -200,7 +201,6 @@ pub async fn start( let global_state = Arc::new(Mutex::new(GlobalState { log_level, - inbound_listener_handle, tunnel_listener_handle: tun_runner_handle, dns_listener_handle, reload_tx, @@ -237,7 +237,9 @@ pub async fn start( })); tasks.push(Box::pin(async move { - let _ = tokio::signal::ctrl_c().await; + tokio::signal::ctrl_c() + .await + .expect("failed to listen for ^C event"); Ok(()) })); @@ -259,8 +261,9 @@ pub async fn start( done.send(()).unwrap(); debug!("stopping listeners"); + inbound_manager.shutdown().await; let mut g = global_state.lock().await; - g.inbound_listener_handle.abort(); + if let Some(h) = g.tunnel_listener_handle.take() { h.abort(); } @@ -271,13 +274,9 @@ pub async fn start( h.abort(); } + let inbound_manager = new_components.inbound_manager.clone(); debug!("reloading inbound listener"); - let inbound_listener_handle = new_components - .inbound_manager - .lock() - .await - .get_runner() - .map(tokio::spawn)?; + inbound_manager.restart().await; debug!("reloading tun runner"); let tun_runner_handle = new_components.tun_runner.map(tokio::spawn); @@ -301,7 +300,6 @@ pub async fn start( ) .map(tokio::spawn); - g.inbound_listener_handle = inbound_listener_handle; g.tunnel_listener_handle = tun_runner_handle; g.dns_listener_handle = dns_listener_handle; g.api_listener_handle = api_listener_handle; @@ -322,7 +320,7 @@ struct RuntimeComponents { router: Arc, dispatcher: Arc, statistics_manager: Arc, - inbound_manager: Arc>, + inbound_manager: Arc, tun_runner: Option, dns_listener: Option, @@ -331,7 +329,7 @@ struct RuntimeComponents { async fn create_components( cwd: PathBuf, config: InternalConfig, -) -> Result { +) -> Result { let system_resolver = Arc::new( SystemResolver::new(config.dns.ipv6) .map_err(|x| Error::DNSError(x.to_string()))?, @@ -442,11 +440,16 @@ async fn create_components( let authenticator = Arc::new(auth::PlainAuthenticator::new(config.users)); debug!("initializing inbound manager"); - let inbound_manager = Arc::new(Mutex::new(InboundManager::new( - config.general.inbound, - dispatcher.clone(), - authenticator, - )?)); + let inbound_manager = Arc::new( + InboundManager::new( + config.general.bind_address, + config.general.authentication, + dispatcher.clone(), + authenticator, + config.listeners, + ) + .await?, + ); debug!("initializing tun runner"); let tun_runner = diff --git a/clash_lib/src/proxy/http/inbound/mod.rs b/clash_lib/src/proxy/http/inbound/mod.rs index c5a5245cb..ba2e7f585 100644 --- a/clash_lib/src/proxy/http/inbound/mod.rs +++ b/clash_lib/src/proxy/http/inbound/mod.rs @@ -5,30 +5,29 @@ mod proxy; use crate::{ Dispatcher, common::auth::ThreadSafeAuthenticator, - proxy::{InboundListener, utils::apply_tcp_options}, + proxy::{inbound::InboundHandlerTrait, utils::apply_tcp_options}, }; -use async_trait::async_trait; pub use proxy::handle as handle_http; -use std::{io, net::SocketAddr, sync::Arc}; +use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpListener; use tracing::warn; #[derive(Clone)] -pub struct Listener { +pub struct HttpInbound { addr: SocketAddr, dispatcher: Arc, authenticator: ThreadSafeAuthenticator, } -impl Drop for Listener { +impl Drop for HttpInbound { fn drop(&mut self) { warn!("HTTP inbound listener on {} stopped", self.addr); } } -impl Listener { +impl HttpInbound { pub fn new( addr: SocketAddr, dispatcher: Arc, @@ -42,8 +41,7 @@ impl Listener { } } -#[async_trait] -impl InboundListener for Listener { +impl InboundHandlerTrait for HttpInbound { fn handle_tcp(&self) -> bool { true } @@ -52,7 +50,7 @@ impl InboundListener for Listener { false } - async fn listen_tcp(&self) -> std::io::Result<()> { + async fn listen_tcp(&self) -> anyhow::Result<()> { let listener = TcpListener::bind(self.addr).await?; loop { @@ -69,7 +67,7 @@ impl InboundListener for Listener { } } - async fn listen_udp(&self) -> std::io::Result<()> { - Err(io::Error::new(io::ErrorKind::Other, "unsupported")) + async fn listen_udp(&self) -> anyhow::Result<()> { + Err(anyhow!("unsupported")) } } diff --git a/clash_lib/src/proxy/http/mod.rs b/clash_lib/src/proxy/http/mod.rs index 393b265e8..8d0125f4f 100644 --- a/clash_lib/src/proxy/http/mod.rs +++ b/clash_lib/src/proxy/http/mod.rs @@ -1,3 +1,3 @@ mod inbound; -pub use inbound::{Listener, handle_http}; +pub use inbound::{HttpInbound, handle_http}; diff --git a/clash_lib/src/proxy/inbound.rs b/clash_lib/src/proxy/inbound.rs new file mode 100644 index 000000000..811a28177 --- /dev/null +++ b/clash_lib/src/proxy/inbound.rs @@ -0,0 +1,26 @@ +use enum_dispatch::enum_dispatch; + +use super::{ + http::HttpInbound, mixed::MixedInbound, socks::SocksInbound, + tunnel::TunnelInbound, +}; + +#[enum_dispatch(InboudHandler)] +pub trait InboundHandlerTrait { + /// support tcp or not + fn handle_tcp(&self) -> bool; + /// support udp or not + fn handle_udp(&self) -> bool; + async fn listen_tcp(&self) -> anyhow::Result<()>; + async fn listen_udp(&self) -> anyhow::Result<()>; +} + +#[enum_dispatch] +pub enum InboudHandler { + Http(HttpInbound), + Socks(SocksInbound), + Mixed(MixedInbound), + #[cfg(target_os = "linux")] + TProxy(super::tproxy::TproxyInbound), + Tunnel(TunnelInbound), +} diff --git a/clash_lib/src/proxy/mixed/mod.rs b/clash_lib/src/proxy/mixed/mod.rs index c9e8b4d8d..c9b0630ea 100644 --- a/clash_lib/src/proxy/mixed/mod.rs +++ b/clash_lib/src/proxy/mixed/mod.rs @@ -1,30 +1,29 @@ use crate::{ Dispatcher, common::auth::ThreadSafeAuthenticator, - proxy::InboundListener, session::{Network, Session}, }; -use async_trait::async_trait; + use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpListener; use tracing::warn; -use super::{http, socks, utils::apply_tcp_options}; +use super::{http, inbound::InboundHandlerTrait, socks, utils::apply_tcp_options}; -pub struct Listener { +pub struct MixedInbound { addr: SocketAddr, dispatcher: Arc, authenticator: ThreadSafeAuthenticator, } -impl Drop for Listener { +impl Drop for MixedInbound { fn drop(&mut self) { warn!("MixedPort inbound listener on {} stopped", self.addr); } } -impl Listener { +impl MixedInbound { pub fn new( addr: SocketAddr, dispatcher: Arc, @@ -38,8 +37,7 @@ impl Listener { } } -#[async_trait] -impl InboundListener for Listener { +impl InboundHandlerTrait for MixedInbound { fn handle_tcp(&self) -> bool { true } @@ -48,7 +46,7 @@ impl InboundListener for Listener { false } - async fn listen_tcp(&self) -> std::io::Result<()> { + async fn listen_tcp(&self) -> anyhow::Result<()> { let listener = TcpListener::bind(self.addr).await?; loop { @@ -99,7 +97,7 @@ impl InboundListener for Listener { } } - async fn listen_udp(&self) -> std::io::Result<()> { - unreachable!("don't listen to me :)") + async fn listen_udp(&self) -> anyhow::Result<()> { + Err(anyhow!("UDP is not supported")) } } diff --git a/clash_lib/src/proxy/mod.rs b/clash_lib/src/proxy/mod.rs index 6729bc351..c18b54c04 100644 --- a/clash_lib/src/proxy/mod.rs +++ b/clash_lib/src/proxy/mod.rs @@ -55,8 +55,10 @@ pub mod group; pub use group::{fallback, loadbalance, relay, selector, urltest}; mod common; +pub mod inbound; mod options; mod transport; +pub mod tunnel; pub use options::HandlerCommonOptions; @@ -111,18 +113,6 @@ impl OutboundDatagram for T where pub type AnyOutboundDatagram = Box>; -#[async_trait] -pub trait InboundListener: Send + Sync + Unpin { - /// support tcp or not - fn handle_tcp(&self) -> bool; - /// support udp or not - fn handle_udp(&self) -> bool; - async fn listen_tcp(&self) -> io::Result<()>; - async fn listen_udp(&self) -> io::Result<()>; -} - -pub type AnyInboundListener = Arc; - #[derive(Serialize, Deserialize)] pub enum OutboundType { Shadowsocks, diff --git a/clash_lib/src/proxy/socks/inbound/mod.rs b/clash_lib/src/proxy/socks/inbound/mod.rs index a7759a3c6..3cbcb8016 100644 --- a/clash_lib/src/proxy/socks/inbound/mod.rs +++ b/clash_lib/src/proxy/socks/inbound/mod.rs @@ -4,10 +4,10 @@ mod stream; use crate::{ Dispatcher, common::auth::ThreadSafeAuthenticator, - proxy::{InboundListener, utils::apply_tcp_options}, + proxy::{inbound::InboundHandlerTrait, utils::apply_tcp_options}, session::{Network, Session, Type}, }; -use async_trait::async_trait; + use std::{net::SocketAddr, sync::Arc}; pub use stream::handle_tcp; use tokio::net::TcpListener; @@ -15,19 +15,19 @@ use tracing::warn; pub use datagram::Socks5UDPCodec; -pub struct Listener { +pub struct SocksInbound { addr: SocketAddr, dispatcher: Arc, authenticator: ThreadSafeAuthenticator, } -impl Drop for Listener { +impl Drop for SocksInbound { fn drop(&mut self) { warn!("SOCKS5 inbound listener on {} stopped", self.addr); } } -impl Listener { +impl SocksInbound { pub fn new( addr: SocketAddr, dispatcher: Arc, @@ -41,8 +41,7 @@ impl Listener { } } -#[async_trait] -impl InboundListener for Listener { +impl InboundHandlerTrait for SocksInbound { fn handle_tcp(&self) -> bool { true } @@ -51,7 +50,7 @@ impl InboundListener for Listener { false } - async fn listen_tcp(&self) -> std::io::Result<()> { + async fn listen_tcp(&self) -> anyhow::Result<()> { let listener = TcpListener::bind(self.addr).await?; loop { @@ -76,7 +75,8 @@ impl InboundListener for Listener { } } - async fn listen_udp(&self) -> std::io::Result<()> { - unreachable!("don't listen to me :)") + async fn listen_udp(&self) -> anyhow::Result<()> { + // TODO + Err(anyhow!("UDP is not supported")) } } diff --git a/clash_lib/src/proxy/socks/mod.rs b/clash_lib/src/proxy/socks/mod.rs index 7fa55f94c..6d44a408a 100644 --- a/clash_lib/src/proxy/socks/mod.rs +++ b/clash_lib/src/proxy/socks/mod.rs @@ -2,6 +2,6 @@ mod inbound; mod outbound; mod socks5; -pub use inbound::{Listener, Socks5UDPCodec, handle_tcp}; +pub use inbound::{Socks5UDPCodec, SocksInbound, handle_tcp}; pub use outbound::{Handler, HandlerOptions}; pub use socks5::SOCKS5_VERSION; diff --git a/clash_lib/src/proxy/tproxy/mod.rs b/clash_lib/src/proxy/tproxy/mod.rs index b96e0d0ee..5c8707054 100644 --- a/clash_lib/src/proxy/tproxy/mod.rs +++ b/clash_lib/src/proxy/tproxy/mod.rs @@ -1,10 +1,10 @@ -use super::tun::TunDatagram; +use super::{inbound::InboundHandlerTrait, tun::TunDatagram}; use crate::{ app::dispatcher::Dispatcher, - proxy::{InboundListener, datagram::UdpPacket, utils::apply_tcp_options}, + proxy::{datagram::UdpPacket, utils::apply_tcp_options}, session::{Network, Session, Type}, }; -use async_trait::async_trait; + use socket2::{Domain, Socket}; use std::{ net::SocketAddr, @@ -14,25 +14,24 @@ use std::{ use tokio::net::TcpListener; use tracing::{trace, warn}; -pub struct Listener { +pub struct TproxyInbound { addr: SocketAddr, dispather: Arc, } -impl Drop for Listener { +impl Drop for TproxyInbound { fn drop(&mut self) { warn!("Tproxy inbound listener on {} stopped", self.addr); } } -impl Listener { +impl TproxyInbound { pub fn new(addr: SocketAddr, dispather: Arc) -> Self { Self { addr, dispather } } } -#[async_trait] -impl InboundListener for Listener { +impl InboundHandlerTrait for TproxyInbound { fn handle_tcp(&self) -> bool { true } @@ -41,7 +40,7 @@ impl InboundListener for Listener { true } - async fn listen_tcp(&self) -> std::io::Result<()> { + async fn listen_tcp(&self) -> anyhow::Result<()> { let socket = Socket::new(socket2::Domain::IPV4, socket2::Type::STREAM, None)?; socket.set_ip_transparent(true)?; @@ -76,7 +75,7 @@ impl InboundListener for Listener { } } - async fn listen_udp(&self) -> std::io::Result<()> { + async fn listen_udp(&self) -> anyhow::Result<()> { let socket = Socket::new(Domain::IPV4, socket2::Type::DGRAM, None)?; socket.set_ip_transparent(true)?; socket.set_nonblocking(true)?; @@ -104,7 +103,7 @@ impl InboundListener for Listener { async fn handle_inbound_datagram( socket: Arc, dispatcher: Arc, -) -> std::io::Result<()> { +) -> anyhow::Result<()> { // dispatcher <-> tproxy communications let (l_tx, mut l_rx) = tokio::sync::mpsc::channel(32); diff --git a/clash_lib/src/proxy/tunnel/mod.rs b/clash_lib/src/proxy/tunnel/mod.rs new file mode 100644 index 000000000..95b702a4b --- /dev/null +++ b/clash_lib/src/proxy/tunnel/mod.rs @@ -0,0 +1,250 @@ +use std::{ + io, + net::SocketAddr, + ops::DerefMut, + pin::Pin, + str::FromStr, + sync::Arc, + task::{Context, Poll}, +}; + +use crate::{ + app::dispatcher::Dispatcher, + common::errors::new_io_error, + session::{Network, Session, SocksAddr, Type}, +}; +use futures::{Sink, Stream}; +use tokio::{ + io::ReadBuf, + net::{TcpListener, UdpSocket}, +}; +use tracing::{info, warn}; + +use super::{ + datagram::UdpPacket, inbound::InboundHandlerTrait, utils::apply_tcp_options, +}; + +#[derive(Clone)] +pub struct TunnelInbound { + listen: SocketAddr, + dispatcher: Arc, + network: Vec, + target: SocksAddr, +} + +impl Drop for TunnelInbound { + fn drop(&mut self) { + warn!("HTTP inbound listener on {} stopped", self.listen); + } +} + +impl TunnelInbound { + pub fn new( + addr: SocketAddr, + dispatcher: Arc, + network: Vec, + target: String, + ) -> anyhow::Result { + Ok(Self { + listen: addr, + dispatcher, + network, + target: SocksAddr::from_str(&target)?, + }) + } +} + +impl InboundHandlerTrait for TunnelInbound { + fn handle_tcp(&self) -> bool { + true + } + + fn handle_udp(&self) -> bool { + true + } + + async fn listen_tcp(&self) -> anyhow::Result<()> { + if !self.network.contains(&"tcp".to_string()) { + return Ok(()); + } + info!( + "[Tunnel-TCP] listening on {}, remote: {}", + self.listen, self.target + ); + let listener = TcpListener::bind(self.listen).await?; + + loop { + let (socket, src_addr) = listener.accept().await?; + + let stream = apply_tcp_options(socket)?; + + let dispatcher = self.dispatcher.clone(); + let sess = Session { + network: Network::Tcp, + typ: Type::Tunnel, + source: src_addr, + destination: self.target.clone(), + ..Default::default() + }; + + tokio::spawn(async move { + dispatcher.dispatch_stream(sess, Box::new(stream)).await; + }); + } + } + + async fn listen_udp(&self) -> anyhow::Result<()> { + if !self.network.contains(&"udp".to_string()) { + return Ok(()); + } + info!( + "[Tunnel-UDP] listening on {}, remote: {}", + self.listen, self.target + ); + let socket = UdpSocket::bind(self.listen).await?; + let sess = Session { + network: Network::Udp, + typ: Type::Tunnel, + destination: self.target.clone(), + ..Default::default() + }; + let inbound = UdpSession::new(socket, self.target.clone()); + + _ = self + .dispatcher + .dispatch_datagram(sess, Box::new(inbound)) + .await; + Ok(()) + } +} + +#[derive(Debug)] +struct UdpSession { + pub socket: UdpSocket, + pub dst_addr: SocksAddr, + pub read_buf: Vec, + pub send_buf: Option<(Vec, SocketAddr)>, +} + +impl UdpSession { + fn new(socket: UdpSocket, dst_addr: SocksAddr) -> Self { + Self { + socket, + dst_addr, + read_buf: Vec::with_capacity(65507), + send_buf: None, + } + } +} + +impl Sink for UdpSession { + type Error = io::Error; + + fn poll_ready( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let this = self.deref_mut(); + // "Back pressure" mechanism, new data is allowed to be written only when the + // buffer is empty + match this.send_buf { + Some(_) => Poll::Pending, + None => Poll::Ready(Ok(())), + } + } + + fn start_send( + mut self: Pin<&mut Self>, + item: UdpPacket, + ) -> Result<(), Self::Error> { + let this = self.deref_mut(); + let socket = &this.socket; + let dst_addr = match item.dst_addr { + SocksAddr::Ip(socket_addr) => socket_addr, + SocksAddr::Domain(..) => { + return Err(new_io_error( + "UdpPacket dst_src MUSTBE IpAddr instead of Domain", + )); + } + }; + + // Try to send immediately, if blocked, enter the buffer and wait for + // poll_flush to process + match socket.try_send_to(&item.data, dst_addr) { + Ok(_) => Ok(()), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + this.send_buf = Some((item.data, dst_addr)); + Ok(()) + } + Err(e) => Err(e), + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.deref_mut(); + let socket = &this.socket; + let send_buf = &this.send_buf; + if let Some((data, dst_addr)) = send_buf { + return match socket.try_send_to(data, *dst_addr) { + Ok(_) => { + this.send_buf.take(); + Poll::Ready(Ok(())) + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + // Register Waker to wake up when the socket is writable + socket.poll_send_ready(cx) + } + Err(e) => Poll::Ready(Err(e)), + }; + } + // No data needs flush + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl Stream for UdpSession { + type Item = UdpPacket; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let this = self.deref_mut(); + let socket = &this.socket; + this.read_buf.resize(this.read_buf.capacity(), 0); + let mut buf = ReadBuf::new(&mut this.read_buf); + dbg!(buf.initialized().len()); + buf.clear(); + match socket.poll_recv_from(cx, &mut buf) { + Poll::Ready(Ok(src_addr)) => { + let data = buf.filled().to_vec(); + let dst_addr = this.dst_addr.clone(); + let src_addr = SocksAddr::from(src_addr); + Poll::Ready(Some(UdpPacket { + data, + src_addr, + dst_addr, + })) + } + Poll::Ready(Err(e)) => { + if e.kind() == io::ErrorKind::WouldBlock { + Poll::Pending + } else { + // FIXME + Poll::Ready(None) + } + } + Poll::Pending => Poll::Pending, + } + } +} diff --git a/clash_lib/src/proxy/utils/mod.rs b/clash_lib/src/proxy/utils/mod.rs index ddd50756a..71b59679e 100644 --- a/clash_lib/src/proxy/utils/mod.rs +++ b/clash_lib/src/proxy/utils/mod.rs @@ -122,7 +122,7 @@ pub fn get_outbound_interface() -> Option { all_outbounds.into_iter().next() } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub enum Interface { IpAddr(IpAddr), Name(String), diff --git a/clash_lib/src/session.rs b/clash_lib/src/session.rs index ddce0385c..20ffde58c 100644 --- a/clash_lib/src/session.rs +++ b/clash_lib/src/session.rs @@ -3,6 +3,7 @@ use std::{ fmt::{Debug, Display, Formatter}, io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + str::FromStr, }; use crate::proxy::utils::Interface; @@ -18,6 +19,40 @@ pub enum SocksAddr { Domain(String, u16), } +impl FromStr for SocksAddr { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let mut s = s.to_string(); + if !s.contains(':') { + s = format!("{s}:80"); + } + match SocketAddr::from_str(&s) { + Ok(v) => Ok(Self::Ip(v)), + Err(_) => { + let tokens: Vec<_> = s.split(':').collect(); + if tokens.len() == 2 { + let port: u16 = tokens.get(1).unwrap().parse()?; + Ok(Self::Domain(tokens.first().unwrap().to_string(), port)) + } else { + Err(anyhow!("SocksAddr parse error, value: {s}")) + } + } + } + } +} +#[test] +fn test_from_str() { + assert_eq!( + SocksAddr::from_str("127.0.0.1").unwrap(), + SocksAddr::Ip(SocketAddr::V4("127.0.0.1:80".parse().unwrap())) + ); + assert!(SocksAddr::from_str("127.0.0.1:80").is_ok()); + assert!(SocksAddr::from_str("hosta.com").is_ok()); + assert!(SocksAddr::from_str("hosta.com:443").is_ok()); + assert!(SocksAddr::from_str("hosta.:com:443").is_err()); +} + impl Default for SocksAddr { fn default() -> Self { Self::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0)) @@ -57,6 +92,7 @@ impl SocksAddr { )) } + // TODO move to vmess pub fn write_buf(&self, buf: &mut T) { match self { Self::Ip(addr) => match addr { @@ -374,7 +410,7 @@ pub enum Type { Tun, #[cfg(target_os = "linux")] Tproxy, - + Tunnel, Ignore, }