Skip to content

Commit 3379dc4

Browse files
committed
Merge branch 'main' of github.com:abetlen/llama_cpp_python into main
2 parents 9522284 + 628e3fb commit 3379dc4

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

llama_cpp/llama.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ def __getitem__(self, key: Sequence[int]) -> "LlamaState":
141141
if _key is None:
142142
raise KeyError("Key not found")
143143
value: "LlamaState" = self.cache.pop(_key) # type: ignore
144-
self.cache.push(_key, side="front") # type: ignore
144+
# NOTE: This puts an integer as key in cache, which breaks
145+
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
146+
# self.cache.push(_key, side="front") # type: ignore
145147
return value
146148

147149
def __contains__(self, key: Sequence[int]) -> bool:
@@ -168,7 +170,7 @@ def __init__(
168170
eval_logits: Deque[List[float]],
169171
input_ids: npt.NDArray[np.intc],
170172
scores: npt.NDArray[np.single],
171-
llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
173+
llama_state: bytes,
172174
llama_state_size: int,
173175
):
174176
self.eval_tokens = eval_tokens
@@ -1512,7 +1514,7 @@ def save_state(self) -> LlamaState:
15121514
eval_logits=self.eval_logits.copy(),
15131515
scores=self._scores.copy(),
15141516
input_ids=self._input_ids.copy(),
1515-
llama_state=llama_state_compact,
1517+
llama_state=bytes(llama_state_compact),
15161518
llama_state_size=n_bytes,
15171519
)
15181520

@@ -1523,7 +1525,10 @@ def load_state(self, state: LlamaState) -> None:
15231525
self._scores = state.scores.copy()
15241526
self._input_ids = state.input_ids.copy()
15251527
state_size = state.llama_state_size
1526-
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
1528+
LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
1529+
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
1530+
1531+
if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
15271532
raise RuntimeError("Failed to set llama state data")
15281533

15291534
def n_ctx(self) -> int:

0 commit comments

Comments
 (0)