Files changed (2) hide show
  1. pyproject.toml +1 -0
  2. tdagent/grchat.py +416 -295
pyproject.toml CHANGED
@@ -146,4 +146,5 @@ convention = "google"
146
  [tool.ruff.lint.per-file-ignores]
147
  "*/__init__.py" = ["F401"]
148
  "tdagent/cli/**/*.py" = ["D103", "T201"]
 
149
  "tests/*.py" = ["D103", "PLR2004", "S101"]
 
146
  [tool.ruff.lint.per-file-ignores]
147
  "*/__init__.py" = ["F401"]
148
  "tdagent/cli/**/*.py" = ["D103", "T201"]
149
+ "tdagent/grchat.py" = ["ANN401", "FBT001"]
150
  "tests/*.py" = ["D103", "PLR2004", "S101"]
tdagent/grchat.py CHANGED
@@ -1,5 +1,7 @@
1
  from __future__ import annotations
2
 
 
 
3
  import os
4
  from collections import OrderedDict
5
  from collections.abc import Mapping, Sequence
@@ -12,7 +14,9 @@ import botocore.exceptions
12
  import gradio as gr
13
  import gradio.themes as gr_themes
14
  from langchain_aws import ChatBedrock
 
15
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
 
16
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
17
  from langchain_mcp_adapters.client import MultiServerMCPClient
18
  from langchain_openai import AzureChatOpenAI
@@ -29,22 +33,46 @@ if TYPE_CHECKING:
29
 
30
  #### Constants ####
31
 
