23
23
24
24
_current_step = contextvars .ContextVar ("current_step" )
25
25
_current_trace = contextvars .ContextVar ("current_trace" )
26
+ _rag_context = contextvars .ContextVar ("rag_context" )
26
27
27
28
28
29
def get_current_trace () -> Optional [traces .Trace ]:
@@ -35,6 +36,11 @@ def get_current_step() -> Optional[steps.Step]:
35
36
return _current_step .get (None )
36
37
37
38
39
+ def get_rag_context () -> Optional [Dict [str , Any ]]:
40
+ """Returns the current context."""
41
+ return _rag_context .get (None )
42
+
43
+
38
44
@contextmanager
39
45
def create_step (
40
46
name : str ,
@@ -57,6 +63,7 @@ def create_step(
57
63
logger .debug ("Starting a new trace..." )
58
64
current_trace = traces .Trace ()
59
65
_current_trace .set (current_trace ) # Set the current trace in context
66
+ _rag_context .set (None ) # Reset the context
60
67
current_trace .add_step (new_step )
61
68
else :
62
69
logger .debug ("Adding step %s to parent step %s" , name , parent_step .name )
@@ -91,6 +98,9 @@ def create_step(
91
98
)
92
99
)
93
100
101
+ if "context" in trace_data :
102
+ config .update ({"context_column_name" : "context" })
103
+
94
104
if isinstance (new_step , steps .ChatCompletionStep ):
95
105
config .update (
96
106
{
@@ -121,7 +131,7 @@ def add_chat_completion_step_to_trace(**kwargs) -> None:
121
131
122
132
123
133
# ----------------------------- Tracing decorator ---------------------------- #
124
- def trace (* step_args , inference_pipeline_id : Optional [str ] = None , ** step_kwargs ):
134
+ def trace (* step_args , inference_pipeline_id : Optional [str ] = None , context_kwarg : Optional [ str ] = None , ** step_kwargs ):
125
135
"""Decorator to trace a function.
126
136
127
137
Examples
@@ -182,6 +192,12 @@ def wrapper(*func_args, **func_kwargs):
182
192
inputs .pop ("self" , None )
183
193
inputs .pop ("cls" , None )
184
194
195
+ if context_kwarg :
196
+ if context_kwarg in inputs :
197
+ log_context (inputs .get (context_kwarg ))
198
+ else :
199
+ logger .warning ("Context kwarg `%s` not found in inputs of the current function." , context_kwarg )
200
+
185
201
step .log (
186
202
inputs = inputs ,
187
203
output = output ,
@@ -198,7 +214,9 @@ def wrapper(*func_args, **func_kwargs):
198
214
return decorator
199
215
200
216
201
- def trace_async (* step_args , inference_pipeline_id : Optional [str ] = None , ** step_kwargs ):
217
+ def trace_async (
218
+ * step_args , inference_pipeline_id : Optional [str ] = None , context_kwarg : Optional [str ] = None , ** step_kwargs
219
+ ):
202
220
"""Decorator to trace a function.
203
221
204
222
Examples
@@ -259,6 +277,12 @@ async def wrapper(*func_args, **func_kwargs):
259
277
inputs .pop ("self" , None )
260
278
inputs .pop ("cls" , None )
261
279
280
+ if context_kwarg :
281
+ if context_kwarg in inputs :
282
+ log_context (inputs .get (context_kwarg ))
283
+ else :
284
+ logger .warning ("Context kwarg `%s` not found in inputs of the current function." , context_kwarg )
285
+
262
286
step .log (
263
287
inputs = inputs ,
264
288
output = output ,
@@ -292,6 +316,19 @@ def run_async_func(coroutine: Awaitable[Any]) -> Any:
292
316
return result
293
317
294
318
319
+ def log_context (context : List [str ]) -> None :
320
+ """Logs context information to the current step of the trace.
321
+
322
+ The `context` parameter should be a list of strings representing the
323
+ context chunks retrieved by the context retriever."""
324
+ current_step = get_current_step ()
325
+ if current_step :
326
+ _rag_context .set (context )
327
+ current_step .log (metadata = {"context" : context })
328
+ else :
329
+ logger .warning ("No current step found to log context." )
330
+
331
+
295
332
# --------------------- Helper post-processing functions --------------------- #
296
333
def post_process_trace (
297
334
trace_obj : traces .Trace ,
@@ -323,4 +360,8 @@ def post_process_trace(
323
360
if input_variables :
324
361
trace_data .update (input_variables )
325
362
363
+ context = get_rag_context ()
364
+ if context :
365
+ trace_data ["context" ] = context
366
+
326
367
return trace_data , input_variable_names
0 commit comments