Commit d8577b2

Merge pull request rllm-org#106 from evanmiller-anthropic/docvqa
DocVQA implementation
2 parents: 7c3622e + f377e86

File tree: 7 files changed (+381, −1 lines)

README.md

Lines changed: 9 additions & 1 deletion

@@ -300,6 +300,14 @@ The questions were generated by GPT-4 based on the "Computer Systems Security: P
 inspect eval inspect_evals/boolq
 ```

+- ### [DocVQA: A Dataset for VQA on Document Images](src/inspect_evals/docvqa)
+  DocVQA is a Visual Question Answering benchmark that consists of 50,000 questions covering 12,000+ document images. This implementation solves and scores the "validation" split.
+  <sub><sup>Contributed by: [@evanmiller-anthropic](https://github.com/evanmiller-anthropic)</sub></sup>
+
+  ```bash
+  inspect eval inspect_evals/docvqa
+  ```
+
 - ### [DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs](src/inspect_evals/drop)
   Evaluates reading comprehension where models must resolve references in a question, perhaps to multiple input positions, and perform discrete operations over them (such as addition, counting, or sorting).
   <sub><sup>Contributed by: [@xeon27](https://github.com/xeon27)</sub></sup>

@@ -443,4 +451,4 @@ The questions were generated by GPT-4 based on the "Computer Systems Security: P
 inspect eval inspect_evals/agie_lsat_lr
 ```

-<!-- /Eval Listing: Automatically Generated -->
+<!-- /Eval Listing: Automatically Generated -->

src/inspect_evals/_registry.py

Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@
     cybermetric_10000,
 )
 from .cyberseceval_2 import interpreter_abuse, prompt_injection, vulnerability_exploit
+from .docvqa import docvqa
 from .drop import drop
 from .ds1000 import ds1000
 from .gaia import gaia, gaia_level1, gaia_level2, gaia_level3

src/inspect_evals/docvqa/README.md

Lines changed: 73 additions & 0 deletions

# DocVQA: A Dataset for VQA on Document Images

[DocVQA](https://arxiv.org/abs/2007.00398) is a Visual Question Answering benchmark that consists of 50,000 questions covering 12,000+ document images. This implementation solves and scores the "validation" split.

<!-- Contributors: Automatically Generated -->
Contributed by [@evanmiller-anthropic](https://github.com/evanmiller-anthropic)
<!-- /Contributors: Automatically Generated -->

<!-- Usage: Automatically Generated -->
## Usage

First, install the `inspect_ai` and `inspect_evals` Python packages with:

```bash
pip install inspect_ai
pip install git+https://github.com/UKGovernmentBEIS/inspect_evals
```

Then, evaluate against one or more models with:

```bash
inspect eval inspect_evals/docvqa --model openai/gpt-4o
```

After running evaluations, you can view their logs using the `inspect view` command:

```bash
inspect view
```

If you don't want to specify the `--model` each time you run an evaluation, create a `.env` configuration file in your working directory that defines the `INSPECT_EVAL_MODEL` environment variable along with your API key. For example:

```bash
INSPECT_EVAL_MODEL=anthropic/claude-3-5-sonnet-20240620
ANTHROPIC_API_KEY=<anthropic-api-key>
```
<!-- /Usage: Automatically Generated -->

<!-- Options: Automatically Generated -->
## Options

You can control a variety of options from the command line. For example:

```bash
inspect eval inspect_evals/docvqa --limit 10
inspect eval inspect_evals/docvqa --max-connections 10
inspect eval inspect_evals/docvqa --temperature 0.5
```

See `inspect eval --help` for all available options.
<!-- /Options: Automatically Generated -->

## Dataset

The DocVQA dataset contains a "validation" split and a "test" split. To prevent leakage into training data, the authors of DocVQA have chosen to hold back the answers to the "test" split. Scoring on the "test" split requires coordinating with the DocVQA authors.
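For context, the snippet below shows roughly how this implementation loads that split with `inspect_ai`'s `hf_dataset()` helper; the `record_to_sample` converter and the full task definition appear in `docvqa.py` later in this commit.

```python
from inspect_ai.dataset import hf_dataset

# "answers" in the "test" split are held back, so score against "validation"
dataset = hf_dataset(
    path="lmms-lab/DocVQA",
    name="DocVQA",
    split="validation",
    sample_fields=record_to_sample,  # defined in docvqa.py
    trust=True,
    shuffle=True,
)
```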
Each split contains several questions about each image. Here is an example image:

![Diabetes in Australia](https://rrc.cvc.uab.es/files/DocVQA_exT3_2_Infographics.png)

And associated example questions:
* How many females are affected by diabetes?
* What percentage of cases can be prevented?
* What could lead to blindness or stroke diabetes?

The model is tasked to answer each question by referring to the image. The prompts are based on OpenAI's [simple-evals](https://github.com/openai/simple-evals/blob/294cb1f/drop_eval.py#L261C13-L283C91).
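Concretely, each question is wrapped in the answer-format template (`FREEFORM_TEMPLATE` in `docvqa.py` below), while the document image is attached as a separate image content block. A rendered prompt for the first example question would look roughly like this:

```
Answer the following question. The entire content of your response should be of the following format: 'ANSWER: $ANSWER' (without quotes) where $ANSWER is your answer.

How many females are affected by diabetes?
```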
## Scoring

DocVQA computes the Average Normalized Levenshtein Similarity (ANLS):

[Average Normalized Levenshtein Similarity definition](https://user-images.githubusercontent.com/48327001/195277520-b1ef2be2-c4d7-417b-91ec-5fda8aa6db06.png)
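For reference, a transcription of that definition (equation (1) of the ANLS paper), with the threshold τ = 0.5 used by this implementation:

```latex
\mathrm{ANLS} = \frac{1}{N} \sum_{i=1}^{N} \max_{j}\, s\!\left(a_{ij}, o_{i}\right),
\qquad
s\!\left(a_{ij}, o_{i}\right) =
\begin{cases}
  1 - \mathrm{NL}\!\left(a_{ij}, o_{i}\right) & \text{if } \mathrm{NL}\!\left(a_{ij}, o_{i}\right) < \tau \\
  0 & \text{otherwise}
\end{cases}
```

Here $o_i$ is the model's answer to question $i$, $a_{ij}$ ranges over that question's ground-truth answers, and $\mathrm{NL}$ is the Levenshtein distance normalized by the length of the longer string.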

src/inspect_evals/docvqa/__init__.py

Lines changed: 3 additions & 0 deletions

from .docvqa import docvqa

__all__ = ["docvqa"]

src/inspect_evals/docvqa/docvqa.py

Lines changed: 188 additions & 0 deletions

import re
from io import BytesIO
from pathlib import Path
from typing import Any

from inspect_ai import Task, task
from inspect_ai.dataset import Sample, hf_dataset
from inspect_ai.model import ChatMessage, ChatMessageUser, ContentImage, ContentText
from inspect_ai.scorer import (
    INCORRECT,
    AnswerPattern,
    Score,
    Scorer,
    Target,
    accuracy,
    scorer,
    stderr,
)
from inspect_ai.solver import (
    Generate,
    Solver,
    TaskState,
    solver,
)
from PIL import Image
from platformdirs import user_cache_dir

FREEFORM_TEMPLATE = r"""
Answer the following question. The entire content of your response should be of the following format: 'ANSWER: $ANSWER' (without quotes) where $ANSWER is your answer.

{question}
"""

IMAGE_BASE_DIR = Path(user_cache_dir("inspect_evals")) / "docvqa_images"


def _levenshtein_distance(str1: str, str2: str) -> int:
    """Computes a Levenshtein distance, same as Levenshtein.distance in the python-Levenshtein package."""
    # Create a matrix of size (len(str1) + 1) x (len(str2) + 1)
    matrix = [[0 for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]

    # Initialize the first row and column
    for i in range(len(str1) + 1):
        matrix[i][0] = i
    for j in range(len(str2) + 1):
        matrix[0][j] = j

    # Fill in the rest of the matrix
    for i in range(1, len(str1) + 1):
        for j in range(1, len(str2) + 1):
            matrix[i][j] = min(
                matrix[i - 1][j] + 1,  # deletion
                matrix[i][j - 1] + 1,  # insertion
                matrix[i - 1][j - 1] + int(str1[i - 1] != str2[j - 1]),  # substitution
            )

    return matrix[len(str1)][len(str2)]


def _best_normalized_levenshtein_similiarity(
    completion: str, ground_truths: list[str], threshold: float
) -> float:
    """
    Compute the best normalized Levenshtein similarity, an input into the Average Normalized Levenshtein Similarity (ANLS)

    The Average Normalized Levenshtein Similarity (ANLS) is defined in equation (1) of
    https://arxiv.org/pdf/1907.00490.pdf

    Note that the "average" is computed by the accuracy metric -- not here. This function computes
    the term inside the summation of equation (1).
    """
    best_score = 0.0
    for ground_truth in ground_truths:
        if len(ground_truth) == 0 and len(completion) == 0:
            best_score = 1
            break
        levenshtein_distance = _levenshtein_distance(
            completion.lower(), ground_truth.lower()
        )
        normed_levenshtein_distance = levenshtein_distance / max(
            len(completion), len(ground_truth)
        )
        if normed_levenshtein_distance < threshold:
            score = 1.0 - normed_levenshtein_distance
        else:
            score = 0.0
        if score > best_score:
            best_score = score
    return best_score


@task
def docvqa() -> Task:
    dataset = hf_dataset(
        path="lmms-lab/DocVQA",
        name="DocVQA",
        split="validation",  # "answers" in the "test" split are held back by the authors
        sample_fields=record_to_sample,
        trust=True,
        shuffle=True,
    )

    return Task(
        dataset=dataset,
        solver=[docvqa_solver()],
        scorer=docvqa_scorer(),
    )


@scorer(metrics=[accuracy(), stderr()])
def docvqa_scorer() -> Scorer:
    async def normalized_levenshtein_similiarity_score(
        state: TaskState, target: Target
    ) -> Score:
        threshold = 0.5
        ground_truths = target.target
        match = re.search(
            AnswerPattern.LINE,
            state.output.completion,
            re.IGNORECASE,
        )
        if match:
            completion = match.groups()[0]
            return Score(
                value=_best_normalized_levenshtein_similiarity(
                    completion, ground_truths, threshold
                ),
                answer=completion,
            )

        else:
            # didn't find the scoring pattern
            return Score(
                value=INCORRECT,
                explanation="Scoring pattern not matched in output: "
                + f"{state.output.completion}",
            )

    return normalized_levenshtein_similiarity_score


@solver
def docvqa_solver() -> Solver:
    async def solve(state: TaskState, generate: Generate) -> TaskState:
        state.user_prompt.text = FREEFORM_TEMPLATE.format(
            question=state.user_prompt.text
        )
        return await generate(state)

    return solve


def record_to_sample(record: dict[str, Any]) -> Sample:
    # extract image
    image_path = Path(IMAGE_BASE_DIR / record["image"]["path"])

    image_bytes = record["image"]["bytes"]
    assert is_image_png(image_bytes)

    if not image_path.exists():
        print(f"Extracting {image_path.name}")
        # ensure parent
        image_path.parent.mkdir(exist_ok=True, parents=True)
        # reduce the image size
        img = Image.open(BytesIO(image_bytes))
        img.thumbnail((1024, 1024))
        # save preserving format
        img.save(image_path, format=img.format)

    message: list[ChatMessage] = [
        ChatMessageUser(
            content=[
                ContentText(text=record["question"]),
                ContentImage(image=image_path.as_posix()),
            ]
        )
    ]

    return Sample(
        input=message,
        target=record["answers"],
        id=record["questionId"],
        metadata={"document_id": record["docId"]},
    )


def is_image_png(image_bytes: bytes) -> bool:
    return image_bytes[:8] == b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
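As a quick smoke test, the task can also be run programmatically instead of through the CLI; a minimal sketch, assuming `inspect_ai`'s `eval()` entry point and a model API key configured in the environment:

```python
from inspect_ai import eval

from inspect_evals.docvqa import docvqa

# run a small slice of the validation split against a single model
eval(docvqa(), model="openai/gpt-4o", limit=10)
```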

tests/docvqa/test_docvqa.py

Lines changed: 97 additions & 0 deletions

from inspect_evals.docvqa.docvqa import (
    _best_normalized_levenshtein_similiarity as best,
)
from inspect_evals.docvqa.docvqa import (
    _levenshtein_distance as levenshtein,
)


def test_levenshtein():
    # Basic test cases
    assert levenshtein("", "") == 0  # Empty strings
    assert levenshtein("a", "a") == 0  # Same single char
    assert levenshtein("abc", "abc") == 0  # Same string

    # Single operations
    assert levenshtein("a", "") == 1  # Single deletion
    assert levenshtein("", "a") == 1  # Single insertion
    assert levenshtein("a", "b") == 1  # Single substitution

    # Multiple operations
    assert levenshtein("kitten", "sitting") == 3  # Classic example
    assert levenshtein("sunday", "saturday") == 3  # Real words


def test_best_normalized_levenshtein_distance():
    def best_norm_lev_sim(completion, ground_truths, threshold=2.0):
        return round(best(completion, ground_truths, threshold), 3)

    # Basic cases
    assert best_norm_lev_sim("", [""]) == 1.0  # Empty strings
    assert best_norm_lev_sim("a", ["a"]) == 1.0  # Single char match
    assert best_norm_lev_sim("", ["a"]) == 0.0  # Empty vs char
    assert best_norm_lev_sim("a", ["b"]) == 0.0  # Different chars

    # Multiple correct answers
    assert (
        best_norm_lev_sim("color", ["color", "colour"]) == 1.0
    )  # Exact match with variants

    assert (
        best_norm_lev_sim("theatre", ["theater", "theatre"]) == 1.0
    )  # Regional spellings

    # Partial matches with multiple answers
    assert best_norm_lev_sim("thetre", ["theater", "theatre"]) == round(
        1 - 1 / 7, 3
    )  # One deletion

    # Case insensitivity
    assert best_norm_lev_sim("HELLO", ["hello", "hola"]) == 1.0  # All case differences

    # Length differences
    assert best_norm_lev_sim("hi", ["hello", "hey"]) == round(
        1 - 2 / 3, 3
    )  # Short vs longer options

    assert best_norm_lev_sim("hi", ["hello", "hey"], 0.5) == 0.0  # Test threshold

    assert best_norm_lev_sim("hi", ["hello", "hey"], 0.75) == round(
        1 - 2 / 3, 3
    )  # Test threshold

    # Numeric and special characters
    assert (
        best_norm_lev_sim("2nd floor", ["second floor", "2nd floor", "floor 2"]) == 1.0
    )  # Number representations

    # Common abbreviations
    assert (
        best_norm_lev_sim("dept", ["department", "dept.", "dept"]) == 1.0
    )  # Abbreviation matches

    # Multiple errors
    assert best_norm_lev_sim(
        "californa", ["california", "calif", "ca"]
    ) > best_norm_lev_sim(
        "calfrnia", ["california", "calif", "ca"]
    )  # Better partial match

    # Spaces and formatting
    assert (
        best_norm_lev_sim("new york", ["newyork", "new york", "ny"]) == 1.0
    )  # Space variations

    # Unicode and special characters
    assert best_norm_lev_sim("café", ["cafe", "café", "caffè"]) == 1.0  # Accent marks

    # Long string comparisons
    assert (
        best_norm_lev_sim(
            "mississipi river", ["mississippi river", "river mississippi"]
        )
        > 0.9
    )  # Minor spelling error

    # Completely different strings
    assert best_norm_lev_sim("kiwi", ["banana", "orange"]) == 0.0  # No similarity
