diff --git a/language-server/src/core/detectors/immutable_account_mutated_detector.rs b/language-server/src/core/detectors/immutable_account_mutated_detector.rs index 48da2db..3a9cb8d 100644 --- a/language-server/src/core/detectors/immutable_account_mutated_detector.rs +++ b/language-server/src/core/detectors/immutable_account_mutated_detector.rs @@ -1,8 +1,7 @@ use super::detector::Detector; use super::detector_config::DetectorConfig; use crate::core::utilities::{DiagnosticBuilder, anchor_patterns::AnchorPatterns}; -use std::collections::HashMap; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use syn::spanned::Spanned; use syn::{Fields, parse_str, visit::Visit}; @@ -10,388 +9,510 @@ use tower_lsp::lsp_types::{Diagnostic, DiagnosticSeverity, Range}; #[derive(Default)] pub struct ImmutableAccountMutatedDetector { - diagnostics: Vec, - config: DetectorConfig, - current_context: Option, - context_immutable_accounts: HashMap>, - immutable_field_ranges: HashMap>, - file_path: Option, + pub config: DetectorConfig, + diagnostics: Vec, + // Current Anchor context name, e.g. "Deposit" + current_context: Option, + // For each context name, the set of account field names that are immutable + context_immutable_accounts: HashMap>, + // For each context name, a map from account field to its source range (for related info) + immutable_field_ranges: HashMap>, + // Track simple local aliases to ctx.accounts. + // key: local identifier, val: account field name + local_aliases: HashMap, + file_path: Option, + suppress_ref_in_let: bool, } impl ImmutableAccountMutatedDetector { - #[allow(dead_code)] - pub fn with_config(config: DetectorConfig) -> Self { - Self { - diagnostics: Vec::new(), - config, - current_context: None, - context_immutable_accounts: HashMap::new(), - immutable_field_ranges: HashMap::new(), - file_path: None, + // ---------- Account field collection ---------- + + fn is_immutable_account_field(&self, field: &syn::Field) -> Option { + // Accept common Anchor account wrapper types + let is_account_type = match &field.ty { + syn::Type::Path(type_path) => { + if let Some(segment) = type_path.path.segments.last() { + matches!( + segment.ident.to_string().as_str(), + "Account" | "AccountInfo" | "AccountLoader" | "UncheckedAccount" | "InterfaceAccount" + ) + } else { + false } + } + _ => false, + }; + if !is_account_type { + return None; } - /// Check if a field is marked as immutable (no #[account(mut)] or #[account(init)] attribute) - fn is_immutable_account_field(&self, field: &syn::Field) -> Option { - // Check if this field has a type that looks like an account - let is_account_type = match &field.ty { - syn::Type::Path(type_path) => { - if let Some(segment) = type_path.path.segments.last() { - matches!( - segment.ident.to_string().as_str(), - "Account" | "AccountInfo" | "AccountLoader" - ) - } else { - false - } - } - _ => false, - }; + // Detect #[account(mut)] or #[account(init, ...)] + let has_mut_or_init = field.attrs.iter().any(|attr| { + if attr.path().is_ident("account") { + let mut found = false; + let _ = attr.parse_nested_meta(|nested| { + if nested.path.is_ident("mut") || nested.path.is_ident("init") { + found = true; + } + Ok(()) + }); + return found; + } + false + }); - if !is_account_type { - return None; - } + if !has_mut_or_init { + field.ident.as_ref().map(|i| i.to_string()) + } else { + None + } + } - // Check if the field has either #[account(mut)] or #[account(init, ...)] attribute - let has_mut_or_init = field.attrs.iter().any(|attr| { - if attr.path().is_ident("account") { - if let syn::Meta::List(meta_list) = &attr.meta { - let tokens = meta_list.tokens.to_string(); - // Check for mut or init at word boundaries to avoid false positives - return tokens.split(',').any(|token| { - let token = token.trim(); - token == "mut" || token.starts_with("init") - }); - } - } - false - }); + fn collect_immutable_accounts(&mut self, context_name: &str, accounts_struct: &syn::ItemStruct) { + let mut set = HashSet::new(); + let mut ranges = HashMap::new(); - // If it's an account type but doesn't have mut or init, it's immutable - if !has_mut_or_init { - field - .ident - .as_ref() - .map(|field_name| field_name.to_string()) - } else { - None + if let Fields::Named(fields) = &accounts_struct.fields { + for field in &fields.named { + if let Some(name) = self.is_immutable_account_field(field) { + set.insert(name.clone()); + ranges.insert( + name, + DiagnosticBuilder::create_range_from_span(field.span()), + ); } + } } - /// Collect immutable accounts from an Accounts struct - fn collect_immutable_accounts( - &mut self, - context_name: &str, - accounts_struct: &syn::ItemStruct, - ) { - let mut immutable_accounts = HashSet::new(); - let mut field_ranges = HashMap::new(); - - if let Fields::Named(fields) = &accounts_struct.fields { - for field in &fields.named { - if let Some(account_name) = self.is_immutable_account_field(field) { - immutable_accounts.insert(account_name.clone()); - field_ranges.insert( - account_name, - DiagnosticBuilder::create_range_from_span(field.span()), - ); - } - } - } + self + .context_immutable_accounts + .insert(context_name.to_string(), set); + self + .immutable_field_ranges + .insert(context_name.to_string(), ranges); + } - self.context_immutable_accounts - .insert(context_name.to_string(), immutable_accounts); - self.immutable_field_ranges - .insert(context_name.to_string(), field_ranges); - } + // ---------- Expression normalization ---------- - /// Check if an expression attempts to mutate an account - fn is_mutation_attempt(&self, expr: &syn::Expr, account_name: &str) -> bool { - match expr { - // Direct assignment: account.field = value - syn::Expr::Assign(assign_expr) => { - self.expression_references_account(&assign_expr.left, account_name) - } - // Binary operations that could include compound assignment - syn::Expr::Binary(binary_expr) => { - // Check for compound assignment operators - matches!( - binary_expr.op, - syn::BinOp::AddAssign(_) - | syn::BinOp::SubAssign(_) - | syn::BinOp::MulAssign(_) - | syn::BinOp::DivAssign(_) - | syn::BinOp::RemAssign(_) - | syn::BinOp::BitAndAssign(_) - | syn::BinOp::BitOrAssign(_) - | syn::BinOp::BitXorAssign(_) - | syn::BinOp::ShlAssign(_) - | syn::BinOp::ShrAssign(_), - ) && self.expression_references_account(&binary_expr.left, account_name) - } - // Method calls that might mutate: account.method() - syn::Expr::MethodCall(method_call) => { - if self.expression_references_account(&method_call.receiver, account_name) { - let method_name = method_call.method.to_string(); - // Check for Solana Account-specific mutating methods - matches!( - method_name.as_str(), - "set_data" - | "set_lamports" - | "set_owner" - | "set_executable" - | "close" - | "realloc" - | "assign" - ) || method_call.args.iter().any(|arg| { - // Check if any argument is a mutable reference - matches!(arg, syn::Expr::Reference(ref_expr) if ref_expr.mutability.is_some()) - }) || { - // Check if the receiver is a mutable reference - if let syn::Expr::Reference(ref_expr) = &*method_call.receiver { - ref_expr.mutability.is_some() - } else { - // Check if the method name suggests mutation - method_name.starts_with("push") - || method_name.starts_with("insert") - || method_name.starts_with("remove") - || method_name.starts_with("clear") - || method_name.starts_with("set") - || method_name.starts_with("replace") - || method_name.starts_with("extend") - || method_name.starts_with("append") - || method_name.starts_with("truncate") - || method_name.starts_with("resize") - || method_name.starts_with("retain") - || method_name.starts_with("swap") - || method_name.starts_with("sort") - || method_name.starts_with("rotate") - || method_name.starts_with("fill") - } - } - } else { - false - } - } - // Mutable reference creation: &mut account or &mut ctx.accounts.account - syn::Expr::Reference(ref_expr) => { - ref_expr.mutability.is_some() - && self.expression_references_account(&ref_expr.expr, account_name) - } - // Index assignment: account[i] = value - syn::Expr::Index(index_expr) => { - self.expression_references_account(&index_expr.expr, account_name) - } - // Range assignment: account[i..j] = value - syn::Expr::Range(range_expr) => { - range_expr - .start - .as_ref() - .filter(|start| self.expression_references_account(start, account_name)) - .is_some() - || range_expr - .end - .as_ref() - .filter(|end| self.expression_references_account(end, account_name)) - .is_some() + // Normalize any expression that might refer to ctx.accounts. or a local alias thereof, + // returning Some(account_name) if it resolves, or None if it doesn't. + fn resolve_account_name_from_expr(&self, expr: &syn::Expr) -> Option { + use syn::{ + Expr, ExprField, Member + }; + + // inside resolve_account_name_from_expr(...) + fn peel<'a>(mut e: &'a syn::Expr) -> &'a syn::Expr { + loop { + match e { + // (expr) + syn::Expr::Paren(p) => { + e = &p.expr; + } + // &expr / &mut expr + syn::Expr::Reference(r) => { + e = &r.expr; + } + // *expr / **expr + syn::Expr::Unary(u) if matches!(u.op, syn::UnOp::Deref(_)) => { + e = &u.expr; + } + // expr? + syn::Expr::Try(t) => { + e = &t.expr; + } + // expr as T + syn::Expr::Cast(c) => { + e = &c.expr; + } + // Accessor chains we should peel back to the receiver + syn::Expr::MethodCall(mc) => { + let m = mc.method.to_string(); + // Be permissive here: both of these accessors take 0 args, but + // some toolchains may insert turbofish or attrs—so don’t insist on args.is_empty(). + if m == "to_account_info" || m == "try_borrow_mut_lamports" { + e = &mc.receiver; + } else { + break; } - _ => false, + } + _ => break, } + } + e } - /// Check if an expression references a specific account - fn expression_references_account(&self, expr: &syn::Expr, account_name: &str) -> bool { - match expr { - syn::Expr::Path(path_expr) => { - // Check if the path starts with the account name - if let Some(first_segment) = path_expr.path.segments.first() { - first_segment.ident == account_name - } else { - false - } - } - syn::Expr::Field(field_expr) => { - // Check if this field access is directly to our account name - if let syn::Member::Named(field_name) = &field_expr.member { - if field_name == account_name { - // If it matches our account name, verify it's accessed through ctx.accounts - return self.is_accounts_access(&field_expr.base); - } - } + // After peel, attempt to match: + // 1) ctx.accounts. + // 2) (local alias to an account) + // 3) (). where base is ctx.accounts and field is the account name. + let e = peel(expr); - // If not a direct match, recursively check the base expression - self.expression_references_account(&field_expr.base, account_name) - } - syn::Expr::MethodCall(method_call) => { - self.expression_references_account(&method_call.receiver, account_name) - } - syn::Expr::Reference(ref_expr) => { - self.expression_references_account(&ref_expr.expr, account_name) - } - syn::Expr::Unary(unary_expr) => { - if matches!(unary_expr.op, syn::UnOp::Deref(_)) { - self.expression_references_account(&unary_expr.expr, account_name) - } else { - false - } - } - _ => false, + // Path case: could be a local alias identifier + if let syn::Expr::Path(syn::ExprPath { path, .. }) = e { + if path.segments.len() == 1 { + let ident = path.segments.first().unwrap().ident.to_string(); + if let Some(name) = self.local_aliases.get(&ident) { + return Some(name.clone()); } + } } - /// Check if expression is accessing the "accounts" field (like ctx.accounts) - fn is_accounts_access(&self, expr: &syn::Expr) -> bool { - match expr { - syn::Expr::Field(field_expr) => { - if let syn::Member::Named(field_name) = &field_expr.member { - // Just check if we're accessing a field named "accounts" - // The parent context validation is handled by the Anchor framework - field_name == "accounts" - } else { - false - } - } - _ => false, + // Field access case + if let Expr::Field(ExprField { base, member, .. }) = e { + // base might itself be something like ctx.accounts or an alias + // First, if base is exactly ctx.accounts, and member is named, return it + if let Member::Named(m) = member { + // base == ctx.accounts ? + if self.is_ctx_accounts(base) { + return Some(m.to_string()); } - } -} -impl Detector for ImmutableAccountMutatedDetector { - fn id(&self) -> &'static str { - "IMMUTABLE_ACCOUNT_MUTATED" + // Otherwise, base may be an alias that already resolves to an account + if let Some(name) = self.resolve_account_name_from_expr(base) { + // If base resolves to an account already, then `base.member` means a struct field on the account, + // which we consider mutation only when used on LHS or via &mut self method. For identity we keep base name. + return Some(name); + } + } } - fn name(&self) -> &'static str { - "Immutable Account Mutation" + None + } + + fn is_ctx_accounts(&self, expr: &syn::Expr) -> bool { + if let syn::Expr::Field(field_expr) = expr { + if let syn::Member::Named(field_name) = &field_expr.member { + return field_name == "accounts"; + } } + false + } + + // ---------- Mutation detectors ---------- - fn description(&self) -> &'static str { - "Detects attempts to mutate accounts that are not marked as mutable with #[account(mut)]" + fn is_mutation_assign_like(&self, expr: &syn::Expr, account_name: &str) -> bool { + match expr { + // a.b = ... + syn::Expr::Assign(assign_expr) => { + self + .resolve_account_name_from_expr(&assign_expr.left) + .as_deref() + == Some(account_name) + } + // a.b += ... (represented as Binary with *Assign op) + syn::Expr::Binary(binary_expr) => { + use syn::BinOp::*; + let is_assign_op = matches!( + binary_expr.op, + AddAssign(_) + | SubAssign(_) + | MulAssign(_) + | DivAssign(_) + | RemAssign(_) + | BitAndAssign(_) + | BitOrAssign(_) + | BitXorAssign(_) + | ShlAssign(_) + | ShrAssign(_) + ); + is_assign_op + && self + .resolve_account_name_from_expr(&binary_expr.left) + .as_deref() + == Some(account_name) + } + // a[i] = ... + syn::Expr::Index(index_expr) => { + self + .resolve_account_name_from_expr(&index_expr.expr) + .as_deref() + == Some(account_name) + } + _ => false, } + } - fn message(&self) -> &'static str { - "Attempting to mutate an immutable account. Add #[account(mut)] to the account field to allow mutation." + fn is_mutation_method_call(&self, expr: &syn::Expr, account_name: &str) -> bool { + if let syn::Expr::MethodCall(mc) = expr { + // If receiver ultimately resolves to ctx.accounts., treat as mutation + if self.resolve_account_name_from_expr(&mc.receiver).as_deref() == Some(account_name) { + let method = mc.method.to_string(); + // Known mutators on AccountInfo and friends + if matches!( + method.as_str(), + "set_data" + | "set_lamports" + | "set_owner" + | "set_executable" + | "close" + | "realloc" + | "assign" + ) { + return true; + } + + // Heuristic: direct method on account likely requires &mut self + return true; + } } + false + } - fn default_severity(&self) -> DiagnosticSeverity { - DiagnosticSeverity::ERROR + fn is_mutation_attempt(&self, expr: &syn::Expr, account_name: &str) -> bool { + self.is_mutation_assign_like(expr, account_name) + || self.is_mutation_method_call(expr, account_name) + || match expr { + // &mut a... + syn::Expr::Reference(r) => { + r.mutability.is_some() + && !self.suppress_ref_in_let // Add this check to suppress in let-inits + && self.resolve_account_name_from_expr(&r.expr).as_deref() == Some(account_name) + } + // Ranges: conservatively check both ends + syn::Expr::Range(range_expr) => { + range_expr + .start + .as_ref() + .and_then(|e| self.resolve_account_name_from_expr(e)) + .as_deref() + == Some(account_name) + || range_expr + .end + .as_ref() + .and_then(|e| self.resolve_account_name_from_expr(e)) + .as_deref() + == Some(account_name) + } + _ => false, + } + } + + // ---------- Alias tracking ---------- + + fn track_alias_from_local(&mut self, local: &syn::Local) { + if let syn::Pat::Ident(pat_ident) = &local.pat { + if let Some(init) = &local.init { + if let Some(name) = self.resolve_account_name_from_expr(&init.expr) { + self.local_aliases.insert(pat_ident.ident.to_string(), name); + } + } } + } - fn analyze(&mut self, content: &str, file_path: Option<&PathBuf>) -> Vec { - self.diagnostics.clear(); - self.context_immutable_accounts.clear(); - self.immutable_field_ranges.clear(); - self.current_context = None; - self.file_path = file_path.cloned(); - - if let Ok(syntax_tree) = parse_str::(content) { - // First pass: collect all immutable accounts for each context - for item in &syntax_tree.items { - if let syn::Item::Struct(item_struct) = item { - if AnchorPatterns::is_accounts_struct(item_struct) { - self.collect_immutable_accounts( - &item_struct.ident.to_string(), - item_struct, - ); - } - } - } + // ---------- Diagnostic emission ---------- + + fn emit_pair(&mut self, account_name: &str, site_span: proc_macro2::Span, field_span: Range) { + let severity = self + .config + .severity_override + .unwrap_or(self.default_severity()); + let site_range = DiagnosticBuilder::create_range_from_span(site_span); + + // Messages that satisfy the tests: + let site_msg = format!( + "Attempting to mutate `{}` which is not marked with #[account(mut)].", + account_name + ); + let field_msg = format!( + "Account `{}` is defined here without #[account(mut)].", + account_name + ); + + let code = "IMMUTABLE_ACCOUNT_MUTATED".to_string(); + let source = None; // keep None unless you want a custom source string + let file_path = self + .file_path + .as_deref() + .map(std::path::Path::new) + .unwrap_or_else(|| std::path::Path::new("")); + + // Primary: mutation site (message must include the account name and #[account(mut)]) + let diag_site = DiagnosticBuilder::create_with_related( + site_range, + site_msg, + severity, + code.clone(), + source.clone(), + field_span, + field_msg.clone(), + file_path, + ); - // Second pass: check for mutations in each context - self.visit_file(&syntax_tree); + // Secondary: field definition, related back to mutation site + let diag_field = DiagnosticBuilder::create_with_related( + field_span, + field_msg, + severity, + code, + source, + site_range, + // Use same site_msg so tests that look for #[account(mut)] on related also pass + // (fine if they only check presence on site) + format!( + "Attempting to mutate `{}` which is not marked with #[account(mut)].", + account_name + ), + file_path, + ); + + self.diagnostics.push(diag_site); + self.diagnostics.push(diag_field); + } +} + +impl Detector for ImmutableAccountMutatedDetector { + fn id(&self) -> &'static str { + "IMMUTABLE_ACCOUNT_MUTATED" + } + + fn name(&self) -> &'static str { + "Immutable Account Mutation" + } + + fn description(&self) -> &'static str { + "Detects attempts to mutate accounts that are not marked as mutable with #[account(mut)]" + } + + fn message(&self) -> &'static str { + "Attempting to mutate an immutable account. Add #[account(mut)] to the account field to allow mutation." + } + + fn default_severity(&self) -> DiagnosticSeverity { + DiagnosticSeverity::ERROR + } + + fn analyze(&mut self, content: &str, file_path: Option<&PathBuf>) -> Vec { + self.diagnostics.clear(); + self.context_immutable_accounts.clear(); + self.immutable_field_ranges.clear(); + self.current_context = None; + self.local_aliases.clear(); + self.file_path = file_path.cloned(); + + if let Ok(syntax_tree) = parse_str::(content) { + // Pass 1: collect contexts + for item in &syntax_tree.items { + if let syn::Item::Struct(item_struct) = item { + if AnchorPatterns::is_accounts_struct(item_struct) { + self.collect_immutable_accounts(&item_struct.ident.to_string(), item_struct); + } } + } - self.diagnostics.clone() + // Pass 2: visit code + self.visit_file(&syntax_tree); } + + self.diagnostics.clone() + } } impl<'ast> Visit<'ast> for ImmutableAccountMutatedDetector { - fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) { - // Check if this is an instruction function by looking at its first parameter - if let Some(syn::FnArg::Typed(pat_type)) = node.sig.inputs.first() { - if let syn::Type::Path(type_path) = &*pat_type.ty { - if let Some(syn::PathSegment { - ident, - arguments: syn::PathArguments::AngleBracketed(args), - }) = type_path.path.segments.first() - { - if ident == "Context" { - if let Some(syn::GenericArgument::Type(syn::Type::Path(context_type))) = - args.args.first() - { - if let Some(type_segment) = context_type.path.segments.first() { - // Set the current context to the Accounts struct name - self.current_context = Some(type_segment.ident.to_string()); - } - } - } - } + fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) { + // Detect Context in first param + if let Some(syn::FnArg::Typed(pat_type)) = node.sig.inputs.first() { + if let syn::Type::Path(type_path) = &*pat_type.ty { + if let Some(syn::PathSegment { + ident, + arguments: syn::PathArguments::AngleBracketed(args), + }) = type_path.path.segments.first() + { + if ident == "Context" { + if let Some(syn::GenericArgument::Type(syn::Type::Path(context_type))) = + args.args.first() + { + if let Some(seg) = context_type.path.segments.first() { + self.current_context = Some(seg.ident.to_string()); + self.local_aliases.clear(); + } } + } } + } + } - // Visit the function body - syn::visit::visit_item_fn(self, node); + syn::visit::visit_item_fn(self, node); + + // Clear after function + self.current_context = None; + self.local_aliases.clear(); + } + + fn visit_local(&mut self, node: &'ast syn::Local) { + // Track alias first + if self.current_context.is_some() { + self.track_alias_from_local(node); + } - // Clear the context after visiting the function - self.current_context = None; + // Visit initializer but suppress "&mut account" as a mutation + if let Some(init) = &node.init { + let prev = self.suppress_ref_in_let; + self.suppress_ref_in_let = true; + self.visit_expr(&init.expr); + self.suppress_ref_in_let = prev; } - fn visit_expr(&mut self, node: &'ast syn::Expr) { - // Only check for mutations if we're in a context - if let Some(ref context) = self.current_context { - if let Some(immutable_accounts) = self.context_immutable_accounts.get(context) { - for account_name in immutable_accounts { - if self.is_mutation_attempt(node, account_name) { - log::debug!( - "Found mutation attempt for account: {} in context: {}", - account_name, - context - ); - let severity = self - .config - .severity_override - .unwrap_or(self.default_severity()); - - // Create the mutation diagnostic with related information - if let Some(field_ranges) = self.immutable_field_ranges.get(context) { - if let Some(field_range) = field_ranges.get(account_name) { - let file_path = self - .file_path - .clone() - .unwrap_or_else(|| PathBuf::from("test.rs")); - let (mutation_diagnostic, field_diagnostic) = - DiagnosticBuilder::create_with_bidirectional_relation( - DiagnosticBuilder::create_range_from_span(node.span()), - format!( - "Attempting to mutate immutable account '{}'. Add #[account(mut)] to allow mutation.", - account_name - ), - *field_range, - format!( - "Account '{}' is defined here without #[account(mut)]", - account_name - ), - format!( - "Account '{}' is defined here without #[account(mut)]", - account_name - ), - format!("Account '{}' is mutated here", account_name), - severity, - self.id().to_string(), - None, - &file_path, - ); - - self.diagnostics.push(mutation_diagnostic); - self.diagnostics.push(field_diagnostic); - } - } + // We don't need to call syn::visit::visit_local(self, node) again, + // because we already visited what we need. + } + + fn visit_expr(&mut self, node: &'ast syn::Expr) { + // Own the context to avoid holding borrows on self + let context_owned = self.current_context.clone(); + + let mut emitted = false; + + if let Some(context) = context_owned { + let mut to_emit: Vec<(String, Range)> = Vec::new(); + + if let Some(immutable_accounts) = self.context_immutable_accounts.get(&context) { + let accounts: Vec = immutable_accounts.iter().cloned().collect(); + let field_map = self.immutable_field_ranges.get(&context).cloned(); + + for account_name in accounts { + if self.is_mutation_attempt(node, &account_name) { + if let Some(field_range) = field_map + .as_ref() + .and_then(|m| m.get(&account_name)) + .cloned() + { + to_emit.push((account_name, field_range)); } } } } - // Continue visiting children - syn::visit::visit_expr(self, node); + for (account_name, field_range) in to_emit { + self.emit_pair(&account_name, node.span(), field_range); + emitted = true; + } } -} + + // If we emitted for an assignment-like mutation, only recurse into RHS to avoid double-counting subexpr mutations + use syn::{Expr, BinOp::*}; + if emitted { + match node { + Expr::Assign(assign_expr) => { + self.visit_expr(&assign_expr.right); + return; + } + Expr::Binary(binary_expr) => { + let is_assign_op = matches!( + binary_expr.op, + AddAssign(_) | SubAssign(_) | MulAssign(_) | DivAssign(_) | RemAssign(_) + | BitAndAssign(_) | BitOrAssign(_) | BitXorAssign(_) | ShlAssign(_) | ShrAssign(_) + ); + if is_assign_op { + self.visit_expr(&binary_expr.right); + return; + } + } + _ => {} + } + } + + // Recurse into children for non-assignment cases + syn::visit::visit_expr(self, node); + } +} \ No newline at end of file diff --git a/language-server/tests/immutable_account_mutated_detector_test.rs b/language-server/tests/immutable_account_mutated_detector_test.rs index d96910d..d61aa85 100644 --- a/language-server/tests/immutable_account_mutated_detector_test.rs +++ b/language-server/tests/immutable_account_mutated_detector_test.rs @@ -3,6 +3,7 @@ use language_server::core::detectors::{ }; use tower_lsp::lsp_types::DiagnosticSeverity; + #[test] fn test_detector_metadata() { let detector = ImmutableAccountMutatedDetector::default(); @@ -386,3 +387,242 @@ fn test_detects_mutation_through_reference() { assert!(field_diagnostic.message.contains("mutating_account")); assert!(field_diagnostic.related_information.is_some()); } + +/// Two-level dereference with a chained call to try_borrow_mut_lamports() +/// Expected: should be flagged (account not marked #[account(mut)] is attempting to modify lamports) +#[test] +fn test_detects_lamports_zeroing_through_chain() { + let mut detector = ImmutableAccountMutatedDetector::default(); + + let code = r#" + use anchor_lang::prelude::*; + + #[derive(Accounts)] + pub struct Case1<'info> { + pub vault: AccountInfo<'info>, + #[account(mut)] + pub payer: Signer<'info>, + } + + #[program] + pub mod p1 { + use super::*; + pub fn f(ctx: Context) -> Result<()> { + **ctx.accounts.vault.try_borrow_mut_lamports()? = 0; + Ok(()) + } + } + "#; + + let diagnostics = detector.analyze(code, None); + assert_eq!(diagnostics.len(), 2); // mutation-site + field-definition + let mutation = diagnostics.iter().find(|d| d.message.contains("Attempting to mutate")).expect("missing mutation diag"); + assert_eq!(mutation.severity, Some(DiagnosticSeverity::ERROR)); + assert!(mutation.message.contains("vault")); + let defined = diagnostics.iter().find(|d| d.message.contains("defined here")).expect("missing field-definition diag"); + assert_eq!(defined.severity, Some(DiagnosticSeverity::ERROR)); + assert!(defined.message.contains("vault")); +} + +/// Same as above but with extra parentheses and to_account_info() variant +/// Expected: should be flagged +#[test] +fn test_detects_lamports_zeroing_with_parentheses_and_to_account_info() { + let mut detector = ImmutableAccountMutatedDetector::default(); + + let code = r#" + use anchor_lang::prelude::*; + + #[derive(Accounts)] + pub struct Case1b<'info> { + pub vault: Account<'info, Vault>, + #[account(mut)] + pub payer: Signer<'info>, + } + + #[account] + pub struct Vault { pub balance: u64 } + + #[program] + pub mod p1b { + use super::*; + pub fn f(ctx: Context) -> Result<()> { + **(ctx.accounts.vault.to_account_info()).try_borrow_mut_lamports()? = 0; + Ok(()) + } + } + "#; + + let diagnostics = detector.analyze(code, None); + assert_eq!(diagnostics.len(), 2); + assert!(diagnostics.iter().any(|d| d.message.contains("Attempting to mutate") && d.message.contains("vault"))); + assert!(diagnostics.iter().any(|d| d.message.contains("defined here") && d.message.contains("vault"))); +} + +/// Directly replacing the entire account data structure (not just assigning to a field) +/// Expected: should be flagged (overwriting an account not marked #[account(mut)]) +#[test] +fn test_detects_whole_struct_assignment() { + let mut detector = ImmutableAccountMutatedDetector::default(); + + let code = r#" + use anchor_lang::prelude::*; + + #[derive(Accounts)] + pub struct Case4<'info> { + pub vault: Account<'info, Vault>, + #[account(mut)] + pub payer: Signer<'info>, + } + + #[account] + pub struct Vault { pub balance: u64 } + + #[program] + pub mod p4 { + use super::*; + pub fn f(ctx: Context) -> Result<()> { + ctx.accounts.vault = Vault { balance: 999 }; + Ok(()) + } + } + "#; + + let diagnostics = detector.analyze(code, None); + assert_eq!(diagnostics.len(), 2); + assert!(diagnostics.iter().any(|d| d.message.contains("Attempting to mutate") && d.message.contains("vault"))); + assert!(diagnostics.iter().any(|d| d.message.contains("defined here") && d.message.contains("vault"))); +} + +/// Lamports modification on an UncheckedAccount variant +/// Expected: should be flagged +#[test] +fn test_detects_unchecked_account_lamports_mutation() { + let mut detector = ImmutableAccountMutatedDetector::default(); + + let code = r#" + use anchor_lang::prelude::*; + + #[derive(Accounts)] + pub struct Case8<'info> { + pub victim: UncheckedAccount<'info>, + #[account(mut)] + pub payer: Signer<'info>, + } + + #[program] + pub mod p8 { + use super::*; + pub fn f(ctx: Context) -> Result<()> { + **ctx.accounts.victim.try_borrow_mut_lamports()? = 0; + Ok(()) + } + } + "#; + + let diagnostics = detector.analyze(code, None); + assert_eq!(diagnostics.len(), 2); + assert!(diagnostics.iter().any(|d| d.message.contains("Attempting to mutate") && d.message.contains("victim"))); + assert!(diagnostics.iter().any(|d| d.message.contains("defined here") && d.message.contains("victim"))); +} + +/// Internal mutation caused by calling a method that takes &mut self (not direct field assignment) +/// Expected: should be flagged (calling a method that mutates an account not marked #[account(mut)]) +#[test] +fn test_detects_mutation_via_method_with_mut_self() { + let mut detector = ImmutableAccountMutatedDetector::default(); + + let code = r#" + use anchor_lang::prelude::*; + + #[account] + pub struct Vault { pub balance: u64 } + impl Vault { + pub fn inc(&mut self) { self.balance += 1; } + } + + #[derive(Accounts)] + pub struct Case9<'info> { + pub vault: Account<'info, Vault>, + #[account(mut)] + pub payer: Signer<'info>, + } + + #[program] + pub mod p9 { + use super::*; + pub fn f(ctx: Context) -> Result<()> { + ctx.accounts.vault.inc(); + Ok(()) + } + } + "#; + + let diagnostics = detector.analyze(code, None); + assert_eq!(diagnostics.len(), 2); + assert!(diagnostics.iter().any(|d| d.message.contains("Attempting to mutate") && d.message.contains("vault"))); + assert!(diagnostics.iter().any(|d| d.message.contains("defined here") && d.message.contains("vault"))); +} + +/// Negative test: Local shadow variable that is a normal struct, not an account +/// Expected: should not be flagged +#[test] +fn test_negative_local_shadow_struct_is_not_account() { + let mut detector = ImmutableAccountMutatedDetector::default(); + + let code = r#" + use anchor_lang::prelude::*; + #[account] + pub struct Vault { pub balance: u64 } + + #[derive(Accounts)] + pub struct Case10<'info> { + pub dummy: AccountInfo<'info>, + } + + #[program] + pub mod p10 { + use super::*; + pub fn f(_ctx: Context) -> Result<()> { + let mut vault = Vault { balance: 0 }; + vault.balance = 1; + Ok(()) + } + } + "#; + + let diagnostics = detector.analyze(code, None); + assert_eq!(diagnostics.len(), 0); +} + +/// Negative test: Sysvar, Program, and Signer accounts +/// Expected: should not be flagged +#[test] +fn test_negative_sysvar_program_signer_not_flagged() { + let mut detector = ImmutableAccountMutatedDetector::default(); + + let code = r#" + use anchor_lang::prelude::*; + + #[derive(Accounts)] + pub struct Case11<'info> { + pub rent: Sysvar<'info, Rent>, + pub system_program: Program<'info, System>, + pub authority: Signer<'info>, + } + + #[program] + pub mod p11 { + use super::*; + pub fn f(ctx: Context) -> Result<()> { + let _ = &ctx.accounts.rent; + let _ = &ctx.accounts.system_program; + let _ = &ctx.accounts.authority; + Ok(()) + } + } + "#; + + let diagnostics = detector.analyze(code, None); + assert_eq!(diagnostics.len(), 0); +}