32
- SYSTEM_MESSAGE = SystemMessage(
33
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  You are a security analyst assistant responsible for collecting, analyzing
35
  and disseminating actionable intelligence related to cyber threats,
36
  vulnerabilities and threat actors.
37
 
38
  When presented with potential incidents information or tickets, you should
39
- evaluate the presented evidence, decide what is missing and gather
40
- additional data using any tool at your disposal. After gathering more
41
- information you must evaluate if the incident is a threat or
42
- not and, if possible, remediation actions.
 
43
 
44
- You must always present the conducted analysis and final conclusion.
45
  Never use external means of communication, like emails or SMS, unless
46
  instructed to do so.
47
  """.strip(),
 
 
 
 
 
 
 
 
 
 
48
  )
49
 
50
 
@@ -55,6 +83,7 @@ GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
55
  },
56
  )
57
 
 
58
  MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
59
  (
60
  (
@@ -90,9 +119,50 @@ MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
90
  ),
91
  )
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  #### Shared variables ####
94
 
95
  llm_agent: CompiledGraph | None = None
 
96
 
97
  #### Utility functions ####
98
 
@@ -158,6 +228,8 @@ def create_hf_llm(
158
 
159
 
160
  ## OpenAI LLM creation ##
 
 
161
  def create_openai_llm(
162
  model_id: str,
163
  token_id: str,
@@ -208,6 +280,56 @@ def create_azure_llm(
208
 
209
 
210
  #### UI functionality ####
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  async def gr_connect_to_bedrock( # noqa: PLR0913
212
  model_id: str,
213
  access_key: str,
@@ -215,11 +337,14 @@ async def gr_connect_to_bedrock( # noqa: PLR0913
215
  session_token: str,
216
  region: str,
217
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
 
 
218
  temperature: float = 0.8,
219
  max_tokens: int = 512,
220
  ) -> str:
221
  """Initialize Bedrock agent."""
222
  global llm_agent # noqa: PLW0603
 
223
  if not access_key or not secret_key:
224
  return "❌ Please provide both Access Key ID and Secret Access Key"
225
 
@@ -236,32 +361,13 @@ async def gr_connect_to_bedrock( # noqa: PLR0913
236
  if llm is None:
237
  return f"❌ Connection failed: {error}"
238
 
239
- # client = MultiServerMCPClient(
240
- # {
241
- # "toolkit": {
242
- # "url": "https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
243
- # "transport": "sse",
244
- # },
245
- # }
246
- # )
247
- # tools = await client.get_tools()
248
- if mcp_servers:
249
- client = MultiServerMCPClient(
250
- {
251
- server.name.replace(" ", "-"): {
252
- "url": server.value,
253
- "transport": "sse",
254
- }
255
- for server in mcp_servers
256
- },
257
- )
258
- tools = await client.get_tools()
259
- else:
260
- tools = []
261
  llm_agent = create_react_agent(
262
  model=llm,
263
- tools=tools,
264
- prompt=SYSTEM_MESSAGE,
 
 
 
265
  )
266
 
267
  return "βœ… Successfully connected to AWS Bedrock!"
@@ -271,6 +377,8 @@ async def gr_connect_to_hf(
271
  model_id: str,
272
  hf_access_token_textbox: str | None,
273
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
 
 
274
  temperature: float = 0.8,
275
  max_tokens: int = 512,
276
  ) -> str:
@@ -286,34 +394,27 @@ async def gr_connect_to_hf(
286
 
287
  if llm is None:
288
  return f"❌ Connection failed: {error}"
289
- tools = []
290
- if mcp_servers:
291
- client = MultiServerMCPClient(
292
- {
293
- server.name.replace(" ", "-"): {
294
- "url": server.value,
295
- "transport": "sse",
296
- }
297
- for server in mcp_servers
298
- },
299
- )
300
- tools = await client.get_tools()
301
 
302
  llm_agent = create_react_agent(
303
  model=llm,
304
- tools=tools,
305
- prompt=SYSTEM_MESSAGE,
 
 
 
306
  )
307
 
308
  return "βœ… Successfully connected to Hugging Face!"
309
 
310
 
311
- async def gr_connect_to_azure(
312
  model_id: str,
313
  azure_endpoint: str,
314
  api_key: str,
315
  api_version: str,
316
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
 
 
317
  temperature: float = 0.8,
318
  max_tokens: int = 512,
319
  ) -> str:
@@ -331,59 +432,47 @@ async def gr_connect_to_azure(
331
 
332
  if llm is None:
333
  return f"❌ Connection failed: {error}"
334
- tools = []
335
- if mcp_servers:
336
- client = MultiServerMCPClient(
337
- {
338
- server.name.replace(" ", "-"): {
339
- "url": server.value,
340
- "transport": "sse",
341
- }
342
- for server in mcp_servers
343
- },
344
- )
345
- tools = await client.get_tools()
346
 
347
  llm_agent = create_react_agent(
348
  model=llm,
349
- tools=tools,
350
- prompt=SYSTEM_MESSAGE,
351
  )
352
 
353
  return "βœ… Successfully connected to Azure OpenAI!"
354
 
355
 
356
- async def gr_connect_to_nebius(
357
- model_id: str,
358
- nebius_access_token_textbox: str,
359
- mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
360
- ) -> str:
361
- """Initialize Hugging Face agent."""
362
- global llm_agent # noqa: PLW0603
363
-
364
- llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
365
-
366
- if llm is None:
367
- return f"❌ Connection failed: {error}"
368
- tools = []
369
- if mcp_servers:
370
- client = MultiServerMCPClient(
371
- {
372
- server.name.replace(" ", "-"): {
373
- "url": server.value,
374
- "transport": "sse",
375
- }
376
- for server in mcp_servers
377
- },
378
- )
379
- tools = await client.get_tools()
380
-
381
- llm_agent = create_react_agent(
382
- model=str(llm),
383
- tools=tools,
384
- prompt=SYSTEM_MESSAGE,
385
- )
386
- return "βœ… Successfully connected to nebius!"
387
 
388
 
389
  async def gr_chat_function( # noqa: D103
@@ -401,12 +490,17 @@ async def gr_chat_function( # noqa: D103
401
 
402
  messages.append(HumanMessage(content=message))
403
  try:
 
 
 
404
  llm_response = await llm_agent.ainvoke(
405
  {
406
  "messages": messages,
407
  },
408
  )
409
- return llm_response["messages"][-1].content
 
 
410
  except Exception as err:
411
  raise gr.Error(
412
  f"We encountered an error while invoking the model:\n{err}",
@@ -414,106 +508,26 @@ async def gr_chat_function( # noqa: D103
414
  ) from err
415
 
416
 
417
- ## UI components ##
 
 
 
418
 
 
 
 
 
 
 
 
 
 
 
419
 
420
- # Function to toggle visibility and set model IDs
421
- def toggle_model_fields(
422
- provider: str,
423
- ) -> tuple[
424
- dict[str, Any],
425
- dict[str, Any],
426
- dict[str, Any],
427
- dict[str, Any],
428
- dict[str, Any],
429
- dict[str, Any],
430
- dict[str, Any],
431
- dict[str, Any],
432
- dict[str, Any],
433
- ]: # ignore: F821
434
- """Toggle visibility of model fields based on the selected provider."""
435
- # Update model choices based on the selected provider
436
- if provider in MODEL_OPTIONS:
437
- model_choices = list(MODEL_OPTIONS[provider].keys())
438
- model_pretty = gr.update(
439
- choices=model_choices,
440
- value=model_choices[0],
441
- visible=True,
442
- interactive=True,
443
- )
444
- else:
445
- model_pretty = gr.update(choices=[], visible=False)
446
-
447
- # Visibility settings for fields specific to each provider
448
- is_aws = provider == "AWS Bedrock"
449
- is_hf = provider == "HuggingFace"
450
- is_azure = provider == "Azure OpenAI"
451
- # is_nebius = provider == "Nebius"
452
- return (
453
- model_pretty,
454
- gr.update(visible=is_aws, interactive=is_aws),
455
- gr.update(visible=is_aws, interactive=is_aws),
456
- gr.update(visible=is_aws, interactive=is_aws),
457
- gr.update(visible=is_aws, interactive=is_aws),
458
- gr.update(visible=is_hf, interactive=is_hf),
459
- gr.update(visible=is_azure, interactive=is_azure),
460
- gr.update(visible=is_azure, interactive=is_azure),
461
- gr.update(visible=is_azure, interactive=is_azure),
462
- )
463
-
464
 
465
- async def update_connection_status( # noqa: PLR0913
466
- provider: str,
467
- model_id: str,
468
- mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
469
- aws_access_key_textbox: str,
470
- aws_secret_key_textbox: str,
471
- aws_session_token_textbox: str,
472
- aws_region_dropdown: str,
473
- hf_token: str,
474
- azure_endpoint: str,
475
- azure_api_token: str,
476
- azure_api_version: str,
477
- temperature: float,
478
- max_tokens: int,
479
- ) -> str:
480
- """Update the connection status based on the selected provider and model."""
481
- if not provider or not model_id:
482
- return "❌ Please select a provider and model."
483
- connection = "❌ Invalid provider"
484
- if provider == "AWS Bedrock":
485
- connection = await gr_connect_to_bedrock(
486
- model_id,
487
- aws_access_key_textbox,
488
- aws_secret_key_textbox,
489
- aws_session_token_textbox,
490
- aws_region_dropdown,
491
- mcp_list_state,
492
- temperature,
493
- max_tokens,
494
- )
495
- elif provider == "HuggingFace":
496
- connection = await gr_connect_to_hf(
497
- model_id,
498
- hf_token,
499
- mcp_list_state,
500
- temperature,
501
- max_tokens,
502
- )
503
- elif provider == "Azure OpenAI":
504
- connection = await gr_connect_to_azure(
505
- model_id,
506
- azure_endpoint,
507
- azure_api_token,
508
- azure_api_version,
509
- mcp_list_state,
510
- temperature,
511
- max_tokens,
512
- )
513
- elif provider == "Nebius":
514
- connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
515
 
516
- return connection
517
 
518
 
519
  with (
@@ -549,65 +563,66 @@ with (
549
  value=None,
550
  label="Select Model Provider",
551
  )
552
- aws_access_key_textbox = gr.Textbox(
553
- label="AWS Access Key ID",
554
- type="password",
555
- placeholder="Enter your AWS Access Key ID",
556
- visible=False,
557
- )
558
- aws_secret_key_textbox = gr.Textbox(
559
- label="AWS Secret Access Key",
560
- type="password",
561
- placeholder="Enter your AWS Secret Access Key",
562
- visible=False,
563
- )
564
- aws_region_dropdown = gr.Dropdown(
565
- label="AWS Region",
566
- choices=[
567
- "us-east-1",
568
- "us-west-2",
569
- "eu-west-1",
570
- "eu-central-1",
571
- "ap-southeast-1",
572
- ],
573
- value="eu-west-1",
574
- visible=False,
575
- )
576
- aws_session_token_textbox = gr.Textbox(
577
- label="AWS Session Token",
578
- type="password",
579
- placeholder="Enter your AWS session token",
580
- visible=False,
581
- )
582
- hf_token = gr.Textbox(
583
- label="HuggingFace Token",
584
- type="password",
585
- placeholder="Enter your Hugging Face Access Token",
586
- visible=False,
587
- )
588
- azure_endpoint = gr.Textbox(
589
- label="Azure OpenAI Endpoint",
590
- type="text",
591
- placeholder="Enter your Azure OpenAI Endpoint",
592
- visible=False,
593
- )
594
- azure_api_token = gr.Textbox(
595
- label="Azure Access Token",
596
- type="password",
597
- placeholder="Enter your Azure OpenAI Access Token",
598
- visible=False,
599
- )
600
- azure_api_version = gr.Textbox(
601
- label="Azure OpenAI API Version",
602
- type="text",
603
- placeholder="Enter your Azure OpenAI API Version",
604
- value="2024-12-01-preview",
605
- visible=False,
606
- )
 
607
 
608
  with gr.Accordion("🧠 Model Configuration", open=True):
609
- model_display_id = gr.Dropdown(
610
- label="Select Model from the list",
611
  choices=[],
612
  visible=False,
613
  )
@@ -618,31 +633,24 @@ with (
618
  visible=False,
619
  interactive=True,
620
  )
621
- model_provider.change(
622
- toggle_model_fields,
623
- inputs=[model_provider],
624
- outputs=[
625
- model_display_id,
626
- aws_access_key_textbox,
627
- aws_secret_key_textbox,
628
- aws_session_token_textbox,
629
- aws_region_dropdown,
630
- hf_token,
631
- azure_endpoint,
632
- azure_api_token,
633
- azure_api_version,
634
- ],
635
- )
636
- model_display_id.change(
637
- lambda x, y: gr.update(
638
- value=MODEL_OPTIONS.get(y, {}).get(x),
639
- visible=True,
640
  )
641
- if x
642
- else model_id_textbox.value,
643
- inputs=[model_display_id, model_provider],
644
- outputs=[model_id_textbox],
645
- )
 
646
  # Initialize the temperature and max tokens based on model specifications
647
  temperature = gr.Slider(
648
  label="Temperature",
@@ -653,44 +661,157 @@ with (
653
  )
654
  max_tokens = gr.Slider(
655
  label="Max Tokens",
656
- minimum=64,
657
- maximum=4096,
658
- value=512,
659
  step=64,
660
  )
661
 
662
- connect_btn = gr.Button("πŸ”Œ Connect to Model", variant="primary")
663
- status_textbox = gr.Textbox(label="Connection Status", interactive=False)
664
-
665
- connect_btn.click(
666
- update_connection_status,
667
- inputs=[
668
- model_provider,
669
- model_id_textbox,
670
- mcp_list.state,
671
- aws_access_key_textbox,
672
- aws_secret_key_textbox,
673
- aws_session_token_textbox,
674
- aws_region_dropdown,
675
- hf_token,
676
- azure_endpoint,
677
- azure_api_token,
678
- azure_api_version,
679
- temperature,
680
- max_tokens,
681
- ],
682
- outputs=[status_textbox],
683
  )
 
 
 
 
 
 
 
684
 
685
  with gr.Column(scale=2):
686
  chat_interface = gr.ChatInterface(
687
  fn=gr_chat_function,
688
  type="messages",
689
  examples=[], # Add examples if needed
690
- title="πŸ‘©β€πŸ’» TDAgent",
691
- description="This is a simple agent that uses MCP tools.",
692
  )
693
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
694
 
695
  if __name__ == "__main__":
696
  gr_app.launch()
 
1
  from __future__ import annotations
2
 
3
+ import dataclasses
4
+ import enum
5
  import os
6
  from collections import OrderedDict
7
  from collections.abc import Mapping, Sequence
 
14
  import gradio as gr
15
  import gradio.themes as gr_themes
16
  from langchain_aws import ChatBedrock
17
+ from langchain_core.callbacks import BaseCallbackHandler
18
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
19
+ from langchain_core.tools import BaseTool
20
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
21
  from langchain_mcp_adapters.client import MultiServerMCPClient
22
  from langchain_openai import AzureChatOpenAI
 
33
 
34
  #### Constants ####
35
 
36
+
37
+ class AgentType(str, enum.Enum):
38
+ """TDAgent type."""
39
+
40
+ INCIDENT_HANDLER = "Incident handler"
41
+ DATA_ENRICHER = "Data enricher"
42
+
43
+ def __str__(self) -> str: # noqa: D105
44
+ return self.value
45
+
46
+
47
+ AGENT_SYSTEM_MESSAGES = OrderedDict(
48
+ (
49
+ (
50
+ AgentType.INCIDENT_HANDLER,
51
+ """
52
  You are a security analyst assistant responsible for collecting, analyzing
53
  and disseminating actionable intelligence related to cyber threats,
54
  vulnerabilities and threat actors.
55
 
56
  When presented with potential incidents information or tickets, you should
57
+ evaluate the presented evidence, gather additional data using any tool at
58
+ your disposal and take corrective actions if possible.
59
+
60
+ Afterwards, generate a cybersecurity report including: key findings, challenges,
61
+ actions taken and recommendations.
62
 
 
63
  Never use external means of communication, like emails or SMS, unless
64
  instructed to do so.
65
  """.strip(),
66
+ ),
67
+ (
68
+ AgentType.DATA_ENRICHER,
69
+ """
70
+ You are a cybersecurity incidence data enriching assistant. Analysts
71
+ will present information about security incidents and you must use
72
+ all the tools at your disposal to enrich the data as much as possible.
73
+ """.strip(),
74
+ ),
75
+ ),
76
  )
77
 
78
 
 
83
  },
84
  )
85
 
86
+
87
  MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
88
  (
89
  (
 
119
  ),
120
  )
121
 
122
+
123
+ @dataclasses.dataclass
124
+ class ToolInvocationInfo:
125
+ """Information related to a tool invocation by the LLM."""
126
+
127
+ name: str
128
+ inputs: Mapping[str, Any]
129
+
130
+
131
+ class ToolsTracerCallback(BaseCallbackHandler):
132
+ """Callback that registers tools invoked by the Agent."""
133
+
134
+ def __init__(self) -> None:
135
+ self._tools_trace: list[ToolInvocationInfo] = []
136
+
137
+ def on_tool_start( # noqa: D102
138
+ self,
139
+ serialized: dict[str, Any],
140
+ *args: Any,
141
+ inputs: dict[str, Any] | None = None,
142
+ **kwargs: Any,
143
+ ) -> Any:
144
+ self._tools_trace.append(
145
+ ToolInvocationInfo(
146
+ name=serialized.get("name", "<unknown-function-name>"),
147
+ inputs=inputs if inputs else {},
148
+ ),
149
+ )
150
+ return super().on_tool_start(serialized, *args, inputs=inputs, **kwargs)
151
+
152
+ @property
153
+ def tools_trace(self) -> Sequence[ToolInvocationInfo]:
154
+ """Tools trace information."""
155
+ return self._tools_trace
156
+
157
+ def clear(self) -> None:
158
+ """Clear tools trace."""
159
+ self._tools_trace.clear()
160
+
161
+
162
  #### Shared variables ####
163
 
164
  llm_agent: CompiledGraph | None = None
165
+ llm_tools_tracer: ToolsTracerCallback | None = None
166
 
167
  #### Utility functions ####
168
 
 
228
 
229
 
230
  ## OpenAI LLM creation ##
231
+
232
+
233
  def create_openai_llm(
234
  model_id: str,
235
  token_id: str,
 
280
 
281
 
282
  #### UI functionality ####
283
+
284
+
285
+ async def gr_fetch_mcp_tools(
286
+ mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
287
+ *,
288
+ trace_tools: bool,
289
+ ) -> list[BaseTool]:
290
+ """Fetch tools from MCP servers."""
291
+ global llm_tools_tracer # noqa: PLW0603
292
+
293
+ if mcp_servers:
294
+ client = MultiServerMCPClient(
295
+ {
296
+ server.name.replace(" ", "-"): {
297
+ "url": server.value,
298
+ "transport": "sse",
299
+ }
300
+ for server in mcp_servers
301
+ },
302
+ )
303
+ tools = await client.get_tools()
304
+ if trace_tools:
305
+ llm_tools_tracer = ToolsTracerCallback()
306
+ for tool in tools:
307
+ if tool.callbacks is None:
308
+ tool.callbacks = [llm_tools_tracer]
309
+ elif isinstance(tool.callbacks, list):
310
+ tool.callbacks.append(llm_tools_tracer)
311
+ else:
312
+ tool.callbacks.add_handler(llm_tools_tracer)
313
+ else:
314
+ llm_tools_tracer = None
315
+
316
+ return tools
317
+
318
+ return []
319
+
320
+
321
+ def gr_make_system_message(
322
+ agent_type: AgentType,
323
+ ) -> SystemMessage:
324
+ """Make agent's system message."""
325
+ try:
326
+ system_msg = AGENT_SYSTEM_MESSAGES[agent_type]
327
+ except KeyError as err:
328
+ raise gr.Error(f"Unknown agent type '{agent_type}'") from err
329
+
330
+ return SystemMessage(system_msg)
331
+
332
+
333
  async def gr_connect_to_bedrock( # noqa: PLR0913
334
  model_id: str,
335
  access_key: str,
 
337
  session_token: str,
338
  region: str,
339
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
340
+ agent_type: AgentType,
341
+ trace_tool_calls: bool,
342
  temperature: float = 0.8,
343
  max_tokens: int = 512,
344
  ) -> str:
345
  """Initialize Bedrock agent."""
346
  global llm_agent # noqa: PLW0603
347
+
348
  if not access_key or not secret_key:
349
  return "❌ Please provide both Access Key ID and Secret Access Key"
350
 
 
361
  if llm is None:
362
  return f"❌ Connection failed: {error}"
363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  llm_agent = create_react_agent(
365
  model=llm,
366
+ tools=await gr_fetch_mcp_tools(
367
+ mcp_servers,
368
+ trace_tools=trace_tool_calls,
369
+ ),
370
+ prompt=gr_make_system_message(agent_type=agent_type),
371
  )
372
 
373
  return "βœ… Successfully connected to AWS Bedrock!"
 
377
  model_id: str,
378
  hf_access_token_textbox: str | None,
379
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
380
+ agent_type: AgentType,
381
+ trace_tool_calls: bool,
382
  temperature: float = 0.8,
383
  max_tokens: int = 512,
384
  ) -> str:
 
