Skip to content
This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Commit d9fa183

Browse files
authored
Merge pull request #169 from Chainlit/willy/preprocess-steps
feat: add preprocessing functions for step ingestion
2 parents 587dc96 + 9582df6 commit d9fa183

File tree

5 files changed

+184
-6
lines changed

5 files changed

+184
-6
lines changed

literalai/client.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import os
44
from contextlib import redirect_stdout
5-
from typing import Any, Dict, List, Optional, Union
5+
from typing import Any, Callable, Dict, List, Optional, Union
66

77
from traceloop.sdk import Traceloop
88
from typing_extensions import deprecated
@@ -29,6 +29,7 @@
2929
MessageStepType,
3030
Step,
3131
StepContextManager,
32+
StepDict,
3233
TrueStepType,
3334
step_decorator,
3435
)
@@ -94,15 +95,55 @@ def __init__(
9495

9596
def to_sync(self) -> "LiteralClient":
    """
    Return a synchronous client equivalent to this one.

    When ``self.api`` is an ``AsyncLiteralAPI``, a new ``LiteralClient`` is
    built from this client's batch size, API key, URL and ``disabled`` flag,
    and any registered step-preprocessing function is carried over to the
    new client's event processor. Otherwise this client is returned as-is.

    Returns:
        LiteralClient: The synchronous client.
    """
    if not isinstance(self.api, AsyncLiteralAPI):
        return self  # type: ignore

    sync_client = LiteralClient(
        batch_size=self.event_processor.batch_size,
        api_key=self.api.api_key,
        url=self.api.url,
        disabled=self.disabled,
    )

    # Compare against None instead of relying on truthiness: a callable
    # object that happens to be falsy (e.g. defines __len__ or __bool__)
    # must still be propagated. This also matches the `is not None` check
    # the event processor uses before invoking the function.
    preprocess = self.event_processor.preprocess_steps_function
    if preprocess is not None:
        sync_client.event_processor.set_preprocess_steps_function(preprocess)

    return sync_client
105112

113+
def set_preprocess_steps_function(
    self,
    preprocess_steps_function: Optional[
        Callable[[List["StepDict"]], List["StepDict"]]
    ],
) -> None:
    """
    Set a function that will preprocess steps before sending them to the API.
    This can be used for tasks like PII removal or other data transformations.

    The preprocess function should:
    - Accept a list of StepDict objects as input
    - Return a list of modified StepDict objects
    - Be thread-safe as it will be called from a background thread
    - Handle exceptions internally: if it raises, or returns something that
      is not a list, the event processor logs the problem and sends the
      original, unprocessed batch

    Pass ``None`` to remove a previously registered function.

    Example:
    ```python
    def remove_pii(steps):
        # Message text created with client.message(content=...) is stored
        # under step["output"]["content"]; adapt the keys to the fields
        # your steps actually carry.
        for step in steps:
            output = step.get("output")
            if isinstance(output, dict) and output.get("content"):
                output["content"] = my_pii_removal_function(output["content"])
        return steps

    client.set_preprocess_steps_function(remove_pii)
    ```

    Args:
        preprocess_steps_function (Optional[Callable[[List["StepDict"]], List["StepDict"]]]):
            Function that takes a list of steps and returns a processed list,
            or None to disable preprocessing
    """
    self.event_processor.set_preprocess_steps_function(preprocess_steps_function)
146+
106147
@deprecated("Use Literal.initialize instead")
107148
def instrument_openai(self):
108149
"""

literalai/event_processor.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import threading
55
import time
66
import traceback
7-
from typing import TYPE_CHECKING, List
7+
from typing import TYPE_CHECKING, Callable, List, Optional
88

99
logger = logging.getLogger(__name__)
1010

@@ -31,7 +31,15 @@ class EventProcessor:
3131
batch: List["StepDict"]
3232
batch_timeout: float = 5.0
3333

34-
def __init__(self, api: "LiteralAPI", batch_size: int = 1, disabled: bool = False):
34+
def __init__(
35+
self,
36+
api: "LiteralAPI",
37+
batch_size: int = 1,
38+
disabled: bool = False,
39+
preprocess_steps_function: Optional[
40+
Callable[[List["StepDict"]], List["StepDict"]]
41+
] = None,
42+
):
3543
self.stop_event = threading.Event()
3644
self.batch_size = batch_size
3745
self.api = api
@@ -40,6 +48,7 @@ def __init__(self, api: "LiteralAPI", batch_size: int = 1, disabled: bool = Fals
4048
self.processing_counter = 0
4149
self.counter_lock = threading.Lock()
4250
self.last_batch_time = time.time()
51+
self.preprocess_steps_function = preprocess_steps_function
4352
self.processing_thread = threading.Thread(
4453
target=self._process_events, daemon=True
4554
)
@@ -56,6 +65,22 @@ async def a_add_events(self, event: "StepDict"):
5665
self.processing_counter += 1
5766
await to_thread(self.event_queue.put, event)
5867

68+
def set_preprocess_steps_function(
    self,
    preprocess_steps_function: Optional[
        Callable[[List["StepDict"]], List["StepDict"]]
    ],
) -> None:
    """
    Set a function that will preprocess steps before sending them to the API.
    The function should take a list of StepDict objects and return a list of processed StepDict objects.
    This can be used for tasks like PII removal.

    The value is stored on the processor as ``preprocess_steps_function``;
    passing ``None`` disables preprocessing.

    Args:
        preprocess_steps_function (Optional[Callable[[List["StepDict"]], List["StepDict"]]]):
            The preprocessing function, or None to clear it
    """
    self.preprocess_steps_function = preprocess_steps_function
83+
5984
def _process_events(self):
6085
while True:
6186
batch = []
@@ -83,6 +108,24 @@ def _process_events(self):
83108

84109
def _try_process_batch(self, batch: List):
85110
try:
111+
# Apply preprocessing function if it exists
112+
if self.preprocess_steps_function is not None:
113+
try:
114+
processed_batch = self.preprocess_steps_function(batch)
115+
# Only use the processed batch if it's valid
116+
if processed_batch is not None and isinstance(
117+
processed_batch, list
118+
):
119+
batch = processed_batch
120+
else:
121+
logger.warning(
122+
"Preprocess function returned invalid result, using original batch"
123+
)
124+
except Exception as e:
125+
logger.error(f"Error in preprocess function: {str(e)}")
126+
logger.error(traceback.format_exc())
127+
# Continue with the original batch
128+
86129
return self.api.send_steps(batch)
87130
except Exception:
88131
logger.error(f"Failed to send steps: {traceback.format_exc()}")

literalai/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.1.202"
1+
__version__ = "0.1.3"

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="literalai",
5-
version="0.1.202", # update version in literalai/version.py
5+
version="0.1.3", # update version in literalai/version.py
66
description="An SDK for observability in Python applications",
77
long_description=open("README.md").read(),
88
long_description_content_type="text/markdown",

tests/e2e/test_e2e.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,3 +801,97 @@ async def test_environment(self, staging_client: LiteralClient):
801801
persisted_run = staging_client.api.get_step(run_id)
802802
assert persisted_run is not None
803803
assert persisted_run.environment == "staging"
804+
805+
@pytest.mark.timeout(5)
async def test_pii_removal(
    self, client: LiteralClient, async_client: AsyncLiteralClient
):
    """Test that PII is properly removed by the preprocess function."""
    import re

    # Patterns for common PII, compiled once rather than on every batch.
    # - email: the TLD class is [A-Za-z]; a "|" inside a character class is a
    #   literal pipe, not alternation.
    # - phone: a lookbehind replaces the leading \b so an opening "(" is part
    #   of the match instead of being left behind in front of the placeholder.
    redactions = (
        (
            re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"),
            "[EMAIL REDACTED]",
        ),
        (
            re.compile(
                r"(?<!\w)(?:\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}\b"
            ),
            "[PHONE REDACTED]",
        ),
        (
            re.compile(r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b"),
            "[SSN REDACTED]",
        ),
    )

    # Define a PII removal function
    def remove_pii(steps):
        for step in steps:
            # Only output["content"] is processed. Skip steps whose output is
            # missing, None, or has no text, so the function never raises; a
            # raising hook would make the processor send the raw batch.
            output = step.get("output")
            if not isinstance(output, dict):
                continue
            content = output.get("content")
            if not isinstance(content, str) or not content:
                continue

            # Order matters: emails, then phone numbers, then SSNs.
            for pattern, placeholder in redactions:
                content = pattern.sub(placeholder, content)
            output["content"] = content

        return steps

    # Set the PII removal function on the client
    client.set_preprocess_steps_function(remove_pii)

    try:

        @client.thread
        def thread_with_pii():
            thread = client.get_current_thread()

            # User message with PII. The metadata also contains a phone
            # number, but remove_pii does not inspect metadata and this test
            # does not assert on it.
            user_step = client.message(
                content="My email is test@example.com and my phone is (123) 456-7890. My SSN is 123-45-6789.",
                type="user_message",
                metadata={"contact_info": "Call me at 987-654-3210"},
            )
            user_step_id = user_step.id

            # Assistant message with PII reference
            assistant_step = client.message(
                content="I'll contact you at test@example.com", type="assistant_message"
            )
            assistant_step_id = assistant_step.id

            return thread.id, user_step_id, assistant_step_id

        # Run the thread
        thread_id, user_step_id, assistant_step_id = thread_with_pii()

        # Wait for processing to occur
        client.flush()

        # Fetch the steps and verify PII was removed
        user_step = client.api.get_step(id=user_step_id)
        assistant_step = client.api.get_step(id=assistant_step_id)

        assert user_step
        assert assistant_step

        user_step_output = user_step.output["content"]  # type: ignore

        # Check user message
        assert "test@example.com" not in user_step_output
        assert "(123) 456-7890" not in user_step_output
        assert "123-45-6789" not in user_step_output
        assert "[EMAIL REDACTED]" in user_step_output
        assert "[PHONE REDACTED]" in user_step_output
        assert "[SSN REDACTED]" in user_step_output

        assistant_step_output = assistant_step.output["content"]  # type: ignore

        # Check assistant message
        assert "test@example.com" not in assistant_step_output
        assert "[EMAIL REDACTED]" in assistant_step_output

        # Clean up
        client.api.delete_step(id=user_step_id)
        client.api.delete_step(id=assistant_step_id)
        client.api.delete_thread(id=thread_id)
    finally:
        # Reset the preprocess function even when an assertion above fails,
        # so the hook cannot leak into other tests that use this client.
        client.set_preprocess_steps_function(None)

0 commit comments

Comments
 (0)