Skip to content

Commit 18d7e47

Browse files
committed
[Rust] Add GET and POST helper functions for DownloadInstance
1 parent f7464b7 commit 18d7e47

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed

rust/src/download/instance.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ use crate::headless::is_shutdown_requested;
33
use crate::rc::{Ref, RefCountable};
44
use crate::string::{strings_to_string_list, BnString, IntoCStr};
55
use binaryninjacore_sys::*;
6+
use std::cell::RefCell;
67
use std::collections::HashMap;
78
use std::ffi::{c_void, CStr};
89
use std::mem::{ManuallyDrop, MaybeUninit};
910
use std::os::raw::c_char;
1011
use std::ptr::null_mut;
12+
use std::rc::Rc;
1113
use std::slice;
1214

1315
pub trait CustomDownloadInstance: Sized {
@@ -83,6 +85,36 @@ pub struct DownloadResponse {
8385
pub headers: HashMap<String, String>,
8486
}
8587

88+
pub struct OwnedDownloadResponse {
89+
pub data: Vec<u8>,
90+
pub status_code: u16,
91+
pub headers: HashMap<String, String>,
92+
}
93+
94+
impl OwnedDownloadResponse {
95+
/// Attempt to parse the response body as UTF-8.
96+
pub fn text(&self) -> Result<String, std::string::FromUtf8Error> {
97+
String::from_utf8(self.data.clone())
98+
}
99+
100+
/// Attempt to deserialize the response body as JSON into T.
101+
pub fn json<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
102+
serde_json::from_slice(&self.data)
103+
}
104+
105+
/// Convenience to get a header value by case-insensitive name.
106+
pub fn header(&self, name: &str) -> Option<&str> {
107+
self.headers
108+
.get(&name.to_ascii_lowercase())
109+
.map(|s| s.as_str())
110+
}
111+
112+
/// True if the status code is in the 2xx range.
113+
pub fn is_success(&self) -> bool {
114+
(200..300).contains(&self.status_code)
115+
}
116+
}
117+
86118
/// A reader for a [`DownloadInstance`].
87119
pub struct DownloadInstanceReader {
88120
pub instance: Ref<DownloadInstance>,
@@ -199,6 +231,100 @@ impl DownloadInstance {
199231
unsafe { BNNotifyProgressForDownloadInstance(self.handle, progress, total) }
200232
}
201233

234+
pub fn get<I>(&mut self, url: &str, headers: I) -> Result<OwnedDownloadResponse, String>
235+
where
236+
I: IntoIterator<Item = (String, String)>,
237+
{
238+
let buf: Rc<RefCell<Vec<u8>>> = Rc::new(RefCell::new(Vec::new()));
239+
let buf_closure = Rc::clone(&buf);
240+
let callbacks = DownloadInstanceInputOutputCallbacks {
241+
read: None,
242+
write: Some(Box::new(move |data: &[u8]| {
243+
buf_closure.borrow_mut().extend_from_slice(data);
244+
data.len()
245+
})),
246+
progress: Some(Box::new(|_, _| true)),
247+
};
248+
249+
let resp = self.perform_custom_request("GET", url, headers, &callbacks)?;
250+
drop(callbacks);
251+
let out = Rc::try_unwrap(buf).map_err(|_| "Buffer held with strong reference")?;
252+
Ok(OwnedDownloadResponse {
253+
data: out.into_inner(),
254+
status_code: resp.status_code,
255+
headers: resp.headers,
256+
})
257+
}
258+
259+
pub fn post<I>(
260+
&mut self,
261+
url: &str,
262+
headers: I,
263+
body: Vec<u8>,
264+
) -> Result<OwnedDownloadResponse, String>
265+
where
266+
I: IntoIterator<Item = (String, String)>,
267+
{
268+
let resp_buf: Rc<RefCell<Vec<u8>>> = Rc::new(RefCell::new(Vec::new()));
269+
let resp_buf_closure = Rc::clone(&resp_buf);
270+
// Request body position tracker captured by the read closure
271+
let mut pos = 0usize;
272+
let total = body.len();
273+
let callbacks = DownloadInstanceInputOutputCallbacks {
274+
// Supply request body to the core
275+
read: Some(Box::new(move |dst: &mut [u8]| -> Option<usize> {
276+
if pos >= total {
277+
return Some(0);
278+
}
279+
let remaining = total - pos;
280+
let to_copy = remaining.min(dst.len());
281+
dst[..to_copy].copy_from_slice(&body[pos..pos + to_copy]);
282+
pos += to_copy;
283+
Some(to_copy)
284+
})),
285+
// Collect response body
286+
write: Some(Box::new(move |data: &[u8]| {
287+
resp_buf_closure.borrow_mut().extend_from_slice(data);
288+
data.len()
289+
})),
290+
progress: Some(Box::new(|_, _| true)),
291+
};
292+
293+
let resp = self.perform_custom_request("POST", url, headers, &callbacks)?;
294+
drop(callbacks);
295+
if !(200..300).contains(&(resp.status_code as i32)) {
296+
return Err(format!("HTTP error: {}", resp.status_code));
297+
}
298+
299+
let out = Rc::try_unwrap(resp_buf).map_err(|_| "Buffer held with strong reference")?;
300+
Ok(OwnedDownloadResponse {
301+
data: out.into_inner(),
302+
status_code: resp.status_code,
303+
headers: resp.headers,
304+
})
305+
}
306+
307+
pub fn post_json<I, T>(
308+
&mut self,
309+
url: &str,
310+
headers: I,
311+
body: &T,
312+
) -> Result<OwnedDownloadResponse, String>
313+
where
314+
I: IntoIterator<Item = (String, String)>,
315+
T: serde::Serialize,
316+
{
317+
let mut headers: Vec<(String, String)> = headers.into_iter().collect();
318+
if !headers
319+
.iter()
320+
.any(|(k, _)| k.eq_ignore_ascii_case("content-type"))
321+
{
322+
headers.push(("content-type".into(), "application/json".into()));
323+
}
324+
let bytes = serde_json::to_vec(body).map_err(|e| e.to_string())?;
325+
self.post(url, headers, bytes)
326+
}
327+
202328
pub fn perform_request(
203329
&mut self,
204330
url: &str,

0 commit comments

Comments
 (0)