394
 
395
  if llm is None:
396
  return f"❌ Connection failed: {error}"
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
  llm_agent = create_react_agent(
399
  model=llm,
400
+ tools=await gr_fetch_mcp_tools(
401
+ mcp_servers,
402
+ trace_tools=trace_tool_calls,
403
+ ),
404
+ prompt=gr_make_system_message(agent_type=agent_type),
405
  )
406
 
407
  return "βœ… Successfully connected to Hugging Face!"
408
 
409
 
410
+ async def gr_connect_to_azure( # noqa: PLR0913
411
  model_id: str,
412
  azure_endpoint: str,
413
  api_key: str,
414
  api_version: str,
415
  mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
416
+ agent_type: AgentType,
417
+ trace_tool_calls: bool,
418
  temperature: float = 0.8,
419
  max_tokens: int = 512,
420
  ) -> str:
 
432
 
433
  if llm is None:
434
  return f"❌ Connection failed: {error}"
 
 
 
 
 
 
 
 
 
 
 
 
435
 
436
  llm_agent = create_react_agent(
437
  model=llm,
438
+ tools=await gr_fetch_mcp_tools(mcp_servers, trace_tools=trace_tool_calls),
439
+ prompt=gr_make_system_message(agent_type=agent_type),
440
  )
441
 
