-
Notifications
You must be signed in to change notification settings - Fork 3
V2/most confident indices #105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: wit-files-v2
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,194 +1,98 @@ | ||||||
| use std::{cmp::Ordering, convert::TryInto, fmt::Display}; | ||||||
|
|
||||||
| use crate::proc_block_v1::*; | ||||||
|
|
||||||
| use hotg_rune_proc_blocks::{ | ||||||
| common, | ||||||
| ndarray::ArrayView1, | ||||||
| runtime_v1::{self, *}, | ||||||
| BufferExt, SliceExt, ValueType, | ||||||
| use hotg_rune_proc_blocks::guest::{ | ||||||
| Argument, ArgumentMetadata, ArgumentType, CreateError, Dimensions, | ||||||
| ElementType, InvalidInput, Metadata, ProcBlock, RunError, Tensor, | ||||||
| TensorConstraint, TensorConstraints, TensorMetadata, parse | ||||||
| }; | ||||||
|
|
||||||
| wit_bindgen_rust::export!("../wit-files/rune/proc-block-v1.wit"); | ||||||
|
|
||||||
| struct ProcBlockV1; | ||||||
|
|
||||||
| impl proc_block_v1::ProcBlockV1 for ProcBlockV1 { | ||||||
| fn register_metadata() { | ||||||
| let metadata = | ||||||
| Metadata::new("Most Confident Indices", env!("CARGO_PKG_VERSION")); | ||||||
| metadata.set_description( | ||||||
| "Given some confidence values, create a tensor containing the indices of the top N highest confidences.", | ||||||
| ); | ||||||
| metadata.set_repository(env!("CARGO_PKG_REPOSITORY")); | ||||||
| metadata.set_homepage(env!("CARGO_PKG_HOMEPAGE")); | ||||||
| metadata.add_tag("classify"); | ||||||
|
|
||||||
| let count = ArgumentMetadata::new("count"); | ||||||
| count.set_description("The number of indices to return."); | ||||||
| count.set_default_value("1"); | ||||||
| let hint = | ||||||
| runtime_v1::supported_argument_type(ArgumentType::UnsignedInteger); | ||||||
| count.add_hint(&hint); | ||||||
| metadata.add_argument(&count); | ||||||
| use std::{cmp::Ordering, convert::TryInto}; | ||||||
|
|
||||||
| let element_type = ArgumentMetadata::new(common::element_type::NAME); | ||||||
| element_type.set_description(common::element_type::DESCRIPTION); | ||||||
| let hint = runtime_v1::interpret_as_string_in_enum( | ||||||
| common::element_type::NUMERIC, | ||||||
| ); | ||||||
| element_type.add_hint(&hint); | ||||||
| metadata.add_argument(&element_type); | ||||||
| use hotg_rune_proc_blocks::ndarray::ArrayView1; | ||||||
|
|
||||||
| let input = TensorMetadata::new("confidences"); | ||||||
| input.set_description("A 1D tensor of numeric confidence values."); | ||||||
| let hint = supported_shapes( | ||||||
| &[ | ||||||
| ElementType::U8, | ||||||
| ElementType::I8, | ||||||
| ElementType::U16, | ||||||
| ElementType::I16, | ||||||
| ElementType::U32, | ||||||
| ElementType::I32, | ||||||
| ElementType::F32, | ||||||
| ElementType::U64, | ||||||
| ElementType::I64, | ||||||
| ElementType::F64, | ||||||
| ], | ||||||
| DimensionsParam::Dynamic, | ||||||
| ); | ||||||
| input.add_hint(&hint); | ||||||
| metadata.add_input(&input); | ||||||
|
|
||||||
| let output = TensorMetadata::new("indices"); | ||||||
| output | ||||||
| .set_description("The indices, in order of descending confidence."); | ||||||
| let hint = | ||||||
| supported_shapes(&[ElementType::U32], DimensionsParam::Fixed(&[0])); | ||||||
| output.add_hint(&hint); | ||||||
| metadata.add_output(&output); | ||||||
|
|
||||||
| register_node(&metadata); | ||||||
| } | ||||||
| hotg_rune_proc_blocks::export_proc_block! { | ||||||
| metadata: metadata, | ||||||
| proc_block: MostConfidentIndices, | ||||||
| } | ||||||
|
|
||||||
| fn graph(node_id: String) -> Result<(), GraphError> { | ||||||
| let ctx = GraphContext::for_node(&node_id) | ||||||
| .ok_or(GraphError::MissingContext)?; | ||||||
| fn metadata() -> Metadata { | ||||||
| Metadata::new("Most Confident Indices", env!("CARGO_PKG_VERSION")) | ||||||
| .with_description( | ||||||
| "Given some confidence values, create a tensor containing the indices of the top N highest confidences.", | ||||||
| ) | ||||||
| .with_repository(env!("CARGO_PKG_REPOSITORY")) | ||||||
| .with_homepage(env!("CARGO_PKG_HOMEPAGE")) | ||||||
| .with_tag("classify") | ||||||
| .with_tag("confidence") | ||||||
| .with_argument(ArgumentMetadata::new("count") | ||||||
| .with_default_value("1") | ||||||
| .with_description("The number of indices to return") | ||||||
| .with_hint(ArgumentType::Float)) | ||||||
| .with_input(TensorMetadata::new("confidences").with_description("A 1D tensor of numeric confidence values.")) | ||||||
| .with_output( | ||||||
| TensorMetadata::new("indices") | ||||||
| .with_description("The indices, in order of descending confidence."), | ||||||
| ) | ||||||
| } | ||||||
|
|
||||||
| let element_type = get_element_type(|n| ctx.get_argument(n)) | ||||||
| .map_err(GraphError::InvalidArgument)?; | ||||||
| let count = get_count(|n| ctx.get_argument(n)) | ||||||
| .map_err(GraphError::InvalidArgument)?; | ||||||
| struct MostConfidentIndices { | ||||||
| count: u32, | ||||||
| } | ||||||
|
|
||||||
| ctx.add_input_tensor( | ||||||
| "confidences", | ||||||
| element_type, | ||||||
| DimensionsParam::Dynamic, | ||||||
| ); | ||||||
| ctx.add_output_tensor( | ||||||
| "indices", | ||||||
| element_type, | ||||||
| DimensionsParam::Fixed(&[count]), | ||||||
| ); | ||||||
| impl ProcBlock for MostConfidentIndices { | ||||||
| fn tensor_constraints(&self) -> TensorConstraints { | ||||||
|
|
||||||
| let count = parse::required_arg(&vec![Argument{ name: "count".to_string(), value: self.count.to_string() }], "count").unwrap(); | ||||||
|
|
||||||
| Ok(()) | ||||||
| TensorConstraints { | ||||||
| inputs: vec![TensorConstraint::numeric( | ||||||
| "confidences", | ||||||
| Dimensions::Dynamic, | ||||||
| )], | ||||||
| outputs: vec![TensorConstraint::numeric("indices", vec![count])], | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| fn kernel(node_id: String) -> Result<(), KernelError> { | ||||||
| let ctx = KernelContext::for_node(&node_id) | ||||||
| .ok_or(KernelError::MissingContext)?; | ||||||
|
|
||||||
| let count = get_count(|n| ctx.get_argument(n)) | ||||||
| .map_err(KernelError::InvalidArgument)?; | ||||||
|
|
||||||
| let TensorResult { | ||||||
| element_type, | ||||||
| dimensions, | ||||||
| buffer, | ||||||
| } = ctx.get_input_tensor("input").ok_or_else(|| { | ||||||
| KernelError::InvalidInput(InvalidInput { | ||||||
| name: "indices".to_string(), | ||||||
| reason: BadInputReason::NotFound, | ||||||
| }) | ||||||
| })?; | ||||||
|
|
||||||
| let count = count as usize; | ||||||
|
|
||||||
| let indices = match element_type { | ||||||
| ElementType::U8 => preprocess_buffer::<u8>(&buffer, &dimensions) | ||||||
| .and_then(|t| most_confident_indices(t, count))?, | ||||||
| ElementType::I8 => preprocess_buffer::<i8>(&buffer, &dimensions) | ||||||
| .and_then(|t| most_confident_indices(t, count))?, | ||||||
| ElementType::U16 => preprocess_buffer::<u16>(&buffer, &dimensions) | ||||||
| .and_then(|t| most_confident_indices(t, count))?, | ||||||
| ElementType::I16 => preprocess_buffer::<i16>(&buffer, &dimensions) | ||||||
| .and_then(|t| most_confident_indices(t, count))?, | ||||||
| ElementType::U32 => preprocess_buffer::<u32>(&buffer, &dimensions) | ||||||
| .and_then(|t| most_confident_indices(t, count))?, | ||||||
| ElementType::I32 => preprocess_buffer::<i32>(&buffer, &dimensions) | ||||||
| .and_then(|t| most_confident_indices(t, count))?, | ||||||
| ElementType::F32 => preprocess_buffer::<f32>(&buffer, &dimensions) | ||||||
| .and_then(|t| most_confident_indices(t, count))?, | ||||||
| ElementType::U64 => preprocess_buffer::<u64>(&buffer, &dimensions) | ||||||
| .and_then(|t| most_confident_indices(t, count))?, | ||||||
| ElementType::I64 => preprocess_buffer::<i64>(&buffer, &dimensions) | ||||||
| .and_then(|t| most_confident_indices(t, count))?, | ||||||
| ElementType::F64 => preprocess_buffer::<f64>(&buffer, &dimensions) | ||||||
| .and_then(|t| most_confident_indices(t, count))?, | ||||||
| ElementType::Utf8 => { | ||||||
| unreachable!("Already checked by get_element_type()") | ||||||
| fn run(&self, inputs: Vec<Tensor>) -> Result<Vec<Tensor>, RunError> { | ||||||
| let tensor = Tensor::get_named(&inputs, "confidences")?; | ||||||
|
|
||||||
| let count = parse::required_arg(&vec![Argument{ name: "count".to_string(), value: self.count.to_string() }], "count").unwrap(); | ||||||
|
|
||||||
| let indices = match tensor.element_type { | ||||||
| ElementType::U8 => most_confident_indices(tensor.view_1d::<u8>()?, count)?, | ||||||
| ElementType::I8 => most_confident_indices(tensor.view_1d::<i8>()?, count)?, | ||||||
| ElementType::U16 => most_confident_indices(tensor.view_1d::<u16>()?, count)?, | ||||||
| ElementType::I16 =>most_confident_indices(tensor.view_1d::<i16>()?, count)?, | ||||||
| ElementType::U32 => most_confident_indices(tensor.view_1d::<u32>()?, count)?, | ||||||
| ElementType::I32 => most_confident_indices(tensor.view_1d::<i32>()?, count)?, | ||||||
| ElementType::F32 => most_confident_indices(tensor.view_1d::<f32>()?, count)?, | ||||||
| ElementType::U64 => most_confident_indices(tensor.view_1d::<u64>()?, count)?, | ||||||
| ElementType::I64 => most_confident_indices(tensor.view_1d::<i64>()?, count)?, | ||||||
| ElementType::F64 => most_confident_indices(tensor.view_1d::<f64>()?, count)?, | ||||||
| _ => { | ||||||
| return Err(InvalidInput::incompatible_element_type( | ||||||
| "confidences", | ||||||
| ) | ||||||
| .into()); | ||||||
| }, | ||||||
| }; | ||||||
|
|
||||||
| ctx.set_output_tensor( | ||||||
| "indices", | ||||||
| TensorParam { | ||||||
| dimensions: &dimensions, | ||||||
| element_type: ElementType::U32, | ||||||
| buffer: indices.as_bytes(), | ||||||
| }, | ||||||
| ); | ||||||
|
|
||||||
| Ok(()) | ||||||
| Ok(vec![Tensor::new_1d("indices", &indices.to_vec())]) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The
Suggested change
Footnotes
|
||||||
| } | ||||||
| } | ||||||
|
|
||||||
| fn preprocess_buffer<'buf, T>( | ||||||
| buffer: &'buf [u8], | ||||||
| dimensions: &[u32], | ||||||
| ) -> Result<ArrayView1<'buf, T>, KernelError> | ||||||
| where | ||||||
| T: ValueType, | ||||||
| { | ||||||
| buffer | ||||||
| .view::<T>(dimensions) | ||||||
| .and_then(|t| t.into_dimensionality()) | ||||||
| .map_err(|e| { | ||||||
| KernelError::InvalidInput(InvalidInput::invalid_value( | ||||||
| "confidences", | ||||||
| e, | ||||||
| )) | ||||||
| }) | ||||||
| } | ||||||
|
|
||||||
| fn most_confident_indices<T>( | ||||||
| tensor: ArrayView1<T>, | ||||||
| count: usize, | ||||||
| ) -> Result<Vec<u32>, KernelError> | ||||||
| ) -> Result<Vec<u32>, RunError> | ||||||
| where | ||||||
| T: PartialOrd + Copy, | ||||||
| { | ||||||
| if count > tensor.len() { | ||||||
| return Err(KernelError::InvalidArgument( | ||||||
| InvalidArgument::invalid_value( | ||||||
| "count", | ||||||
| format!( | ||||||
| "Requesting {} indices from a tensor with only {} elements", | ||||||
| count, | ||||||
| tensor.len() | ||||||
| ), | ||||||
| ), | ||||||
| )); | ||||||
| return Err(RunError::other(format!( | ||||||
| "Requesting {} indices from a tensor with only {} elements", | ||||||
| count, | ||||||
| tensor.len() | ||||||
| ))); | ||||||
| } | ||||||
|
|
||||||
| let mut indices_and_confidence: Vec<_> = | ||||||
|
|
@@ -204,59 +108,15 @@ where | |||||
| .collect()) | ||||||
| } | ||||||
|
|
||||||
| fn get_count( | ||||||
| get_argument: impl FnOnce(&str) -> Option<String>, | ||||||
| ) -> Result<u32, InvalidArgument> { | ||||||
| get_argument("count") | ||||||
| .ok_or_else(|| InvalidArgument::not_found("count"))? | ||||||
| .parse::<u32>() | ||||||
| .map_err(|e| InvalidArgument::invalid_value("count", e)) | ||||||
| } | ||||||
|
|
||||||
| fn get_element_type( | ||||||
| get_argument: impl FnOnce(&str) -> Option<String>, | ||||||
| ) -> Result<ElementType, InvalidArgument> { | ||||||
| match get_argument("element_type").as_deref() { | ||||||
| Some("u8") => Ok(ElementType::U8), | ||||||
| Some("i8") => Ok(ElementType::I8), | ||||||
| Some("u16") => Ok(ElementType::U16), | ||||||
| Some("i16") => Ok(ElementType::I16), | ||||||
| Some("u32") => Ok(ElementType::U32), | ||||||
| Some("i32") => Ok(ElementType::I32), | ||||||
| Some("f32") => Ok(ElementType::F32), | ||||||
| Some("u64") => Ok(ElementType::U64), | ||||||
| Some("i64") => Ok(ElementType::I64), | ||||||
| Some("f64") => Ok(ElementType::F64), | ||||||
| Some(other) => Err(InvalidArgument::invalid_value( | ||||||
| "element_type", | ||||||
| format!("Unsupported element type: {}", other), | ||||||
| )), | ||||||
| None => Err(InvalidArgument::not_found("element_type")), | ||||||
| } | ||||||
| } | ||||||
| impl TryFrom<Vec<Argument>> for MostConfidentIndices { | ||||||
| type Error = CreateError; | ||||||
|
|
||||||
| impl InvalidArgument { | ||||||
| fn not_found(name: impl Into<String>) -> Self { | ||||||
| InvalidArgument { | ||||||
| name: name.into(), | ||||||
| reason: BadArgumentReason::NotFound, | ||||||
| } | ||||||
| } | ||||||
| fn try_from(args: Vec<Argument>) -> Result<Self, Self::Error> { | ||||||
| let count = | ||||||
| hotg_rune_proc_blocks::guest::parse::optional_arg(&args, "count")? | ||||||
| .unwrap_or(1); | ||||||
|
|
||||||
| fn invalid_value(name: impl Into<String>, reason: impl Display) -> Self { | ||||||
| InvalidArgument { | ||||||
| name: name.into(), | ||||||
| reason: BadArgumentReason::InvalidValue(reason.to_string()), | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| impl InvalidInput { | ||||||
| fn invalid_value(name: impl Into<String>, reason: impl Display) -> Self { | ||||||
| InvalidInput { | ||||||
| name: name.into(), | ||||||
| reason: BadInputReason::InvalidValue(reason.to_string()), | ||||||
| } | ||||||
| Ok(MostConfidentIndices { count }) | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
|
|
@@ -266,30 +126,24 @@ mod tests { | |||||
| use hotg_rune_proc_blocks::ndarray; | ||||||
|
|
||||||
| #[test] | ||||||
| fn only_works_with_1d() { | ||||||
| let buffer = [1, 2, 3, 4, 5, 6]; | ||||||
| fn get_top_3_values_from_2_d() { | ||||||
| let elements = ndarray::arr2(&[[0.0, 0.5, 10.0, 3.5, -200.0]]); | ||||||
|
|
||||||
| let error = preprocess_buffer::<u8>(&buffer, &[2, 3]).unwrap_err(); | ||||||
| let tensor = Tensor::new("confidences", &elements); | ||||||
|
|
||||||
|
|
||||||
| assert!(matches!(error, KernelError::InvalidInput(_))); | ||||||
| } | ||||||
| let got = most_confident_indices(tensor.view_1d::<f64>().unwrap(), 3).unwrap(); | ||||||
|
|
||||||
| #[test] | ||||||
| fn tensors_equivalent_to_1d_are_okay_too() { | ||||||
| let buffer = [1, 2, 3, 4, 5, 6]; | ||||||
|
|
||||||
| let error = preprocess_buffer::<u8>(&buffer, &[1, 6, 1]).unwrap_err(); | ||||||
|
|
||||||
| assert!(matches!(error, KernelError::InvalidInput(_))); | ||||||
| assert_eq!(got, &[2, 3, 1]); | ||||||
| } | ||||||
|
|
||||||
| #[test] | ||||||
| fn count_must_be_less_than_input_size() { | ||||||
| let input = ndarray::arr1(&[1_u32, 2, 3]); | ||||||
|
|
||||||
| let error = most_confident_indices(input.view(), 42).unwrap_err(); | ||||||
| most_confident_indices(input.view(), 42).unwrap_err(); | ||||||
|
|
||||||
| assert!(matches!(error, KernelError::InvalidArgument(_))); | ||||||
| // assert!(matches!(error, RunError)); | ||||||
| } | ||||||
|
|
||||||
| #[test] | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line should be deleted because we get
`count` when first using the `TryFrom<Vec<Argument>>` to construct our `MostConfidentIndices` object.