This commit is contained in:
2026-04-03 01:24:37 -03:00
parent cae9312db1
commit db3b94a6a1
5 changed files with 34 additions and 6 deletions

View File

@@ -34,6 +34,7 @@ class SessionContext:
mentioned_frames: list[FrameRef] = field(default_factory=list) mentioned_frames: list[FrameRef] = field(default_factory=list)
transcript_segments: list[TranscriptRef] = field(default_factory=list) transcript_segments: list[TranscriptRef] = field(default_factory=list)
mentioned_transcripts: list[TranscriptRef] = field(default_factory=list) mentioned_transcripts: list[TranscriptRef] = field(default_factory=list)
history: list[tuple[str, str]] = field(default_factory=list) # [(role, text), ...]
class AgentProvider(ABC): class AgentProvider(ABC):

View File

@@ -62,6 +62,12 @@ def _build_prompt(message: str, context: SessionContext) -> str:
tm2, ts2 = divmod(int(t.end), 60) tm2, ts2 = divmod(int(t.end), 60)
lines.append(f" {t.id} [{tm1:02d}:{ts1:02d}-{tm2:02d}:{ts2:02d}] {t.text}") lines.append(f" {t.id} [{tm1:02d}:{ts1:02d}-{tm2:02d}:{ts2:02d}] {t.text}")
if context.history:
lines.append("\nConversation history:")
for role, text in context.history:
prefix = "User" if role == "user" else "Assistant"
lines.append(f" {prefix}: {text}")
lines.append(f"\nUser message: {message}") lines.append(f"\nUser message: {message}")
return "\n".join(lines) return "\n".join(lines)

View File

@@ -118,12 +118,14 @@ class OpenAICompatProvider(AgentProvider):
except Exception as e: except Exception as e:
log.warning("Could not encode frame %s: %s", frame.id, e) log.warning("Could not encode frame %s: %s", frame.id, e)
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
for role, text in context.history:
messages.append({"role": role, "content": text})
messages.append({"role": "user", "content": content})
stream = client.chat.completions.create( stream = client.chat.completions.create(
model=self._model, model=self._model,
messages=[ messages=messages,
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": content},
],
stream=True, stream=True,
) )
for chunk in stream: for chunk in stream:

View File

@@ -146,6 +146,8 @@ class AgentRunner:
def __init__(self): def __init__(self):
self._provider: AgentProvider | None = None self._provider: AgentProvider | None = None
self._history: list[tuple[str, str]] = [] # (role, text)
self.include_history = False # toggled by UI
def _get_provider(self) -> AgentProvider: def _get_provider(self) -> AgentProvider:
if self._provider is None: if self._provider is None:
@@ -178,6 +180,9 @@ class AgentRunner:
def model(self, value: str): def model(self, value: str):
self._get_provider().model = value self._get_provider().model = value
def clear_history(self):
self._history.clear()
def send( def send(
self, self,
message: str, message: str,
@@ -205,9 +210,14 @@ class AgentRunner:
mentioned_frames=mentioned_frames, mentioned_frames=mentioned_frames,
transcript_segments=transcript, transcript_segments=transcript,
mentioned_transcripts=mentioned_transcripts, mentioned_transcripts=mentioned_transcripts,
history=list(self._history) if self.include_history else [],
) )
self._history.append(("user", message))
response_chunks = []
for chunk in provider.stream(message, context): for chunk in provider.stream(message, context):
response_chunks.append(chunk)
on_chunk(chunk) on_chunk(chunk)
self._history.append(("assistant", "".join(response_chunks)))
on_done(None) on_done(None)
except Exception as e: except Exception as e:
log.error("Agent error: %s", e) log.error("Agent error: %s", e)

View File

@@ -350,6 +350,7 @@ class ChtWindow(Adw.ApplicationWindow):
self._waveform_engine.reset() self._waveform_engine.reset()
self._waveform_widget.set_peaks(None, 0.05) self._waveform_widget.set_peaks(None, 0.05)
self._transcriber.reset() self._transcriber.reset()
self._agent.clear_history()
self._transcript_order.clear() self._transcript_order.clear()
self._transcript_rows.clear() self._transcript_rows.clear()
self._transcript_texts.clear() self._transcript_texts.clear()
@@ -539,7 +540,7 @@ class ChtWindow(Adw.ApplicationWindow):
clear_btn = Gtk.Button(label="Clear") clear_btn = Gtk.Button(label="Clear")
clear_btn.add_css_class("flat") clear_btn.add_css_class("flat")
clear_btn.connect("clicked", lambda b: self._agent_output_view.get_buffer().set_text("")) clear_btn.connect("clicked", self._on_clear_agent_output)
header.append(clear_btn) header.append(clear_btn)
box.append(header) box.append(header)
@@ -600,6 +601,11 @@ class ChtWindow(Adw.ApplicationWindow):
self._lang_dropdown.connect("notify::selected", self._on_lang_changed) self._lang_dropdown.connect("notify::selected", self._on_lang_changed)
actions_box.append(self._lang_dropdown) actions_box.append(self._lang_dropdown)
self._history_toggle = Gtk.CheckButton(label="Chat")
self._history_toggle.set_tooltip_text("Include conversation history in prompts")
self._history_toggle.connect("toggled", lambda b: setattr(self._agent, "include_history", b.get_active()))
actions_box.append(self._history_toggle)
outer.append(actions_box) outer.append(actions_box)
# Text entry + send # Text entry + send
@@ -775,7 +781,7 @@ class ChtWindow(Adw.ApplicationWindow):
self._send_message(msg) self._send_message(msg)
return True return True
elif keyval == Gdk.KEY_Delete: elif keyval == Gdk.KEY_Delete:
self._agent_output_view.get_buffer().set_text("") self._on_clear_agent_output(None)
return True return True
return False return False
@@ -989,6 +995,9 @@ class ChtWindow(Adw.ApplicationWindow):
buf.apply_tag(tag, buf.get_iter_at_mark(mark), it) buf.apply_tag(tag, buf.get_iter_at_mark(mark), it)
buf.delete_mark(mark) buf.delete_mark(mark)
def _on_clear_agent_output(self, _button):
self._agent_output_view.get_buffer().set_text("")
def _on_lang_changed(self, dropdown, _pspec): def _on_lang_changed(self, dropdown, _pspec):
idx = dropdown.get_selected() idx = dropdown.get_selected()
lang_names = list(LANGUAGES.keys()) lang_names = list(LANGUAGES.keys())