442
  return "βœ… Successfully connected to Azure OpenAI!"
443
 
444
 
445
+ # async def gr_connect_to_nebius(
446
+ # model_id: str,
447
+ # nebius_access_token_textbox: str,
448
+ # mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
449
+ # ) -> str:
450
+ # """Initialize Hugging Face agent."""
451
+ # global llm_agent
452
+
453
+ # llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
454
+
455
+ # if llm is None:
456
+ # return f"❌ Connection failed: {error}"
457
+ # tools = []
458
+ # if mcp_servers:
459
+ # client = MultiServerMCPClient(
460
+ # {
461
+ # server.name.replace(" ", "-"): {
462
+ # "url": server.value,
463
+ # "transport": "sse",
464
+ # }
465
+ # for server in mcp_servers
466
+ # },
467
+ # )
468
+ # tools = await client.get_tools()
469
+
470
+ # llm_agent = create_react_agent(
471
+ # model=str(llm),
472
+ # tools=tools,
473
+ # prompt=SYSTEM_MESSAGE,
474
+ # )
475
+ # return "βœ… Successfully connected to nebius!"
476
 
477
 
478
  async def gr_chat_function( # noqa: D103
 
490
 
491
  messages.append(HumanMessage(content=message))
492
  try:
493
+ if llm_tools_tracer is not None:
494
+ llm_tools_tracer.clear()
495
+
496
  llm_response = await llm_agent.ainvoke(
497
  {
498
  "messages": messages,
499
  },
500
  )
501
+ return _add_tools_trace_to_message(
502
+ llm_response["messages"][-1].content,
503
+ )
504
  except Exception as err:
505
  raise gr.Error(
506
  f"We encountered an error while invoking the model:\n{err}",
 
508
  ) from err
509
 
510
 
511
+ def _add_tools_trace_to_message(message: str) -> str:
512
+ if not llm_tools_tracer or not llm_tools_tracer.tools_trace:
513
+ return message
514
+ import json
515
 
516
+ traces = []
517
+ for index, tool_info in enumerate(llm_tools_tracer.tools_trace):
518
+ trace_msg = f" {index}. {tool_info.name}"
519
+ if tool_info.inputs:
520
+ trace_msg += "\n"
521
+ trace_msg += " * Arguments:\n"
522
+ trace_msg += " ```json\n"
523
+ trace_msg += f" {json.dumps(tool_info.inputs, indent=4)}\n"
524
+ trace_msg += " ```\n"
525
+ traces.append(trace_msg)
526
 
527
+ return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
 
530
+ ## UI components ##
531
 
532
 
533
  with (
 
563
  value=None,
564
  label="Select Model Provider",
565
  )
566
+
567
+ ## Amazon Bedrock Configuration ##
568
+ with gr.Group(visible=False) as aws_bedrock_conf_group:
569
+ aws_access_key_textbox = gr.Textbox(
570
+ label="AWS Access Key ID",
571
+ type="password",
572
+ placeholder="Enter your AWS Access Key ID",
573
+ )
574
+ aws_secret_key_textbox = gr.Textbox(
575
+ label="AWS Secret Access Key",
576
+ type="password",
577
+ placeholder="Enter your AWS Secret Access Key",
578
+ )
579
+ aws_region_dropdown = gr.Dropdown(
580
+ label="AWS Region",
581
+ choices=[
582
+ "us-east-1",
583
+ "us-west-2",
584
+ "eu-west-1",
585
+ "eu-central-1",
586
+ "ap-southeast-1",
587
+ ],
588
+ value="eu-west-1",
589
+ )
590
+ aws_session_token_textbox = gr.Textbox(
591
+ label="AWS Session Token",
592
+ type="password",
593
+ placeholder="Enter your AWS session token",
594
+ )
595
+
596
+ ## Huggingface Configuration ##
597
+ with gr.Group(visible=False) as hf_conf_group:
598
+ hf_token = gr.Textbox(
599
+ label="HuggingFace Token",
600
+ type="password",
601
+ placeholder="Enter your Hugging Face Access Token",
602
+ )
603
+
604
+ ## Azure Configuration ##
605
+ with gr.Group(visible=False) as azure_conf_group:
606
+ azure_endpoint = gr.Textbox(
607
+ label="Azure OpenAI Endpoint",
608
+ type="text",
609
+ placeholder="Enter your Azure OpenAI Endpoint",
610
+ )
611
+ azure_api_token = gr.Textbox(
612
+ label="Azure Access Token",
613
+ type="password",
614
+ placeholder="Enter your Azure OpenAI Access Token",
615
+ )
616
+ azure_api_version = gr.Textbox(
617
+ label="Azure OpenAI API Version",
618
+ type="text",
619
+ placeholder="Enter your Azure OpenAI API Version",
620
+ value="2024-12-01-preview",
621
+ )
622
 
623
  with gr.Accordion("🧠 Model Configuration", open=True):
624
+ model_id_dropdown = gr.Dropdown(
625
+ label="Select known model id or type your own below",
626
  choices=[],
627
  visible=False,
628
  )
 
633
  visible=False,
634
  interactive=True,
635
  )
