Description
User Story
As a researcher, I want to use the SEAL algorithm to train a model that can improve its own learning process by leveraging unlabeled data within a few-shot task, so that I can achieve higher accuracy than traditional meta-learning methods.
Dependencies
- [Episodic Data Abstractions for Meta-Learning (N-way K-shot) #290: Episodic Data Abstractions]: This issue depends on the creation of an `EpisodicDataLoader` that can provide tasks where the query set is significantly larger than the support set and can be treated as unlabeled.
Phase 1: SEALTrainer Implementation
Goal: Create the main trainer class and implement the full SEAL algorithm, which combines self-supervision, active learning, and meta-learning.
AC 1.1: Scaffolding SEALTrainer<T> (3 points)
Requirement: Create the main class structure for the trainer.
- Create a new file: `src/Models/Meta/SEALTrainer.cs`.
- Define a `public class SEALTrainer<T>`.
- The constructor must accept:
  - `IModel<T> metaModel`: The initial model to be meta-trained.
  - `IOptimizer<T> metaOptimizer`: The optimizer for the outer-loop meta-updates.
  - `ILossFunction<T> supervisedLoss`: The loss function for the final supervised fine-tuning.
  - `ISelfSupervisedLoss<T> selfSupervisedLoss`: A loss for the pre-training phase (e.g., a loss for rotation prediction).
  - `int selfSupervisedSteps`: Number of pre-training steps on the query set.
  - `int supervisedSteps`: Number of fine-tuning steps on the support/pseudo-labeled set.
  - `int activeLearningSelectionCount`: The number of query examples to select and pseudo-label.
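A minimal sketch of that class shape follows. The interface names come from the list above; the field names and validation details are illustrative assumptions, not the project's actual definitions.

```csharp
// Illustrative skeleton only -- members and validation are assumptions
// based on the constructor list above, not confirmed AiDotNet APIs.
public class SEALTrainer<T>
{
    private readonly IModel<T> _metaModel;
    private readonly IOptimizer<T> _metaOptimizer;
    private readonly ILossFunction<T> _supervisedLoss;
    private readonly ISelfSupervisedLoss<T> _selfSupervisedLoss;
    private readonly int _selfSupervisedSteps;
    private readonly int _supervisedSteps;
    private readonly int _activeLearningSelectionCount;

    public SEALTrainer(
        IModel<T> metaModel,
        IOptimizer<T> metaOptimizer,
        ILossFunction<T> supervisedLoss,
        ISelfSupervisedLoss<T> selfSupervisedLoss,
        int selfSupervisedSteps,
        int supervisedSteps,
        int activeLearningSelectionCount)
    {
        // Validate per the "Beginner-Friendly Defaults" checklist.
        if (selfSupervisedSteps <= 0)
            throw new ArgumentException("Must be positive.", nameof(selfSupervisedSteps));
        if (supervisedSteps <= 0)
            throw new ArgumentException("Must be positive.", nameof(supervisedSteps));
        if (activeLearningSelectionCount <= 0)
            throw new ArgumentException("Must be positive.", nameof(activeLearningSelectionCount));

        _metaModel = metaModel ?? throw new ArgumentNullException(nameof(metaModel));
        _metaOptimizer = metaOptimizer ?? throw new ArgumentNullException(nameof(metaOptimizer));
        _supervisedLoss = supervisedLoss ?? throw new ArgumentNullException(nameof(supervisedLoss));
        _selfSupervisedLoss = selfSupervisedLoss ?? throw new ArgumentNullException(nameof(selfSupervisedLoss));
        _selfSupervisedSteps = selfSupervisedSteps;
        _supervisedSteps = supervisedSteps;
        _activeLearningSelectionCount = activeLearningSelectionCount;
    }
}
```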
AC 1.2: Implement the Train Method (13 points)
Requirement: Implement the main training method containing the full, multi-stage SEAL algorithm.
- Define a public method: `public void Train(EpisodicDataLoader<T> dataLoader, int metaIterations)`.
- Outer Loop: Implement the meta-training loop: `for (int i = 0; i < metaIterations; i++)`.
- Inside the outer loop, perform the following steps in sequence:
  1. Task Sampling: Get a new task from the `dataLoader`. Treat the `QuerySet` as unlabeled for the initial phases.
  2. Clone Model: Create a deep copy of the `metaModel`'s weights (`taskModel`).
  3. Self-Supervised Pre-training (on Query Set):
     - Create an inner-loop optimizer for the `taskModel`.
     - Loop for `selfSupervisedSteps`:
       - Create a self-supervised batch from the `QuerySet` data. For example, for images, create a new tensor of randomly rotated images and a corresponding tensor of rotation labels (e.g., 0=0°, 1=90°, 2=180°, 3=270°).
       - Perform a training step on the `taskModel` using this self-supervised batch and the `selfSupervisedLoss`.
  4. Active Learning & Pseudo-Labeling (on Query Set):
     - Use the `taskModel` (which has now been pre-trained) to make predictions on the original, non-augmented `QuerySet` data.
     - Based on an acquisition function (e.g., highest prediction probability, i.e., confidence), select the top `activeLearningSelectionCount` examples from the `QuerySet`.
     - Create a new pseudo-labeled dataset by pairing these selected examples with their predicted labels from the model.
  5. Supervised Fine-tuning:
     - Combine the original, small, human-labeled `SupportSet` with the new pseudo-labeled dataset.
     - Loop for `supervisedSteps`: perform a standard training step on the `taskModel` using the combined labeled data and the `supervisedLoss`.
  6. Meta-Update (Reptile-style):
     - Calculate the difference between the initial `metaModel` weights and the final `taskModel` weights after all training phases.
     - Use the `metaOptimizer` to update the `metaModel`'s weights, moving them a fraction of the way towards the `taskModel`'s weights.
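The six steps above map onto a `Train` method roughly like this. Every method called on the task, model, and optimizer objects (`SampleTask`, `Clone`, `Predict`, `TrainStep`, `GetWeights`, `MetaUpdate`, and the private helpers) is a hypothetical placeholder for whatever the actual AiDotNet APIs expose:

```csharp
// Sketch of the SEAL outer loop; method names on the task/model/optimizer
// objects are hypothetical placeholders, not confirmed AiDotNet APIs.
public void Train(EpisodicDataLoader<T> dataLoader, int metaIterations)
{
    for (int i = 0; i < metaIterations; i++)
    {
        // 1. Task sampling: the query set is treated as unlabeled for now.
        var task = dataLoader.SampleTask();

        // 2. Clone the meta-model so inner-loop updates don't touch it.
        var taskModel = _metaModel.Clone();
        var innerOptimizer = CreateInnerOptimizer(taskModel);

        // 3. Self-supervised pre-training on the (unlabeled) query set.
        for (int s = 0; s < _selfSupervisedSteps; s++)
        {
            var (rotated, rotationLabels) = CreateRotationBatch(task.QuerySet);
            innerOptimizer.TrainStep(taskModel, rotated, rotationLabels, _selfSupervisedLoss);
        }

        // 4. Active learning & pseudo-labeling: keep the model's most
        //    confident predictions on the original query examples.
        var predictions = taskModel.Predict(task.QuerySet);
        var pseudoLabeled = SelectTopByConfidence(
            task.QuerySet, predictions, _activeLearningSelectionCount);

        // 5. Supervised fine-tuning on support set + pseudo-labeled data.
        var combined = Combine(task.SupportSet, pseudoLabeled);
        for (int s = 0; s < _supervisedSteps; s++)
            innerOptimizer.TrainStep(taskModel, combined, _supervisedLoss);

        // 6. Reptile-style meta-update: move the meta-weights a fraction
        //    of the way toward the adapted task weights.
        var direction = taskModel.GetWeights().Subtract(_metaModel.GetWeights());
        _metaOptimizer.MetaUpdate(_metaModel, direction);
    }
}
```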
Phase 2: Validation and Testing
Goal: Verify that the SEAL implementation is correct and improves model performance.
AC 2.1: Unit / Smoke Test (3 points)
Requirement: Create a test to ensure the algorithm runs end-to-end without errors.
- Create a new test file:
tests/UnitTests/Meta/SEALTrainerTests.cs. - Create a smoke test that initializes the
SEALTrainerwith mock components. - Run the
Trainmethod for a single meta-iteration (metaIterations = 1). - Assert that the method completes without throwing any exceptions and that the
metaModel's weights have been updated (i.e., they are not NaN and not equal to the initial weights).
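The smoke test could take a shape like the following (xUnit assumed; the `CreateMock*` and `ContainsNaN` helpers are placeholders for whatever test doubles the project provides):

```csharp
// Hypothetical smoke-test shape; helper methods are illustrative stand-ins.
[Fact]
public void Train_SingleMetaIteration_UpdatesWeightsWithoutThrowing()
{
    var metaModel = CreateMockModel();
    var initialWeights = metaModel.GetWeights().Clone();
    var trainer = CreateTrainerWithMockComponents(metaModel);

    // Should complete a full SEAL cycle without throwing.
    trainer.Train(CreateMockDataLoader(), metaIterations: 1);

    var updated = metaModel.GetWeights();
    Assert.False(ContainsNaN(updated));          // weights stayed finite
    Assert.NotEqual(initialWeights, updated);    // meta-update actually ran
}
```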
AC 2.2: Integration Test (8 points)
Requirement: Create an integration test on a synthetic problem to prove meta-learning is occurring.
- Synthetic Data: Create a synthetic few-shot image classification task (e.g., classifying rotated MNIST digits) where the support set is very small (e.g., 1-shot) and the query set is large and unlabeled.
- Test Setup:
  - Instantiate a simple CNN as the `metaModel`.
  - Instantiate the `SEALTrainer` with a rotation prediction loss for the self-supervised phase.
- Test Logic:
  - Evaluate the initial `metaModel` on a set of unseen test tasks and store the baseline accuracy.
  - Run the `SEALTrainer.Train()` method for a significant number of meta-iterations.
  - Evaluate the trained `metaModel` on the same set of unseen test tasks.
  - Assert that the accuracy after meta-training is significantly higher than the baseline accuracy.
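The test logic above could be sketched as follows. The task-generation and `Evaluate` helpers are assumptions, and the iteration count and 10-point accuracy margin are arbitrary examples, not requirements:

```csharp
// Sketch of the integration test; helpers and thresholds are illustrative.
[Fact]
public void Train_OnRotatedMnistTasks_BeatsUntrainedBaseline()
{
    var metaModel = CreateSimpleCnn();
    var testTasks = CreateUnseenTestTasks();
    double baseline = Evaluate(metaModel, testTasks);

    var trainer = CreateSealTrainerWithRotationLoss(metaModel);
    trainer.Train(CreateRotatedMnistLoader(), metaIterations: 1000);

    double trained = Evaluate(metaModel, testTasks);
    Assert.True(trained > baseline + 0.10,
        "Meta-trained accuracy should be significantly above the baseline.");
}
```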
Definition of Done
- All checklist items are complete.
- The `SEALTrainer` correctly implements the self-supervised, active learning, fine-tuning, and meta-update steps.
- The integration test demonstrates that SEAL can successfully meta-learn a solution to a relevant few-shot problem.
- All new code meets the project's >= 90% test coverage requirement.
⚠️ CRITICAL ARCHITECTURAL REQUIREMENTS
Before implementing this user story, you MUST review:
- 📋 Full Requirements: `.github/USER_STORY_ARCHITECTURAL_REQUIREMENTS.md`
- 📐 Project Rules: `.github/PROJECT_RULES.md`
Mandatory Implementation Checklist
1. INumericOperations Usage (CRITICAL)
- Include `protected static readonly INumericOperations<T> NumOps = MathHelper.GetNumericOperations<T>();` in the base class
- NEVER hardcode `double`, `float`, or specific numeric types - use the generic `T`
- NEVER use `default(T)` - use `NumOps.Zero` instead
- Use `NumOps.Zero`, `NumOps.One`, `NumOps.FromDouble()` for values
- Use `NumOps.Add()`, `NumOps.Multiply()`, etc. for arithmetic
- Use `NumOps.LessThan()`, `NumOps.GreaterThan()`, etc. for comparisons
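For example, averaging prediction confidences during the active-learning step might use the pattern like this. `NumOps.Divide` is assumed to exist alongside the `Add`/`Multiply` operations listed above:

```csharp
// Example of the generic-numeric pattern from the checklist above.
protected static readonly INumericOperations<T> NumOps =
    MathHelper.GetNumericOperations<T>();

protected T MeanConfidence(Vector<T> confidences)
{
    T sum = NumOps.Zero;                        // never default(T)
    for (int i = 0; i < confidences.Length; i++)
        sum = NumOps.Add(sum, confidences[i]);  // no hardcoded double/float math
    return NumOps.Divide(sum, NumOps.FromDouble(confidences.Length));
}
```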
2. Inheritance Pattern (REQUIRED)
- Create `I{FeatureName}.cs` in `src/Interfaces/` (root level, NOT subfolders)
- Create `{FeatureName}Base.cs` in `src/{FeatureArea}/`, inheriting from the interface
- Create concrete classes inheriting from the Base class (NOT directly from the interface)
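Applied to this feature, the three layers might look like the following; the `ISEALTrainer`/`SEALTrainerBase` names are one possible instantiation of the `{FeatureName}` templates, not mandated names:

```csharp
// src/Interfaces/ISEALTrainer.cs -- interface at root of src/Interfaces/
public interface ISEALTrainer<T>
{
    void Train(EpisodicDataLoader<T> dataLoader, int metaIterations);
}

// src/Models/Meta/SEALTrainerBase.cs -- base class implements the interface
public abstract class SEALTrainerBase<T> : ISEALTrainer<T>
{
    public abstract void Train(EpisodicDataLoader<T> dataLoader, int metaIterations);
}

// src/Models/Meta/SEALTrainer.cs -- concrete class inherits the base,
// not the interface directly
public class SEALTrainer<T> : SEALTrainerBase<T>
{
    public override void Train(EpisodicDataLoader<T> dataLoader, int metaIterations)
    {
        // SEAL algorithm from AC 1.2 goes here.
    }
}
```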
3. PredictionModelBuilder Integration (REQUIRED)
- Add a private field to `PredictionModelBuilder.cs`: `private I{FeatureName}<T>? _{featureName};`
- Add a Configure method taking ONLY the interface (no other parameters): `public IPredictionModelBuilder<T, TInput, TOutput> Configure{FeatureName}(I{FeatureName}<T> {featureName}) { _{featureName} = {featureName}; return this; }`
- Use the feature in `Build()` with a default: `var {featureName} = _{featureName} ?? new Default{FeatureName}<T>();`
- Verify the feature is ACTUALLY USED in the execution flow
4. Beginner-Friendly Defaults (REQUIRED)
- Constructor parameters with defaults from research/industry standards
- Document WHY each default was chosen (cite papers/standards)
- Validate parameters and throw `ArgumentException` for invalid values
5. Property Initialization (CRITICAL)
- NEVER use the `default!` operator
- String properties: `= string.Empty;`
- Collections: `= new List<T>();` or `= new Vector<T>(0);`
- Numeric properties: an appropriate default or `NumOps.Zero`
6. Class Organization (REQUIRED)
- One class/enum/interface per file
- ALL interfaces in `src/Interfaces/` (root level)
- Namespace mirrors the folder structure (e.g., `src/Regularization/` → `namespace AiDotNet.Regularization`)
7. Documentation (REQUIRED)
- XML documentation for all public members
- `<b>For Beginners:</b>` sections with analogies and examples
- Document all `<param>`, `<returns>`, `<exception>` tags
- Explain default value choices
8. Testing (REQUIRED)
- Minimum 80% code coverage
- Test with multiple numeric types (double, float)
- Test default values are applied correctly
- Test edge cases and exceptions
- Integration tests for PredictionModelBuilder usage
See full details: .github/USER_STORY_ARCHITECTURAL_REQUIREMENTS.md