Skip to content

Commit 05c5df5

Browse files
author
Stainless Bot
committed
feat: feat: allow specification of context column name when using tracers
1 parent 82cf45a commit 05c5df5

File tree

1 file changed

+43
-2
lines changed

1 file changed

+43
-2
lines changed

src/openlayer/lib/tracing/tracer.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
_current_step = contextvars.ContextVar("current_step")
2525
_current_trace = contextvars.ContextVar("current_trace")
26+
_rag_context = contextvars.ContextVar("rag_context")
2627

2728

2829
def get_current_trace() -> Optional[traces.Trace]:
@@ -35,6 +36,11 @@ def get_current_step() -> Optional[steps.Step]:
3536
return _current_step.get(None)
3637

3738

39+
def get_rag_context() -> Optional[Dict[str, Any]]:
40+
"""Returns the current context."""
41+
return _rag_context.get(None)
42+
43+
3844
@contextmanager
3945
def create_step(
4046
name: str,
@@ -57,6 +63,7 @@ def create_step(
5763
logger.debug("Starting a new trace...")
5864
current_trace = traces.Trace()
5965
_current_trace.set(current_trace) # Set the current trace in context
66+
_rag_context.set(None) # Reset the context
6067
current_trace.add_step(new_step)
6168
else:
6269
logger.debug("Adding step %s to parent step %s", name, parent_step.name)
@@ -91,6 +98,9 @@ def create_step(
9198
)
9299
)
93100

101+
if "context" in trace_data:
102+
config.update({"context_column_name": "context"})
103+
94104
if isinstance(new_step, steps.ChatCompletionStep):
95105
config.update(
96106
{
@@ -121,7 +131,7 @@ def add_chat_completion_step_to_trace(**kwargs) -> None:
121131

122132

123133
# ----------------------------- Tracing decorator ---------------------------- #
124-
def trace(*step_args, inference_pipeline_id: Optional[str] = None, **step_kwargs):
134+
def trace(*step_args, inference_pipeline_id: Optional[str] = None, context_kwarg: Optional[str] = None, **step_kwargs):
125135
"""Decorator to trace a function.
126136
127137
Examples
@@ -182,6 +192,12 @@ def wrapper(*func_args, **func_kwargs):
182192
inputs.pop("self", None)
183193
inputs.pop("cls", None)
184194

195+
if context_kwarg:
196+
if context_kwarg in inputs:
197+
log_context(inputs.get(context_kwarg))
198+
else:
199+
logger.warning("Context kwarg `%s` not found in inputs of the current function.", context_kwarg)
200+
185201
step.log(
186202
inputs=inputs,
187203
output=output,
@@ -198,7 +214,9 @@ def wrapper(*func_args, **func_kwargs):
198214
return decorator
199215

200216

201-
def trace_async(*step_args, inference_pipeline_id: Optional[str] = None, **step_kwargs):
217+
def trace_async(
218+
*step_args, inference_pipeline_id: Optional[str] = None, context_kwarg: Optional[str] = None, **step_kwargs
219+
):
202220
"""Decorator to trace a function.
203221
204222
Examples
@@ -259,6 +277,12 @@ async def wrapper(*func_args, **func_kwargs):
259277
inputs.pop("self", None)
260278
inputs.pop("cls", None)
261279

280+
if context_kwarg:
281+
if context_kwarg in inputs:
282+
log_context(inputs.get(context_kwarg))
283+
else:
284+
logger.warning("Context kwarg `%s` not found in inputs of the current function.", context_kwarg)
285+
262286
step.log(
263287
inputs=inputs,
264288
output=output,
@@ -292,6 +316,19 @@ def run_async_func(coroutine: Awaitable[Any]) -> Any:
292316
return result
293317

294318

319+
def log_context(context: List[str]) -> None:
320+
"""Logs context information to the current step of the trace.
321+
322+
The `context` parameter should be a list of strings representing the
323+
context chunks retrieved by the context retriever."""
324+
current_step = get_current_step()
325+
if current_step:
326+
_rag_context.set(context)
327+
current_step.log(metadata={"context": context})
328+
else:
329+
logger.warning("No current step found to log context.")
330+
331+
295332
# --------------------- Helper post-processing functions --------------------- #
296333
def post_process_trace(
297334
trace_obj: traces.Trace,
@@ -323,4 +360,8 @@ def post_process_trace(
323360
if input_variables:
324361
trace_data.update(input_variables)
325362

363+
context = get_rag_context()
364+
if context:
365+
trace_data["context"] = context
366+
326367
return trace_data, input_variable_names

0 commit comments

Comments
 (0)