636
+
637
+ # Agent configuration options
638
+ with gr.Group():
639
+ agent_system_message_radio = gr.Radio(
640
+ choices=list(AGENT_SYSTEM_MESSAGES.keys()),
641
+ value=next(iter(AGENT_SYSTEM_MESSAGES.keys())),
642
+ label="Agent type",
643
+ info=(
644
+ "Changes the system message to pre-condition the agent"
645
+ " to act in a desired way."
646
+ ),
 
 
 
 
 
 
 
 
647
  )
648
+ agent_trace_tools_checkbox = gr.Checkbox(
649
+ value=False,
650
+ label="Trace tool calls",
651
+ info="Add the invoked tools trace at the end of the message",
652
+ )
653
+
654
  # Initialize the temperature and max tokens based on model specifications
655
  temperature = gr.Slider(
656
  label="Temperature",
 
661
  )
662
  max_tokens = gr.Slider(
663
  label="Max Tokens",
664
+ minimum=128,
665
+ maximum=8192,
666
+ value=2048,
667
  step=64,
668
  )
669
 
670
+ connect_aws_bedrock_btn = gr.Button(
671
+ "πŸ”Œ Connect to Bedrock",
672
+ variant="primary",
673
+ visible=False,
674
+ )
675
+ connect_hf_btn = gr.Button(
676
+ "πŸ”Œ Connect to Huggingface πŸ€—",
677
+ variant="primary",
678
+ visible=False,
 
 
 
 
 
 
 
 
 
 
 
 
679
  )
