Skip to content

Commit e3c2fc6

Browse files
committed
fix Session init race condition
1 parent f6db8b0 commit e3c2fc6

21 files changed

+86
-87
lines changed

rust/src/headless.rs

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ use crate::{
1818
};
1919
use std::io;
2020
use std::path::{Path, PathBuf};
21-
use std::sync::atomic::AtomicUsize;
22-
use std::sync::atomic::Ordering::SeqCst;
2321
use thiserror::Error;
2422

2523
use crate::enterprise::release_license;
@@ -35,7 +33,7 @@ use std::time::Duration;
3533
static MAIN_THREAD_HANDLE: Mutex<Option<JoinHandle<()>>> = Mutex::new(None);
3634

3735
/// Used to prevent shutting down Binary Ninja if there are other [`Session`]'s.
38-
static SESSION_COUNT: AtomicUsize = AtomicUsize::new(0);
36+
static SESSION_COUNT: Mutex<usize> = Mutex::new(0);
3937

4038
#[derive(Error, Debug)]
4139
pub enum InitializationError {
@@ -47,6 +45,8 @@ pub enum InitializationError {
4745
InvalidLicense,
4846
#[error("no license could located, please see `binaryninja::set_license` for details")]
4947
NoLicenseFound,
48+
#[error("unable to apply options to an previously created `binaryninja::headless::Session`")]
49+
SessionAlreadyInitialized,
5050
}
5151

5252
/// Loads plugins, core architecture, platform, etc.
@@ -284,39 +284,56 @@ pub fn license_location() -> Option<LicenseLocation> {
284284
}
285285

286286
/// Wrapper for [`init`] and [`shutdown`]. Instantiating this at the top of your script will initialize everything correctly and then clean itself up at exit as well.
287-
pub struct Session {}
287+
pub struct Session {
288+
/// lock that don't allow the user to create a session directly
289+
_lock: std::marker::PhantomData<()>,
290+
}
288291

289292
impl Session {
290293
/// Get a registered [`Session`] for use.
291294
///
292295
/// This is required so that we can keep track of the [`SESSION_COUNT`].
293-
fn registered_session() -> Self {
294-
let _previous_count = SESSION_COUNT.fetch_add(1, SeqCst);
295-
Self {}
296+
fn register_session(
297+
options: Option<InitializationOptions>,
298+
) -> Result<Self, InitializationError> {
299+
// if we were able to locate a license, continue with initialization.
300+
if license_location().is_none() {
301+
// otherwise you must call [Self::new_with_license].
302+
return Err(InitializationError::NoLicenseFound);
303+
}
304+
305+
// This is required so that we call init only once
306+
let mut session_count = SESSION_COUNT.lock().unwrap();
307+
match (*session_count, options) {
308+
// no session, just create one
309+
(0, options) => init_with_opts(options.unwrap_or_default())?,
310+
// session already created, can't apply options
311+
(1.., Some(_)) => return Err(InitializationError::SessionAlreadyInitialized),
312+
// NOTE if the existing session was created with options,
313+
// returning the current session may not be exacly what the
314+
// user expects.
315+
(1.., None) => {}
316+
}
317+
*session_count += 1;
318+
Ok(Self {
319+
_lock: std::marker::PhantomData,
320+
})
296321
}
297322

298323
/// Before calling new you must make sure that the license is retrievable, otherwise the core won't be able to initialize.
299324
///
300325
/// If you cannot otherwise provide a license via `BN_LICENSE_FILE` environment variable or the Binary Ninja user directory
301326
/// you can call [`Session::new_with_opts`] instead of this function.
302327
pub fn new() -> Result<Self, InitializationError> {
303-
if license_location().is_some() {
304-
// We were able to locate a license, continue with initialization.
305-
init()?;
306-
Ok(Self::registered_session())
307-
} else {
308-
// There was no license that could be automatically retrieved, you must call [Self::new_with_license].
309-
Err(InitializationError::NoLicenseFound)
310-
}
328+
Self::register_session(None)
311329
}
312330

313331
/// Initialize with options, the same rules apply as [`Session::new`], see [`InitializationOptions::default`] for the regular options passed.
314332
///
315333
/// This differs from [`Session::new`] in that it does not check to see if there is a license that the core
316334
/// can discover by itself, therefor it is expected that you know where your license is when calling this directly.
317335
pub fn new_with_opts(options: InitializationOptions) -> Result<Self, InitializationError> {
318-
init_with_opts(options)?;
319-
Ok(Self::registered_session())
336+
Self::register_session(Some(options))
320337
}
321338

322339
/// ```no_run
@@ -410,10 +427,13 @@ impl Session {
410427

411428
impl Drop for Session {
412429
fn drop(&mut self) {
413-
let previous_count = SESSION_COUNT.fetch_sub(1, SeqCst);
414-
if previous_count == 1 {
430+
let mut session_count = SESSION_COUNT.lock().unwrap();
431+
match *session_count {
432+
0 => unreachable!(),
415433
// We were the last session, therefor we can safely shut down.
416-
shutdown();
434+
1 => shutdown(),
435+
2.. => {}
417436
}
437+
*session_count -= 1;
418438
}
419439
}

rust/tests/background_task.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@ use binaryninja::headless::Session;
33
use rstest::*;
44

55
#[fixture]
6-
#[once]
76
fn session() -> Session {
87
Session::new().expect("Failed to initialize session")
98
}
109

1110
#[rstest]
12-
fn test_background_task_registered(_session: &Session) {
11+
fn test_background_task_registered(_session: Session) {
1312
let task_progress = "test registered";
1413
let task = BackgroundTask::new(task_progress, false);
1514
BackgroundTask::running_tasks()
@@ -25,7 +24,7 @@ fn test_background_task_registered(_session: &Session) {
2524
}
2625

2726
#[rstest]
28-
fn test_background_task_cancellable(_session: &Session) {
27+
fn test_background_task_cancellable(_session: Session) {
2928
let task_progress = "test cancellable";
3029
let task = BackgroundTask::new(task_progress, false);
3130
BackgroundTask::running_tasks()
@@ -38,7 +37,7 @@ fn test_background_task_cancellable(_session: &Session) {
3837
}
3938

4039
#[rstest]
41-
fn test_background_task_progress(_session: &Session) {
40+
fn test_background_task_progress(_session: Session) {
4241
let task = BackgroundTask::new("test progress", false);
4342
let first_progress = task.progress_text().to_string();
4443
assert_eq!(first_progress, "test progress");

rust/tests/binary_reader.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@ use std::io::{Read, Seek, SeekFrom};
66
use std::path::PathBuf;
77

88
#[fixture]
9-
#[once]
109
fn session() -> Session {
1110
Session::new().expect("Failed to initialize session")
1211
}
1312

1413
#[rstest]
15-
fn test_binary_reader_seek(_session: &Session) {
14+
fn test_binary_reader_seek(_session: Session) {
1615
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
1716
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
1817
let mut reader = BinaryReader::new(&view);
@@ -50,7 +49,7 @@ fn test_binary_reader_seek(_session: &Session) {
5049
}
5150

5251
#[rstest]
53-
fn test_binary_reader_read(_session: &Session) {
52+
fn test_binary_reader_read(_session: Session) {
5453
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
5554
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
5655
let mut reader = BinaryReader::new(&view);

rust/tests/binary_view.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@ use rstest::*;
66
use std::path::PathBuf;
77

88
#[fixture]
9-
#[once]
109
fn session() -> Session {
1110
Session::new().expect("Failed to initialize session")
1211
}
1312

1413
#[rstest]
15-
fn test_binary_loading(_session: &Session) {
14+
fn test_binary_loading(_session: Session) {
1615
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
1716
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
1817
assert!(view.has_initial_analysis(), "No initial analysis");
@@ -22,7 +21,7 @@ fn test_binary_loading(_session: &Session) {
2221
}
2322

2423
#[rstest]
25-
fn test_binary_saving(_session: &Session) {
24+
fn test_binary_saving(_session: Session) {
2625
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
2726
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
2827
// Verify the contents before we modify.
@@ -45,7 +44,7 @@ fn test_binary_saving(_session: &Session) {
4544
}
4645

4746
#[rstest]
48-
fn test_binary_saving_database(_session: &Session) {
47+
fn test_binary_saving_database(_session: Session) {
4948
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
5049
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
5150
// Update a symbol to verify modification

rust/tests/binary_writer.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@ use std::io::{Read, Seek, SeekFrom, Write};
77
use std::path::PathBuf;
88

99
#[fixture]
10-
#[once]
1110
fn session() -> Session {
1211
Session::new().expect("Failed to initialize session")
1312
}
1413

1514
#[rstest]
16-
fn test_binary_writer_seek(_session: &Session) {
15+
fn test_binary_writer_seek(_session: Session) {
1716
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
1817
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
1918
let mut writer = BinaryWriter::new(&view);
@@ -51,7 +50,7 @@ fn test_binary_writer_seek(_session: &Session) {
5150
}
5251

5352
#[rstest]
54-
fn test_binary_writer_write(_session: &Session) {
53+
fn test_binary_writer_write(_session: Session) {
5554
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
5655
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
5756
let mut reader = BinaryReader::new(&view);

rust/tests/collaboration.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use serial_test::serial;
99
use std::path::PathBuf;
1010

1111
#[fixture]
12-
#[once]
1312
fn session() -> Session {
1413
Session::new().expect("Failed to initialize session")
1514
}
@@ -58,7 +57,7 @@ fn temp_project_scope<T: Fn(&RemoteProject)>(remote: &Remote, project_name: &str
5857

5958
#[rstest]
6059
#[serial]
61-
fn test_connection(_session: &Session) {
60+
fn test_connection(_session: Session) {
6261
if !has_collaboration_support() {
6362
eprintln!("No collaboration support, skipping test...");
6463
return;
@@ -74,7 +73,7 @@ fn test_connection(_session: &Session) {
7473

7574
#[rstest]
7675
#[serial]
77-
fn test_project_creation(_session: &Session) {
76+
fn test_project_creation(_session: Session) {
7877
if !has_collaboration_support() {
7978
eprintln!("No collaboration support, skipping test...");
8079
return;
@@ -155,7 +154,7 @@ fn test_project_creation(_session: &Session) {
155154

156155
#[rstest]
157156
#[serial]
158-
fn test_project_sync(_session: &Session) {
157+
fn test_project_sync(_session: Session) {
159158
if !has_collaboration_support() {
160159
eprintln!("No collaboration support, skipping test...");
161160
return;

rust/tests/component.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@ use rstest::*;
55
use std::path::PathBuf;
66

77
#[fixture]
8-
#[once]
98
fn session() -> Session {
109
Session::new().expect("Failed to initialize session")
1110
}
1211

1312
#[rstest]
14-
fn test_component_creation(_session: &Session) {
13+
fn test_component_creation(_session: Session) {
1514
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
1615
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
1716
let component = ComponentBuilder::new(view.clone()).name("test").finalize();

rust/tests/demangler.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@ use binaryninja::types::{QualifiedName, Type};
99
use rstest::*;
1010

1111
#[fixture]
12-
#[once]
1312
fn session() -> Session {
1413
Session::new().expect("Failed to initialize session")
1514
}
1615

1716
#[rstest]
18-
fn test_demangler_simple(_session: &Session) {
17+
fn test_demangler_simple(_session: Session) {
1918
let placeholder_arch = CoreArchitecture::by_name("x86").expect("x86 exists");
2019
// Example LLVM-style mangled name
2120
let llvm_mangled = "_Z3fooi"; // "foo(int)" in LLVM mangling
@@ -46,7 +45,7 @@ fn test_demangler_simple(_session: &Session) {
4645
}
4746

4847
#[rstest]
49-
fn test_custom_demangler(_session: &Session) {
48+
fn test_custom_demangler(_session: Session) {
5049
struct TestDemangler;
5150

5251
impl CustomDemangler for TestDemangler {

rust/tests/high_level_il.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@ use rstest::*;
55
use std::path::PathBuf;
66

77
#[fixture]
8-
#[once]
98
fn session() -> Session {
109
Session::new().expect("Failed to initialize session")
1110
}
1211

1312
#[rstest]
14-
fn test_hlil_info(_session: &Session) {
13+
fn test_hlil_info(_session: Session) {
1514
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
1615
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
1716

rust/tests/low_level_il.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@ use rstest::*;
1212
use std::path::PathBuf;
1313

1414
#[fixture]
15-
#[once]
1615
fn session() -> Session {
1716
Session::new().expect("Failed to initialize session")
1817
}
1918

2019
#[rstest]
21-
fn test_llil_info(_session: &Session) {
20+
fn test_llil_info(_session: Session) {
2221
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
2322
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
2423

@@ -170,7 +169,7 @@ fn test_llil_info(_session: &Session) {
170169
}
171170

172171
#[rstest]
173-
fn test_llil_visitor(_session: &Session) {
172+
fn test_llil_visitor(_session: Session) {
174173
let out_dir = env!("OUT_DIR").parse::<PathBuf>().unwrap();
175174
let view = binaryninja::load(out_dir.join("atox.obj")).expect("Failed to create view");
176175
let platform = view.default_platform().unwrap();

0 commit comments

Comments
 (0)