Description
User Story
As a machine learning practitioner, I want access to a broader suite of model interpretability techniques, including SHAP values and permutation feature importance, so that I can better understand how my models make predictions, identify key influencing factors, and build trust in complex black-box models.
Problem: Missing Key Model Interpretability Techniques
The src/Interpretability module is strong in fairness and bias detection, and includes LIME and Partial Dependence Plots. However, several other fundamental and widely-used interpretability techniques, crucial for understanding complex models, are currently missing.
Phase 1: Implement SHAP (SHapley Additive exPlanations)
Goal: Provide a robust, game-theory-based method for explaining individual predictions.
AC 1.1: Create SHAPExplainer.cs (18 points)
Requirement: Implement a SHAP explainer for model-agnostic explanations.
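- File: src/Interpretability/SHAPExplainer.cs
- Class: public class SHAPExplainer<T>
- Constructor: Takes an IModel<T> (black-box model) and a Matrix<T> (background dataset).
- Method: public Vector<T> Explain(Vector<T> instance): Returns one SHAP value per feature for a single instance.
- Logic: Implement the core SHAP algorithm. Exact Shapley values require evaluating every coalition of features (2^M coalitions for M features), so use an approximation such as Kernel SHAP or Monte Carlo permutation sampling when the feature count is large.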
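The following is a minimal sketch of the core SHAP computation. It computes exact Shapley values by enumerating every feature coalition. An "absent" feature is filled in from each background row, and the model output is averaged over all background rows. The sketch makes several assumptions to stay self-contained. It uses double[] and Func<double[], double> in place of the project's Vector<T>, Matrix<T>, and IModel<T>. The name ExactShapSketch is hypothetical. It costs O(2^M) coalition evaluations, so it is only practical for a handful of features or as a test oracle for an approximate implementation.

```csharp
using System;
using System.Linq;

// Hypothetical sketch (not the project's SHAPExplainer<T>): exact Shapley values
// for a black-box model, using a background dataset to fill in "absent" features.
public static class ExactShapSketch
{
    public static double[] Explain(Func<double[], double> model, double[][] background, double[] instance)
    {
        int m = instance.Length;
        var phi = new double[m];

        // Factorials for the Shapley weight |S|! (M - |S| - 1)! / M!
        var fact = new double[m + 1];
        fact[0] = 1.0;
        for (int k = 1; k <= m; k++) fact[k] = fact[k - 1] * k;

        for (int i = 0; i < m; i++)
        {
            // Every coalition S of the other features, encoded as a bitmask.
            for (int mask = 0; mask < (1 << m); mask++)
            {
                if ((mask & (1 << i)) != 0) continue; // S must exclude feature i
                int size = CountBits(mask);
                double weight = fact[size] * fact[m - size - 1] / fact[m];
                double gain = Value(model, background, instance, mask | (1 << i))
                            - Value(model, background, instance, mask);
                phi[i] += weight * gain;
            }
        }
        return phi;
    }

    // v(S): mean prediction when features in S come from the instance and the
    // remaining features come from each background row in turn.
    private static double Value(Func<double[], double> model, double[][] background, double[] instance, int mask) =>
        background.Average(row => model(
            instance.Select((xj, j) => (mask & (1 << j)) != 0 ? xj : row[j]).ToArray()));

    private static int CountBits(int v)
    {
        int count = 0;
        for (; v != 0; v >>= 1) count += v & 1;
        return count;
    }
}
```

Take a linear model f(x) = 2*x0 + 3*x1 + 1 with background rows (0, 0) and (2, 4). For the instance (3, 5), the sketch returns SHAP values of 4 and 9. Those values come from weight times the instance value minus the background mean. Any faster approximation can be checked against this oracle on small inputs.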
AC 1.2: Unit Tests for SHAPExplainer (10 points)
Requirement: Verify the correctness of the SHAP explainer.
- File: tests/UnitTests/Interpretability/SHAPExplainerTests.cs
- Test Cases: Test with a simple model and dataset, ensuring the SHAP values sum to the difference between the instance's prediction and the baseline prediction (the model's average prediction over the background dataset).
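The additivity check above is the Shapley "efficiency" property. The derivation below assumes two things: the baseline is the mean prediction over the background dataset B, and absent features are filled from background rows. Under those assumptions, linear models also have a closed form that works as an exact test oracle:

```latex
\sum_{j=1}^{M} \phi_j(x) \;=\; f(x) \;-\; \frac{1}{|B|} \sum_{b \in B} f(b)

f(x) = w_0 + \sum_{j=1}^{M} w_j x_j
\;\;\Longrightarrow\;\;
\phi_j(x) = w_j \left( x_j - \bar{b}_j \right),
\qquad
\bar{b}_j = \frac{1}{|B|} \sum_{b \in B} b_j
```

For example, take w_0 = 1, w = (2, 3), B = {(0, 0), (2, 4)}, and x = (3, 5). The background mean is (1, 2), so phi = (4, 9), which sums to 13. Separately, f(x) = 22 and the mean background prediction is (1 + 17) / 2 = 9. The difference 22 - 9 = 13 matches the sum.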
Phase 2: Implement Permutation Feature Importance
Goal: Provide a model-agnostic method to assess the importance of features.
AC 2.1: Create PermutationFeatureImportance.cs (13 points)
Requirement: Implement Permutation Feature Importance.
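- File: src/Interpretability/PermutationFeatureImportance.cs
- Class: public class PermutationFeatureImportance<T>
- Constructor: Takes an IModel<T> (model) and an IMetric<T> (evaluation metric).
- Method: public Dictionary<int, double> Calculate(Matrix<T> X, Vector<T> y): Returns a dictionary mapping each feature index to its importance score.
- Logic: Measures how much the model's score drops when a single feature's values are randomly shuffled. Shuffling breaks that feature's relationship with the target while leaving its distribution intact. Averaging over several shuffles reduces the variance of the estimate. For metrics where lower is better (e.g., error metrics), flip the sign of the difference so that larger values still mean more important.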
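The following is a minimal sketch of the logic above. The importance of feature j is the baseline score minus the mean score after shuffling column j, averaged over nRepeats shuffles. It makes several assumptions. It uses Func delegates and double arrays instead of the project's IModel<T>, IMetric<T>, Matrix<T>, and Vector<T>. The score delegate is assumed to be higher-is-better. The name PermutationImportanceSketch is hypothetical. The default of 5 repeats mirrors scikit-learn's permutation_importance (n_repeats=5).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch of permutation feature importance on plain double arrays.
// 'score' is assumed to be higher-is-better, e.g. R² or negative MSE.
public static class PermutationImportanceSketch
{
    public static Dictionary<int, double> Calculate(
        Func<double[], double> model,
        Func<double[], double[], double> score, // (yTrue, yPred) -> score
        double[][] X, double[] y, int nRepeats = 5, int seed = 42)
    {
        var rng = new Random(seed);
        double baseline = score(y, X.Select(model).ToArray());
        var importances = new Dictionary<int, double>();
        int nFeatures = X[0].Length;

        for (int j = 0; j < nFeatures; j++)
        {
            double totalDrop = 0.0;
            for (int r = 0; r < nRepeats; r++)
            {
                // Copy rows so the caller's X is never mutated.
                var shuffled = X.Select(row => (double[])row.Clone()).ToArray();
                var column = shuffled.Select(row => row[j]).ToArray();
                Shuffle(column, rng);
                for (int i = 0; i < shuffled.Length; i++) shuffled[i][j] = column[i];

                totalDrop += baseline - score(y, shuffled.Select(model).ToArray());
            }
            importances[j] = totalDrop / nRepeats; // mean drop in score
        }
        return importances;
    }

    // Fisher-Yates shuffle, in place.
    private static void Shuffle(double[] a, Random rng)
    {
        for (int i = a.Length - 1; i > 0; i--)
        {
            int k = rng.Next(i + 1);
            (a[i], a[k]) = (a[k], a[i]);
        }
    }
}
```

A feature the model never reads gets an importance of exactly zero, because shuffling it leaves every prediction unchanged. That makes a useful deterministic assertion in the unit tests.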
AC 2.2: Unit Tests for PermutationFeatureImportance (8 points)
Requirement: Verify the correctness of Permutation Feature Importance.
- File: tests/UnitTests/Interpretability/PermutationFeatureImportanceTests.cs
- Test Cases: Test with a simple model and dataset, ensuring more important features receive higher scores (for example, a feature the model ignores entirely should score zero).
Phase 3: Implement Global Surrogate Models
Goal: Provide a method to explain a complex model's global behavior using a simpler, interpretable model.
AC 3.1: Create GlobalSurrogateExplainer.cs (13 points)
Requirement: Implement a Global Surrogate Model explainer.
- File: src/Interpretability/GlobalSurrogateExplainer.cs
- Class: public class GlobalSurrogateExplainer<T>
- Constructor: Takes an IModel<T> (black-box model) and an IModel<T> (interpretable surrogate model, e.g., a decision tree).
- Method: public void Fit(Matrix<T> X): Trains the surrogate model on X, using the black-box model's predictions on X as the targets (not the true labels).
- Method: public IModel<T> GetSurrogateModel(): Returns the trained surrogate model.
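The following is a minimal sketch of the Fit step. It labels X with the black box's own predictions and trains the surrogate on those labels. The surrogate is a depth-1 regression tree (decision stump) as a stand-in for "e.g., Decision Tree". The sketch also computes fidelity as R² against the black-box predictions. The names ISurrogateModel, DecisionStump, and GlobalSurrogateSketch are hypothetical. The sketch uses double arrays instead of the project's IModel<T> and Matrix<T>.

```csharp
using System;
using System.Linq;

// Hypothetical minimal trainable-model contract standing in for IModel<T>.
public interface ISurrogateModel
{
    void Fit(double[][] X, double[] y);
    double Predict(double[] x);
}

// Depth-1 regression tree: one split, each side predicts its mean target.
public sealed class DecisionStump : ISurrogateModel
{
    public int Feature { get; private set; }
    public double Threshold { get; private set; }
    public double LeftValue { get; private set; }   // used when x[Feature] <= Threshold
    public double RightValue { get; private set; }  // used otherwise

    public void Fit(double[][] X, double[] y)
    {
        double bestSse = double.PositiveInfinity;
        LeftValue = RightValue = y.Average();          // fallback: global mean
        Threshold = double.PositiveInfinity;
        for (int j = 0; j < X[0].Length; j++)
        {
            foreach (double t in X.Select(r => r[j]).Distinct())
            {
                var left = Enumerable.Range(0, y.Length).Where(i => X[i][j] <= t).Select(i => y[i]).ToArray();
                var right = Enumerable.Range(0, y.Length).Where(i => X[i][j] > t).Select(i => y[i]).ToArray();
                if (left.Length == 0 || right.Length == 0) continue;
                double lm = left.Average(), rm = right.Average();
                double sse = left.Sum(v => (v - lm) * (v - lm)) + right.Sum(v => (v - rm) * (v - rm));
                if (sse < bestSse)
                {
                    bestSse = sse; Feature = j; Threshold = t; LeftValue = lm; RightValue = rm;
                }
            }
        }
    }

    public double Predict(double[] x) => x[Feature] <= Threshold ? LeftValue : RightValue;
}

public static class GlobalSurrogateSketch
{
    // Train the surrogate on the black box's predictions, not on true labels.
    public static ISurrogateModel Fit(Func<double[], double> blackBox, ISurrogateModel surrogate, double[][] X)
    {
        double[] yHat = X.Select(blackBox).ToArray();
        surrogate.Fit(X, yHat);
        return surrogate;
    }

    // Fidelity: R² of the surrogate's predictions against the black box's predictions.
    public static double Fidelity(Func<double[], double> blackBox, ISurrogateModel surrogate, double[][] X)
    {
        double[] yHat = X.Select(blackBox).ToArray();
        double mean = yHat.Average();
        double ssRes = X.Select((x, i) => Math.Pow(yHat[i] - surrogate.Predict(x), 2)).Sum();
        double ssTot = yHat.Sum(v => (v - mean) * (v - mean));
        return 1.0 - ssRes / ssTot;
    }
}
```

The inspected surrogate explains the black box, not the data. A high fidelity score means the stump's single split is a faithful summary of what the black box does on X.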
AC 3.2: Unit Tests for GlobalSurrogateExplainer (8 points)
Requirement: Verify the correctness of the Global Surrogate Explainer.
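- File: tests/UnitTests/Interpretability/GlobalSurrogateExplainerTests.cs
- Test Cases: Test with a complex model and a simple surrogate, ensuring the surrogate approximates the complex model's behavior. Measure this as fidelity (e.g., R² between surrogate and black-box predictions) and require it to exceed a chosen threshold.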
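One way to make "approximates the complex model's behavior" testable is the fidelity R² of surrogate g against black box f on the n rows of X. The formula below uses the black box's predictions as ground truth. The pass threshold, for example 0.9, is a test-design choice and is not given by this story:

```latex
R^2_{\text{fidelity}} = 1 - \frac{\sum_{i=1}^{n} \left( f(x_i) - g(x_i) \right)^2}{\sum_{i=1}^{n} \left( f(x_i) - \bar{f} \right)^2},
\qquad
\bar{f} = \frac{1}{n} \sum_{i=1}^{n} f(x_i)
```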
Definition of Done
- All checklist items are complete.
- SHAPExplainer, PermutationFeatureImportance, and GlobalSurrogateExplainer are implemented and unit-tested.
- All new tests pass.
⚠️ CRITICAL ARCHITECTURAL REQUIREMENTS
Before implementing this user story, you MUST review:
- 📋 Full Requirements: .github/USER_STORY_ARCHITECTURAL_REQUIREMENTS.md
- 📐 Project Rules: .github/PROJECT_RULES.md
Mandatory Implementation Checklist
1. INumericOperations Usage (CRITICAL)
- Include protected static readonly INumericOperations<T> NumOps = MathHelper.GetNumericOperations<T>(); in the base class
- NEVER hardcode double, float, or other specific numeric types; use generic T
- NEVER use default(T); use NumOps.Zero instead
- Use NumOps.Zero, NumOps.One, NumOps.FromDouble() for values
- Use NumOps.Add(), NumOps.Multiply(), etc. for arithmetic
- Use NumOps.LessThan(), NumOps.GreaterThan(), etc. for comparisons
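The following is a minimal sketch of the NumOps pattern. A static per-T operations object is resolved once in a generic base class, and all arithmetic goes through it. INumericOps<T>, DoubleOps, FloatOps, MathHelperSketch, and InterpretabilityBaseSketch are hypothetical stand-ins. They mirror only the members named in the checklist. The project's real INumericOperations<T> and MathHelper may have different or additional members.

```csharp
using System;

// Hypothetical stand-in mirroring only the NumOps members named in the checklist.
public interface INumericOps<T>
{
    T Zero { get; }
    T One { get; }
    T FromDouble(double value);
    T Add(T a, T b);
    T Multiply(T a, T b);
    bool GreaterThan(T a, T b);
}

public sealed class DoubleOps : INumericOps<double>
{
    public double Zero => 0.0;
    public double One => 1.0;
    public double FromDouble(double value) => value;
    public double Add(double a, double b) => a + b;
    public double Multiply(double a, double b) => a * b;
    public bool GreaterThan(double a, double b) => a > b;
}

public sealed class FloatOps : INumericOps<float>
{
    public float Zero => 0f;
    public float One => 1f;
    public float FromDouble(double value) => (float)value;
    public float Add(float a, float b) => a + b;
    public float Multiply(float a, float b) => a * b;
    public bool GreaterThan(float a, float b) => a > b;
}

public static class MathHelperSketch
{
    public static INumericOps<T> GetNumericOperations<T>()
    {
        if (typeof(T) == typeof(double)) return (INumericOps<T>)(object)new DoubleOps();
        if (typeof(T) == typeof(float)) return (INumericOps<T>)(object)new FloatOps();
        throw new NotSupportedException($"No numeric operations registered for {typeof(T).Name}.");
    }
}

public abstract class InterpretabilityBaseSketch<T>
{
    // Initialized once per closed generic type (T = double, T = float, ...).
    protected static readonly INumericOps<T> NumOps = MathHelperSketch.GetNumericOperations<T>();

    // Written purely in terms of NumOps: no double literals, no default(T).
    protected static T WeightedSum(T[] values, T[] weights)
    {
        T total = NumOps.Zero;
        for (int i = 0; i < values.Length; i++)
            total = NumOps.Add(total, NumOps.Multiply(values[i], weights[i]));
        return total;
    }
}

public sealed class WeightedSumDemo<T> : InterpretabilityBaseSketch<T>
{
    public T Compute(T[] values, T[] weights) => WeightedSum(values, weights);
}
```

The same WeightedSumDemo<T> code path serves both double and float. This is what the testing requirement for multiple numeric types exercises.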
2. Inheritance Pattern (REQUIRED)
- Create I{FeatureName}.cs in src/Interfaces/ (root level, NOT subfolders)
- Create {FeatureName}Base.cs in src/{FeatureArea}/ implementing the interface
- Create concrete classes inheriting from the Base class (NOT directly from the interface)
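The following is a minimal sketch of the interface, base class, and concrete class layering, with the intended file path noted above each type. The base class owns shared input validation, and the concrete class supplies only the algorithm-specific step. The feature name (IFeatureRanker, FeatureRankerBase, DescendingFeatureRanker) is purely hypothetical. Comparer<T>.Default is used only to keep the sketch self-contained; under checklist item 1, comparisons would go through NumOps instead.

```csharp
#nullable enable
using System;
using System.Collections.Generic;
using System.Linq;

// File: src/Interfaces/IFeatureRanker.cs (hypothetical feature name)
public interface IFeatureRanker<T>
{
    IReadOnlyList<int> Rank(IReadOnlyDictionary<int, T> importances);
}

// File: src/Interpretability/FeatureRankerBase.cs
// Shared validation lives here so every concrete ranker gets it for free.
public abstract class FeatureRankerBase<T> : IFeatureRanker<T>
{
    public IReadOnlyList<int> Rank(IReadOnlyDictionary<int, T> importances)
    {
        if (importances is null) throw new ArgumentNullException(nameof(importances));
        if (importances.Count == 0) throw new ArgumentException("At least one feature is required.", nameof(importances));
        return RankCore(importances);
    }

    // Concrete classes implement only the algorithm-specific step.
    protected abstract IReadOnlyList<int> RankCore(IReadOnlyDictionary<int, T> importances);
}

// File: src/Interpretability/DescendingFeatureRanker.cs
// Inherits from the base class, not directly from the interface.
public sealed class DescendingFeatureRanker<T> : FeatureRankerBase<T>
{
    private readonly IComparer<T> _comparer;

    public DescendingFeatureRanker(IComparer<T>? comparer = null) => _comparer = comparer ?? Comparer<T>.Default;

    protected override IReadOnlyList<int> RankCore(IReadOnlyDictionary<int, T> importances) =>
        importances.OrderByDescending(kv => kv.Value, _comparer).Select(kv => kv.Key).ToList();
}
```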
3. PredictionModelBuilder Integration (REQUIRED)
- Add private field: private I{FeatureName}<T>? _{featureName}; to PredictionModelBuilder.cs
- Add a Configure method whose ONLY parameter is the interface (no additional configuration parameters):
public IPredictionModelBuilder<T, TInput, TOutput> Configure{FeatureName}(I{FeatureName}<T> {featureName}) { _{featureName} = {featureName}; return this; }
- Use the feature in Build() with a default: var {featureName} = _{featureName} ?? new Default{FeatureName}<T>();
- Verify the feature is ACTUALLY USED in the execution flow
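The following is a minimal sketch of the three builder steps above. A nullable private field is set by a Configure method that returns the builder for chaining. Build() falls back to a default implementation and then actually calls it. BuilderSketch, IExplainerSketch, DefaultExplainerSketch, and CustomExplainerSketch are hypothetical stand-ins. Build() returns a string here only so the wiring can be asserted; the real PredictionModelBuilder.Build() produces a model.

```csharp
#nullable enable
using System;

// Hypothetical feature interface plus default and custom implementations.
public interface IExplainerSketch<T>
{
    string Describe();
}

public sealed class DefaultExplainerSketch<T> : IExplainerSketch<T>
{
    public string Describe() => "default explainer";
}

public sealed class CustomExplainerSketch<T> : IExplainerSketch<T>
{
    public string Describe() => "custom explainer";
}

// Minimal stand-in for PredictionModelBuilder: field, Configure, Build default.
public sealed class BuilderSketch<T>
{
    private IExplainerSketch<T>? _explainer;

    // The interface is the only parameter; returning 'this' enables chaining.
    public BuilderSketch<T> ConfigureExplainer(IExplainerSketch<T> explainer)
    {
        _explainer = explainer;
        return this;
    }

    public string Build()
    {
        // Fall back to the default when nothing was configured.
        var explainer = _explainer ?? new DefaultExplainerSketch<T>();
        // The resolved instance must actually be used in the execution flow.
        return explainer.Describe();
    }
}
```

An integration test for this pattern checks both paths. Calling Build() alone should use the default. Calling ConfigureExplainer(...) first and then Build() should use the configured instance.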
4. Beginner-Friendly Defaults (REQUIRED)
- Constructor parameters with defaults from research/industry standards
- Document WHY each default was chosen (cite papers/standards)
- Validate parameters and throw ArgumentException for invalid values
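The following is a minimal sketch of the defaults, documentation, and validation conventions applied to one parameter. It shows a documented default with its source, a <b>For Beginners:</b> remark, and an ArgumentException for invalid input. The class PermutationImportanceOptionsSketch and its members are hypothetical. The cited fact is that scikit-learn's permutation_importance defaults to n_repeats=5.

```csharp
using System;

/// <summary>
/// Hypothetical options for permutation feature importance.
/// </summary>
/// <remarks>
/// <b>For Beginners:</b> each repeat reshuffles one feature and re-scores the model;
/// more repeats give a steadier average at the cost of more model calls.
/// </remarks>
public sealed class PermutationImportanceOptionsSketch
{
    /// <summary>Shuffles per feature. Default 5 matches scikit-learn's permutation_importance (n_repeats=5).</summary>
    public int Repeats { get; }

    /// <summary>Seed for reproducible shuffles; null means an unseeded random generator.</summary>
    public int? Seed { get; }

    /// <param name="repeats">Number of shuffles per feature; must be at least 1.</param>
    /// <param name="seed">Optional seed for reproducibility.</param>
    /// <exception cref="ArgumentException">Thrown when <paramref name="repeats"/> is less than 1.</exception>
    public PermutationImportanceOptionsSketch(int repeats = 5, int? seed = null)
    {
        if (repeats < 1)
            throw new ArgumentException("repeats must be at least 1.", nameof(repeats));
        Repeats = repeats;
        Seed = seed;
    }
}
```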
5. Property Initialization (CRITICAL)
- NEVER use the default! operator
- String properties: = string.Empty;
- Collections: = new List<T>(); or = new Vector<T>(0);
- Numeric properties: an appropriate default or NumOps.Zero
6. Class Organization (REQUIRED)
- One class/enum/interface per file
- ALL interfaces in src/Interfaces/ (root level)
- Namespace mirrors folder structure (e.g., src/Regularization/ → namespace AiDotNet.Regularization)
7. Documentation (REQUIRED)
- XML documentation for all public members
- <b>For Beginners:</b> sections with analogies and examples
- Document all <param>, <returns>, and <exception> tags
- Explain default value choices
8. Testing (REQUIRED)
- Minimum 80% code coverage
- Test with multiple numeric types (double, float)
- Test default values are applied correctly
- Test edge cases and exceptions
- Integration tests for PredictionModelBuilder usage
See full details: .github/USER_STORY_ARCHITECTURAL_REQUIREMENTS.md