680
+ connect_azure_btn = gr.Button(
681
+ "πŸ”Œ Connect to Azure",
682
+ variant="primary",
683
+ visible=False,
684
+ )
685
+
686
+ status_textbox = gr.Textbox(label="Connection Status", interactive=False)
687
 
688
  with gr.Column(scale=2):
689
  chat_interface = gr.ChatInterface(
690
  fn=gr_chat_function,
691
  type="messages",
692
  examples=[], # Add examples if needed
693
+ title="πŸ‘©β€πŸ’» TDAgent πŸ‘¨β€πŸ’»",
694
+ description="A simple threat analyst agent with MCP tools.",
695
  )
696
 
697
+ ## UI Events ##
698
+
699
+ def _toggle_model_choices_ui(
700
+ provider: str,
701
+ ) -> dict[str, Any]:
702
+ if provider in MODEL_OPTIONS:
703
+ model_choices = list(MODEL_OPTIONS[provider].keys())
704
+ return gr.update(
705
+ choices=model_choices,
706
+ value=model_choices[0],
707
+ visible=True,
708
+ interactive=True,
709
+ )
710
+
711
+ return gr.update(choices=[], visible=False)
712
+
713
+ def _toggle_model_aws_bedrock_conf_ui(
714
+ provider: str,
715
+ ) -> tuple[dict[str, Any], ...]:
716
+ is_aws = provider == "AWS Bedrock"
717
+ return gr.update(visible=is_aws), gr.update(visible=is_aws)
718
+
719
+ def _toggle_model_hf_conf_ui(
720
+ provider: str,
721
+ ) -> tuple[dict[str, Any], ...]:
722
+ is_hf = provider == "HuggingFace"
723
+ return gr.update(visible=is_hf), gr.update(visible=is_hf)
724
+
725
+ def _toggle_model_azure_conf_ui(
726
+ provider: str,
727
+ ) -> tuple[dict[str, Any], ...]:
728
+ is_azure = provider == "Azure OpenAI"
729
+ return gr.update(visible=is_azure), gr.update(visible=is_azure)
730
+
731
+ ## Connect Event Listeners ##
732
+
733
+ model_provider.change(
734
+ _toggle_model_choices_ui,
735
+ inputs=[model_provider],
736
+ outputs=[model_id_dropdown],
737
+ )
738
+ model_provider.change(
739
+ _toggle_model_aws_bedrock_conf_ui,
740
+ inputs=[model_provider],
741
+ outputs=[aws_bedrock_conf_group, connect_aws_bedrock_btn],
742
+ )
743
+ model_provider.change(
744
+ _toggle_model_hf_conf_ui,
745
+ inputs=[model_provider],
746
+ outputs=[hf_conf_group, connect_hf_btn],
747
+ )
748
+ model_provider.change(
749
+ _toggle_model_azure_conf_ui,
750
+ inputs=[model_provider],
751
+ outputs=[azure_conf_group, connect_azure_btn],
752
+ )
753
+
754
+ connect_aws_bedrock_btn.click(
755
+ gr_connect_to_bedrock,
756
+ inputs=[
757
+ model_id_textbox,
758
+ aws_access_key_textbox,
759
+ aws_secret_key_textbox,
760
+ aws_session_token_textbox,
761
+ aws_region_dropdown,
762
+ mcp_list.state,
763
+ agent_system_message_radio,
764
+ agent_trace_tools_checkbox,
765
+ temperature,
766
+ max_tokens,
767
+ ],
768
+ outputs=[status_textbox],
769
+ )
770
+
771
+ connect_hf_btn.click(
772
+ gr_connect_to_hf,
773
+ inputs=[
774
+ model_id_textbox,
775
+ hf_token,
776
+ mcp_list.state,
777
+ agent_system_message_radio,
778
+ agent_trace_tools_checkbox,
779
+ temperature,
780
+ max_tokens,
781
+ ],
782
+ outputs=[status_textbox],
783
+ )
784
+
785
+ connect_azure_btn.click(
786
+ gr_connect_to_azure,
787
+ inputs=[
788
+ model_id_textbox,
789
+ azure_endpoint,
790
+ azure_api_token,
791
+ azure_api_version,
792
+ mcp_list.state,
793
+ agent_system_message_radio,
794
+ agent_trace_tools_checkbox,
795
+ temperature,
796
+ max_tokens,
797
+ ],
798
+ outputs=[status_textbox],
799
+ )
800
+
801
+ model_id_dropdown.change(
802
+ lambda x, y: (
803
+ gr.update(
804
+ value=MODEL_OPTIONS.get(y, {}).get(x),
805
+ visible=True,
806
+ )
807
+ if x
808
+ else model_id_textbox.value
809
+ ),
810
+ inputs=[model_id_dropdown, model_provider],
811
+ outputs=[model_id_textbox],
812
+ )
813
+
814
+ ## Entry Point ##
815
 
816
  if __name__ == "__main__":
817
  gr_app.launch()