Skip to content

Commit 0905307

Browse files
AI models generalization
The updated `AIGaussianCheck` and `AIGLMCheck` demonstrate that the models are compatible with any API. If the API includes complex endpoints with unsupported genes, the classifier will return `null`, and testing will fall back to random sampling. The petShop API has been updated to include such an endpoint for validation purposes. In `AIResponseClassification`, the term `probability` has been replaced by `score`, which is a more general term representing the likelihood of occurrence depending on the selected model.
1 parent 2d82dc7 commit 0905307

File tree

7 files changed

+190
-93
lines changed

7 files changed

+190
-93
lines changed

core-it/src/main/kotlin/bar/examples/it/spring/aiconstraint/numeric/AICMultiTypeApplication.kt

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ open class AICMultiTypeApplication {
3333
}
3434

3535
@GetMapping("/petInfo")
36-
open fun getString(
36+
open fun getPetInfo(
3737

3838
@RequestParam("category", required = true)
3939
@Parameter(required = true, description = "Pet category")
@@ -93,6 +93,11 @@ open class AICMultiTypeApplication {
9393

9494
@GetMapping("/ownerInfo")
9595
open fun getOwnerInfo(
96+
97+
@RequestParam("name", required = true)
98+
@Parameter(required = true, description = "Owner's name")
99+
name: String,
100+
96101
@RequestParam("id", required = true)
97102
@Parameter(required = true, description = "Owner's id")
98103
id: Int,
@@ -117,4 +122,61 @@ open class AICMultiTypeApplication {
117122

118123
}
119124

125+
@PostMapping("/petInfo")
126+
open fun createPet(
127+
@RequestParam("category", required = true)
128+
@Parameter(required = true, description = "Pet category")
129+
category: Category,
130+
131+
@RequestParam("gender", required = true)
132+
@Parameter(required = true, description = "Pet gender")
133+
gender: Gender,
134+
135+
@RequestParam("birthYear", required = true)
136+
@Parameter(
137+
required = true,
138+
description = "Pet birth year"
139+
)
140+
birthYear: Int,
141+
142+
@RequestParam("vaccinationYear", required = true)
143+
@Parameter(
144+
required = true,
145+
description = "Pet vaccination year"
146+
)
147+
vaccinationYear: Int,
148+
149+
@RequestParam("isAlive", required = true)
150+
@Parameter(required = true, description = "Pet is alive?")
151+
isAlive: Boolean,
152+
153+
@RequestParam("weight", required = true)
154+
@Parameter(
155+
required = true,
156+
description = "Pet weight"
157+
)
158+
weight: Double
159+
): ResponseEntity<String> {
160+
161+
// Validation
162+
if (birthYear <= 0) {
163+
return ResponseEntity.status(400).body("Birth year must be a positive number.")
164+
}
165+
if (weight <= 0) {
166+
return ResponseEntity.status(400).body("Weight must be a positive number.")
167+
}
168+
if (vaccinationYear <= 0) {
169+
return ResponseEntity.status(400).body("Vaccination year must be a positive number.")
170+
}
171+
if (vaccinationYear < birthYear) {
172+
return ResponseEntity.status(400).body("Vaccination year cannot be earlier than birth year.")
173+
}
174+
175+
// Response
176+
return ResponseEntity.status(200).body(
177+
"Pet created successfully: Birth Year = $birthYear, Vaccination Year = $vaccinationYear, Gender = $gender, Is Alive = $isAlive, Weight = $weight, Category = $category"
178+
)
179+
}
180+
181+
120182
}

core-it/src/test/kotlin/org/evomaster/core/problem/rest/aiconstraint/numeric/AIGLMCheck.kt

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import org.evomaster.core.problem.enterprise.SampleType
55
import org.evomaster.core.problem.rest.IntegrationTestRestBase
66
import org.evomaster.core.problem.rest.data.RestCallAction
77
import org.evomaster.core.problem.rest.data.RestCallResult
8-
import org.evomaster.core.problem.rest.service.AIResponseClassifier
98
import org.evomaster.core.search.gene.BooleanGene
109
import org.evomaster.core.search.gene.collection.EnumGene
1110
import org.evomaster.core.search.gene.numeric.DoubleGene
@@ -14,7 +13,6 @@ import org.evomaster.core.problem.rest.builder.RestActionBuilderV3
1413
import org.evomaster.core.problem.rest.schema.RestSchema
1514
import org.evomaster.core.EMConfig
1615
import org.evomaster.core.problem.rest.classifier.GLMOnlineClassifier
17-
import org.evomaster.core.problem.rest.data.RestPath
1816
import org.evomaster.core.problem.rest.schema.OpenApiAccess
1917
import org.evomaster.core.problem.rest.service.sampler.AbstractRestSampler
2018
import org.evomaster.core.search.action.Action
@@ -89,83 +87,102 @@ class AIGLMCheck : IntegrationTestRestBase() {
8987
probRestExamples = 0.0
9088
}
9189

92-
9390
val options = RestActionBuilderV3.Options(config)
9491
val actionCluster = mutableMapOf<String, Action>()
9592
RestActionBuilderV3.addActionsFromSwagger(restSchema, actionCluster, options = options)
9693

9794
val actionList = actionCluster.values.filterIsInstance<RestCallAction>()
9895

99-
val pathToDimension = mutableMapOf<RestPath, Int>()
96+
val endpointToDimension = mutableMapOf<String, Int?>()
10097
for (action in actionList) {
101-
val path = action.path
102-
if (pathToDimension.containsKey(path)) continue
98+
val name = action.getName()
10399

104-
val dimension = action.parameters.count { p ->
100+
val hasUnsupportedGene = action.parameters.any { p ->
105101
val g = p.gene
106-
g is IntegerGene || g is DoubleGene || g is BooleanGene || g is EnumGene<*>
102+
g !is IntegerGene && g !is DoubleGene && g !is BooleanGene && g !is EnumGene<*>
103+
}
104+
105+
val dimension = if (hasUnsupportedGene) {
106+
null
107+
} else {
108+
action.parameters.count { p ->
109+
val g = p.gene
110+
g is IntegerGene || g is DoubleGene || g is BooleanGene || g is EnumGene<*>
111+
}
107112
}
108-
pathToDimension[path] = dimension
109-
}
110113

111-
val pathToClassifier = mutableMapOf<RestPath, GLMOnlineClassifier>()
112-
for ((path, dimension) in pathToDimension) {
113-
val model = GLMOnlineClassifier()
114-
model.setDimension(dimension)
115-
pathToClassifier[path] = model
114+
println("Endpoint: $name, dimension: $dimension")
115+
endpointToDimension[name] = dimension
116116
}
117117

118-
println("Classifiers initialized with their dimensions:")
119-
for ((path, expected) in pathToDimension) {
120-
val classifier = pathToClassifier[path]!!
121-
println("$path -> expected: $expected, actualDim: ${classifier.getDimension()}")
118+
/**
119+
* Initialize a classifier for each endpoint
120+
* For an endpoint containing unsupported genes, the associated classifier is null
121+
*/
122+
val endpointToClassifier = mutableMapOf<String, GLMOnlineClassifier?>()
123+
for ((name, dimension) in endpointToDimension) {
124+
if(dimension==null){
125+
endpointToClassifier[name] = null
126+
}else{
127+
val model = GLMOnlineClassifier()
128+
model.setDimension(dimension)
129+
endpointToClassifier[name] = model
130+
}
122131
}
123132

133+
for ((name, expectedDimension) in endpointToDimension) {
134+
println("Expected dimension for $name: $expectedDimension")
135+
}
124136

137+
// Execute the procedure for a period of time
125138
val random = Randomness()
126139
val sampler = injector.getInstance(AbstractRestSampler::class.java)
127140
var time = 1
128-
val timeLimit = 20
141+
val timeLimit = 200
129142
while (time <= timeLimit) {
130143
val template = random.choose(actionList)
131144
val sampledAction = template.copy() as RestCallAction
132145
sampledAction.doInitialize(random)
133146

134-
val path = sampledAction.path
135-
val dimension = pathToDimension[path] ?: error("No dimension for path: $path")
136-
val classifier = pathToClassifier[path] ?: error("Expected classifier for path: $path")
147+
val name = sampledAction.getName()
148+
val classifier = endpointToClassifier[name]
149+
val dimension = endpointToDimension[name]
137150
val geneValues = sampledAction.parameters.map { it.gene.getValueAsRawString() }
138151

139-
println("*************************************************")
152+
println("**********************************************")
140153
println("Time : $time")
141-
println("Path : $path")
154+
println("Path : $name")
155+
println("Classifier : ${if (classifier == null) "null" else "GLM"}")
156+
println("Dimension : $dimension")
142157
println("Input Genes : ${geneValues.joinToString(", ")}")
143-
println("Input dim : ${classifier.getDimension()}")
144-
println("Expected Dim : $dimension")
145158
println("Actual Genes : ${geneValues.size}")
146159

147-
// //executeRestCallAction is replaced with createIndividual to avoid override error
160+
// executeRestCallAction is replaced with createIndividual to avoid override error
148161
// val individual = createIndividual(listOf(sampledAction), SampleType.RANDOM)
149162
val individual = sampler.createIndividual(SampleType.RANDOM, listOf(sampledAction).toMutableList())
150163
val action = individual.seeMainExecutableActions()[0]
151164
val result = executeRestCallAction(action,"$baseUrlOfSut")
152-
println("Response:\n${result.getStatusCode()}")
165+
println("Response : ${result.getStatusCode()}")
153166

167+
// Skip classification for the endpoints with unsupported genes
168+
if (classifier==null){
169+
println("No classification as the classifier is null, i.e., the endpoint contains unsupported genes")
170+
continue
171+
}
154172

155173
// Update and classify
156174
classifier.updateModel(action, result)
157175
val classification = classifier.classify(action)
158176

159-
println("Probabilities: ${classification.probabilities}")
160-
require(classification.probabilities.values.all { it in 0.0..1.0 }) {
161-
"All probabilities must be in [0,1]"
177+
println("Probabilities: ${classification.scores}")
178+
require(classification.scores.values.all { it in 0.0..1.0 }) {
179+
"In GLM, scores must be in [0,1]"
162180
}
163181

164182
if (classifier != null) {
165183
val weightsAndBias = classifier.getModelParams()
166184
println("Weights and Bias = $weightsAndBias")
167185
println("**********************************************")
168-
println("**********************************************")
169186
} else {
170187
println("The classifier is not a GLMOnlineClassifier")
171188
}

core-it/src/test/kotlin/org/evomaster/core/problem/rest/aiconstraint/numeric/AIGaussianCheck.kt

Lines changed: 50 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import org.evomaster.core.problem.rest.builder.RestActionBuilderV3
1313
import org.evomaster.core.problem.rest.schema.RestSchema
1414
import org.evomaster.core.EMConfig
1515
import org.evomaster.core.problem.rest.classifier.GaussianOnlineClassifier
16-
import org.evomaster.core.problem.rest.data.RestPath
1716
import org.evomaster.core.problem.rest.schema.OpenApiAccess
1817
import org.evomaster.core.problem.rest.service.sampler.AbstractRestSampler
1918
import org.evomaster.core.search.action.Action
@@ -71,7 +70,6 @@ class AIGaussianCheck : IntegrationTestRestBase() {
7170
return result
7271
}
7372

74-
7573
fun runClassifierExample() {
7674
val schema = OpenApiAccess.getOpenAPIFromLocation("$baseUrlOfSut/v3/api-docs")
7775
val restSchema = RestSchema(schema)
@@ -84,88 +82,96 @@ class AIGaussianCheck : IntegrationTestRestBase() {
8482
probRestExamples = 0.0
8583
}
8684

87-
8885
val options = RestActionBuilderV3.Options(config)
8986
val actionCluster = mutableMapOf<String, Action>()
9087
RestActionBuilderV3.addActionsFromSwagger(restSchema, actionCluster, options = options)
9188

9289
val actionList = actionCluster.values.filterIsInstance<RestCallAction>()
9390

94-
val pathToDimension = mutableMapOf<RestPath, Int>()
91+
val endpointToDimension = mutableMapOf<String, Int?>()
9592
for (action in actionList) {
96-
val path = action.path
97-
if (pathToDimension.containsKey(path)) continue
93+
val name = action.getName()
9894

99-
val dimension = action.parameters.count { p ->
95+
val hasUnsupportedGene = action.parameters.any { p ->
10096
val g = p.gene
101-
g is IntegerGene || g is DoubleGene || g is BooleanGene || g is EnumGene<*>
97+
g !is IntegerGene && g !is DoubleGene && g !is BooleanGene && g !is EnumGene<*>
98+
}
99+
100+
val dimension = if (hasUnsupportedGene) {
101+
null
102+
} else {
103+
action.parameters.count { p ->
104+
val g = p.gene
105+
g is IntegerGene || g is DoubleGene || g is BooleanGene || g is EnumGene<*>
106+
}
102107
}
103-
pathToDimension[path] = dimension
104-
}
105108

106-
val pathToClassifier = mutableMapOf<RestPath, GaussianOnlineClassifier>()
107-
for ((path, dimension) in pathToDimension) {
108-
val model = GaussianOnlineClassifier()
109-
model.setDimension(dimension)
110-
pathToClassifier[path] = model
109+
println("Endpoint: $name, dimension: $dimension")
110+
endpointToDimension[name] = dimension
111111
}
112112

113-
println("Classifiers initialized with their dimensions:")
114-
for ((path, expected) in pathToDimension) {
115-
val classifier = pathToClassifier[path]!!
116-
println("$path -> expected: $expected, actualDim: ${classifier.getDimension()}")
113+
/**
114+
* Initialize a classifier for each endpoint
115+
* For an endpoint containing unsupported genes, the associated classifier is null
116+
*/
117+
val endpointToClassifier = mutableMapOf<String, GaussianOnlineClassifier?>()
118+
for ((name, dimension) in endpointToDimension) {
119+
if(dimension==null){
120+
endpointToClassifier[name] = null
121+
}else{
122+
val model = GaussianOnlineClassifier()
123+
model.setDimension(dimension)
124+
endpointToClassifier[name] = model
125+
}
117126
}
118127

119-
// val random = Randomness()
120-
// val sampler = injector.getInstance(AbstractRestSampler::class.java)
121-
// val template = random.choose(actionList)
122-
// val sampledAction = template.copy() as RestCallAction
123-
// sampledAction.doInitialize(random)
124-
//
125-
// // createIndividual doesn't work!
126-
// // val individual = createIndividual(listOf(sampledAction), SampleType.RANDOM)
127-
//
128-
// val individual = sampler.createIndividual(SampleType.RANDOM, listOf(sampledAction).toMutableList())
129-
// val mainAction = individual.seeMainExecutableActions()[0]
130-
// val response = executeRestCallAction(mainAction,"$baseUrlOfSut")
131-
// println("Response:\n${response.getStatusCode()}")
128+
for ((name, expectedDimension) in endpointToDimension) {
129+
println("Expected dimension for $name: $expectedDimension")
130+
}
132131

132+
// Execute the procedure for a period of time
133133
val random = Randomness()
134134
val sampler = injector.getInstance(AbstractRestSampler::class.java)
135135
var time = 1
136-
val timeLimit = 20
136+
val timeLimit = 200
137137
while (time <= timeLimit) {
138138
val template = random.choose(actionList)
139139
val sampledAction = template.copy() as RestCallAction
140140
sampledAction.doInitialize(random)
141141

142-
val path = sampledAction.path
143-
val dimension = pathToDimension[path] ?: error("No dimension for path: $path")
144-
val classifier = pathToClassifier[path] ?: error("Expected classifier for path: $path")
142+
val name = sampledAction.getName()
143+
val classifier = endpointToClassifier[name]
144+
val dimension = endpointToDimension[name]
145145
val geneValues = sampledAction.parameters.map { it.gene.getValueAsRawString() }
146146

147147
println("*************************************************")
148148
println("Time : $time")
149-
println("Path : $path")
149+
println("Path : $name")
150+
println("Classifier : ${if (classifier == null) "null" else "GAUSSIAN"}")
151+
println("Dimension : $dimension")
150152
println("Input Genes : ${geneValues.joinToString(", ")}")
151-
println("Input dim : ${classifier.getDimension()}")
152-
println("Expected Dim : $dimension")
153153
println("Actual Genes : ${geneValues.size}")
154154

155-
// //executeRestCallAction is replaced with createIndividual to avoid override error
155+
// executeRestCallAction is replaced with createIndividual to avoid override error
156156
// val individual = createIndividual(listOf(sampledAction), SampleType.RANDOM)
157157
val individual = sampler.createIndividual(SampleType.RANDOM, listOf(sampledAction).toMutableList())
158158
val action = individual.seeMainExecutableActions()[0]
159159
val result = executeRestCallAction(action,"$baseUrlOfSut")
160-
println("Response:\n${result.getStatusCode()}")
160+
println("Response : ${result.getStatusCode()}")
161+
162+
// Skip classification for the endpoints with unsupported genes
163+
if (classifier==null){
164+
println("No classification as the classifier is null, i.e., the endpoint contains unsupported genes")
165+
continue
166+
}
161167

162168
// Update and classify
163169
classifier.updateModel(action, result)
164170
val classification = classifier.classify(action)
165171

166-
println("Probabilities: ${classification.probabilities}")
167-
require(classification.probabilities.values.all { it in 0.0..1.0 }) {
168-
"All probabilities must be in [0,1]"
172+
println("Score: ${classification.scores}")
173+
require(classification.scores.values.all { it >= 0 }) {
174+
"In Gaussian scores must be positive as they are likelihoods."
169175
}
170176

171177
val d200 = classifier.getDensity200()
@@ -180,7 +186,6 @@ class AIGaussianCheck : IntegrationTestRestBase() {
180186
formatStats("Density200", d200.mean, d200.variance, d200.n)
181187
formatStats("Density400", d400.mean, d400.variance, d400.n)
182188

183-
println("----------------------------------")
184189
time++
185190
}
186191
}

0 commit comments

Comments
 (0)