diff --git a/cht/agent/base.py b/cht/agent/base.py index 7335123..17be1f0 100644 --- a/cht/agent/base.py +++ b/cht/agent/base.py @@ -34,6 +34,7 @@ class SessionContext: mentioned_frames: list[FrameRef] = field(default_factory=list) transcript_segments: list[TranscriptRef] = field(default_factory=list) mentioned_transcripts: list[TranscriptRef] = field(default_factory=list) + history: list[tuple[str, str]] = field(default_factory=list) # [(role, text), ...] class AgentProvider(ABC): diff --git a/cht/agent/claude_sdk_provider.py b/cht/agent/claude_sdk_provider.py index 6f8deac..ea5b42a 100644 --- a/cht/agent/claude_sdk_provider.py +++ b/cht/agent/claude_sdk_provider.py @@ -62,6 +62,12 @@ def _build_prompt(message: str, context: SessionContext) -> str: tm2, ts2 = divmod(int(t.end), 60) lines.append(f" {t.id} [{tm1:02d}:{ts1:02d}-{tm2:02d}:{ts2:02d}] {t.text}") + if context.history: + lines.append("\nConversation history:") + for role, text in context.history: + prefix = "User" if role == "user" else "Assistant" + lines.append(f" {prefix}: {text}") + lines.append(f"\nUser message: {message}") return "\n".join(lines) diff --git a/cht/agent/openai_compat_provider.py b/cht/agent/openai_compat_provider.py index 20a045a..14eff50 100644 --- a/cht/agent/openai_compat_provider.py +++ b/cht/agent/openai_compat_provider.py @@ -118,12 +118,14 @@ class OpenAICompatProvider(AgentProvider): except Exception as e: log.warning("Could not encode frame %s: %s", frame.id, e) + messages = [{"role": "system", "content": SYSTEM_PROMPT}] + for role, text in context.history: + messages.append({"role": role, "content": text}) + messages.append({"role": "user", "content": content}) + stream = client.chat.completions.create( model=self._model, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": content}, - ], + messages=messages, stream=True, ) for chunk in stream: diff --git a/cht/agent/runner.py b/cht/agent/runner.py index e644f7e..e947be1 100644 --- a/cht/agent/runner.py +++ b/cht/agent/runner.py @@ -146,6 +146,8 @@ class AgentRunner: def __init__(self): self._provider: AgentProvider | None = None + self._history: list[tuple[str, str]] = [] # (role, text) + self.include_history = False # toggled by UI def _get_provider(self) -> AgentProvider: if self._provider is None: @@ -178,6 +180,9 @@ class AgentRunner: def model(self, value: str): self._get_provider().model = value + def clear_history(self): + self._history.clear() + def send( self, message: str, @@ -205,9 +210,14 @@ class AgentRunner: mentioned_frames=mentioned_frames, transcript_segments=transcript, mentioned_transcripts=mentioned_transcripts, + history=list(self._history) if self.include_history else [], ) + self._history.append(("user", message)) + response_chunks = [] for chunk in provider.stream(message, context): + response_chunks.append(chunk) on_chunk(chunk) + self._history.append(("assistant", "".join(response_chunks))) on_done(None) except Exception as e: log.error("Agent error: %s", e) diff --git a/cht/window.py b/cht/window.py index 16c9c0b..5180d92 100644 --- a/cht/window.py +++ b/cht/window.py @@ -350,6 +350,7 @@ class ChtWindow(Adw.ApplicationWindow): self._waveform_engine.reset() self._waveform_widget.set_peaks(None, 0.05) self._transcriber.reset() + self._agent.clear_history() self._transcript_order.clear() self._transcript_rows.clear() self._transcript_texts.clear() @@ -539,7 +540,7 @@ class ChtWindow(Adw.ApplicationWindow): clear_btn = Gtk.Button(label="Clear") clear_btn.add_css_class("flat") - clear_btn.connect("clicked", lambda b: self._agent_output_view.get_buffer().set_text("")) + clear_btn.connect("clicked", self._on_clear_agent_output) header.append(clear_btn) box.append(header) @@ -600,6 +601,11 @@ class ChtWindow(Adw.ApplicationWindow): self._lang_dropdown.connect("notify::selected", self._on_lang_changed) actions_box.append(self._lang_dropdown) + self._history_toggle = Gtk.CheckButton(label="Chat") + self._history_toggle.set_tooltip_text("Include conversation history in prompts") + self._history_toggle.connect("toggled", lambda b: setattr(self._agent, "include_history", b.get_active())) + actions_box.append(self._history_toggle) + outer.append(actions_box) # Text entry + send @@ -775,7 +781,7 @@ class ChtWindow(Adw.ApplicationWindow): self._send_message(msg) return True elif keyval == Gdk.KEY_Delete: - self._agent_output_view.get_buffer().set_text("") + self._on_clear_agent_output(None) return True return False @@ -989,6 +995,9 @@ class ChtWindow(Adw.ApplicationWindow): buf.apply_tag(tag, buf.get_iter_at_mark(mark), it) buf.delete_mark(mark) + def _on_clear_agent_output(self, _button): + self._agent_output_view.get_buffer().set_text("") + def _on_lang_changed(self, dropdown, _pspec): idx = dropdown.get_selected() lang_names = list(LANGUAGES.keys())