diff --git a/.tekton/on-cm-runner.yaml b/.tekton/on-cm-runner.yaml
index 27d0f667..6630f8f2 100644
--- a/.tekton/on-cm-runner.yaml
+++ b/.tekton/on-cm-runner.yaml
@@ -269,7 +269,7 @@ spec:
resources:
requests:
memory: "1Gi"
- cpu: "500m"
+ cpu: "300m"
readinessProbe:
tcpSocket:
port: 9092
diff --git a/Makefile b/Makefile
index 74b2a19c..cf733572 100644
--- a/Makefile
+++ b/Makefile
@@ -71,8 +71,8 @@ test: test-unit test-llm-metrics ## Run all tests.
@echo "All tests have been run."
test-unit: ## Run unit tests.
- @echo "Running unit tests in $(SRC_DIR)..."
- @python -m pytest $(SRC_DIR) $(PYTEST_OPTS)
+ @echo "Running unit tests in $(SRC_DIR) and tests/..."
+ @python -m pytest $(SRC_DIR) tests/ $(PYTEST_OPTS)
test-llm-metrics: ## Run LLM metrics tests.
@echo "Running LLM metrics tests..."
diff --git a/src/exploit_iq_commons/utils/dep_tree.py b/src/exploit_iq_commons/utils/dep_tree.py
index 68a42657..e3c9c531 100644
--- a/src/exploit_iq_commons/utils/dep_tree.py
+++ b/src/exploit_iq_commons/utils/dep_tree.py
@@ -142,6 +142,10 @@ class DependencyTreeBuilder(ABC):
supported ecosystem.
"""
+ # Directory name where fetched dependency sources are stored (e.g. "vendor", "node_modules").
+ # Each subclass overrides this; an empty string means the ecosystem has no such directory.
+ DEP_SOURCE_DIR: str = ""
+
@abstractmethod
# Build a sort of "upside down" tree - a dict containing mapping of each
# package to a list of all consuming packages
@@ -157,6 +161,8 @@ def install_dependencies(self, manifest_path: Path):
class CCppDependencyTreeBuilder(DependencyTreeBuilder):
+ DEP_SOURCE_DIR = C_DEP_LIBS_NAME
+
# Pre-compiled regex patterns (optimization: compile once, use many times)
INCLUDE_COMBINED_RE = re.compile(
r'#include\s*([<"])([^>"]+)[>"]'
@@ -181,7 +187,7 @@ def __init__(self):
"bench", "benchmark", "demo", "sample"
]
self.C_STANDARD_LIB = "glibc"
- self.RPM_LIBS_DIR = C_DEP_LIBS_NAME
+ self.RPM_LIBS_DIR = self.DEP_SOURCE_DIR
self.output_json_path = None
self.ccp_dep_tree = None
@@ -788,6 +794,7 @@ def find_project_name(self, root_dir="."):
class GoDependencyTreeBuilder(DependencyTreeBuilder):
+ DEP_SOURCE_DIR = "vendor"
def install_dependencies(self, manifest_path: Path):
self.download_go_mod_vendor(manifest_path)
@@ -897,6 +904,7 @@ def extract_package_name(self, package_name: str) -> str:
return package_name
class JavaDependencyTreeBuilder(DependencyTreeBuilder):
+ DEP_SOURCE_DIR = "dependencies-sources"
def __init__(self, query: str):
self._query = query
@@ -911,7 +919,7 @@ def __check_file_exists(self, dir_path: str | Path, filename: str) -> bool:
def install_dependencies(self, manifest_path: Path):
mvn_command = "mvn"
settings_path = os.getenv('JAVA_MAVEN_DEFAULT_SETTINGS_FILE_PATH','../../../../kustomize/base/settings.xml')
- source_path = "dependencies-sources"
+ source_path = self.DEP_SOURCE_DIR
if self.__check_file_exists(manifest_path, "mvnw"):
mvn_command = "./mvnw"
@@ -1168,6 +1176,7 @@ def looks_like_version(v: str) -> bool:
return depth, coord
class PythonDependencyTreeBuilder(DependencyTreeBuilder):
+ DEP_SOURCE_DIR = TRANSITIVE_ENV_NAME
def build_tree(self, manifest_path: Path) -> defaultdict[Any, list]:
venv_python = f'{manifest_path}/{TRANSITIVE_ENV_NAME}/bin/python'
@@ -1565,6 +1574,7 @@ def install_dependency(self, dependency, repo_path):
logger.warning('Failed to install dependency %s', dependency)
class JavaScriptDependencyTreeBuilder(DependencyTreeBuilder):
+ DEP_SOURCE_DIR = "node_modules"
def build_tree(self, manifest_path: Path) -> dict[str, list[str]]:
@@ -1639,6 +1649,25 @@ def get_dependency_tree_builder(programming_language: Ecosystem, query: str = ""
)
+# Maps each ecosystem to its builder class — used to build ECOSYSTEM_DEP_DIRS
+# from class-level DEP_SOURCE_DIR without instantiating builders.
+_ECOSYSTEM_BUILDER_MAP: dict[Ecosystem, type[DependencyTreeBuilder]] = {
+ Ecosystem.C_CPP: CCppDependencyTreeBuilder,
+ Ecosystem.GO: GoDependencyTreeBuilder,
+ Ecosystem.JAVA: JavaDependencyTreeBuilder,
+ Ecosystem.PYTHON: PythonDependencyTreeBuilder,
+ Ecosystem.JAVASCRIPT: JavaScriptDependencyTreeBuilder,
+}
+
+# Mapping of ecosystem value -> dependency source directory prefix (trailing "/" included),
+# built from each builder's DEP_SOURCE_DIR; builders with an empty DEP_SOURCE_DIR are skipped.
+ECOSYSTEM_DEP_DIRS: dict[str, str] = {
+ eco.value: cls.DEP_SOURCE_DIR + "/"
+ for eco, cls in _ECOSYSTEM_BUILDER_MAP.items()
+ if cls.DEP_SOURCE_DIR
+}
+
+
class DependencyTree:
"""
A class that represents a dependency tree to access an appropriate
diff --git a/src/exploit_iq_commons/utils/document_embedding.py b/src/exploit_iq_commons/utils/document_embedding.py
index 035e8615..642757b2 100644
--- a/src/exploit_iq_commons/utils/document_embedding.py
+++ b/src/exploit_iq_commons/utils/document_embedding.py
@@ -217,6 +217,52 @@ def lazy_parse(self, blob: Blob) -> typing.Iterator[Document]:
)
+class _LoggingEmbeddingProxy:
+ """Wraps an Embeddings instance to log per-batch progress during VDB creation.
+
+ FAISS calls embed_documents once with all texts; the NIM SDK loops
+ internally in batches of max_batch_size calling _embed per batch.
+ This proxy intercepts embed_documents and does the batching itself
+ so it can log progress between batches.
+
+ NOTE: Tightly coupled with langchain_nvidia_ai_endpoints.NVIDIAEmbeddings.
+ Calls the private _embed(texts, model_type="passage") method directly.
+ Other Embeddings implementations (langchain ABC) don't expose _embed,
+ so this proxy will break if the embedding type changes or if NVIDIAEmbeddings
+ renames/removes _embed in a future version.
+ """
+
+ def __init__(self, embedding, total_chunks: int, start_time: float):
+ self._embedding = embedding
+ self._total_chunks = total_chunks
+ self._start_time = start_time
+ self._embedded = 0
+
+ def embed_documents(self, texts):
+ batch_size = getattr(self._embedding, "max_batch_size", 128)
+ all_embeddings = []
+ for i in range(0, len(texts), batch_size):
+ batch = texts[i:i + batch_size]
+ batch_start = time.time()
+ all_embeddings.extend(self._embedding._embed(batch, model_type="passage"))
+ self._embedded += len(batch)
+ elapsed = time.time() - self._start_time
+ rate = self._embedded / elapsed if elapsed > 0 else 0
+ remaining_min = ((self._total_chunks - self._embedded) / rate / 60) if rate > 0 else 0
+ logger.info("Embedding progress: %d / %d chunks (%.1f%%) - batch took %.2fs - ETA %.1f min",
+ self._embedded, self._total_chunks,
+ self._embedded / self._total_chunks * 100,
+ time.time() - batch_start,
+ remaining_min)
+ return all_embeddings
+
+ def embed_query(self, text):
+ return self._embedding.embed_query(text)
+
+ def __getattr__(self, name):
+ return getattr(self._embedding, name)
+
+
class DocumentEmbedding:
"""
A class to create a FAISS database from a list of source documents. The source documents are collected from git
@@ -374,8 +420,10 @@ def collect_documents(self, source_info: SourceDocumentsInfo) -> list[Document]:
"""
repo_path = self.get_repo_path(source_info)
+ cache_name = source_info.type if source_info.type != "code" else ""
documents, documents_were_in_cache = retrieve_from_cache(self._pickle_cache_directory,
- source_info.git_repo, source_info.ref)
+ source_info.git_repo, source_info.ref,
+ documents_name=cache_name)
if documents_were_in_cache or len(documents) > 0:
return documents
@@ -387,7 +435,8 @@ def collect_documents(self, source_info: SourceDocumentsInfo) -> list[Document]:
with repo_lock:
# Re-check cache — another thread may have populated it while we waited.
documents, documents_were_in_cache = retrieve_from_cache(self._pickle_cache_directory,
- source_info.git_repo, source_info.ref)
+ source_info.git_repo, source_info.ref,
+ documents_name=cache_name)
if documents_were_in_cache or len(documents) > 0:
return documents
@@ -403,7 +452,8 @@ def collect_documents(self, source_info: SourceDocumentsInfo) -> list[Document]:
documents = loader.load()
logger.info("Collected documents for '%s', Document count: %d", repo_path, len(documents))
- save_to_cache(self._pickle_cache_directory, source_info.git_repo, source_info.ref, documents)
+ save_to_cache(self._pickle_cache_directory, source_info.git_repo, source_info.ref, documents,
+ documents_name=cache_name)
return documents
def create_vdb(self, source_infos: list[SourceDocumentsInfo], output_path: PathLike):
@@ -465,8 +515,12 @@ def create_vdb(self, source_infos: list[SourceDocumentsInfo], output_path: PathL
embedding_start_time = time.time()
+ # Wrap embedding in a proxy that logs batch progress
+ total_chunks = len(chunked_documents)
+ logging_embedding = _LoggingEmbeddingProxy(self._embedding, total_chunks, embedding_start_time)
+
# Create the FAISS database
- db = FAISS.from_documents(chunked_documents, self._embedding)
+ db = FAISS.from_documents(chunked_documents, logging_embedding)
logger.info("Completed embedding in %.2f seconds for '%s'", time.time() - embedding_start_time, output_path)
@@ -513,7 +567,7 @@ def build_vdbs(self,
# Create embeddings for each source type
for source_type in ["code", "doc"]:
- if ignore_code_embedding:
+ if ignore_code_embedding and source_type == "code":
continue
# Filter the source documents
diff --git a/src/exploit_iq_commons/utils/functions_parsers/c_lang_function_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/c_lang_function_parsers.py
index 37b73af1..7289e39c 100644
--- a/src/exploit_iq_commons/utils/functions_parsers/c_lang_function_parsers.py
+++ b/src/exploit_iq_commons/utils/functions_parsers/c_lang_function_parsers.py
@@ -763,3 +763,9 @@ def is_call_allowed(self, pkg_docs: list[Document], caller_function: Document, c
return False
return True
+
+ def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]:
+ escaped = re.escape(package_name)
+ return [
+ re.compile(rf'#include\s*[<"]({escaped}[^>"]*)[>"]', re.IGNORECASE | re.MULTILINE),
+ ]
diff --git a/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py
index 78805062..ed7e9470 100644
--- a/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py
+++ b/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py
@@ -598,4 +598,11 @@ def is_package_imported(self, code_content: str, identifier: str, callee_package
package_name = import_package_line.split(r"\s")[1]
if package_name.strip().lower() == callee_package.strip().lower():
return True
- return False
\ No newline at end of file
+ return False
+
+ def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]:
+ escaped = re.escape(package_name)
+ return [
+ re.compile(rf'import\s+"({escaped}[^"]*)"', re.IGNORECASE | re.MULTILINE),
+ re.compile(rf'import\s+\(\s*[^)]*"({escaped}[^"]*)"', re.IGNORECASE | re.MULTILINE),
+ ]
\ No newline at end of file
diff --git a/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py
index a5aff80f..231041d6 100644
--- a/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py
+++ b/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py
@@ -1570,17 +1570,20 @@ def _iter_fqcn_candidates(raw_type_token: str):
yield t
return
- # Simple token: try target allow-list by simple name
+ # Explicit import takes precedence: the caller's import statement
+ # is the definitive type binding for this simple name.
+ imp = _explicit_imports.get(t)
+ if imp:
+ yield imp
+ return
+
+ # Fallback: try target allow-list by simple name (only when
+ # the caller has no explicit import for this token).
cands = _target_by_simple.get(t)
if cands:
for fq in cands:
yield fq
- # Explicit import disambiguation
- imp = _explicit_imports.get(t)
- if imp:
- yield imp
-
# Wildcard imports: only usable if we can cheaply construct candidates
# (we only build candidates that are already in target_class_names to avoid work)
if _wild_import_pkgs and cands:
@@ -1748,7 +1751,9 @@ def _extract_ctor_type(expr: str) -> str:
caller_function_index=_caller_key(),
target_class_names=target_class_names,
function=function,
- code_documents=code_documents
+ code_documents=code_documents,
+ caller_explicit_imports=_explicit_imports,
+ caller_package=_caller_pkg,
)
if traced_cast:
return True
@@ -1807,7 +1812,9 @@ def _extract_ctor_type(expr: str) -> str:
caller_function_index=caller_key,
target_class_names=target_class_names,
function=function,
- code_documents=code_documents
+ code_documents=code_documents,
+ caller_explicit_imports=_explicit_imports,
+ caller_package=_caller_pkg,
)
if traced:
return True
@@ -1850,6 +1857,8 @@ def __trace_down_package(
target_class_names: frozenset[str],
function: Document,
code_documents: dict[str, Document],
+ caller_explicit_imports: dict[str, str] | None = None,
+ caller_package: str = "",
) -> bool:
variables_mappings = functions_local_variables_index.get(caller_function_index, {}) # CHANGED: safe fallback
parts = expression.split(".")
@@ -1880,6 +1889,22 @@ def _fqcn_candidates_from_token(type_token: str) -> list[str]:
# Derive simple name and match against allow-list by suffix
simple = t.rsplit(".", 1)[-1]
+
+ if caller_explicit_imports:
+ imported_fqcn = caller_explicit_imports.get(simple)
+ if imported_fqcn and imported_fqcn not in target_class_names:
+ return []
+
+ if (not caller_explicit_imports or simple not in caller_explicit_imports) and caller_package:
+ same_pkg_fqcn = f"{caller_package}.{simple}"
+ if same_pkg_fqcn not in target_class_names:
+ for td in type_documents:
+ td_src = td.metadata.get('source', '')
+ if simple in td_src and td_src.endswith(f"/{simple}.java"):
+ pkg_path = caller_package.replace('.', '/')
+ if pkg_path in td_src:
+ return []
+
out: list[str] = []
dot_suffix = "." + simple
dollar_suffix = "$" + simple
@@ -1916,7 +1941,9 @@ def _has_matching_type(type_token: str) -> bool:
struct_initializer_expression=struct_initializer_expression,
type_documents=type_documents,
value_list=value_list,
- target_class_names=target_class_names
+ target_class_names=target_class_names,
+ caller_explicit_imports=caller_explicit_imports,
+ caller_package=caller_package,
)
# Property/member is not in function, check if it's member/property of a type
@@ -3312,11 +3339,18 @@ def _simple_to_fqcns_index(target_class_names: frozenset[str]) -> dict[str, tupl
idx.setdefault(simple, []).append(fq)
return {k: tuple(v) for k, v in idx.items()}
- def _fqcn_candidates_from_token(self, type_token: str, target_class_names: frozenset[str]) -> tuple[str, ...]:
+ def _fqcn_candidates_from_token(self, type_token: str, target_class_names: frozenset[str],
+ caller_explicit_imports: dict[str, str] | None = None,
+ caller_package: str = "",
+ type_documents: list | None = None) -> tuple[str, ...]:
"""
Map a possibly-simple type token to FQCN candidates strictly within `target_class_names`.
- If token already equals an allowed FQCN => (token,)
- Else => all allowed FQCNs that share the same simple name
+ - If caller_explicit_imports maps the simple name to a FQCN NOT in target_class_names,
+ the caller's type is definitively different => ()
+ - If the caller's own package contains a class with this simple name and that FQCN
+ is NOT in target_class_names, the caller's type is the same-package one => ()
"""
t = self._normalize_type_token(type_token)
if not t:
@@ -3326,6 +3360,19 @@ def _fqcn_candidates_from_token(self, type_token: str, target_class_names: froze
simple = self._simple_name_from_type_token(t)
if not simple:
return ()
+ if caller_explicit_imports:
+ imported_fqcn = caller_explicit_imports.get(simple)
+ if imported_fqcn and imported_fqcn not in target_class_names:
+ return ()
+ if (not caller_explicit_imports or simple not in caller_explicit_imports) and caller_package and type_documents:
+ same_pkg_fqcn = f"{caller_package}.{simple}"
+ if same_pkg_fqcn not in target_class_names:
+ pkg_path = caller_package.replace('.', '/')
+ suffix = f"/{simple}.java"
+ for td in type_documents:
+ td_src = td.metadata.get('source', '')
+ if td_src.endswith(suffix) and pkg_path in td_src:
+ return ()
return self._simple_to_fqcns_index(target_class_names).get(simple, ())
def _has_matching_type_in_package(
@@ -3334,11 +3381,16 @@ def _has_matching_type_in_package(
type_token: str,
type_documents: list[Document],
target_class_names: frozenset[str],
+ caller_explicit_imports: dict[str, str] | None = None,
+ caller_package: str = "",
) -> bool:
"""
Calls __get_type_docs_matched_with_callee_type with FQCN candidates only.
"""
- for fq in self._fqcn_candidates_from_token(type_token, target_class_names):
+ for fq in self._fqcn_candidates_from_token(type_token, target_class_names,
+ caller_explicit_imports=caller_explicit_imports,
+ caller_package=caller_package,
+ type_documents=type_documents):
if self.__get_type_docs_matched_with_callee_type(callee_package, fq, type_documents, target_class_names):
return True
return False
@@ -3459,11 +3511,15 @@ def __lookup_package(
type_documents,
value_list,
target_class_names: frozenset[str],
+ caller_explicit_imports: dict[str, str] | None = None,
+ caller_package: str = "",
) -> bool:
if not struct_initializer_expression and resolved_type not in JAVA_METHOD_PRIM_TYPES:
if resolved_type and resolved_type not in JAVA_METHOD_PRIM_TYPES:
if self._has_matching_type_in_package(
- callee_package, resolved_type, type_documents, target_class_names
+ callee_package, resolved_type, type_documents, target_class_names,
+ caller_explicit_imports=caller_explicit_imports,
+ caller_package=caller_package,
):
return True
@@ -3473,7 +3529,9 @@ def __lookup_package(
elif struct_initializer_expression:
struct_type = struct_initializer_expression.group(0) # TODO list of expressions
if self._has_matching_type_in_package(
- callee_package, struct_type, type_documents, target_class_names
+ callee_package, struct_type, type_documents, target_class_names,
+ caller_explicit_imports=caller_explicit_imports,
+ caller_package=caller_package,
):
return True
@@ -3652,4 +3710,10 @@ def get_package_name(self, function: Document, package_name: str) -> str:
artifact_version = f"{parts[1]}:{parts[2]}"
else:
artifact_version = package_name
- return package_name if jar_name == artifact_version else ''
\ No newline at end of file
+ return package_name if jar_name == artifact_version else ''
+
+ def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]:
+ escaped = re.escape(package_name)
+ return [
+ re.compile(rf"import\s+(?:static\s+)?({escaped}[\w.]*)\s*;", re.IGNORECASE | re.MULTILINE),
+ ]
\ No newline at end of file
diff --git a/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py b/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py
index 0275b64f..3c6017d7 100644
--- a/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py
+++ b/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py
@@ -1011,6 +1011,14 @@ def document_imports_package(self, documents: dict[str, Document], package_name:
importing_docs.append(doc)
return importing_docs
+ def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]:
+ escaped = re.escape(package_name)
+ return [
+ re.compile(rf"""require\s*\(\s*['"]({escaped}[^'"]*)['"]\s*\)""", re.IGNORECASE | re.MULTILINE),
+ re.compile(rf"""import\s+.*?\s+from\s+['"]({escaped}[^'"]*)['"]\s*""", re.IGNORECASE | re.MULTILINE),
+ re.compile(rf"""import\s+['"]({escaped}[^'"]*)['"]\s*""", re.IGNORECASE | re.MULTILINE),
+ ]
+
def is_a_package(self, package_name: str, doc: Document) -> bool:
return package_name in self.get_package_names(doc)
diff --git a/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py
index f7160593..ab1da406 100644
--- a/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py
+++ b/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py
@@ -198,4 +198,7 @@ def create_dummy_for_standard_lib(self, package_name: str) -> bool:
# is this call allowed by the language?
def is_call_allowed(self,pkg_docs: list[Document], caller_function: Document, callee_function: Document) -> bool:
- return True
\ No newline at end of file
+ return True
+
+ def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]:
+ return [re.compile(re.escape(package_name), re.IGNORECASE)]
\ No newline at end of file
diff --git a/src/exploit_iq_commons/utils/functions_parsers/python_functions_parser.py b/src/exploit_iq_commons/utils/functions_parsers/python_functions_parser.py
index f019cbbf..fe7eeec7 100644
--- a/src/exploit_iq_commons/utils/functions_parsers/python_functions_parser.py
+++ b/src/exploit_iq_commons/utils/functions_parsers/python_functions_parser.py
@@ -417,6 +417,13 @@ def is_package_imported(self, code_content: str, identifier: str, callee_package
return True
return False
+ def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]:
+ escaped = re.escape(package_name)
+ return [
+ re.compile(rf"import\s+({escaped}[\w.]*)", re.IGNORECASE | re.MULTILINE),
+ re.compile(rf"from\s+({escaped}[\w.]*)\s+import\s+", re.IGNORECASE | re.MULTILINE),
+ ]
+
def is_a_package(self, package_name: str, doc: Document) -> bool:
return (not self.is_root_package(doc) and
self.get_package_name(function=doc, package_name=package_name))
\ No newline at end of file
diff --git a/src/vuln_analysis/configs/config-http-nim.yml b/src/vuln_analysis/configs/config-http-nim.yml
index d6321ae6..1d3ae4c2 100644
--- a/src/vuln_analysis/configs/config-http-nim.yml
+++ b/src/vuln_analysis/configs/config-http-nim.yml
@@ -78,6 +78,13 @@ functions:
max_retries: 5
Container Analysis Data:
_type: container_image_analysis_data
+ Configuration Scanner:
+ _type: configuration_scanner
+ max_results: 15
+ context_lines: 5
+ Import Usage Analyzer:
+ _type: import_usage_analyzer
+ max_files: 20
cve_agent_executor:
_type: cve_agent_executor
llm_name: cve_agent_executor_llm
@@ -90,6 +97,8 @@ functions:
- Function Caller Finder
- Function Locator
- Function Library Version Finder
+ - Configuration Scanner
+ - Import Usage Analyzer
max_concurrency: null
max_iterations: 10
prompt_examples: false
diff --git a/src/vuln_analysis/configs/config-http-openai.yml b/src/vuln_analysis/configs/config-http-openai.yml
index a67e18c1..80c16d7b 100644
--- a/src/vuln_analysis/configs/config-http-openai.yml
+++ b/src/vuln_analysis/configs/config-http-openai.yml
@@ -85,6 +85,13 @@ functions:
max_retries: 5
Container Analysis Data:
_type: container_image_analysis_data
+ Configuration Scanner:
+ _type: configuration_scanner
+ max_results: 15
+ context_lines: 5
+ Import Usage Analyzer:
+ _type: import_usage_analyzer
+ max_files: 20
cve_agent_executor:
_type: cve_agent_executor
llm_name: cve_agent_executor_llm
@@ -97,6 +104,8 @@ functions:
- Function Caller Finder
- Function Locator
- Function Library Version Finder
+ - Configuration Scanner
+ - Import Usage Analyzer
max_concurrency: null
max_iterations: 10
prompt_examples: false
diff --git a/src/vuln_analysis/configs/config-tracing.yml b/src/vuln_analysis/configs/config-tracing.yml
index 397258e6..a4d89ff9 100644
--- a/src/vuln_analysis/configs/config-tracing.yml
+++ b/src/vuln_analysis/configs/config-tracing.yml
@@ -89,6 +89,13 @@ functions:
max_retries: 5
Container Analysis Data:
_type: container_image_analysis_data
+ Configuration Scanner:
+ _type: configuration_scanner
+ max_results: 15
+ context_lines: 5
+ Import Usage Analyzer:
+ _type: import_usage_analyzer
+ max_files: 20
cve_agent_executor:
_type: cve_agent_executor
llm_name: cve_agent_executor_llm
@@ -101,6 +108,8 @@ functions:
- Function Caller Finder
- Function Locator
- Function Library Version Finder
+ - Configuration Scanner
+ - Import Usage Analyzer
max_concurrency: null
max_iterations: 10
prompt_examples: false
diff --git a/src/vuln_analysis/configs/config.yml b/src/vuln_analysis/configs/config.yml
index 779b8efe..9ba5f23b 100644
--- a/src/vuln_analysis/configs/config.yml
+++ b/src/vuln_analysis/configs/config.yml
@@ -62,6 +62,13 @@ functions:
_type: container_image_analysis_data
Function Library Version Finder:
_type: calling_function_library_version_finder
+ Configuration Scanner:
+ _type: configuration_scanner
+ max_results: 15
+ context_lines: 5
+ Import Usage Analyzer:
+ _type: import_usage_analyzer
+ max_files: 20
cve_agent_executor:
_type: cve_agent_executor
llm_name: cve_agent_executor_llm
@@ -71,6 +78,8 @@ functions:
# - Code Keyword Search # Uncomment to enable keyword search
- CVE Web Search
- Function Library Version Finder
+ - Configuration Scanner
+ - Import Usage Analyzer
max_concurrency: null
max_iterations: 10
prompt_examples: false
diff --git a/src/vuln_analysis/functions/agent_registry.py b/src/vuln_analysis/functions/agent_registry.py
new file mode 100644
index 00000000..4c12f81a
--- /dev/null
+++ b/src/vuln_analysis/functions/agent_registry.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from vuln_analysis.functions.base_graph_agent import BaseGraphAgent
+
+_AGENT_REGISTRY: dict[str, type[BaseGraphAgent]] = {}
+
+
+def register_agent(agent_type: str):
+ """Class decorator that registers a BaseGraphAgent subclass under the given type name."""
+ def wrapper(cls):
+ _AGENT_REGISTRY[agent_type] = cls
+ return cls
+ return wrapper
+
+
+def get_agent_class(agent_type: str) -> type[BaseGraphAgent]:
+ if agent_type not in _AGENT_REGISTRY:
+ raise KeyError(f"Unknown agent type '{agent_type}'. Registered: {list(_AGENT_REGISTRY.keys())}")
+ return _AGENT_REGISTRY[agent_type]
+
+
+def get_all_agent_types() -> list[str]:
+ return list(_AGENT_REGISTRY.keys())
diff --git a/src/vuln_analysis/functions/base_graph_agent.py b/src/vuln_analysis/functions/base_graph_agent.py
new file mode 100644
index 00000000..063e3baf
--- /dev/null
+++ b/src/vuln_analysis/functions/base_graph_agent.py
@@ -0,0 +1,473 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import uuid
+from abc import ABC, abstractmethod
+
+from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage
+from langgraph.graph import StateGraph, END, START
+from langgraph.prebuilt import ToolNode
+
+from exploit_iq_commons.logging.loggers_factory import LoggingFactory
+from nat.builder.context import Context
+from vuln_analysis.tools.tool_names import ToolNames
+from vuln_analysis.functions.react_internals import (
+ AgentState,
+ BaseRulesTracker,
+ Thought,
+ Observation,
+ CodeFindings,
+ PackageSelection,
+ _build_tool_arguments,
+ FORCED_FINISH_PROMPT,
+ COMPREHENSION_PROMPT,
+ MEMORY_UPDATE_PROMPT,
+)
+from vuln_analysis.runtime_context import ctx_state
+from vuln_analysis.utils.token_utils import count_tokens, estimate_tokens, truncate_tool_output
+
+logger = LoggingFactory.get_agent_logger(__name__)
+AGENT_TRACER = Context.get()
+
+# Availability predicates for tools whose usefulness depends on run inputs:
+# a tool is offered only when its required index/VDB path or feature flag is set.
+_TOOL_AVAILABILITY = {
+    ToolNames.CODE_SEMANTIC_SEARCH: lambda config, state: state.code_vdb_path is not None,
+    ToolNames.DOCS_SEMANTIC_SEARCH: lambda config, state: state.doc_vdb_path is not None,
+    ToolNames.CODE_KEYWORD_SEARCH: lambda config, state: state.code_index_path is not None,
+    ToolNames.IMPORT_USAGE_ANALYZER: lambda config, state: state.code_index_path is not None,
+    ToolNames.CVE_WEB_SEARCH: lambda config, state: config.cve_web_search_enabled,
+    ToolNames.CALL_CHAIN_ANALYZER: lambda config, state: config.transitive_search_tool_enabled and state.code_index_path is not None,
+    ToolNames.FUNCTION_CALLER_FINDER: lambda config, state: config.transitive_search_tool_enabled and state.code_index_path is not None,
+    ToolNames.FUNCTION_LOCATOR: lambda config, state: config.transitive_search_tool_enabled and state.code_index_path is not None,
+}
+
+
+def _is_tool_available(tool_name, config, state):
+    """Return True if *tool_name* may be used; tools without a predicate default to available."""
+    check = _TOOL_AVAILABILITY.get(tool_name)
+    return check(config, state) if check else True
+
+
+class BaseGraphAgent(ABC):
+    """Template for LangGraph-based CVE investigation agents.
+
+    Subclasses must implement:
+    - pre_process_node: initialize agent state (prompts, package selection, etc.)
+    - get_tools: load and select which tools this agent uses
+    - create_rules_tracker: return the appropriate RulesTracker instance
+
+    Shared graph nodes (thought, observation, forced_finish, should_continue)
+    and the graph wiring are provided by this base class.
+    """
+
+    def __init__(self, tools: list, llm, config):
+        self.tools = tools
+        self.config = config
+        # One structured-output wrapper per schema used by the graph nodes.
+        self.thought_llm = llm.with_structured_output(Thought)
+        self.comprehension_llm = llm.with_structured_output(CodeFindings)
+        self.observation_llm = llm.with_structured_output(Observation)
+        self.package_filter_llm = llm.with_structured_output(PackageSelection)
+
+    @property
+    def agent_type(self) -> str:
+        """Short identifier for tracing spans (e.g. 'reachability', 'cu')."""
+        return "base"
+
+    @staticmethod
+    def _load_all_tools(builder, config) -> list:
+        """Load all configured tools from the builder, wrapped for LangChain."""
+        from aiq.builder.framework_enum import LLMFrameworkEnum
+        return builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+
+    async def _select_package(
+        self, ecosystem: str, candidate_packages: list[dict], critical_context: list[str],
+        workflow_state,
+    ) -> tuple[list[str], str | None]:
+        """Select target package from candidates using LLM filter.
+        Returns (filtered_critical_context, selected_package).
+        """
+        # HumanMessage is already imported at module scope; only the project-local
+        # helpers are imported lazily here (presumably to avoid import cycles —
+        # TODO confirm that is why they are function-scoped).
+        from vuln_analysis.functions.react_internals import build_package_filter_prompt, _find_image_matching_candidate
+        from vuln_analysis.utils.intel_utils import filter_context_to_package
+
+        selected_package = None
+        if len(candidate_packages) > 1:
+            image_input = workflow_state.original_input.input.image
+            image_name = image_input.name
+            image_repo = image_input.source_info[0].git_repo if image_input.source_info else None
+
+            # Cheap heuristic first: match a candidate against the image/repo name
+            # so the LLM filter only runs when the heuristic fails.
+            matched = _find_image_matching_candidate(candidate_packages, image_name, image_repo)
+            if matched:
+                selected_package = matched
+                logger.info("Package filter matched '%s' from image/repo (no LLM call needed, %d candidates)",
+                            selected_package, len(candidate_packages))
+            else:
+                filter_prompt = build_package_filter_prompt(
+                    ecosystem, candidate_packages,
+                    image_name=image_name, image_repo=image_repo,
+                    critical_context=critical_context,
+                )
+                selection = await self.package_filter_llm.ainvoke([HumanMessage(content=filter_prompt)])
+                selected_package = selection.selected_package
+                logger.info("Package filter selected '%s' from %d candidates (reason: %s)",
+                            selected_package, len(candidate_packages), selection.reason)
+            critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages)
+        elif len(candidate_packages) == 1:
+            selected_package = candidate_packages[0].get("name")
+            logger.info("Single candidate package: '%s'", selected_package)
+            critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages)
+        return critical_context, selected_package
+
+    # --- Template methods (subclasses MUST override) ---
+
+    @abstractmethod
+    async def pre_process_node(self, state: AgentState) -> AgentState:
+        """Initialize agent state: select package, build prompts, set rules."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def get_tools(builder, config, state) -> list:
+        """Load tools from the builder and select those this agent needs."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def create_rules_tracker() -> BaseRulesTracker:
+        """Return a fresh RulesTracker instance for a single question."""
+        ...
+
+    # --- Optional hooks (subclasses MAY override) ---
+
+    def build_comprehension_context(self, state: AgentState) -> str:
+        """Return the vulnerability context string for the comprehension prompt.
+        Override in subclasses to reduce context and prevent hallucination."""
+        ctx_lines = state.get("critical_context", [])
+        return "\n".join(ctx_lines) if ctx_lines else "N/A"
+
+    def sanitize_findings(self, findings: list[str], state: AgentState) -> list[str]:
+        """Replace hallucinated CVE IDs in comprehension findings.
+        Only CVE IDs present in the current investigation are kept."""
+        workflow_state = ctx_state.get()
+        allowed_ids = {intel.vuln_id for intel in workflow_state.cve_intel}
+
+        def _replace(m):
+            # Keep IDs that belong to this investigation; neutralize the rest.
+            return m.group(0) if m.group(0) in allowed_ids else "the investigated vulnerability"
+
+        # The replacement closure only depends on allowed_ids, so it is defined
+        # once instead of once per finding.
+        return [re.sub(r"CVE-\d{4}-\d+", _replace, finding) for finding in findings]
+
+    def post_observation(self, state: AgentState, tool_used: str,
+                         tool_output: str, tool_input_detail: str) -> dict:
+        """Per-agent post-processing after observation_node core logic.
+        Returns dict to merge into observation_node output. Default: empty."""
+        return {}
+
+    def should_truncate_tool_output(self, state: AgentState, tool_used: str) -> bool:
+        """Whether to apply Java-specific truncation to tool output. Default: False."""
+        return False
+
+    def check_finish_allowed(self, state: AgentState) -> tuple[bool, str]:
+        """Hook to block premature finish. Returns (allowed, error_message).
+        Override in subclasses to enforce rules before allowing the agent to finish."""
+        return True, ""
+
+    # --- Shared concrete graph nodes ---
+
+    async def thought_node(self, state: AgentState) -> AgentState:
+        """Core reasoning step: ask the LLM for the next Thought (act or finish)."""
+        step_num = state.get("step", 0)
+        with AGENT_TRACER.push_active_function(f"{self.agent_type}_thought", input_data=f"step:{step_num}") as span:
+            try:
+                active_prompt = state.get("runtime_prompt")
+                messages = [SystemMessage(content=active_prompt)] + state["messages"]
+                obs = state.get("observation", None)
+                if obs is not None:
+                    memory_list = obs.memory if obs.memory else ["No prior knowledge."]
+                    recent_findings = obs.results if obs.results else ["No recent findings."]
+                    memory_context = "\n".join(f"- {m}" for m in memory_list)
+                    findings_context = "\n".join(f"- {f}" for f in recent_findings)
+                    context_block = f"KNOWLEDGE:\n{memory_context}\nLATEST FINDINGS:\n{findings_context}"
+                    messages.append(SystemMessage(content=context_block))
+
+                max_message_tokens = self.config.context_window_token_limit
+                raw_total = sum(
+                    count_tokens(m.content) for m in messages
+                    if hasattr(m, "content") and isinstance(m.content, str)
+                )
+                # cl100k_base undercounts by ~25% vs Llama 3.1's tokenizer
+                total = int(raw_total * 1.25)
+                if total > max_message_tokens and len(messages) > 2:
+                    # Keep system prompt (messages[0]) and KNOWLEDGE block (messages[-1]).
+                    # ToolMessages are prunable because observation_node already distilled
+                    # their content into the KNOWLEDGE block via comprehension.
+                    # Note: prunable is a slice copy, so removing from messages while
+                    # iterating it is safe; state["messages"] itself is not mutated.
+                    prunable = messages[1:-1]
+                    for msg in prunable:
+                        if total <= max_message_tokens:
+                            break
+                        raw_tok = count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0
+                        messages.remove(msg)
+                        total -= int(raw_tok * 1.25)
+                    logger.info(
+                        "thought_node pruning: estimated tokens now ~%d (limit %d) at step %d",
+                        total, max_message_tokens, step_num,
+                    )
+
+                response: Thought = await self.thought_llm.ainvoke(messages)
+
+                final_answer = "waiting for the agent to respond"
+                if response.mode == "finish":
+                    finish_ok, finish_msg = self.check_finish_allowed(state)
+                    if not finish_ok:
+                        logger.info("%s finish blocked by rules at step %d: %s", self.agent_type, step_num, finish_msg)
+                        span.set_output({"mode": "finish_blocked", "step": step_num + 1})
+                        blocked_ai = AIMessage(content=response.final_answer or response.thought or "I want to finish.")
+                        return {
+                            "messages": [blocked_ai, HumanMessage(content=finish_msg)],
+                            "thought": None,
+                            "step": step_num + 1,
+                            "max_steps": self.config.max_iterations,
+                            "output": "waiting for the agent to respond",
+                        }
+                    # Guard: the structured-output LLM may emit mode="finish" with a
+                    # None final_answer; fall back rather than building a None-content
+                    # AIMessage (same guard as the blocked-finish path above).
+                    final_answer = response.final_answer or response.thought or "No final answer provided."
+                    ai_message = AIMessage(content=final_answer)
+                elif response.actions is None:
+                    logger.warning("%s LLM returned mode='act' but actions is None, forcing finish", self.agent_type)
+                    ai_message = AIMessage(content=response.thought or "No actions provided, finishing.")
+                    response = Thought(
+                        thought=response.thought or "No actions provided",
+                        mode="finish",
+                        actions=None,
+                        final_answer=response.thought or "Insufficient evidence to provide a definitive answer."
+                    )
+                    final_answer = response.final_answer
+                else:
+                    tool_name = response.actions.tool
+                    try:
+                        arguments = _build_tool_arguments(response.actions)
+                    except ValueError as e:
+                        logger.warning(
+                            "%s bad tool arguments at step %d: %s", self.agent_type, step_num, e,
+                        )
+                        span.set_output({"error": str(e), "step": step_num + 1})
+                        error_ai = AIMessage(content=response.thought or "I want to call a tool.")
+                        return {
+                            "messages": [error_ai, HumanMessage(
+                                content=f"ERROR: {e}. Provide the required arguments and try again."
+                            )],
+                            "thought": None,
+                            "step": step_num + 1,
+                            "max_steps": self.config.max_iterations,
+                            "output": "waiting for the agent to respond",
+                        }
+                    tool_call_id = str(uuid.uuid4())
+                    ai_message = AIMessage(
+                        content=response.thought,
+                        tool_calls=[{
+                            "name": tool_name,
+                            "args": arguments,
+                            "id": tool_call_id
+                        }]
+                    )
+
+                span.set_output({"mode": response.mode, "step": step_num + 1})
+                return {
+                    "messages": [ai_message],
+                    "thought": response,
+                    "step": step_num + 1,
+                    "max_steps": self.config.max_iterations,
+                    "output": final_answer
+                }
+            except Exception as e:
+                logger.exception("%s thought_node failed at step %d", self.agent_type, step_num)
+                span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num})
+                raise
+
+    async def should_continue(self, state: AgentState) -> str:
+        """Route after thought_node: forced finish, END, retry thought, or tool call."""
+        if state.get("step", 0) >= state.get("max_steps", self.config.max_iterations):
+            return "forced_finish_node"
+        thought = state.get("thought", None)
+        if thought is None:
+            return "thought_node"
+        if thought.mode == "finish":
+            return END
+        return "tool_node"
+
+    @staticmethod
+    def _build_observation_context(obs) -> str | None:
+        """Format observation memory and recent findings into a context block."""
+        if obs is None:
+            return None
+        memory_list = obs.memory if obs.memory else []
+        recent_findings = obs.results if obs.results else []
+        if not memory_list and not recent_findings:
+            return None
+        parts = []
+        if memory_list:
+            parts.append("KNOWLEDGE:\n" + "\n".join(f"- {m}" for m in memory_list))
+        if recent_findings:
+            parts.append("LATEST FINDINGS:\n" + "\n".join(f"- {f}" for f in recent_findings))
+        return "\n".join(parts)
+
+    async def forced_finish_node(self, state: AgentState) -> AgentState:
+        """Max-steps fallback: prompt the LLM once for a final answer, no more tools."""
+        step_num = state.get("step", 0)
+        with AGENT_TRACER.push_active_function(f"{self.agent_type}_forced_finish", input_data=f"step:{step_num}") as span:
+            try:
+                active_prompt = state.get("runtime_prompt")
+                messages = [SystemMessage(content=active_prompt)]
+                context_block = self._build_observation_context(state.get("observation", None))
+                if context_block:
+                    messages.append(SystemMessage(content=context_block))
+                question = state.get("input", "")
+                finish_prompt = f"QUESTION: {question}\n\n{FORCED_FINISH_PROMPT}" if question else FORCED_FINISH_PROMPT
+                messages.append(HumanMessage(content=finish_prompt))
+                response: Thought = await self.thought_llm.ainvoke(messages)
+                if response.mode == "finish" and response.final_answer:
+                    ai_message = AIMessage(content=response.final_answer)
+                    final_answer = response.final_answer
+                else:
+                    # The model did not produce a usable answer; synthesize one so the
+                    # graph always terminates with a well-formed Thought.
+                    final_answer = "Failed to generate a final answer within the maximum allowed steps."
+                    ai_message = AIMessage(content=final_answer)
+                    response = Thought(
+                        thought=response.thought or "Max steps exceeded",
+                        mode="finish",
+                        actions=None,
+                        final_answer=final_answer
+                    )
+                span.set_output({"final_answer_length": len(final_answer), "step": step_num})
+                return {
+                    "messages": [ai_message],
+                    "thought": response,
+                    "step": step_num,
+                    "max_steps": state.get("max_steps", self.config.max_iterations),
+                    "observation": state.get("observation", None),
+                    "output": final_answer
+                }
+            except Exception as e:
+                logger.exception("%s forced_finish_node failed at step %d", self.agent_type, step_num)
+                span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num})
+                raise
+
+    async def observation_node(self, state: AgentState) -> AgentState:
+        """Distill the latest tool output into findings and an updated memory."""
+        tool_message = state["messages"][-1]
+        last_thought_text = state["thought"].thought if state.get("thought") else "No previous thought."
+        tool_used = state["thought"].actions.tool if state.get("thought") and state["thought"].actions else "Unknown"
+        tool_input_detail = ""
+        if state.get("thought") and state["thought"].actions:
+            actions = state["thought"].actions
+            if actions.package_name and actions.function_name:
+                tool_input_detail = f"{actions.package_name},{actions.function_name}"
+            elif actions.query:
+                tool_input_detail = actions.query
+            elif actions.tool_input:
+                tool_input_detail = actions.tool_input
+        previous_memory = state.get("observation").memory if state.get("observation") else ["No data gathered yet."]
+        rules_tracker = state.get("rules_tracker")
+        with AGENT_TRACER.push_active_function(f"{self.agent_type}_observation", input_data=f"tool:{tool_used}") as span:
+            try:
+                tool_output_for_llm = tool_message.content
+                result, error_message = rules_tracker.check_thought_behavior(tool_used, tool_input_detail, tool_output_for_llm)
+                if result:
+                    span.set_output({"rule_error": error_message})
+                    return {"messages": [HumanMessage(content=error_message)]}
+
+                if self.should_truncate_tool_output(state, tool_used):
+                    truncated_output = truncate_tool_output(tool_output_for_llm, tool_used)
+                else:
+                    truncated_output = tool_output_for_llm
+
+                critical_context_text = self.build_comprehension_context(state)
+                comp_prompt = COMPREHENSION_PROMPT.format(
+                    goal=state.get('input'),
+                    selected_package=state.get('app_package') or "N/A",
+                    critical_context=critical_context_text,
+                    tool_used=tool_used,
+                    tool_input_detail=tool_input_detail,
+                    last_thought_text=last_thought_text,
+                    tool_output=truncated_output,
+                )
+                code_findings: CodeFindings = await self.comprehension_llm.ainvoke([SystemMessage(content=comp_prompt)])
+
+                sanitized = self.sanitize_findings(code_findings.findings, state)
+                findings_text = "\n".join(f"- {f}" for f in sanitized)
+
+                # NOTE(review): previous_memory is a list and is interpolated via its
+                # repr in MEMORY_UPDATE_PROMPT — looks intentional, confirm.
+                mem_prompt = MEMORY_UPDATE_PROMPT.format(
+                    goal=state.get('input'),
+                    selected_package=state.get('app_package') or "N/A",
+                    previous_memory=previous_memory,
+                    findings=findings_text,
+                    tool_outcome=code_findings.tool_outcome,
+                )
+                new_observation: Observation = await self.observation_llm.ainvoke([SystemMessage(content=mem_prompt)])
+
+                messages = state["messages"]
+                active_prompt = state.get("runtime_prompt")
+                estimated = estimate_tokens(active_prompt, messages, new_observation)
+                prune_messages = []
+                orig_estimated = estimated
+
+                span_trace_dict = {"comprehension_findings": sanitized, "tool_outcome": code_findings.tool_outcome}
+
+                if estimated > self.config.context_window_token_limit and len(messages) > 3:
+                    # Emit RemoveMessage markers so LangGraph drops the oldest
+                    # middle messages from persisted state.
+                    prunable = messages[1:-2]
+                    for msg in prunable:
+                        prune_messages.append(RemoveMessage(id=msg.id))
+                        estimated -= count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0
+                        if estimated <= self.config.context_window_token_limit:
+                            break
+                    logger.info(
+                        "Context pruning: removed %d messages, estimated tokens now ~%d (limit %d)",
+                        len(prune_messages), estimated, self.config.context_window_token_limit,
+                    )
+                    span_trace_dict["orig_estimated"] = orig_estimated
+                    span_trace_dict["estimated"] = estimated
+
+                # Agent-specific post-processing hook
+                extra = self.post_observation(state, tool_used, tool_output_for_llm, tool_input_detail)
+
+                span.set_output(span_trace_dict)
+                base_result = {
+                    "messages": prune_messages,
+                    "observation": new_observation,
+                    "step": state.get("step", 0),
+                }
+                base_result.update(extra)
+                return base_result
+            except Exception as e:
+                logger.exception("%s observation_node failed", self.agent_type)
+                span.set_output({"error": str(e), "exception_type": type(e).__name__})
+                raise
+
+    # --- Graph wiring ---
+    # If overridden, should_continue must also be updated to match the new node names.
+
+    async def build_graph(self):
+        """Wire the shared node set into a compiled LangGraph state machine."""
+        tool_node = ToolNode(self.tools, handle_tool_errors=True)
+
+        flow = StateGraph(AgentState)
+        flow.add_node("thought_node", self.thought_node)
+        flow.add_node("tool_node", tool_node)
+        flow.add_node("forced_finish_node", self.forced_finish_node)
+        flow.add_node("pre_process_node", self.pre_process_node)
+        flow.add_node("observation_node", self.observation_node)
+        flow.add_edge(START, "pre_process_node")
+        flow.add_edge("pre_process_node", "thought_node")
+        flow.add_conditional_edges(
+            "thought_node",
+            self.should_continue,
+            {END: END, "tool_node": "tool_node", "forced_finish_node": "forced_finish_node", "thought_node": "thought_node"}
+        )
+        flow.add_edge("tool_node", "observation_node")
+        flow.add_edge("observation_node", "thought_node")
+        flow.add_edge("forced_finish_node", END)
+
+        return flow.compile()
diff --git a/src/vuln_analysis/functions/code_understanding_agent.py b/src/vuln_analysis/functions/code_understanding_agent.py
new file mode 100644
index 00000000..7368dcd2
--- /dev/null
+++ b/src/vuln_analysis/functions/code_understanding_agent.py
@@ -0,0 +1,189 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from exploit_iq_commons.logging.loggers_factory import LoggingFactory
+from nat.builder.context import Context
+from vuln_analysis.functions.agent_registry import register_agent
+from vuln_analysis.functions.base_graph_agent import BaseGraphAgent, _is_tool_available
+
+from vuln_analysis.functions.react_internals import (
+ AgentState,
+ Observation,
+)
+from vuln_analysis.functions.code_understanding_internals import (
+ CU_AGENT_SYS_PROMPT,
+ CU_AGENT_THOUGHT_INSTRUCTIONS,
+ CodeUnderstandingRulesTracker,
+)
+from vuln_analysis.utils.code_understanding_prompt_factory import (
+ CU_TOOL_SELECTION_STRATEGY,
+ CU_TOOL_GENERAL_DESCRIPTIONS,
+)
+from vuln_analysis.utils.intel_utils import build_critical_context
+from vuln_analysis.runtime_context import ctx_state, cu_source_scope
+from vuln_analysis.tools.tool_names import ToolNames
+
+logger = LoggingFactory.get_agent_logger(__name__)
+AGENT_TRACER = Context.get()
+
+
+def _build_cu_system_prompt(descriptions: str, tool_guidance: str) -> str:
+    """Assemble the CU agent's runtime system prompt.
+
+    Interleaves the static system prompt, the per-run tool descriptions and
+    guidance, and the thought-format instructions. The trailing '{' after
+    RESPONSE: appears intended to prime the model to continue a JSON object
+    — confirm against the Thought output parser.
+    """
+    return f"""{CU_AGENT_SYS_PROMPT}
+
+
+{descriptions}
+
+
+
+{tool_guidance}
+
+
+{CU_AGENT_THOUGHT_INSTRUCTIONS}
+
+RESPONSE:
+{{"""
+
+
+def _build_cu_tool_guidance(ecosystem: str, available_tools: list) -> tuple[str, str]:
+    """Return (guidance, descriptions) strings for the CU agent's prompt.
+
+    Guidance comes from the ecosystem-specific strategy table when one exists;
+    otherwise it is assembled from the generic per-tool descriptions.
+    """
+    eco_key = ecosystem.lower() if ecosystem else ""
+    names = [tool.name for tool in available_tools]
+
+    if eco_key in CU_TOOL_SELECTION_STRATEGY:
+        logger.debug("Using %s-specific tool strategy for Code Understanding agent", eco_key)
+        guidance = CU_TOOL_SELECTION_STRATEGY[eco_key]
+    else:
+        logger.debug("No ecosystem-specific strategy for '%s', using generic Code Understanding tool guidance", eco_key)
+        generic_lines = [
+            f"{name}: {CU_TOOL_GENERAL_DESCRIPTIONS[name]}"
+            for name in names
+            if name in CU_TOOL_GENERAL_DESCRIPTIONS
+        ]
+        guidance = "\n".join(generic_lines) if generic_lines else "Use the available tools to investigate the question."
+
+    descriptions = "\n".join(f"{tool.name}: {tool.description}" for tool in available_tools)
+    return guidance, descriptions
+
+
+@register_agent("code_understanding")
+class CodeUnderstandingAgent(BaseGraphAgent):
+    """Agent for code-understanding questions about a CVE: how the vulnerable
+    component is used, configured, or present in the container — explicitly
+    not call-chain reachability."""
+
+    @property
+    def agent_type(self) -> str:
+        """Span-name prefix used by the tracing context."""
+        return "cu"
+
+    # Tools this agent may use; get_tools intersects this with availability checks.
+    _CU_TOOLS = frozenset({
+        ToolNames.DOCS_SEMANTIC_SEARCH,
+        ToolNames.CODE_KEYWORD_SEARCH,
+        ToolNames.CONFIGURATION_SCANNER,
+        ToolNames.IMPORT_USAGE_ANALYZER,
+    })
+
+    @staticmethod
+    def get_tools(builder, config, state) -> list:
+        """Load all configured tools and keep the CU subset that is available."""
+        all_tools = BaseGraphAgent._load_all_tools(builder, config)
+        return [
+            t for t in all_tools
+            if t.name in CodeUnderstandingAgent._CU_TOOLS
+            and _is_tool_available(t.name, config, state)
+        ]
+
+    @staticmethod
+    def create_rules_tracker() -> CodeUnderstandingRulesTracker:
+        """Return a fresh rules tracker (one per question)."""
+        return CodeUnderstandingRulesTracker()
+
+    def build_comprehension_context(self, state: AgentState) -> str:
+        """Return minimal grounding context to prevent CVE ID hallucination.
+
+        The CU agent's comprehension LLM only needs the CVE ID and package name
+        to ground its findings. Feeding the full GHSA/NVD context activates
+        parametric knowledge and causes the model to inject CVE IDs from training.
+        """
+        workflow_state = ctx_state.get()
+        vuln_id = workflow_state.cve_intel[0].vuln_id if workflow_state.cve_intel else "unknown"
+        package = state.get("app_package") or "unknown"
+        return (
+            f"Investigating {vuln_id} in package {package}.\n"
+            "Only extract facts explicitly stated in the tool output. "
+            "Do not add CVE IDs, vulnerability names, or advisory details "
+            "from your own knowledge."
+        )
+
+    async def pre_process_node(self, state: AgentState) -> AgentState:
+        """Initialize CU state: pick the target package, set the source scope
+        context var, build the runtime prompt, and configure the rules tracker."""
+        workflow_state = ctx_state.get()
+        ecosystem = (
+            workflow_state.original_input.input.image.ecosystem.value
+            if workflow_state.original_input.input.image.ecosystem
+            else ""
+        )
+        with AGENT_TRACER.push_active_function("cu_pre_process", input_data=f"ecosystem:{ecosystem}") as span:
+            try:
+                # Reuse intel computed upstream when provided; copies are taken so
+                # later mutations don't leak back into the precomputed data.
+                precomputed = state.get("precomputed_intel")
+                if precomputed is not None:
+                    critical_context = list(precomputed[0])
+                    candidate_packages = [dict(p) for p in precomputed[1]]
+                else:
+                    critical_context, candidate_packages, _ = build_critical_context(
+                        workflow_state.cve_intel
+                    )
+
+                critical_context, selected_package = await self._select_package(
+                    ecosystem, candidate_packages, critical_context, workflow_state,
+                )
+
+                critical_context.append(
+                    "TASK: Investigate usage, configuration, and presence of the vulnerable "
+                    "component in the container. Focus on how the component is used, "
+                    "not on call-chain reachability."
+                )
+
+                # Scope source searches to the image's repo name(s) and the selected
+                # package; None means "no scoping".
+                scope_parts = []
+                image_input = workflow_state.original_input.input.image
+                if image_input.source_info:
+                    for si in image_input.source_info:
+                        if si.git_repo:
+                            repo_name = si.git_repo.rstrip("/").rsplit("/", 1)[-1]
+                            if repo_name.endswith(".git"):
+                                repo_name = repo_name[:-4]
+                            scope_parts.append(repo_name)
+                if selected_package:
+                    scope_parts.append(selected_package)
+                cu_source_scope.set(scope_parts if scope_parts else None)
+                logger.debug("Code Understanding source scope set to: %s", scope_parts)
+
+                tool_guidance, descriptions = _build_cu_tool_guidance(ecosystem, self.tools)
+                runtime_prompt = _build_cu_system_prompt(descriptions, tool_guidance)
+                active_tool_names = [t.name for t in self.tools]
+
+                rules_tracker = state.get("rules_tracker")
+                rules_tracker.set_target_package(selected_package)
+                rules_tracker.set_allowed_tools(active_tool_names)
+
+                span.set_output({
+                    "selected_package": selected_package,
+                    "agent_type": "code_understanding",
+                })
+
+                return {
+                    "ecosystem": ecosystem,
+                    "runtime_prompt": runtime_prompt,
+                    "is_reachability": "no",
+                    "observation": Observation(memory=critical_context, results=[]),
+                    "critical_context": critical_context,
+                    "app_package": selected_package,
+                }
+            except Exception as e:
+                logger.exception("cu_pre_process_node failed")
+                span.set_output({"error": str(e), "exception_type": type(e).__name__})
+                raise
diff --git a/src/vuln_analysis/functions/code_understanding_internals.py b/src/vuln_analysis/functions/code_understanding_internals.py
new file mode 100644
index 00000000..e6fc2bd2
--- /dev/null
+++ b/src/vuln_analysis/functions/code_understanding_internals.py
@@ -0,0 +1,91 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from exploit_iq_commons.logging.loggers_factory import LoggingFactory
+from vuln_analysis.functions.react_internals import BaseRulesTracker
+
+logger = LoggingFactory.get_agent_logger(__name__)
+
+
+class CodeUnderstandingRulesTracker(BaseRulesTracker):
+ """Behavioral rules for the Code Understanding sub-agent."""
+
+ def check_thought_behavior(self, action: str, action_input: str, output) -> tuple[bool, str]:
+ """Check all Code Understanding rules in priority order.
+
+ Returns (True, error_message) if a rule is violated, (False, '') otherwise.
+ Rule check order: duplicate call -> Rule 7 (dotted keywords) -> allowed tools.
+ """
+ if self._rule_duplicate_call(action, action_input):
+ logger.debug("CU duplicate call rule triggered: '%s' with same input", action)
+ return True, (
+ f"You already called {action} with this exact input. "
+ "You MUST use a DIFFERENT tool or a DIFFERENT input query. "
+ "Check KNOWLEDGE for what was already tried."
+ )
+ if self._rule_number_7(action, action_input, output):
+ logger.debug("Code Understanding Rule 7 triggered: dotted query with empty results for tool '%s'", action)
+ return True, (
+ "You are NOT following Rule 7. Your query contains dots and returned "
+ "no results. You MUST retry with just the final component. Follow the rules."
+ )
+ if self._rule_use_allowed_tools(action):
+ logger.debug("Code Understanding allowed-tools rule triggered: '%s' not in %s", action, self.allowed_tools)
+ return True, (
+ f"You are NOT following AVAILABLE_TOOLS. You MUST use the allowed tools "
+ f"{self.allowed_tools}. Follow the rules."
+ )
+ self.add_action(action, action_input, output)
+ return False, ""
+
+
+CU_AGENT_SYS_PROMPT = (
+ "You are a security analyst investigating a code understanding question about a CVE vulnerability.\n"
+ "Your goal is to collect evidence about how the codebase uses, configures, or depends on "
+ "the affected component. This is NOT a reachability question -- do NOT trace call chains.\n"
+ "GENERAL RULES:\n"
+ "- Base conclusions ONLY on tool results, not assumptions.\n"
+ "- If a search returns no results, that is evidence the component is absent or not configured.\n"
+ "- DISTINGUISH where findings come from: main application code vs. dependency libraries.\n"
+ "- A component being present in dependencies does NOT mean the application uses it.\n"
+ "- Configuration in framework dependencies (e.g., Spring defaults) IS relevant evidence.\n"
+ "- Import in an intermediate library (e.g., Spring imports XStream, app imports Spring) IS relevant.\n"
+ "ANSWER QUALITY:\n"
+ "- Answer the SPECIFIC question asked with evidence. Do NOT just report what tools found.\n"
+ "- Always state: WHAT you checked, WHAT you found, and WHY it leads to your conclusion.\n"
+ "- If tool results conflict, state the conflict explicitly.\n"
+ "- When citing evidence, explain HOW it relates to the question."
+)
+
+CU_AGENT_THOUGHT_INSTRUCTIONS = """
+1. Output valid JSON only. thought < 100 words. final_answer < 150 words.
+2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer.
+3. Docs Semantic Search, Code Keyword Search: use query field.
+4. Configuration Scanner: use query field with keywords describing what to look for.
+5. Import Usage Analyzer: use query field with the package/module name.
+6. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries.
+7. If Code Keyword Search returns no results and the query contains dots, retry with just the final component.
+8. Before concluding, if the question involves a specific library or package, you MUST use Import Usage Analyzer to check how it is imported and used across sources.
+9. GUIDELINE: Synthesize findings across documentation, code, configuration, and imports before drawing conclusions.
+
+
+{{"thought": "Check configuration files for security settings related to the vulnerability", "mode": "act", "actions": {{"tool": "Configuration Scanner", "package_name": null, "function_name": null, "query": "deserialization allowlist denylist security", "tool_input": null, "reason": "Check if vulnerable feature is enabled or disabled in config"}}, "final_answer": null}}
+
+
+{{"thought": "Search source code for imports of the vulnerable library", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "import vulnerable_library", "tool_input": null, "reason": "Find where the package is imported in code"}}, "final_answer": null}}
+
+
+{{"thought": "Check how the vulnerable library is imported and used across the codebase", "mode": "act", "actions": {{"tool": "Import Usage Analyzer", "package_name": null, "function_name": null, "query": "vulnerable_library", "tool_input": null, "reason": "Find all imports and usage sites of the vulnerable package"}}, "final_answer": null}}
+"""
\ No newline at end of file
diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py
index f140850d..d1fa7d4d 100644
--- a/src/vuln_analysis/functions/cve_agent.py
+++ b/src/vuln_analysis/functions/cve_agent.py
@@ -12,49 +12,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
+
import asyncio
-import json
-from pathlib import Path
-from vuln_analysis.runtime_context import ctx_state
-import typing
+
from aiq.builder.builder import Builder
from aiq.builder.framework_enum import LLMFrameworkEnum
from aiq.builder.function_info import FunctionInfo
from aiq.cli.register_workflow import register_function
from aiq.data_models.function import FunctionBaseConfig
-from langchain.agents import AgentExecutor
-from langchain.agents import create_react_agent
-from langchain.agents.agent import RunnableAgent
-from langchain.agents.mrkl.output_parser import MRKLOutputParser
from langchain_core.exceptions import OutputParserException
-from langchain_core.prompts import PromptTemplate
+from langchain_core.messages import HumanMessage
from pydantic import Field
+
from vuln_analysis.data_models.state import AgentMorpheusEngineState
-from vuln_analysis.tools.tool_names import ToolNames
-from vuln_analysis.tools.transitive_code_search import package_name_from_locator_query
from vuln_analysis.utils.error_handling_decorator import ToolRaisedException
-from vuln_analysis.utils.prompting import get_agent_prompt
+from vuln_analysis.utils.intel_utils import build_critical_context
+from vuln_analysis.runtime_context import ctx_state
from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id
-
-from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, CodeFindings, SystemRulesTracker, _build_tool_arguments, FORCED_FINISH_PROMPT, COMPREHENSION_PROMPT, MEMORY_UPDATE_PROMPT, AGENT_SYS_PROMPT_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_GO
-from vuln_analysis.utils.prompting import build_tool_descriptions
-from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_SELECTION_STRATEGY_NON_REACHABILITY, TOOL_ECOSYSTEM_REGISTRY, FEW_SHOT_EXAMPLES
-from vuln_analysis.utils.intel_utils import build_critical_context, enrich_go_from_osv, filter_context_to_package
-from exploit_iq_commons.utils.git_utils import sanitize_git_url_for_path
-from exploit_iq_commons.utils.data_utils import DEFAULT_GIT_DIRECTORY
-from langgraph.graph import StateGraph, END, START
-from langgraph.prebuilt import ToolNode
-from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage
-
-import uuid
-import tiktoken
from nat.builder.context import Context
+from vuln_analysis.functions.agent_registry import get_agent_class, get_all_agent_types
+from vuln_analysis.functions.dispatcher import QuestionRouting, dispatch_question
+# Import agent modules to trigger @register_agent decorators
+import vuln_analysis.functions.reachability_agent # noqa: F401
+import vuln_analysis.functions.code_understanding_agent # noqa: F401
+
logger = LoggingFactory.get_agent_logger(__name__)
AGENT_TRACER = Context.get()
+
class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"):
"""
Defines a function that iterates through checklist items using provided tools and gathered intel.
@@ -88,672 +75,74 @@ class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"):
description="Estimated token threshold for pruning old messages in observation node."
)
-async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState) -> tuple[list[typing.Any], list[str], list[str]]:
-
- tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
- # Filter tools that are not available based on state
- tools = [
- tool for tool in tools
- if not ((tool.name == ToolNames.CODE_SEMANTIC_SEARCH and state.code_vdb_path is None) or
- (tool.name == ToolNames.DOCS_SEMANTIC_SEARCH and state.doc_vdb_path is None) or
- (tool.name == ToolNames.CODE_KEYWORD_SEARCH and state.code_index_path is None) or
- (tool.name == ToolNames.CVE_WEB_SEARCH and not config.cve_web_search_enabled) or
- (tool.name == ToolNames.CALL_CHAIN_ANALYZER and (not config.transitive_search_tool_enabled or
- state.code_index_path is None)) or
- (tool.name == ToolNames.FUNCTION_CALLER_FINDER and (not config.transitive_search_tool_enabled or
- state.code_index_path is None)) or
- (tool.name == ToolNames.FUNCTION_LOCATOR and (not config.transitive_search_tool_enabled or
- state.code_index_path is None))
- )
- ]
- # Get tool names after filtering for dynamic guidance
- enabled_tool_names = [tool.name for tool in tools]
- tool_descriptions_list = [t.name + ": " + t.description for t in tools]
- # Build tool selection guidance with strategic context
- tool_descriptions = build_tool_descriptions(enabled_tool_names)
- return tools, tool_descriptions, tool_descriptions_list
-
-def _validate_go_vendor_packages(
- source_info: list,
- candidate_packages: list[dict],
-) -> tuple[list[dict], list[str]]:
- """Check which Go candidate packages actually exist in the vendor directory."""
- code_si = next((si for si in source_info if si.type == "code"), None)
- if code_si is None:
- return candidate_packages, []
-
- repo_path = Path(DEFAULT_GIT_DIRECTORY) / sanitize_git_url_for_path(code_si.git_repo)
- vendor_path = repo_path / "vendor"
- if not vendor_path.is_dir():
- return candidate_packages, []
-
- validated = []
- removed = []
- for pkg in candidate_packages:
- pkg_name = pkg.get("name", "")
- if (vendor_path / pkg_name).is_dir():
- validated.append(pkg)
- else:
- removed.append(pkg_name)
-
- if validated:
- return validated, removed
- return candidate_packages, []
-
-
-async def _enrich_go_candidates(
- cve_intel: list,
- source_info: list,
- critical_context: list[str],
- candidate_packages: list[dict],
- vulnerable_functions_set: set[str],
-) -> tuple[list[dict], list[str]]:
- """Enrich Go candidates via OSV and validate against vendor directory."""
- ghsa_has_packages = any(c.get("source") == "ghsa" for c in candidate_packages)
- if not ghsa_has_packages or not vulnerable_functions_set:
- intel = cve_intel[0] if cve_intel else None
- if intel:
- await enrich_go_from_osv(intel, critical_context, candidate_packages, vulnerable_functions_set)
-
- if candidate_packages:
- candidate_packages, removed_pkgs = _validate_go_vendor_packages(
- source_info, candidate_packages
- )
- if removed_pkgs:
- logger.info("Go vendor validation removed %d packages not in vendor/: %s", len(removed_pkgs), removed_pkgs)
-
- return candidate_packages, sorted(vulnerable_functions_set)
-
-
-async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState):
-
- tools, tool_guidance_list, tool_descriptions_list = await common_build_tools(config, builder, state)
- llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
- thought_llm = llm.with_structured_output(Thought)
- comprehension_llm = llm.with_structured_output(CodeFindings)
- observation_llm = llm.with_structured_output(Observation)
- reachability_llm = llm.with_structured_output(Classification)
- package_filter_llm = llm.with_structured_output(PackageSelection)
- tool_guidance = "\n".join(tool_guidance_list)
- descriptions = "\n".join(tool_descriptions_list)
- default_system_prompt = build_system_prompt(descriptions, tool_guidance)
- tool_node = ToolNode(tools, handle_tool_errors=True)
- TOOL_NODE = "tool_node"
- THOUGHT_NODE = "thought_node"
- FORCED_FINISH_NODE = "forced_finish_node"
- PRE_PROCESS_NODE = "pre_process_node"
- OBSERVATION_NODE = "observation_node"
-
- _tiktoken_enc = tiktoken.get_encoding("cl100k_base")
-
- def _count_tokens(text: str) -> int:
- """Count tokens using tiktoken cl100k_base encoding (~90-95% accurate for Llama 3.1)."""
- try:
- return len(_tiktoken_enc.encode(text))
- except Exception:
- return len(text) // 4
-
- def _truncate_tool_output(tool_output: str, tool_name: str, max_tokens: int = 400) -> str:
- """Truncate tool output to fit observation prompt within completion token budget.
-
- Code Keyword Search is the primary source of bloat (3700+ chars of redundant
- source snippets). CCA and FL outputs are typically small.
- Preserves the most informative parts based on tool type.
- """
- token_count = _count_tokens(tool_output)
- if token_count <= max_tokens:
- return tool_output
-
- lines = tool_output.split('\n')
-
- if tool_name == ToolNames.CALL_CHAIN_ANALYZER:
- head = '\n'.join(lines[:3])
- remaining = token_count - _count_tokens(head)
- return f"{head}\n[... truncated {remaining} tokens ...]"
-
- if tool_name == ToolNames.CODE_KEYWORD_SEARCH:
- kept_lines = []
- kept_tokens = 0
- for line in lines:
- is_header = line.startswith("---") or "Main application" in line or "library dependencies" in line or line.strip() == ""
- line_tokens = _count_tokens(line)
- if is_header or kept_tokens + line_tokens <= max_tokens:
- kept_lines.append(line)
- kept_tokens += line_tokens
- if kept_tokens >= max_tokens:
- break
- kept_lines.append(f"[... truncated {token_count - kept_tokens} tokens ...]")
- return '\n'.join(kept_lines)
-
- # Default: head (70%) + tail (30%)
- head_budget = int(max_tokens * 0.7)
- tail_budget = max_tokens - head_budget
- head_lines = []
- head_tokens = 0
- for line in lines:
- lt = _count_tokens(line)
- if head_tokens + lt > head_budget:
- break
- head_lines.append(line)
- head_tokens += lt
- tail_lines = []
- tail_tokens = 0
- for line in reversed(lines):
- lt = _count_tokens(line)
- if tail_tokens + lt > tail_budget:
- break
- tail_lines.insert(0, line)
- tail_tokens += lt
- truncated = token_count - head_tokens - tail_tokens
- return '\n'.join(head_lines) + f"\n[... truncated {truncated} tokens ...]\n" + '\n'.join(tail_lines)
-
- def _estimate_tokens(runtime_prompt: str, messages: list, observation: Observation | None) -> int:
- """Estimate the token count thought_node will send to the LLM."""
- parts = [runtime_prompt]
- for msg in messages:
- if hasattr(msg, "content") and isinstance(msg.content, str):
- parts.append(msg.content)
- if observation is not None:
- for item in (observation.memory or []):
- parts.append(item)
- for item in (observation.results or []):
- parts.append(item)
- return _count_tokens("\n".join(parts))
-
- def _build_tool_guidance_for_ecosystem(ecosystem: str, available_tools: list, is_reachability: str = "yes") -> tuple[str, str]:
- """Build tool guidance using language-specific strategies when available."""
- filtered_tools = [
- t for t in available_tools
- if (t.name != ToolNames.FUNCTION_CALLER_FINDER or ecosystem == "go") and
- (t.name != ToolNames.FUNCTION_LIBRARY_VERSION_FINDER or ecosystem == "java")
- ]
- list_of_tool_names = [t.name for t in filtered_tools]
- list_of_tool_descriptions = [t.name + ": " + t.description for t in filtered_tools]
-
- strategy = TOOL_SELECTION_STRATEGY if is_reachability == "yes" else TOOL_SELECTION_STRATEGY_NON_REACHABILITY
- lang = ecosystem.lower() if ecosystem else ""
- if lang in strategy:
- tool_guidance_local = strategy[lang]
- if lang == "java":
- tool_guidance_local += (
- " Use Function Library Version Finder to verify the installed version of a library "
- "before concluding exploitability (e.g., input 'commons-beanutils')."
- )
- if is_reachability == "yes":
- hint = FEW_SHOT_EXAMPLES.get(lang, "")
- if hint:
- tool_guidance_local += f"\nHint: {hint}"
- else:
- tool_guidance_list_local = build_tool_descriptions(list_of_tool_names)
- tool_guidance_local = "\n".join(tool_guidance_list_local)
-
- descriptions_local = "\n".join(list_of_tool_descriptions)
- return tool_guidance_local, descriptions_local
-
- async def pre_process_node(state: AgentState) -> AgentState:
- workflow_state = ctx_state.get()
- ecosystem = workflow_state.original_input.input.image.ecosystem.value if workflow_state.original_input.input.image.ecosystem else ""
- with AGENT_TRACER.push_active_function("pre_process node", input_data=f"ecosystem:{ecosystem}") as span:
- try:
- critical_context, candidate_packages, vulnerable_functions = build_critical_context(workflow_state.cve_intel)
- vulnerable_functions_set = set(vulnerable_functions)
-
- if ecosystem == "go":
- candidate_packages, vulnerable_functions = await _enrich_go_candidates(
- workflow_state.cve_intel,
- workflow_state.original_input.input.image.source_info,
- critical_context,
- candidate_packages,
- vulnerable_functions_set,
- )
-
- selected_package = None
- app_package = None
- if len(candidate_packages) > 1:
- image_input = workflow_state.original_input.input.image
- image_name = image_input.name
- source_repos = image_input.source_info
- image_repo = source_repos[0].git_repo if source_repos else None
- filter_prompt = build_package_filter_prompt(
- ecosystem, candidate_packages,
- image_name=image_name, image_repo=image_repo,
- critical_context=critical_context,
- )
- selection: PackageSelection = await package_filter_llm.ainvoke([HumanMessage(content=filter_prompt)])
- selected_package = selection.selected_package
- app_package = selected_package
- logger.info("Package filter selected '%s' from %d candidates (reason: %s)",
- selected_package, len(candidate_packages), selection.reason)
- critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages)
- elif len(candidate_packages) == 1:
- selected_package = candidate_packages[0].get("name")
- app_package = selected_package
- logger.info("Single candidate package after validation: '%s'", selected_package)
- critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages)
-
- critical_context.append(
- "TASK: Investigate usage and reachability of the vulnerable function/module in the container. "
- "Use the vulnerable module name from GHSA as primary investigation target."
- )
-
- question = state.get("input") or ""
- context_block = "\n".join(critical_context)
- classification_prompt = build_classification_prompt(context_block, question)
- classification_result: Classification = await reachability_llm.ainvoke([HumanMessage(content=classification_prompt)])
- span.set_output({
- "critical_context": critical_context,
- "candidate_packages": candidate_packages,
- "selected_package": selected_package,
- "app_package": app_package if selected_package else None,
- "reachability_question": classification_result.is_reachability,
- })
-
- is_reachability = classification_result.is_reachability
-
- if is_reachability == "yes":
- tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools)
- go_instructions = {"instructions": AGENT_THOUGHT_INSTRUCTIONS_GO} if ecosystem == "go" else {}
- runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local, **go_instructions)
- active_tool_names = [t.name for t in tools]
- else:
- reachability_tool_names = { ToolNames.CALL_CHAIN_ANALYZER, ToolNames.FUNCTION_CALLER_FINDER}
- non_reach_tools = [t for t in tools if t.name not in reachability_tool_names]
- tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, non_reach_tools, is_reachability="no")
- runtime_prompt = build_system_prompt(
- descriptions_local, tool_guidance_local,
- instructions=AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY,
- sys_prompt=AGENT_SYS_PROMPT_NON_REACHABILITY,
- )
- active_tool_names = [t.name for t in non_reach_tools]
- logger.info("Non-reachability question detected; removed reachability tools from prompt")
- rules_tracker = state.get("rules_tracker")
- app_package = app_package if selected_package else None
- rules_tracker.set_target_package(app_package)
- rules_tracker.set_allowed_tools(active_tool_names)
- if is_reachability == "yes":
- rules_tracker.set_target_functions(vulnerable_functions)
- return {
- "ecosystem": ecosystem,
- "runtime_prompt": runtime_prompt,
- "is_reachability": is_reachability,
- "observation": Observation(memory=critical_context, results=[]),
- "critical_context": critical_context,
- "app_package": app_package if selected_package else None,
- }
- except Exception as e:
- logger.exception("pre_process_node failed")
- span.set_output({"error": str(e), "exception_type": type(e).__name__})
- raise
-
-
- async def thought_node(state: AgentState) -> AgentState:
- step_num = state.get("step", 0)
- with AGENT_TRACER.push_active_function("thought node", input_data=f"step:{step_num}") as span:
- try:
- active_prompt = state.get("runtime_prompt") or default_system_prompt
- messages = [SystemMessage(content=active_prompt)] + state["messages"]
- obs = state.get("observation", None)
- if obs is not None:
- memory_list = obs.memory if obs.memory else ["No prior knowledge."]
- recent_findings = obs.results if obs.results else ["No recent findings."]
- memory_context = "\n".join(f"- {m}" for m in memory_list)
- findings_context = "\n".join(f"- {f}" for f in recent_findings)
- context_block = f"KNOWLEDGE:\n{memory_context}\nLATEST FINDINGS:\n{findings_context}"
- messages.append(SystemMessage(content=context_block))
- response: Thought = await thought_llm.ainvoke(messages)
-
- final_answer = "waiting for the agent to respond"
- if response.mode == "finish":
- ai_message = AIMessage(content=response.final_answer)
- final_answer = response.final_answer
- elif response.actions is None:
- logger.warning("LLM returned mode='act' but actions is None, forcing finish")
- ai_message = AIMessage(content=response.thought or "No actions provided, finishing.")
- response = Thought(
- thought=response.thought or "No actions provided",
- mode="finish",
- actions=None,
- final_answer=response.thought or "Insufficient evidence to provide a definitive answer."
- )
- final_answer = response.final_answer
- else:
- tool_name = response.actions.tool
- arguments = _build_tool_arguments(response.actions)
- tool_call_id = str(uuid.uuid4())
- ai_message = AIMessage(
- content=response.thought,
- tool_calls=[{
- "name": tool_name,
- "args": arguments,
- "id": tool_call_id
- }]
- )
-
- span.set_output({"mode": response.mode, "step": step_num + 1})
- return {
- "messages": [ai_message],
- "thought": response,
- "step": step_num + 1,
- "max_steps": config.max_iterations,
- "output": final_answer
- }
- except Exception as e:
- logger.exception("thought_node failed at step %d", step_num)
- span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num})
- raise
-
- async def should_continue(state: AgentState) -> str:
- thought = state.get("thought", None)
- if thought is not None and thought.mode == "finish":
- return END
- if state.get("step", 0) >= state.get("max_steps", config.max_iterations):
- return FORCED_FINISH_NODE
- return TOOL_NODE
-
- async def forced_finish_node(state: AgentState) -> AgentState:
- step_num = state.get("step", 0)
- with AGENT_TRACER.push_active_function("forced_finish node", input_data=f"step:{step_num}") as span:
- try:
- active_prompt = state.get("runtime_prompt") or default_system_prompt
- messages = [SystemMessage(content=active_prompt)] + state["messages"]
- messages.append(HumanMessage(content=FORCED_FINISH_PROMPT))
- obs = state.get("observation", None)
- if obs is not None and obs.memory:
- memory_context = "\n".join(f"- {m}" for m in obs.memory)
- messages.append(SystemMessage(content=f"KNOWLEDGE:\n{memory_context}"))
- response: Thought = await thought_llm.ainvoke(messages)
- if response.mode == "finish" and response.final_answer:
- ai_message = AIMessage(content=response.final_answer)
- final_answer = response.final_answer
- else:
- final_answer = "Failed to generate a final answer within the maximum allowed steps."
- ai_message = AIMessage(content=final_answer)
- response = Thought(
- thought=response.thought or "Max steps exceeded",
- mode="finish",
- actions=None,
- final_answer=final_answer
- )
- span.set_output({"final_answer_length": len(final_answer), "step": step_num})
- return {
- "messages": [ai_message],
- "thought": response,
- "step": step_num,
- "max_steps": state.get("max_steps", config.max_iterations),
- "observation": state.get("observation", None),
- "output": final_answer
- }
- except Exception as e:
- logger.exception("forced_finish_node failed at step %d", step_num)
- span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num})
- raise
-
- async def observation_node(state: AgentState) -> AgentState:
- tool_message = state["messages"][-1]
- last_thought_text = state["thought"].thought if state.get("thought") else "No previous thought."
- tool_used = state["thought"].actions.tool if state.get("thought") and state["thought"].actions else "Unknown"
- tool_input_detail = ""
- if state.get("thought") and state["thought"].actions:
- actions = state["thought"].actions
- if actions.package_name and actions.function_name:
- tool_input_detail = f"{actions.package_name},{actions.function_name}"
- elif actions.query:
- tool_input_detail = actions.query
- elif actions.tool_input:
- tool_input_detail = actions.tool_input
- previous_memory = state.get("observation").memory if state.get("observation") else ["No data gathered yet."]
- rules_tracker = state.get("rules_tracker")
- with AGENT_TRACER.push_active_function("observation node", input_data=f"tool used:{tool_used}") as span:
- try:
- tool_output_for_llm = tool_message.content
- result, error_message = rules_tracker.check_thought_behavior(tool_used, tool_input_detail, tool_output_for_llm)
- if result:
- span.set_output({"rule_error": error_message})
- return {"messages": [HumanMessage(content=error_message)]}
-
- if state.get("ecosystem", "").lower() == "java":
- truncated_output = _truncate_tool_output(tool_output_for_llm, tool_used)
- else:
- truncated_output = tool_output_for_llm
-
- # Step 1: Comprehension -- reads raw tool output, produces compact findings
- ctx_lines = state.get("critical_context", [])
- critical_context_text = "\n".join(ctx_lines) if ctx_lines else "N/A"
- comp_prompt = COMPREHENSION_PROMPT.format(
- goal=state.get('input'),
- selected_package=state.get('app_package') or "N/A",
- critical_context=critical_context_text,
- tool_used=tool_used,
- tool_input_detail=tool_input_detail,
- last_thought_text=last_thought_text,
- tool_output=truncated_output,
- )
- code_findings: CodeFindings = await comprehension_llm.ainvoke([SystemMessage(content=comp_prompt)])
-
- findings_text = "\n".join(f"- {f}" for f in code_findings.findings)
-
- # Step 2: Memory update -- merges compressed findings into cumulative memory
- mem_prompt = MEMORY_UPDATE_PROMPT.format(
- goal=state.get('input'),
- selected_package=state.get('app_package') or "N/A",
- previous_memory=previous_memory,
- findings=findings_text,
- tool_outcome=code_findings.tool_outcome,
- )
- new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=mem_prompt)])
-
- messages = state["messages"]
- active_prompt = state.get("runtime_prompt") or default_system_prompt
- estimated = _estimate_tokens(active_prompt, messages, new_observation)
- prune_messages = []
- orig_estimated = estimated
-
- span_trace_dict = {"comprehension_findings": code_findings.findings, "tool_outcome": code_findings.tool_outcome}
-
- if estimated > config.context_window_token_limit and len(messages) > 3:
- prunable = messages[1:-2]
- for msg in prunable:
- prune_messages.append(RemoveMessage(id=msg.id))
- estimated -= _count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0
- if estimated <= config.context_window_token_limit:
- break
- logger.info(
- "Context pruning: removed %d messages, estimated tokens now ~%d (limit %d)",
- len(prune_messages), estimated, config.context_window_token_limit,
- )
- span_trace_dict["orig_estimated"] = orig_estimated
- span_trace_dict["estimated"] = estimated
- span.set_output(span_trace_dict)
- cca_results = list(state.get("cca_results", []))
- if tool_used == ToolNames.CALL_CHAIN_ANALYZER:
- stripped = tool_output_for_llm.strip().lstrip("([")
- first_token = stripped.split(",", 1)[0].strip().lower()
- if first_token == "true":
- cca_results.append(True)
- elif first_token == "false":
- cca_results.append(False)
- package_validated = state.get("package_validated")
- if tool_used == ToolNames.FUNCTION_LOCATOR and state.get("is_reachability") == "yes":
- input_pkg = package_name_from_locator_query(tool_input_detail)
- target_pkg = (state.get("app_package") or "").strip().lower()
- if target_pkg and input_pkg == target_pkg:
- if "Package is valid" in tool_output_for_llm:
- package_validated = True
- elif "Package is not valid" in tool_output_for_llm and package_validated is None:
- package_validated = False
- return {
- "messages": prune_messages,
- "observation": new_observation,
- "step": state.get("step", 0),
- "cca_results": cca_results,
- "package_validated": package_validated,
- }
- except Exception as e:
- logger.exception("observation_node failed")
- span.set_output({"error": str(e), "exception_type": type(e).__name__})
- raise
-
- async def create_graph():
- flow = StateGraph(AgentState)
- flow.add_node(THOUGHT_NODE, thought_node)
- flow.add_node(TOOL_NODE, tool_node)
- flow.add_node(FORCED_FINISH_NODE, forced_finish_node)
- flow.add_node(PRE_PROCESS_NODE, pre_process_node)
- flow.add_node(OBSERVATION_NODE, observation_node)
- flow.add_edge(START, PRE_PROCESS_NODE)
- flow.add_edge(PRE_PROCESS_NODE, THOUGHT_NODE)
- flow.add_conditional_edges(
- THOUGHT_NODE,
- should_continue,
- {END: END, TOOL_NODE: TOOL_NODE, FORCED_FINISH_NODE: FORCED_FINISH_NODE}
- )
- flow.add_edge(TOOL_NODE, OBSERVATION_NODE)
- flow.add_edge(OBSERVATION_NODE, THOUGHT_NODE)
- flow.add_edge(FORCED_FINISH_NODE, END)
-
- app = flow.compile()
- if config.verbose:
- app.get_graph().draw_mermaid_png(output_file_path="flow.png")
- return app
- return await create_graph()
-async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder,
- state: AgentMorpheusEngineState) -> AgentExecutor:
-
- tools, tool_descriptions,_ = await common_build_tools(config, builder, state)
-
- llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
- tool_guidance = "\n".join(tool_descriptions)
-
- # Get prompt template
- prompt_template_str = get_agent_prompt(config.prompt, config.prompt_examples)
-
- # Create prompt with tool_selection_strategy as partial variable
- prompt = PromptTemplate.from_template(
- prompt_template_str,
- partial_variables={
- 'tool_selection_strategy': tool_guidance if tool_guidance else "Use available tools as appropriate."
- }
- )
- # using langchain create_react_agent to create the agent
- agent = create_react_agent(llm=llm,
- tools=tools,
- prompt=prompt,
- output_parser=MRKLOutputParser(),
- stop_sequence=["\nObservation:", "\n\tObservation:"])
-
- agent_executor = AgentExecutor(
- agent=agent,
- tools=tools,
- early_stopping_method="force",
- handle_parsing_errors="Check your output and make sure it conforms, use the Action/Action Input syntax",
- max_iterations=config.max_iterations,
- return_intermediate_steps=config.return_intermediate_steps,
- verbose=config.verbose)
-
- # Disable streaming for accurate token counts
- if isinstance(agent_executor.agent, RunnableAgent):
- agent_executor.agent.stream_runnable = False
-
- return agent_executor
-
-async def _process_steps(agent, steps, semaphore, max_iterations: int = 10):
-
+async def _process_steps(agents: dict, routing_llm, steps, semaphore, max_iterations: int = 10):
+ workflow_state = ctx_state.get()
+ critical_context, candidate_packages, vulnerable_functions = build_critical_context(workflow_state.cve_intel)
+ context_block = "\n".join(critical_context)
+ precomputed_intel = (critical_context, candidate_packages, vulnerable_functions)
async def _process_step(step):
- async def call_agent(initial_state,config=None):
- if config:
- return await agent.ainvoke(initial_state,config=config)
- else:
- return await agent.ainvoke(initial_state)
+ routing = await dispatch_question(routing_llm, step, context_block)
+ agent_type = routing.agent_type
+
+ actual_type = agent_type if agent_type in agents else "reachability"
+ compiled_graph = agents[actual_type]
+ tracker = get_agent_class(actual_type).create_rules_tracker()
+
+ initial_state = {
+ "input": step,
+ "messages": [HumanMessage(content=step)],
+ "step": 0,
+ "max_steps": max_iterations,
+ "thought": None,
+ "observation": None,
+ "output": "waiting for the agent to respond",
+ "rules_tracker": tracker,
+ "precomputed_intel": precomputed_intel,
+ }
+ graph_config = {"recursion_limit": 50}
- initial_state = {"input": step}
- config = None
- if not isinstance(agent, AgentExecutor):
- initial_state = {
- "input": step,
- "messages": [HumanMessage(content=step)],
- "step": 0,
- "max_steps": max_iterations,
- "thought": None,
- "observation": None,
- "output": "waiting for the agent to respond",
- "rules_tracker": SystemRulesTracker(),
- }
- config = {
- "recursion_limit": 50
- }
- with AGENT_TRACER.push_active_function("checklist_question", input_data=step[:80]):
+ with AGENT_TRACER.push_active_function(
+ "checklist_question",
+ input_data=f"[{actual_type}] {step[:80]}",
+ ):
if semaphore:
async with semaphore:
- return await call_agent(initial_state, config)
+ return await compiled_graph.ainvoke(initial_state, config=graph_config)
else:
- return await call_agent(initial_state, config)
+ return await compiled_graph.ainvoke(initial_state, config=graph_config)
return await asyncio.gather(*(_process_step(step) for step in steps), return_exceptions=True)
-def _parse_intermediate_step(step: tuple[typing.Any, typing.Any]) -> dict[str, typing.Any]:
- """
- Parse an agent intermediate step into an AgentIntermediateStep object. Return the dictionary representation for
- compatibility with cudf.
- """
- if len(step) != 2:
- raise ValueError(f"Expected 2 values in each intermediate step but got {len(step)}.")
-
- action, output = step
-
- return {"tool_name": action.tool, "action_log": action.log, "tool_input": action.tool_input, "tool_output": output}
-
-
def _postprocess_results(results: list[list[dict]], replace_exceptions: bool, replace_exceptions_value: str | None,
- checklist_questions: list[list]) -> tuple[list[list[str]], list[list[list]]]:
- """
- Post-process results into lists of outputs and intermediate steps. Replace exceptions with placholder values if
- config.replace_exceptions = True.
- :param values:
+ checklist_questions: list[list]) -> list[list[dict]]:
+ """Post-process graph agent results into a uniform list of output dicts.
+
+ Replaces exceptions with placeholder values if replace_exceptions is True.
"""
outputs = [[] for _ in range(len(results))]
for i, answer_list in enumerate(results):
for j, answer in enumerate(answer_list):
- # Handle exceptions returned by the agent
- # OutputParserException is not a subclass of Exception, so we need to check for it separately
if isinstance(answer, (ToolRaisedException, OutputParserException, Exception)):
if replace_exceptions:
- # If the agent encounters a parsing error or a server error after retries, replace the error
- # with default values to prevent the pipeline from crashing
outputs[i].append({"input": checklist_questions[i][j], "output": replace_exceptions_value,
"intermediate_steps": None, "cca_results": [],
"package_validated": None})
if isinstance(answer, ToolRaisedException):
- tool_raised_exception: ToolRaisedException = answer
- logger.warning(f"An exception encountered during tool execution, in result [{i}][{j}]. for "
- f"question : {checklist_questions[i][j]}"
- f".tool raised exception details=> {tool_raised_exception}"
- f", replacing with default output -> {replace_exceptions_value}")
+ logger.warning(
+ "Tool execution exception in result[%d][%d], replacing with default output: %s",
+ i, j, type(answer).__name__)
else:
logger.warning(
- "General Exception encountered in result[%d][%d]: %s, for question -> %s, "
- "Replacing with default output: \"%s\" and intermediate_steps: None",
- i,
- j,
- str(answer),
- checklist_questions[i][j],
- replace_exceptions_value)
+ "General exception in result[%d][%d]: %s, replacing with default output",
+ i, j, type(answer).__name__)
- # For successful agent responses, extract the output, and intermediate steps if available
else:
- # intermediate_steps availability depends on config.return_intermediate_steps
- if "intermediate_steps" in answer:
- results[i][j]["intermediate_steps"] = [
- _parse_intermediate_step(step) for step in answer["intermediate_steps"]
- ]
- else:
- results[i][j]["intermediate_steps"] = None
-
outputs[i].append({"input": answer["input"], "output": answer["output"],
- "intermediate_steps": results[i][j]["intermediate_steps"],
+ "intermediate_steps": None,
"cca_results": answer.get("cca_results", []),
"package_validated": answer.get("package_validated")})
return outputs
@@ -768,13 +157,38 @@ async def _arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState:
ctx_state.set(state)
checklist_plans = state.checklist_plans
-
- agent = await _create_graph_agent(config, builder, state)
- results = await asyncio.gather(*(_process_steps(agent, steps, semaphore, config.max_iterations)
- for steps in checklist_plans.values()), return_exceptions=True)
- results = _postprocess_results(results, config.replace_exceptions, config.replace_exceptions_value,
- list(checklist_plans.values()))
+ llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+
+ agents = {}
+ for agent_type in get_all_agent_types():
+ agent_cls = get_agent_class(agent_type)
+ tools = agent_cls.get_tools(builder, config, state)
+ if tools:
+ agent_instance = agent_cls(tools=tools, llm=llm, config=config)
+ agents[agent_type] = await agent_instance.build_graph()
+
+ routing_llm = llm.with_structured_output(QuestionRouting)
+
+ results = await asyncio.gather(
+ *(
+ _process_steps(
+ agents,
+ routing_llm,
+ steps,
+ semaphore,
+ config.max_iterations,
+ )
+ for steps in checklist_plans.values()
+ ),
+ return_exceptions=True,
+ )
+ results = _postprocess_results(
+ results,
+ config.replace_exceptions,
+ config.replace_exceptions_value,
+ list(checklist_plans.values()),
+ )
state.checklist_results = dict(zip(checklist_plans.keys(), results))
with AGENT_TRACER.push_active_function("agent_finish", input_data={
diff --git a/src/vuln_analysis/functions/dispatcher.py b/src/vuln_analysis/functions/dispatcher.py
new file mode 100644
index 00000000..0ab1240b
--- /dev/null
+++ b/src/vuln_analysis/functions/dispatcher.py
@@ -0,0 +1,107 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Literal
+
+from langchain_core.messages import HumanMessage
+from pydantic import BaseModel, Field
+
+from exploit_iq_commons.logging.loggers_factory import LoggingFactory
+from nat.builder.context import Context
+
+logger = LoggingFactory.get_agent_logger(__name__)
+AGENT_TRACER = Context.get()
+
+
+class QuestionRouting(BaseModel):
+ """Structured output for LLM-based question routing classification.
+
+ The dispatcher classifies each checklist question into one of two agent types:
+ - reachability: call chain tracing (is vulnerable code reachable from app code?)
+ - code_understanding: configuration/version/presence/usage investigation
+ """
+ agent_type: Literal["reachability", "code_understanding"] = Field(
+ description="Route to 'reachability' if the question asks about call chains, "
+ "code paths, whether vulnerable code is used/called/reachable, "
+ "or whether untrusted data can reach a function. "
+ "Route to 'code_understanding' for configuration, version, "
+ "or application-level settings questions."
+ )
+ reason: str = Field(
+ description="One-sentence justification for the routing decision."
+ )
+
+
+ROUTING_PROMPT_TEMPLATE = """You are classifying a CVE investigation question to route it to the correct sub-agent.
+
+Two sub-agents are available:
+- **reachability**: Traces call chains and code paths. Use when the question asks whether vulnerable code is CALLED, USED, or REACHABLE from application code, or whether untrusted data can reach a specific function. IMPORTANT: If the question asks whether a class/function/module from the VULNERABLE PACKAGE (listed in the context below) is used, imported, or called — route to reachability. The reachability agent has Call Chain Analyzer and Function Locator to verify actual code-level usage and call paths, which is needed to distinguish "imported but never called on vulnerable path" from "actively used."
+- **code_understanding**: Investigates configuration, version, presence, and general application behavior. Use when the question asks about application-level configuration (e.g., XML parsing settings, TLS options), environment setup, or properties — things that do NOT require tracing whether a specific vulnerable function is called.
+
+Context (CVE / vulnerable packages):
+{context_block}
+
+Question: {question}
+
+Examples:
+- "Is the vulnerable function XStream.fromXML() called from application code?" → reachability
+- "Is the application configured to use the affected XML parser?" → code_understanding
+- "Can untrusted data reach BeanUtils.populate() through the call chain?" → reachability
+- "Is the vulnerable version of commons-beanutils installed?" → code_understanding
+- "Is the function parseXML() reachable from any HTTP handler?" → reachability
+- "Does the application enable external entity processing in its XML configuration?" → code_understanding
+- "Is SslHandler used in the application?" → reachability (SslHandler is from the vulnerable package)
+- "Is HttpPostStandardRequestDecoder.offer() called by application code?" → reachability
+- "Is the vulnerable newTransformer() method invoked anywhere in the application's XSLT processing?" → reachability
+- "Does application code call deserialize() or readObject() on the affected library?" → reachability
+- "Does the code sanitize input against path traversal characters and command injection metacharacters?" → code_understanding
+- "Are there input validation mechanisms to prevent malicious Markdown input from reaching the paragraph function?" → code_understanding
+- "Can a single request force the application server to load unbounded data into memory?" → code_understanding
+- "Can malformed input cause an unhandled exception that crashes or restarts the process?" → code_understanding
+
+Classify the question above."""
+
+
+def build_routing_prompt(context_block: str, question: str) -> str:
+ """Format the routing prompt template with CVE context and question text."""
+ return ROUTING_PROMPT_TEMPLATE.format(
+ context_block=context_block,
+ question=question,
+ )
+
+
+async def dispatch_question(
+ routing_llm,
+ question: str,
+ context_block: str,
+) -> QuestionRouting:
+ """Classify a checklist question and return routing decision.
+
+ Uses the routing LLM (with structured output) to decide whether the question
+ should go to the reachability agent (call chain tracing) or the code understanding
+ agent (configuration/presence/usage investigation).
+ """
+ prompt = build_routing_prompt(context_block, question)
+ with AGENT_TRACER.push_active_function("dispatch_question", input_data=f"question:{question[:80]}") as span:
+ logger.debug("Dispatching question to routing LLM: %.80s", question)
+ result: QuestionRouting = await routing_llm.ainvoke([HumanMessage(content=prompt)])
+ logger.info(
+ "Question routed to '%s' (reason: %s): %.80s",
+ result.agent_type,
+ result.reason,
+ question,
+ )
+ span.set_output({"agent_type": result.agent_type, "reason": result.reason})
+    return result
diff --git a/src/vuln_analysis/functions/reachability_agent.py b/src/vuln_analysis/functions/reachability_agent.py
new file mode 100644
index 00000000..6462939e
--- /dev/null
+++ b/src/vuln_analysis/functions/reachability_agent.py
@@ -0,0 +1,337 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
+
+from exploit_iq_commons.logging.loggers_factory import LoggingFactory
+from nat.builder.context import Context
+from vuln_analysis.functions.agent_registry import register_agent
+from vuln_analysis.functions.base_graph_agent import BaseGraphAgent, _is_tool_available
+from vuln_analysis.functions.react_internals import (
+ AgentState,
+ Classification,
+ Observation,
+ Thought,
+ ReachabilityRulesTracker,
+ build_reachability_system_prompt,
+ build_classification_prompt,
+ FORCED_FINISH_PROMPT,
+ REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS_GO,
+ REACHABILITY_AGENT_NON_REACH_SYS_PROMPT,
+ REACHABILITY_AGENT_NON_REACH_THOUGHT_INSTRUCTIONS,
+)
+from vuln_analysis.runtime_context import ctx_state
+from vuln_analysis.tools.tool_names import ToolNames
+from vuln_analysis.tools.transitive_code_search import package_name_from_locator_query
+from vuln_analysis.utils.intel_utils import build_critical_context, enrich_go_from_osv
+from vuln_analysis.utils.prompting import build_tool_descriptions
+from vuln_analysis.utils.prompt_factory import (
+ TOOL_SELECTION_STRATEGY,
+ TOOL_SELECTION_STRATEGY_NON_REACHABILITY,
+ FEW_SHOT_EXAMPLES,
+)
+from pathlib import Path
+from exploit_iq_commons.utils.git_utils import sanitize_git_url_for_path
+from exploit_iq_commons.utils.data_utils import DEFAULT_GIT_DIRECTORY
+
+logger = LoggingFactory.get_agent_logger(__name__)
+AGENT_TRACER = Context.get()
+
+
+def _validate_go_vendor_packages(source_info, candidate_packages):
+ code_si = next((si for si in source_info if si.type == "code"), None)
+ if code_si is None:
+ return candidate_packages, []
+ repo_path = Path(DEFAULT_GIT_DIRECTORY) / sanitize_git_url_for_path(code_si.git_repo)
+ vendor_path = repo_path / "vendor"
+ if not vendor_path.is_dir():
+ return candidate_packages, []
+ validated = []
+ removed = []
+ for pkg in candidate_packages:
+ pkg_name = pkg.get("name", "")
+ if (vendor_path / pkg_name).is_dir():
+ validated.append(pkg)
+ else:
+ removed.append(pkg_name)
+ if validated:
+ return validated, removed
+ return candidate_packages, []
+
+
+async def _enrich_go_candidates(cve_intel, source_info, critical_context, candidate_packages, vulnerable_functions_set):
+ ghsa_has_packages = any(c.get("source") == "ghsa" for c in candidate_packages)
+ if not ghsa_has_packages or not vulnerable_functions_set:
+ intel = cve_intel[0] if cve_intel else None
+ if intel:
+ await enrich_go_from_osv(intel, critical_context, candidate_packages, vulnerable_functions_set)
+ if candidate_packages:
+ candidate_packages, removed_pkgs = _validate_go_vendor_packages(source_info, candidate_packages)
+ if removed_pkgs:
+ logger.info("Go vendor validation removed %d packages not in vendor/: %s", len(removed_pkgs), removed_pkgs)
+ return candidate_packages, sorted(vulnerable_functions_set)
+
+
+@register_agent("reachability")
+class ReachabilityAgent(BaseGraphAgent):
+
+ def __init__(self, tools, llm, config):
+ super().__init__(tools, llm, config)
+ self._classification_llm = llm.with_structured_output(Classification)
+
+ @property
+ def agent_type(self) -> str:
+ return "reachability"
+
+ _REACHABILITY_TOOLS = frozenset({
+ ToolNames.FUNCTION_LOCATOR,
+ ToolNames.CALL_CHAIN_ANALYZER,
+ ToolNames.FUNCTION_LIBRARY_VERSION_FINDER,
+ ToolNames.FUNCTION_CALLER_FINDER,
+ ToolNames.CODE_KEYWORD_SEARCH,
+ ToolNames.CVE_WEB_SEARCH,
+ ToolNames.CODE_SEMANTIC_SEARCH,
+ ToolNames.DOCS_SEMANTIC_SEARCH,
+ })
+
+ @staticmethod
+ def get_tools(builder, config, state) -> list:
+ all_tools = BaseGraphAgent._load_all_tools(builder, config)
+ return [
+ t for t in all_tools
+ if t.name in ReachabilityAgent._REACHABILITY_TOOLS
+ and _is_tool_available(t.name, config, state)
+ ]
+
+ @staticmethod
+ def create_rules_tracker() -> ReachabilityRulesTracker:
+ return ReachabilityRulesTracker()
+
+ def should_truncate_tool_output(self, state: AgentState, tool_used: str) -> bool:
+ return state.get("ecosystem", "").lower() == "java"
+
+ def _build_tool_guidance_for_ecosystem(self, ecosystem: str, available_tools: list,
+ is_reachability: str = "yes") -> tuple[str, str]:
+ filtered_tools = [
+ t for t in available_tools
+ if (t.name != ToolNames.FUNCTION_CALLER_FINDER or ecosystem == "go") and
+ (t.name != ToolNames.FUNCTION_LIBRARY_VERSION_FINDER or ecosystem == "java")
+ ]
+ list_of_tool_names = [t.name for t in filtered_tools]
+ list_of_tool_descriptions = [t.name + ": " + t.description for t in filtered_tools]
+
+ strategy = TOOL_SELECTION_STRATEGY if is_reachability == "yes" else TOOL_SELECTION_STRATEGY_NON_REACHABILITY
+ lang = ecosystem.lower() if ecosystem else ""
+ if lang in strategy:
+ tool_guidance_local = strategy[lang]
+ if lang == "java":
+ tool_guidance_local += (
+ " Use Function Library Version Finder to verify the installed version of a library "
+ "before concluding exploitability (e.g., input 'commons-beanutils')."
+ )
+ if is_reachability == "yes":
+ hint = FEW_SHOT_EXAMPLES.get(lang, "")
+ if hint:
+ tool_guidance_local += f"\nHint: {hint}"
+ else:
+ tool_guidance_list_local = build_tool_descriptions(list_of_tool_names)
+ tool_guidance_local = "\n".join(tool_guidance_list_local)
+
+ descriptions_local = "\n".join(list_of_tool_descriptions)
+ return tool_guidance_local, descriptions_local
+
+ async def pre_process_node(self, state: AgentState) -> AgentState:
+ workflow_state = ctx_state.get()
+ ecosystem = workflow_state.original_input.input.image.ecosystem.value if workflow_state.original_input.input.image.ecosystem else ""
+ with AGENT_TRACER.push_active_function("pre_process node", input_data=f"ecosystem:{ecosystem}") as span:
+ try:
+ precomputed = state.get("precomputed_intel")
+ if precomputed is not None:
+ critical_context = list(precomputed[0])
+ candidate_packages = [dict(p) for p in precomputed[1]]
+ vulnerable_functions = list(precomputed[2])
+ else:
+ critical_context, candidate_packages, vulnerable_functions = build_critical_context(workflow_state.cve_intel)
+ vulnerable_functions_set = set(vulnerable_functions)
+
+ if ecosystem == "go":
+ candidate_packages, vulnerable_functions = await _enrich_go_candidates(
+ workflow_state.cve_intel,
+ workflow_state.original_input.input.image.source_info,
+ critical_context,
+ candidate_packages,
+ vulnerable_functions_set,
+ )
+
+ critical_context, selected_package = await self._select_package(
+ ecosystem, candidate_packages, critical_context, workflow_state,
+ )
+ app_package = selected_package
+
+ critical_context.append(
+ "TASK: Investigate usage and reachability of the vulnerable function/module in the container. "
+ "Use the vulnerable module name from GHSA as primary investigation target."
+ )
+
+ question = state.get("input") or ""
+ context_block = "\n".join(critical_context)
+ classification_prompt = build_classification_prompt(context_block, question)
+ classification_result: Classification = await self._classification_llm.ainvoke([HumanMessage(content=classification_prompt)])
+ span.set_output({
+ "critical_context": critical_context,
+ "candidate_packages": candidate_packages,
+ "selected_package": selected_package,
+ "app_package": app_package if selected_package else None,
+ "reachability_question": classification_result.is_reachability,
+ })
+
+ is_reachability = classification_result.is_reachability
+
+ if is_reachability == "yes":
+ tool_guidance_local, descriptions_local = self._build_tool_guidance_for_ecosystem(ecosystem, self.tools)
+ go_instructions = {"instructions": REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS_GO} if ecosystem == "go" else {}
+ runtime_prompt = build_reachability_system_prompt(descriptions_local, tool_guidance_local, **go_instructions)
+ active_tool_names = [t.name for t in self.tools]
+ else:
+ reachability_tool_names = {ToolNames.CALL_CHAIN_ANALYZER, ToolNames.FUNCTION_CALLER_FINDER}
+ non_reach_tools = [t for t in self.tools if t.name not in reachability_tool_names]
+ tool_guidance_local, descriptions_local = self._build_tool_guidance_for_ecosystem(ecosystem, non_reach_tools, is_reachability="no")
+ runtime_prompt = build_reachability_system_prompt(
+ descriptions_local, tool_guidance_local,
+ instructions=REACHABILITY_AGENT_NON_REACH_THOUGHT_INSTRUCTIONS,
+ sys_prompt=REACHABILITY_AGENT_NON_REACH_SYS_PROMPT,
+ )
+ active_tool_names = [t.name for t in non_reach_tools]
+ logger.info("Non-reachability question detected; removed reachability tools from prompt")
+
+ rules_tracker = state.get("rules_tracker")
+ app_package = app_package if selected_package else None
+ rules_tracker.set_ecosystem(ecosystem)
+ rules_tracker.set_target_package(app_package)
+ rules_tracker.set_allowed_tools(active_tool_names)
+ if is_reachability == "yes":
+ rules_tracker.set_target_functions(vulnerable_functions)
+ return {
+ "ecosystem": ecosystem,
+ "runtime_prompt": runtime_prompt,
+ "is_reachability": is_reachability,
+ "observation": Observation(memory=critical_context, results=[]),
+ "critical_context": critical_context,
+ "app_package": app_package,
+ }
+ except Exception as e:
+ logger.exception("pre_process_node failed")
+ span.set_output({"error": str(e), "exception_type": type(e).__name__})
+ raise
+
+ def check_finish_allowed(self, state: AgentState) -> tuple[bool, str]:
+ if state.get("is_reachability") != "yes":
+ return True, ""
+ rules_tracker = state.get("rules_tracker")
+ cca_results = state.get("cca_results", [])
+ return rules_tracker.check_finish_allowed(cca_results)
+
+ async def forced_finish_node(self, state: AgentState) -> AgentState:
+ """Override forced_finish to inject a no-CCA warning for reachability questions.
+
+ When the reachability agent exhausts all iterations without ever calling CCA,
+ the LLM has no reachability evidence. Without this override the LLM can
+ hallucinate exploitability based on library presence alone. The injected
+ prompt steers it toward 'insufficient evidence / not exploitable'.
+ """
+ cca_results = state.get("cca_results", [])
+ is_reachability = state.get("is_reachability")
+ if is_reachability == "yes" and not cca_results:
+ step_num = state.get("step", 0)
+ with AGENT_TRACER.push_active_function(
+ f"{self.agent_type}_forced_finish", input_data=f"step:{step_num}"
+ ) as span:
+ try:
+ active_prompt = state.get("runtime_prompt")
+ messages = [SystemMessage(content=active_prompt)]
+ context_block = self._build_observation_context(state.get("observation", None))
+ if context_block:
+ messages.append(SystemMessage(content=context_block))
+ question = state.get("input", "")
+ no_cca_prompt = (
+ (f"QUESTION: {question}\n\n" if question else "")
+ + FORCED_FINISH_PROMPT + "\n\n"
+ "CRITICAL: Call Chain Analyzer was NEVER called during this investigation. "
+ "Without Call Chain Analyzer verification, there is NO evidence that the "
+ "vulnerable function is reachable from application code. Library presence "
+ "in dependencies alone does NOT constitute exploitability. "
+ "You MUST conclude that there is insufficient evidence to confirm "
+ "exploitability — the function was NOT confirmed reachable."
+ )
+ messages.append(HumanMessage(content=no_cca_prompt))
+
+ response: Thought = await self.thought_llm.ainvoke(messages)
+ if response.mode == "finish" and response.final_answer:
+ final_answer = response.final_answer
+ else:
+ final_answer = (
+ "Insufficient evidence: Call Chain Analyzer was never invoked. "
+ "Cannot confirm the vulnerable function is reachable from application code."
+ )
+ response = Thought(
+ thought="Max steps exceeded without CCA verification",
+ mode="finish",
+ actions=None,
+ final_answer=final_answer,
+ )
+ ai_message = AIMessage(content=final_answer)
+ logger.info("Reachability forced_finish: no CCA calls, injected no-CCA warning")
+ span.set_output({"final_answer_length": len(final_answer), "step": step_num, "no_cca_warning": True})
+ return {
+ "messages": [ai_message],
+ "thought": response,
+ "step": step_num,
+ "max_steps": state.get("max_steps", self.config.max_iterations),
+ "observation": state.get("observation", None),
+ "output": final_answer,
+ }
+ except Exception as e:
+ logger.exception("%s forced_finish_node failed at step %d", self.agent_type, step_num)
+ span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num})
+ raise
+ return await super().forced_finish_node(state)
+
+ def post_observation(self, state: AgentState, tool_used: str,
+ tool_output: str, tool_input_detail: str) -> dict:
+ extra = {}
+ cca_results = list(state.get("cca_results", []))
+ if tool_used == ToolNames.CALL_CHAIN_ANALYZER:
+ stripped = tool_output.strip().lstrip("([")
+ first_token = stripped.split(",", 1)[0].strip().lower()
+ if first_token == "true":
+ cca_results.append(True)
+ elif first_token == "false":
+ cca_results.append(False)
+ extra["cca_results"] = cca_results
+
+ package_validated = state.get("package_validated")
+ rules_tracker = state.get("rules_tracker")
+ if tool_used == ToolNames.FUNCTION_LOCATOR and state.get("is_reachability") == "yes":
+ input_pkg = package_name_from_locator_query(tool_input_detail)
+ target_pkg = (state.get("app_package") or "").strip().lower()
+ if input_pkg and "Package is valid" in tool_output:
+ rules_tracker.add_validated_package(input_pkg)
+ if target_pkg and input_pkg == target_pkg:
+ package_validated = True
+ elif input_pkg and "Package is not valid" in tool_output:
+ if target_pkg and input_pkg == target_pkg and package_validated is None:
+ package_validated = False
+ extra["package_validated"] = package_validated
+ return extra
diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py
index 9561c349..80340067 100644
--- a/src/vuln_analysis/functions/react_internals.py
+++ b/src/vuln_analysis/functions/react_internals.py
@@ -17,11 +17,16 @@
from typing import Any
from typing import Literal
from langgraph.graph import MessagesState
-#---- REACT Schemas ----#
+
+from exploit_iq_commons.logging.loggers_factory import LoggingFactory
+from vuln_analysis.tools.tool_names import ToolNames
+
+logger = LoggingFactory.get_agent_logger(__name__)
+
+# ---- Pydantic Schemas ---- #
class ToolCall(BaseModel):
tool: str = Field(description="Exact tool name from AVAILABLE_TOOLS")
- #tool_input: str = Field(description="The input for the tool. Example: Code Keyword Search: PQescapeLiteral")
package_name: str | None = Field(
default=None,
description="Package/module name. REQUIRED when using Function Locator, Function Caller Finder, or Call Chain Analyzer. E.g. libpq, urllib, github.com/org/pkg"
@@ -30,12 +35,12 @@ class ToolCall(BaseModel):
default=None,
description="Function or method name with optional args. REQUIRED with package_name for code path tools. E.g. PQescapeLiteral(), parse(), errors.New(\"x\")"
)
- # For search tools (Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search)
+ # For search tools: Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search
query: str | None = Field(
default=None,
description="Search query. Use for search tools when package_name/function_name don't apply"
)
- # Fallback: if LLM uses tool_input for simple query-only tools
+ # Fallback for when LLM uses tool_input for simple query-only tools
tool_input: str | None = Field(
default=None,
description="Legacy/fallback input. Prefer package_name+function_name or query."
@@ -94,18 +99,18 @@ class PackageSelection(BaseModel):
description="One-sentence justification for why this package is the best investigation target."
)
-class SystemRulesTracker:
+class BaseRulesTracker:
+ """Shared behavioral rule infrastructure for all sub-agents."""
def __init__(self):
self.action_history = {}
self.target_package = None
self.allowed_tools = []
- self.target_functions: dict[str, bool] = {}
+
def set_allowed_tools(self, allowed_tools: list[str]):
self.allowed_tools = allowed_tools
+
def set_target_package(self, target_package: str):
self.target_package = target_package
- def set_target_functions(self, functions: list[str]):
- self.target_functions = {f: False for f in functions}
@staticmethod
def _is_empty_result(output) -> bool:
@@ -127,6 +132,11 @@ def add_action(self, action: str, action_input: str, output):
else:
self.action_history[action].append(entry)
+ def _rule_duplicate_call(self, action: str, action_input: str) -> bool:
+ if action not in self.action_history:
+ return False
+ return any(prev["input"] == action_input for prev in self.action_history[action])
+
def _rule_number_7(self, action: str, action_input: str, output) -> bool:
if action != "Code Keyword Search":
return False
@@ -141,10 +151,73 @@ def _rule_number_7(self, action: str, action_input: str, output) -> bool:
return True
return False
+ def _rule_use_allowed_tools(self, action: str) -> bool:
+ if action not in self.allowed_tools:
+ return True
+ return False
+
+ def check_thought_behavior(self, action: str, action_input: str, output) -> tuple[bool, str]:
+ if self._rule_duplicate_call(action, action_input):
+ return True, (
+ f"You already called {action} with this exact input. "
+ "You MUST use a DIFFERENT tool or a DIFFERENT input query. "
+ "Check KNOWLEDGE for what was already tried."
+ )
+ if self._rule_number_7(action, action_input, output):
+ return True, ("You are NOT following Rule 7. Your query contains dots and returned "
+ "no results. You MUST retry with just the final component. Follow the rules.")
+ if self._rule_use_allowed_tools(action):
+ return True, (f"You are NOT following AVAILABLE_TOOLS. You MUST use the allowed tools {self.allowed_tools}. Follow the rules.")
+ self.add_action(action, action_input, output)
+ return False, ""
+
+
+class ReachabilityRulesTracker(BaseRulesTracker):
+ """Behavioral rules for the reachability sub-agent."""
+ def __init__(self):
+ super().__init__()
+ self.target_functions: dict[str, bool] = {}
+ self.ecosystem: str = ""
+ self.validated_packages: set[str] = set()
+
+ def set_ecosystem(self, ecosystem: str):
+ self.ecosystem = ecosystem.lower() if ecosystem else ""
+
+ def check_finish_allowed(self, cca_results: list[bool]) -> tuple[bool, str]:
+ """Block finish when mandatory reachability tools were not used.
+
+ Two rules enforced:
+ 1. CCA must be called before finishing any reachability question.
+ 2. (Java only) FLVF must be called when CCA found reachability.
+ """
+ if not cca_results and ToolNames.CALL_CHAIN_ANALYZER not in self.action_history:
+ return False, (
+ "You MUST use Function Locator and Call Chain Analyzer before concluding. "
+ "Code Keyword Search alone is NOT sufficient to determine reachability. "
+ "Call Function Locator to validate the package, then Call Chain Analyzer "
+ "to check if the vulnerable function is reachable from application code."
+ )
+ if self.ecosystem == "java" and any(cca_results):
+ if ToolNames.FUNCTION_LIBRARY_VERSION_FINDER not in self.action_history:
+ return False, (
+ "MANDATORY VERSION CHECK: Call Chain Analyzer found the vulnerable function is reachable, "
+ "but you have NOT verified the installed library version. "
+ "You MUST call Function Library Version Finder (e.g., input 'commons-beanutils') "
+ "to check whether the installed version is within the vulnerable range before concluding."
+ )
+ return True, ""
+
+ def set_target_functions(self, functions: list[str]):
+ self.target_functions = {f: False for f in functions}
+
@staticmethod
def _normalize_package_name(name: str) -> str:
return name.strip().lower().replace("-", "_")
+ def add_validated_package(self, package_name: str):
+ """Register a package that FL confirmed as valid (e.g., uber-jar alternative)."""
+ self.validated_packages.add(self._normalize_package_name(package_name))
+
def _rule_number_8(self, action: str, action_input: str, output) -> bool:
if self.target_package is None:
return False
@@ -155,11 +228,12 @@ def _rule_number_8(self, action: str, action_input: str, output) -> bool:
target_pkg = self._normalize_package_name(self.target_package)
# Allow Java GAV with version suffix: input "group:artifact:version"
# should match target "group:artifact"
- if input_pkg != target_pkg and not input_pkg.startswith(target_pkg + ":"):
- return True
- return False
- def _rule_use_allowed_tools(self, action: str) -> bool:
- if action not in self.allowed_tools:
+ if input_pkg == target_pkg or input_pkg.startswith(target_pkg + ":"):
+ return False
+        # Allow packages that FL already validated (handles uber-jars like
+        # netty-all containing netty-codec-http classes); a package that is
+        # neither the target nor FL-validated violates the rule.
+        if not any(input_pkg == vp or input_pkg.startswith(vp + ":") for vp in self.validated_packages):
return True
return False
@@ -186,24 +260,35 @@ def _rule_number_9(self, action: str, action_input: str) -> tuple[bool, str]:
return False, ""
def check_thought_behavior(self, action: str, action_input: str, output) -> tuple[bool, str]:
+ if self._rule_duplicate_call(action, action_input):
+ logger.debug("Duplicate call rule triggered: '%s' with same input", action)
+ return True, (
+ f"You already called {action} with this exact input. "
+ "You MUST use a DIFFERENT tool or a DIFFERENT input query. "
+ "Check KNOWLEDGE for what was already tried."
+ )
if self._rule_number_7(action, action_input, output):
+ logger.debug("Reachability Rule 7 triggered: dotted query with empty results for tool '%s'", action)
return True, ("You are NOT following Rule 7. Your query contains dots and returned "
"no results. You MUST retry with just the final component. Follow the rules.")
if self._rule_number_8(action, action_input, output):
+ logger.debug("Reachability Rule 8 triggered: wrong package for tool '%s', expected '%s'", action, self.target_package)
return True, (f"You are NOT following Rule 8. You are using the wrong package name. You MUST use the target package name {self.target_package} see KNOWLEDGE as the package_name before trying alternative packages. Follow the rules.")
if self._rule_use_allowed_tools(action):
+ logger.debug("Reachability allowed-tools rule triggered: '%s' not in %s", action, self.allowed_tools)
return True, (f"You are NOT following AVAILABLE_TOOLS. You MUST use the allowed tools {self.allowed_tools}. Follow the rules.")
rule9, msg9 = self._rule_number_9(action, action_input)
if rule9:
+ logger.debug("Reachability Rule 9 triggered: must investigate target functions first for tool '%s'", action)
return True, msg9
self.add_action(action, action_input, output)
return False, ""
-
+
+
class AgentState(MessagesState):
input: str = ""
step: int = 0
max_steps: int = 10
- #memory: str | None = None
thought: Thought | None = None
observation: Observation | None = None
output: str = ""
@@ -211,14 +296,19 @@ class AgentState(MessagesState):
runtime_prompt: str | None = None
app_package: str | None = None
is_reachability: str = "yes"
- rules_tracker: SystemRulesTracker = SystemRulesTracker()
+ rules_tracker: BaseRulesTracker = ReachabilityRulesTracker()
critical_context: list[str] = []
cca_results: list[bool] = []
package_validated: bool | None = None
-
-### --- End of REACT Schemas ----#
-#---- REACT Prompt Templates ----#
-AGENT_SYS_PROMPT = (
+ precomputed_intel: tuple | None = None
+
+
+# ---- Reachability Sub-Agent Prompts ---- #
+# The following prompts (REACHABILITY_AGENT_SYS_PROMPT, REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS,
+# REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS_GO) are used exclusively by the reachability sub-agent,
+# which traces call chains to determine if vulnerable code is reachable from application code.
+
+REACHABILITY_AGENT_SYS_PROMPT = (
"You are a security analyst investigating CVE exploitability in container images.\n"
"MANDATORY STEPS (follow in order, do NOT skip any):\n"
"1. IDENTIFY the vulnerable component/function from the CVE description.\n"
@@ -249,7 +339,6 @@ class AgentState(MessagesState):
"- When citing evidence, explain HOW it relates to the question -- do not just state that something was found."
)
-# Update LANGGRAPH_SYSTEM_PROMPT_TEMPLATE in react_internals.py
LANGGRAPH_SYSTEM_PROMPT_TEMPLATE = """{sys_prompt}
@@ -265,7 +354,7 @@ class AgentState(MessagesState):
RESPONSE:
{{"""
-AGENT_THOUGHT_INSTRUCTIONS = """
+REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS = """
1. Output valid JSON only. thought < 100 words. final_answer < 150 words.
2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer.
3. Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query.
@@ -286,7 +375,7 @@ class AgentState(MessagesState):
{{"thought": "Function Locator confirmed the package. Now trace reachability with Call Chain Analyzer", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}}
"""
-AGENT_THOUGHT_INSTRUCTIONS_GO = """
+REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS_GO = """
1. Output valid JSON only. thought < 100 words. final_answer < 150 words.
2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer.
3. Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query.
@@ -312,7 +401,7 @@ class AgentState(MessagesState):
{{"thought": "Function Caller Finder returned no callers, but this does not prove unreachable. MUST call Call Chain Analyzer to confirm reachability", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}}
"""
-AGENT_SYS_PROMPT_NON_REACHABILITY = (
+REACHABILITY_AGENT_NON_REACH_SYS_PROMPT = (
"You are a security analyst investigating CVE exploitability in container images.\n"
"This is NOT a reachability question -- do NOT trace call chains.\n"
"MANDATORY STEPS (follow in order, do NOT skip any):\n"
@@ -345,7 +434,7 @@ class AgentState(MessagesState):
"- When citing evidence, explain HOW it relates to the question -- do not just state that something was found."
)
-AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY = """
+REACHABILITY_AGENT_NON_REACH_THOUGHT_INSTRUCTIONS = """
1. Output valid JSON only. thought < 100 words. final_answer < 150 words.
2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer.
3. Function Locator: MUST set package_name AND function_name. Do NOT use query.
@@ -404,6 +493,12 @@ class AgentState(MessagesState):
NEW OUTPUT:
{tool_output}
+CRITICAL — CALL CHAIN ANALYZER REACHABILITY:
+When TOOL USED is "Call Chain Analyzer":
+- If the result is POSITIVE (reachable): your first finding MUST be "REACHABLE via [package] - sufficient evidence." Do NOT hedge, qualify, or say "further investigation required."
+- If the result is NEGATIVE (not reachable): your first finding MUST be "NOT reachable via [package]."
+These override all other rules for Call Chain Analyzer results.
+
CODE COMPREHENSION RULES:
1. READ the actual code snippets in NEW OUTPUT. Do NOT just check whether something was "found" or "not found."
2. For each code snippet or function returned, determine:
@@ -452,17 +547,13 @@ class AgentState(MessagesState):
RESPONSE:
{{"""
-# Legacy prompt kept for backwards compatibility with older traces
-OBSERVATION_NODE_PROMPT = COMPREHENSION_PROMPT
-
-### --- End of REACT Prompt Templates ----#
-def build_system_prompt(
+def build_reachability_system_prompt(
tool_descriptions: str,
tool_guidance: str,
- instructions: str = AGENT_THOUGHT_INSTRUCTIONS,
+ instructions: str = REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS,
sys_prompt: str | None = None,
) -> str:
- sys_prompt = sys_prompt or AGENT_SYS_PROMPT
+ sys_prompt = sys_prompt or REACHABILITY_AGENT_SYS_PROMPT
return LANGGRAPH_SYSTEM_PROMPT_TEMPLATE.format(
sys_prompt=sys_prompt,
tools=tool_descriptions,
@@ -545,6 +636,8 @@ def _build_tool_arguments(actions: ToolCall)->dict[str, Any]:
return {"query": actions.query}
if actions.tool_input:
return {"query": actions.tool_input} # fallback
+ logger.warning("Tool '%s' called without required arguments (package_name=%s, function_name=%s, query=%s)",
+ actions.tool, actions.package_name, actions.function_name, actions.query)
raise ValueError(f"Tool {actions.tool} requires package_name+function_name or query/tool_input")
diff --git a/src/vuln_analysis/register.py b/src/vuln_analysis/register.py
index dc847a87..fe3b4b75 100644
--- a/src/vuln_analysis/register.py
+++ b/src/vuln_analysis/register.py
@@ -47,6 +47,8 @@
from vuln_analysis.tools import container_image_analysis_data
from vuln_analysis.tools import local_vdb
from vuln_analysis.tools import serp
+from vuln_analysis.tools import configuration_scanner
+from vuln_analysis.tools import import_usage_analyzer
from vuln_analysis.utils.error_handling_decorator import catch_pipeline_errors_async
# pylint: enable=unused-import
from vuln_analysis.utils.llm_engine_utils import postprocess_engine_output, finalize_preprocess_engine_input
@@ -255,7 +257,7 @@ async def call_llm_engine_subgraph_node(message: AgentMorpheusEngineInput):
graph = graph_builder.compile()
def convert_str_to_agent_morpheus_input(input: str) -> AgentMorpheusInput:
- logger.debug("Converting input to AgentMorpheusInput: %s", input)
+ logger.debug("Converting JSON string input to AgentMorpheusInput (length: %d)", len(input))
try:
return AgentMorpheusInput.model_validate_json(input)
except Exception as e:
@@ -263,7 +265,7 @@ def convert_str_to_agent_morpheus_input(input: str) -> AgentMorpheusInput:
raise e
def convert_textio_to_agent_morpheus_input(input: TextIOWrapper) -> AgentMorpheusInput:
- logger.debug("Converting input to AgentMorpheusInput: %s", input)
+ logger.debug("Converting TextIOWrapper input to AgentMorpheusInput")
try:
data = input.read()
return AgentMorpheusInput.model_validate_json(data)
@@ -273,7 +275,7 @@ def convert_textio_to_agent_morpheus_input(input: TextIOWrapper) -> AgentMorpheu
raise e
def convert_agent_morpheus_output_to_str(output: AgentMorpheusOutput) -> str:
- logger.debug("Converting AgentMorpheusOutput to str: %s", output)
+ logger.debug("Converting AgentMorpheusOutput to JSON string")
try:
return output.model_dump_json()
except Exception as e:
diff --git a/src/vuln_analysis/runtime_context.py b/src/vuln_analysis/runtime_context.py
index fceb67aa..adca2360 100644
--- a/src/vuln_analysis/runtime_context.py
+++ b/src/vuln_analysis/runtime_context.py
@@ -19,4 +19,9 @@
import contextvars
# Holds the current AgentMorpheusEngineState for the active task
-ctx_state = contextvars.ContextVar("ctx_state", default="default_value")
\ No newline at end of file
+ctx_state = contextvars.ContextVar("ctx_state", default="default_value")
+
+# Source scope for CU agent tools (Docs Semantic Search, Code Keyword Search).
+# List of path substrings; tools keep results whose path matches any entry.
+# None means no scoping (search all sources).
+cu_source_scope = contextvars.ContextVar("cu_source_scope", default=None)
diff --git a/src/vuln_analysis/tools/configuration_scanner.py b/src/vuln_analysis/tools/configuration_scanner.py
new file mode 100644
index 00000000..0ed1ada3
--- /dev/null
+++ b/src/vuln_analysis/tools/configuration_scanner.py
@@ -0,0 +1,249 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+import re
+from collections import OrderedDict
+from pathlib import Path
+
+from aiq.builder.builder import Builder
+from aiq.builder.framework_enum import LLMFrameworkEnum
+from aiq.builder.function_info import FunctionInfo
+from aiq.cli.register_workflow import register_function
+from aiq.data_models.function import FunctionBaseConfig
+from pydantic import Field
+
+from exploit_iq_commons.logging.loggers_factory import LoggingFactory
+from exploit_iq_commons.utils.git_utils import sanitize_git_url_for_path
+from exploit_iq_commons.utils.data_utils import DEFAULT_GIT_DIRECTORY
+from vuln_analysis.utils.error_handling_decorator import catch_tool_errors
+from vuln_analysis.utils.source_classification import is_dependency_path, filter_by_source_scope, format_app_dep_output
+
+CONFIGURATION_SCANNER = "configuration_scanner"
+
+logger = LoggingFactory.get_agent_logger(__name__)
+
+
+def format_context_snippet(lines: list[str], match_line: int, context_lines: int) -> str:
+ """Format source lines around a match, marking the match line with '>'."""
+ start = max(0, match_line - context_lines)
+ end = min(len(lines), match_line + context_lines + 1)
+ return "\n".join(
+ f"{'>' if i == match_line else ' '} {i+1}: {lines[i]}"
+ for i in range(start, end)
+ )
+
+# File patterns and directory names considered as configuration files
+CONFIG_FILE_PATTERNS = [
+ # Named config files — matched anywhere in the repo
+ "application.yml", "application.yaml", "application.properties",
+ "config.yaml", "config.yml", "config.xml",
+ "settings.toml", "settings.yaml", "settings.yml",
+ "web.xml", "beans.xml",
+ "Dockerfile", "Dockerfile.*", "docker-compose*.yml",
+ # Config-specific extensions — safe to match anywhere
+ "*.properties", "*.env", "*.conf", "*.ini",
+]
+
+# Directory names that typically contain configuration files
+CONFIG_DIR_PATTERNS = ["config", "conf", "conf.d", "etc", "resources"]
+
+# Avoids per-file regex compilation
+_CONFIG_EXTENSIONS = []
+_CONFIG_EXACT_NAMES = []
+_CONFIG_WILDCARD_PATTERNS = []
+
+for _pat in CONFIG_FILE_PATTERNS:
+ if _pat.startswith("*."):
+ _CONFIG_EXTENSIONS.append(_pat[1:])
+ elif "*" in _pat:
+ _CONFIG_WILDCARD_PATTERNS.append(
+ re.compile(_pat.replace(".", r"\.").replace("*", ".*"), re.IGNORECASE)
+ )
+ else:
+ _CONFIG_EXACT_NAMES.append(_pat.lower())
+
+
+class ConfigurationScannerToolConfig(FunctionBaseConfig, name=CONFIGURATION_SCANNER):
+ """Scans configuration files for vulnerability-relevant patterns."""
+ max_results: int = Field(default=15, description="Maximum config entries to return")
+ context_lines: int = Field(default=5, description="Lines of context around matches")
+
+
+def _is_config_file(file_path: str) -> bool:
+ """Check if a file path matches any known configuration file pattern."""
+ name = os.path.basename(file_path)
+ lower_name = name.lower()
+
+ # Check in order of speed: extension → exact name → wildcard regex
+ if any(lower_name.endswith(ext) for ext in _CONFIG_EXTENSIONS):
+ return True
+ if lower_name in _CONFIG_EXACT_NAMES:
+ return True
+ if any(p.match(lower_name) for p in _CONFIG_WILDCARD_PATTERNS):
+ return True
+
+ return False
+
+
+def _is_in_config_dir(file_path: str) -> bool:
+ """Check if a file resides under a known configuration directory."""
+ parts = Path(file_path).parts
+ return any(p.lower() in CONFIG_DIR_PATTERNS for p in parts)
+
+
+def _collect_config_files(repo_path: str) -> list[tuple[str, str]]:
+ """Walk repo and collect config file paths and contents.
+
+    Skips .git, __pycache__, node_modules, .tox directories.
+    Skips files of 500,000+ characters (each file is read once before the size check) to avoid caching large generated configs.
+ """
+ config_files = []
+ for root, dirs, files in os.walk(repo_path):
+ dirs[:] = [d for d in dirs if d not in (".git", "__pycache__", "node_modules", ".tox")]
+ for fname in files:
+ full_path = os.path.join(root, fname)
+ rel_path = os.path.relpath(full_path, repo_path)
+ if _is_config_file(rel_path) or _is_in_config_dir(rel_path):
+ try:
+ with open(full_path, "r", errors="ignore") as f:
+ content = f.read()
+ if len(content) < 500_000:
+ config_files.append((rel_path, content))
+ else:
+ logger.warning("Skipping config file %s: size %d exceeds 500KB limit", rel_path, len(content))
+ except Exception:
+ continue
+ logger.debug("Collected %d config files from %s", len(config_files), repo_path)
+ return config_files
+
+
+def _count_config_matches(config_files: list[tuple[str, str]], keywords: list[str]) -> int:
+ """Count total keyword-matching lines across config files (no formatting, no cap)."""
+ count = 0
+ for _, content in config_files:
+ for line in content.split("\n"):
+ if any(kw in line.lower() for kw in keywords):
+ count += 1
+ return count
+
+
+def search_config_content(
+ config_files: list[tuple[str, str]],
+ keywords: list[str],
+ max_results: int = 15,
+ context_lines: int = 5,
+ source_label: str = "unknown",
+) -> list[str]:
+ """Match keywords against config file contents, returning formatted snippets."""
+ matches = []
+ for rel_path, content in config_files:
+ lines = content.split("\n")
+ for line_num, line in enumerate(lines):
+ line_lower = line.lower()
+ if any(kw in line_lower for kw in keywords):
+ snippet = format_context_snippet(lines, line_num, context_lines)
+ matches.append(f"--- {rel_path} (source: {source_label}) line {line_num+1} ---\n{snippet}")
+ if len(matches) >= max_results:
+ break
+ if len(matches) >= max_results:
+ break
+ return matches
+
+
+@register_function(config_type=ConfigurationScannerToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
+async def configuration_scanner(config: ConfigurationScannerToolConfig, builder: Builder):
+ from vuln_analysis.runtime_context import ctx_state, cu_source_scope
+
+ _CONFIG_CACHE_MAX_SIZE = 20
+ _config_files_cache: OrderedDict[tuple, list[tuple[str, str]]] = OrderedDict()
+ _repo_locks: dict[tuple, asyncio.Lock] = {}
+ _repo_locks_guard = asyncio.Lock()
+
+ @catch_tool_errors(CONFIGURATION_SCANNER)
+ async def _arun(query: str) -> str:
+ workflow_state = ctx_state.get()
+ source_infos = workflow_state.original_input.input.image.source_info
+ source_scope = cu_source_scope.get()
+ logger.debug("Configuration scanner: searching for '%s' across %d source(s), scope=%s",
+ query, len(source_infos), source_scope)
+
+ keywords = [w.strip().lower() for w in re.split(r"[,\s]+", query) if len(w.strip()) >= 2]
+ all_app_configs = []
+ all_dep_configs = []
+
+ for si in source_infos:
+ if not hasattr(si, "git_repo"):
+ continue
+ repo_path = Path(DEFAULT_GIT_DIRECTORY) / sanitize_git_url_for_path(si.git_repo)
+ if not repo_path.is_dir():
+ continue
+
+ repo_key = (si.git_repo, si.ref)
+ if repo_key in _config_files_cache:
+ async with _repo_locks_guard:
+ _config_files_cache.move_to_end(repo_key)
+ else:
+ async with _repo_locks_guard:
+ if repo_key not in _repo_locks:
+ _repo_locks[repo_key] = asyncio.Lock()
+ repo_lock = _repo_locks[repo_key]
+ async with repo_lock:
+ if repo_key not in _config_files_cache:
+ _config_files_cache[repo_key] = _collect_config_files(str(repo_path))
+ if len(_config_files_cache) > _CONFIG_CACHE_MAX_SIZE:
+ _config_files_cache.popitem(last=False)
+
+ for cfg in _config_files_cache[repo_key]:
+ if is_dependency_path(cfg[0]):
+ all_dep_configs.append(cfg)
+ else:
+ all_app_configs.append(cfg)
+
+ all_dep_configs = filter_by_source_scope(all_dep_configs, source_scope, lambda x: x[0])
+
+ source_label = source_infos[0].git_repo if source_infos and hasattr(source_infos[0], "git_repo") else "unknown"
+ app_matches = search_config_content(
+ all_app_configs, keywords,
+ max_results=config.max_results,
+ context_lines=config.context_lines,
+ source_label=source_label,
+ )
+ remaining = config.max_results - len(app_matches)
+ dep_matches = search_config_content(
+ all_dep_configs, keywords,
+ max_results=max(remaining, 0),
+ context_lines=config.context_lines,
+ source_label=source_label,
+ ) if remaining > 0 else []
+
+ no_results_msg = f"No configuration entries found matching: {query}"
+ total_app = _count_config_matches(all_app_configs, keywords)
+ total_dep = _count_config_matches(all_dep_configs, keywords)
+ logger.debug("Configuration scanner: %d app + %d dep match(es) for '%s'",
+ total_app, total_dep, query)
+ return format_app_dep_output(app_matches, dep_matches, total_app, total_dep, no_results_msg)
+
+ yield FunctionInfo.from_fn(
+ _arun,
+ description=(
+ "Scans configuration files (YAML, XML, properties, build files, Dockerfiles) "
+ "for patterns related to the vulnerability. Input: keywords describing what "
+ "configuration to look for (e.g. 'XML external entity processing', "
+ "'deserialization enabled', 'xstream allowTypes'). "
+ "Searches all indexed sources including framework dependencies."
+ ),
+    )
diff --git a/src/vuln_analysis/tools/import_usage_analyzer.py b/src/vuln_analysis/tools/import_usage_analyzer.py
new file mode 100644
index 00000000..755c82f2
--- /dev/null
+++ b/src/vuln_analysis/tools/import_usage_analyzer.py
@@ -0,0 +1,158 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from aiq.builder.builder import Builder
+from aiq.builder.framework_enum import LLMFrameworkEnum
+from aiq.builder.function_info import FunctionInfo
+from aiq.cli.register_workflow import register_function
+from aiq.data_models.function import FunctionBaseConfig
+from pydantic import Field
+
+from exploit_iq_commons.logging.loggers_factory import LoggingFactory
+from exploit_iq_commons.utils.functions_parsers.lang_functions_parsers_factory import get_language_function_parser
+from vuln_analysis.utils.error_handling_decorator import catch_tool_errors
+from vuln_analysis.utils.source_classification import is_dependency_path, filter_by_source_scope, format_app_dep_output
+
+IMPORT_USAGE_ANALYZER = "import_usage_analyzer"
+
+logger = LoggingFactory.get_agent_logger(__name__)
+
+
+def _find_usage_in_file(content: str, imported_names: list[str], max_usages: int = 5) -> list[str]:
+ """Find usage sites of imported names, excluding import lines themselves."""
+ usages = []
+ lines = content.split("\n")
+ for line_num, line in enumerate(lines):
+ for name in imported_names:
+ short_name = name.rsplit(".", 1)[-1] if "." in name else name
+ if re.search(rf'\b{re.escape(short_name)}\b', line) and not line.strip().startswith(("import ", "from ", "#include")):
+ usages.append(f" L{line_num+1}: {line.strip()}")
+ if len(usages) >= max_usages:
+ return usages
+ return usages
+
+
+class ImportUsageAnalyzerToolConfig(FunctionBaseConfig, name=IMPORT_USAGE_ANALYZER):
+ """Analyzes imports and usage patterns of a package across indexed sources."""
+ max_files: int = Field(default=20, description="Maximum files to report")
+
+
+def analyze_imports(searcher, import_patterns: list[re.Pattern], package_name: str,
+ max_files: int = 20, source_scope: list[str] | None = None,
+ ecosystem_label: str = "") -> str:
+ """Scan a Tantivy searcher for import patterns, with app/dep source awareness.
+
+ Results are separated into application code and dependency code sections,
+ with dependency results filtered by source_scope when provided.
+ """
+ num_docs = searcher.num_docs
+ app_results = []
+ dep_results = []
+
+ for doc_id in range(num_docs):
+ try:
+ raw = searcher.doc(doc_id)
+ file_path = raw["file_path"][0]
+ content = raw["content"][0]
+ except Exception:
+ continue
+
+ found_imports = []
+ for pattern in import_patterns:
+ for match in pattern.finditer(content):
+ found_imports.append(match.group(0))
+
+ if not found_imports:
+ continue
+
+ imported_names = []
+ for imp in found_imports:
+ parts = imp.split()
+ for p in parts:
+ cleaned = p.strip(";'\",<>(){}").strip()
+ if package_name.lower() in cleaned.lower() and len(cleaned) > 2:
+ imported_names.append(cleaned)
+
+ usages = _find_usage_in_file(content, imported_names)
+
+ entry = f"--- {file_path} ---\n"
+ entry += f" Imports: {'; '.join(found_imports[:3])}\n"
+ if usages:
+ entry += f" Usages ({len(usages)}):\n" + "\n".join(usages)
+ else:
+ entry += " No usage sites found beyond import."
+
+ if is_dependency_path(file_path):
+ dep_results.append((file_path, entry))
+ else:
+ app_results.append(entry)
+
+ dep_results = filter_by_source_scope(dep_results, source_scope, lambda x: x[0])
+ dep_entries = [entry for _, entry in dep_results]
+
+ no_results_msg = f"No imports of '{package_name}' found in indexed sources (ecosystem: {ecosystem_label})."
+ total_app = len(app_results)
+ total_dep = len(dep_entries)
+
+ trimmed_app = app_results[:max_files]
+ remaining = max_files - len(trimmed_app)
+ trimmed_dep = dep_entries[:max(remaining, 0)]
+
+ return format_app_dep_output(trimmed_app, trimmed_dep, total_app, total_dep, no_results_msg)
+
+
+@register_function(config_type=ImportUsageAnalyzerToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
+async def import_usage_analyzer(config: ImportUsageAnalyzerToolConfig, builder: Builder):
+ from vuln_analysis.runtime_context import ctx_state, cu_source_scope
+ from vuln_analysis.utils.full_text_search import FullTextSearch
+
+ @catch_tool_errors(IMPORT_USAGE_ANALYZER)
+ async def _arun(query: str) -> str:
+ workflow_state = ctx_state.get()
+ code_index_path = workflow_state.code_index_path
+ ecosystem = workflow_state.original_input.input.image.ecosystem
+ source_scope = cu_source_scope.get()
+
+ fts = FullTextSearch.get_instance(cache_path=code_index_path)
+ if fts.is_empty():
+ logger.debug("Import usage analyzer: no source code indexed at %s", code_index_path)
+ return "No source code indexed."
+
+ package_name = query.strip()
+ parser = get_language_function_parser(ecosystem, tree=None) if ecosystem else None
+ import_patterns = parser.get_import_search_patterns(package_name) if parser else [re.compile(re.escape(package_name), re.IGNORECASE)]
+ ecosystem_label = ecosystem.value if ecosystem else ""
+ logger.debug("Import usage analyzer: searching for '%s' (ecosystem: %s), scope=%s",
+ package_name, ecosystem_label, source_scope)
+
+ result = analyze_imports(
+ fts.index.searcher(), import_patterns, package_name,
+ max_files=config.max_files, source_scope=source_scope,
+ ecosystem_label=ecosystem_label,
+ )
+ logger.debug("Import usage analyzer: %s", result.split("\n", 1)[0])
+ return result
+
+ yield FunctionInfo.from_fn(
+ _arun,
+ description=(
+ "Finds all imports and usage patterns of a specific package/module across indexed sources. "
+ "Input: package or module name (e.g. 'encoding/xml', 'com.thoughtworks.xstream', "
+ "'urllib.parse'). Reports which files import it, how many, and how the package is used. "
+ "Searches all sources including framework dependencies."
+ ),
+    )
diff --git a/src/vuln_analysis/tools/lexical_full_search.py b/src/vuln_analysis/tools/lexical_full_search.py
index 0b24fcc1..8117e990 100644
--- a/src/vuln_analysis/tools/lexical_full_search.py
+++ b/src/vuln_analysis/tools/lexical_full_search.py
@@ -37,19 +37,25 @@ class LexicalSearchToolConfig(FunctionBaseConfig, name=LEXICAL_CODE_SEARCH):
@register_function(config_type=LexicalSearchToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN])
async def lexical_search(config: LexicalSearchToolConfig, builder: Builder): # pylint: disable=unused-argument
- from vuln_analysis.runtime_context import ctx_state
+ from vuln_analysis.runtime_context import ctx_state, cu_source_scope
from vuln_analysis.utils.full_text_search import FullTextSearch
@catch_tool_errors(LEXICAL_CODE_SEARCH)
async def _arun(query: str) -> str:
workflow_state = ctx_state.get()
code_index_path = workflow_state.code_index_path
- full_text_search = FullTextSearch(cache_path=code_index_path)
+ full_text_search = FullTextSearch.get_instance(cache_path=code_index_path)
if full_text_search.is_empty():
+ logger.debug("Lexical search: index is empty at %s", code_index_path)
raise ValueError(f"Invalid code index at: {code_index_path}, index is empty")
- result = full_text_search.search_index(query, config.top_k)
+ # Pass source_scope from ContextVar to restrict dependency results
+ # to the current question's target package (set by CU pre_process_node)
+ source_scope = cu_source_scope.get()
+ result = full_text_search.search_index(query, config.top_k, source_scope=source_scope)
+ logger.debug("Lexical search: query='%.80s', top_k=%d, source_scope=%s",
+ query, config.top_k, source_scope)
return result
diff --git a/src/vuln_analysis/tools/local_vdb.py b/src/vuln_analysis/tools/local_vdb.py
index 7549032f..98b701d6 100644
--- a/src/vuln_analysis/tools/local_vdb.py
+++ b/src/vuln_analysis/tools/local_vdb.py
@@ -47,15 +47,22 @@ async def load_vectordb_asretriever(config: LocalVDBRetrieverToolConfig, builder
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
- from vuln_analysis.runtime_context import ctx_state
+ from vuln_analysis.runtime_context import ctx_state, cu_source_scope
embedder = await builder.get_embedder(embedder_name=config.embedder_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
+ qa_prompt = PromptTemplate(template=("Use the following pieces of context to answer the question at the end. "
+ "If you don't know the answer, just say that you don't know, "
+ "don't try to make up an answer.\n\n{context}\n\n"
+ "Question: {question}\nHelpful Answer:"),
+ input_variables=['context', 'question'])
+ # FAISS deserialization is expensive — cache per db_source path
+ _retrieval_cache: dict[str, RetrievalQA] = {}
+
@catch_tool_errors(LOCAL_VDB_RETRIEVER)
async def _arun(query: str) -> str | dict:
- # workaround since the agent executor only accepts strings.
workflow_state = ctx_state.get()
if config.vdb_type == VdbType.CODE:
@@ -65,21 +72,32 @@ async def _arun(query: str) -> str | dict:
else:
raise ValueError(f"Invalid VDB type: {config.vdb_type}. Must be one of {VdbType.CODE} or {VdbType.DOC}.")
- qa_prompt = PromptTemplate(template=("Use the following pieces of context to answer the question at the end. "
- "If you don't know the answer, just say that you don't know, "
- "don't try to make up an answer.\n\n{context}\n\n"
- "Question: {question}\nHelpful Answer:"),
- input_variables=['context', 'question'])
-
- vector_db = FAISS.load_local(db_source, embedder, allow_dangerous_deserialization=True)
- retrieval_qa_tool = RetrievalQA.from_chain_type(llm=llm,
- chain_type="stuff",
- chain_type_kwargs={"prompt": qa_prompt},
- retriever=vector_db.as_retriever(),
- return_source_documents=config.return_source_documents)
+ if db_source not in _retrieval_cache:
+ logger.info("Loading FAISS index from %s (type: %s)", db_source, config.vdb_type)
+ vector_db = FAISS.load_local(db_source, embedder, allow_dangerous_deserialization=True)
+ _retrieval_cache[db_source] = RetrievalQA.from_chain_type(
+ llm=llm,
+ chain_type="stuff",
+ chain_type_kwargs={"prompt": qa_prompt},
+ retriever=vector_db.as_retriever(),
+ return_source_documents=config.return_source_documents,
+ )
+ retrieval_qa_tool = _retrieval_cache[db_source]
output_dict = await retrieval_qa_tool.ainvoke(query)
+ if config.return_source_documents and "source_documents" in output_dict:
+ source_scope = cu_source_scope.get()
+ if source_scope:
+ pre_filter = len(output_dict["source_documents"])
+ output_dict["source_documents"] = [
+ doc for doc in output_dict["source_documents"]
+ if any(scope in doc.metadata.get("source", "") for scope in source_scope)
+ ]
+ if pre_filter != len(output_dict["source_documents"]):
+ logger.debug("Source scope %s filtered doc results from %d to %d",
+ source_scope, pre_filter, len(output_dict["source_documents"]))
+
# If returning source documents, include the result and source_documents keys in the output
if config.return_source_documents:
return {k: v for k, v in output_dict.items() if k in ["result", "source_documents"]}
@@ -96,7 +114,9 @@ async def _arun(query: str) -> str | dict:
elif config.vdb_type == VdbType.DOC:
description = (
"Searches container documentation using semantic search. "
- "Answers questions about application purpose, architecture, and features."
+ "Answers questions about application purpose, architecture, and features. "
+ "Queries must be specific — include a concrete term (library name, config property, protocol). "
+ "Vague queries like 'configuration' or 'security settings' return unhelpful results."
)
else:
raise ValueError(f"Invalid VDB type: {config.vdb_type}. Must be one of {VdbType.CODE} or {VdbType.DOC}.")
diff --git a/src/vuln_analysis/tools/tests/test_transitive_code_search.py b/src/vuln_analysis/tools/tests/test_transitive_code_search.py
index b432fc25..65037fe5 100644
--- a/src/vuln_analysis/tools/tests/test_transitive_code_search.py
+++ b/src/vuln_analysis/tools/tests/test_transitive_code_search.py
@@ -1075,3 +1075,27 @@ def test_query_cleaning_strips_unicode_quotes(raw_query, expected_cleaned):
"""Test that query cleaning handles both ASCII and Unicode smart quotes."""
cleaned = raw_query.strip().split("\n")[0].strip().strip("'\"\u2018\u2019\u201c\u201d").strip()
assert cleaned == expected_cleaned, f"For input {repr(raw_query)}: got {repr(cleaned)}, expected {repr(expected_cleaned)}"
+
+
+@pytest.mark.parametrize("raw_query, expected_cleaned", [
+ # Standard ASCII quotes
+ ("'commons-beanutils:commons-beanutils:1.9.4'", "commons-beanutils:commons-beanutils:1.9.4"),
+ ('"commons-beanutils:commons-beanutils:1.9.4"', "commons-beanutils:commons-beanutils:1.9.4"),
+ # Unicode smart quotes (left/right single)
+ ("\u2018commons-beanutils:commons-beanutils:1.9.4\u2019", "commons-beanutils:commons-beanutils:1.9.4"),
+ # Unicode smart quotes (left/right double)
+ ("\u201ccommons-beanutils:commons-beanutils:1.9.4\u201d", "commons-beanutils:commons-beanutils:1.9.4"),
+ # Mixed: ASCII left, unicode right
+ ("'commons-beanutils:commons-beanutils:1.9.4\u2019", "commons-beanutils:commons-beanutils:1.9.4"),
+ ("\"commons-beanutils:commons-beanutils:1.9.4\u201d", "commons-beanutils:commons-beanutils:1.9.4"),
+ # No quotes
+ ("commons-beanutils:commons-beanutils:1.9.4", "commons-beanutils:commons-beanutils:1.9.4"),
+ # Whitespace + quotes
+ (" 'commons-beanutils:commons-beanutils:1.9.4' ", "commons-beanutils:commons-beanutils:1.9.4"),
+ # Trailing newline junk from LLM
+ ("'commons-beanutils:commons-beanutils:1.9.4'\nPlease wait...", "commons-beanutils:commons-beanutils:1.9.4"),
+])
+def test_query_cleaning_strips_quotes_extended(raw_query, expected_cleaned):
+    """Test query cleaning on ASCII, Unicode, and mixed smart quotes, plus surrounding whitespace and trailing junk."""
+ cleaned = raw_query.strip().split("\n")[0].strip().strip("'\"\u2018\u2019\u201c\u201d").strip()
+ assert cleaned == expected_cleaned, f"For input {repr(raw_query)}: got {repr(cleaned)}, expected {repr(expected_cleaned)}"
diff --git a/src/vuln_analysis/tools/tool_names.py b/src/vuln_analysis/tools/tool_names.py
index fa1ddf72..f07d40d5 100644
--- a/src/vuln_analysis/tools/tool_names.py
+++ b/src/vuln_analysis/tools/tool_names.py
@@ -22,45 +22,51 @@
class ToolNames:
- """
- Constants for agent tool names matching YAML configuration keys.
-
+ """Constants for agent tool names matching YAML configuration keys.
+
These names correspond to the function keys in the YAML config files
- (e.g., config.yml, config-http-openai-local.yml).
+ (e.g., config.yml, config-http-nim.yml).
"""
-
+
# Code analysis tools
CODE_SEMANTIC_SEARCH = "Code Semantic Search"
- """Searches container source code using semantic vector embeddings"""
-
+ """Searches container source code using semantic vector embeddings."""
+
DOCS_SEMANTIC_SEARCH = "Docs Semantic Search"
- """Searches container documentation using semantic vector embeddings"""
-
+ """Searches container documentation using semantic vector embeddings."""
+
CODE_KEYWORD_SEARCH = "Code Keyword Search"
- """Searches container source code for exact keyword matches"""
-
- # Code path analysis
+ """Searches container source code for exact keyword matches."""
+
+ # Code path analysis tools
FUNCTION_LOCATOR = "Function Locator"
- """Mandatory first step for code path analysis. Validates package names, locates functions using fuzzy matching, provides ecosystem type."""
+ """Mandatory first step for code path analysis. Validates package names and locates functions using fuzzy matching."""
CALL_CHAIN_ANALYZER = "Call Chain Analyzer"
- """Checks if a function is reachable from application code"""
-
+ """Checks if a function is reachable from application code."""
+
FUNCTION_CALLER_FINDER = "Function Caller Finder"
- """Golang only. Finds functions calling specific standard library methods."""
-
+ """Go only. Finds functions calling specific standard library methods."""
+
# External and cached data sources
CVE_WEB_SEARCH = "CVE Web Search"
- """Searches the web for CVE and vulnerability information"""
-
+ """Searches the web for CVE and vulnerability information."""
+
CONTAINER_ANALYSIS_DATA = "Container Analysis Data"
- """Retrieves pre-analyzed data from earlier container scan steps"""
+ """Retrieves pre-analyzed data from earlier container scan steps."""
FUNCTION_LIBRARY_VERSION_FINDER = "Function Library Version Finder"
- """Checks in which library version the function is used"""
+ """Java only. Checks in which library version the function is used."""
+
+ # Code Understanding agent tools
+ CONFIGURATION_SCANNER = "Configuration Scanner"
+ """Scans configuration files (YAML, XML, properties, build files) for vulnerability-relevant patterns."""
+
+ IMPORT_USAGE_ANALYZER = "Import Usage Analyzer"
+ """Finds all imports and usage patterns of a specific package across indexed sources."""
-# Export as module-level constants
+# Module-level constants for convenience imports
CODE_SEMANTIC_SEARCH = ToolNames.CODE_SEMANTIC_SEARCH
DOCS_SEMANTIC_SEARCH = ToolNames.DOCS_SEMANTIC_SEARCH
CODE_KEYWORD_SEARCH = ToolNames.CODE_KEYWORD_SEARCH
@@ -70,6 +76,8 @@ class ToolNames:
CVE_WEB_SEARCH = ToolNames.CVE_WEB_SEARCH
CONTAINER_ANALYSIS_DATA = ToolNames.CONTAINER_ANALYSIS_DATA
FUNCTION_LIBRARY_VERSION_FINDER = ToolNames.FUNCTION_LIBRARY_VERSION_FINDER
+CONFIGURATION_SCANNER = ToolNames.CONFIGURATION_SCANNER
+IMPORT_USAGE_ANALYZER = ToolNames.IMPORT_USAGE_ANALYZER
@@ -82,6 +90,8 @@ class ToolNames:
'FUNCTION_CALLER_FINDER',
'CVE_WEB_SEARCH',
'CONTAINER_ANALYSIS_DATA',
-    'FUNCTION_LOCATOR',
+    'FUNCTION_LOCATOR',
'FUNCTION_LIBRARY_VERSION_FINDER',
-]
\ No newline at end of file
+ 'CONFIGURATION_SCANNER',
+ 'IMPORT_USAGE_ANALYZER',
+]
diff --git a/src/vuln_analysis/utils/code_understanding_prompt_factory.py b/src/vuln_analysis/utils/code_understanding_prompt_factory.py
new file mode 100644
index 00000000..8063b183
--- /dev/null
+++ b/src/vuln_analysis/utils/code_understanding_prompt_factory.py
@@ -0,0 +1,97 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Per-ecosystem tool strategies and generic tool descriptions for the Code Understanding agent."""
+
+from vuln_analysis.tools.tool_names import ToolNames
+
+_DOCS_SEARCH_STEP = (
+    "4. Use Docs Semantic Search for broader architectural questions about the application "
+    "— queries must include a specific term (library name, config property, protocol)."
+)
+
+
+def _strategy(scan: str, keyword: str, imports: str) -> str:
+    """Assemble the numbered four-step investigation plan shared by every ecosystem.
+
+    Args:
+        scan: step 1 — what Configuration Scanner should inspect.
+        keyword: step 2 — what Code Keyword Search should look for.
+        imports: step 3 — what Import Usage Analyzer should trace.
+    """
+    return (
+        f"1. Start with Configuration Scanner {scan} "
+        f"2. Use Code Keyword Search {keyword} "
+        f"3. Use Import Usage Analyzer {imports} "
+        + _DOCS_SEARCH_STEP
+    )
+
+
+# Generic tool descriptions used when no ecosystem-specific strategy is available
+CU_TOOL_GENERAL_DESCRIPTIONS: dict[str, str] = {
+    ToolNames.DOCS_SEMANTIC_SEARCH: (
+        "Searches container documentation using semantic search. "
+        "Answers questions about application purpose, architecture, features, and dependencies. "
+        "IMPORTANT: Queries must be specific — include at least one concrete term "
+        "(class name, config property, library name, or protocol). "
+        "GOOD queries: 'deserialization whitelist configuration', "
+        "'SSL/TLS certificate validation settings', 'input sanitization before parsing'. "
+        "BAD queries: 'security settings', 'configuration', 'error handling'. "
+        "Vague queries return unhelpful results."
+    ),
+    ToolNames.CODE_KEYWORD_SEARCH: (
+        "Performs keyword search on container source code for exact text matches. "
+        "Input should be a function name, class name, import statement, or code pattern."
+    ),
+    ToolNames.CONFIGURATION_SCANNER: (
+        "Scans configuration files (YAML, XML, properties, build files, Dockerfiles) "
+        "for feature flags, enabled components, and security settings. "
+        "Searches all sources including framework dependencies."
+    ),
+    ToolNames.IMPORT_USAGE_ANALYZER: (
+        "Finds all imports and usage patterns of a package/module across sources. "
+        "Reports which files import it, usage count, and how it's used. "
+        "Searches all sources including framework dependencies."
+    ),
+}
+
+# Per-ecosystem tool selection strategies guiding the agent's investigation order
+CU_TOOL_SELECTION_STRATEGY: dict[str, str] = {
+    "python": _strategy(
+        "to check if affected features are enabled or configured.",
+        "for exact import/function lookups (e.g. 'import xml.etree').",
+        "to understand how the affected package is imported and used.",
+    ),
+    "go": _strategy(
+        "to check configs for feature flags and security settings.",
+        "for import paths (e.g. 'encoding/xml').",
+        "to trace imports of the affected package.",
+    ),
+    "java": _strategy(
+        "to check configs, pom.xml, and build.gradle for feature flags and security settings.",
+        "for import statements (e.g. 'import javax.xml').",
+        "to trace how the affected library is imported across sources.",
+    ),
+    "javascript": _strategy(
+        "to check package.json, .env, and config files.",
+        "for require/import patterns (e.g. 'require(\"xml2js\")').",
+        "to find all imports of the affected package.",
+    ),
+    "c": _strategy(
+        "to check Makefiles, CMakeLists.txt, and config headers.",
+        "for function names and #include patterns.",
+        "to trace #include usage of the affected library.",
+    ),
+}
diff --git a/src/vuln_analysis/utils/full_text_search.py b/src/vuln_analysis/utils/full_text_search.py
index 0a02ab99..6fa0e8e5 100644
--- a/src/vuln_analysis/utils/full_text_search.py
+++ b/src/vuln_analysis/utils/full_text_search.py
@@ -16,6 +16,8 @@
import json
import os
import re
+import threading
+from collections import OrderedDict
from pathlib import Path
from typing import Iterable
@@ -35,9 +37,7 @@
def tokenize_code(code: str) -> str:
- """
- Tokenize code into simple readable document format
- """
+ """Normalize source code identifiers for search: split camelCase/snake_case, lowercase, filter noise."""
matches = re.finditer(r"\b\w{2,}\b", code)
tokens = []
for m in matches:
@@ -47,6 +47,7 @@ def tokenize_code(code: str) -> str:
for part in variable_pattern.findall(section):
if len(part) < 2:
continue
+ # Exclude hex-like strings and repetitive patterns
if (sum(1 for c in part if "a" <= c <= "z" or "A" <= c <= "Z" or "0" <= c <= "9") > len(part) // 2
and len(part) / len(set(part)) < 4):
tokens.append(part.lower())
@@ -55,8 +56,9 @@ def tokenize_code(code: str) -> str:
def clean_query(input_query) -> str:
- """
- Parse a query string with OR/AND operations and fix quotes issues using regex.
+ """Parse a query string with OR/AND operations and fix quote issues.
+
+ Sanitizes LLM-generated queries for Tantivy — unmatched quotes crash its parser.
"""
# Remove anything after newline characters
if '\n' in input_query:
@@ -85,16 +87,21 @@ def replace_quoted(match):
return input_query
-ECOSYSTEM_DEP_DIRS = {
- "c": "rpm_libs/",
- "go": "vendor/",
- "javascript": "node_modules/",
- "python": "transitive_env/",
- "java": "dependencies-sources/",
-}
+from vuln_analysis.utils.source_classification import is_dependency_path, filter_by_source_scope  # NOTE(review): mid-module import — consider moving to the top-of-file import block
class FullTextSearch:
+ """Tantivy-based full text search index for source code and documentation.
+
+ Wraps a Tantivy index with document management, query sanitization, and
+ result separation into app code vs. dependency code. Instances are cached
+ via get_instance() with LRU eviction to avoid holding too many open indexes.
+ """
INDEX_TYPE = "tantivy"
+ # LRU cache of FullTextSearch instances keyed by cache_path.
+ # Prevents unbounded growth of open Tantivy index handles.
+    _instances: "OrderedDict[str, FullTextSearch] | None" = None
+ _INSTANCE_CACHE_MAX = 10
+ _instances_lock = threading.Lock()
def __init__(self, cache_path: str = None, tokenizer=False):
if cache_path:
@@ -104,13 +111,37 @@ def __init__(self, cache_path: str = None, tokenizer=False):
self.index.reload()
self.tokenizer = tokenizer
+ @classmethod
+ def get_instance(cls, cache_path: str, tokenizer=False) -> "FullTextSearch":
+ """Return a cached instance for the given path, with LRU eviction.
+
+ On cache hit, moves the entry to the end (most recently used).
+ On cache miss, creates a new instance and evicts the least recently used
+ if the cache exceeds _INSTANCE_CACHE_MAX.
+ """
+ with cls._instances_lock:
+ if cls._instances is None:
+ cls._instances = OrderedDict()
+ if cache_path in cls._instances:
+ cls._instances.move_to_end(cache_path)
+ return cls._instances[cache_path]
+ instance = cls(cache_path=cache_path, tokenizer=tokenizer)
+ with cls._instances_lock:
+ if cache_path in cls._instances:
+ return cls._instances[cache_path]
+ cls._instances[cache_path] = instance
+ if len(cls._instances) > cls._INSTANCE_CACHE_MAX:
+ evicted = cls._instances.popitem(last=False)
+ logger.debug("Evicted LRU FullTextSearch instance: %s", evicted[0])
+ return instance
+
@classmethod
def get_index_directory(cls, base_path: str, hash_value: str) -> Path:
"""Returns the path where the index should be stored"""
return Path(base_path) / cls.INDEX_TYPE / hash_value
def _build_schema(self):
- """Build schema for the code index"""
+ """Build Tantivy schema: file_path + content (both stored for retrieval) + doc_id."""
schema_builder = SchemaBuilder()
schema_builder.add_text_field("file_path", stored=True)
schema_builder.add_text_field("content", stored=True)
@@ -119,17 +150,28 @@ def _build_schema(self):
return schema
def add_documents(self, documents: Iterable):
-
+ """Index an iterable of (file_path, content) tuples into Tantivy."""
writer = self.index.writer()
for doc_id, (title, text) in enumerate(documents):
writer.add_document(TantivyDocument(file_path=title, content=text, doc_id=doc_id))
writer.commit()
def is_empty(self):
+ """Check if the index contains any documents."""
return self.index.searcher().num_docs == 0
- def search_index(self, query: str, top_k: int = 10) -> str:
+ def search_index(self, query: str, top_k: int = 10, source_scope: list[str] | None = None) -> str:
+ """Search the index using Tantivy's ranked query engine.
+
+ Returns results separated into application code and dependency code sections.
+ When source_scope is provided, dependency results are filtered to only
+ include files matching the scope (e.g. specific package names).
+ Args:
+ query: Search query string (sanitized internally)
+ top_k: Maximum total results to return (app results prioritized)
+ source_scope: Optional list of path substrings to filter dependency results
+ """
self.index.reload()
try:
if self.tokenizer:
@@ -141,16 +183,20 @@ def search_index(self, query: str, top_k: int = 10) -> str:
app_docs = []
dep_docs = []
-
- vendors_list = list(ECOSYSTEM_DEP_DIRS.values())
for _, doc_id in results:
raw = searcher.doc(doc_id)
doc = {"source": raw["file_path"][0], "content": raw["content"][0]}
- if any(doc["source"].startswith(vendor) for vendor in vendors_list):
+ if is_dependency_path(doc["source"]):
dep_docs.append(doc)
else:
app_docs.append(doc)
+ pre_filter_count = len(dep_docs)
+ dep_docs = filter_by_source_scope(dep_docs, source_scope, lambda d: d["source"])
+ if pre_filter_count != len(dep_docs):
+ logger.debug("Source scope %s filtered dep results from %d to %d",
+ source_scope, pre_filter_count, len(dep_docs))
+
total_app = len(app_docs)
total_dep = len(dep_docs)
@@ -174,11 +220,11 @@ def search_index(self, query: str, top_k: int = 10) -> str:
return "\n".join(parts)
except Exception as e:
- logger.exception(e)
+ logger.exception("Search failed for query")
return "There was an error searching for the query."
def add_documents_from_langchain_chunks(self, documents: list[Document]):
- """Create an index from langchain chunked documents"""
+ """Index pre-chunked LangChain documents (used by document_embedding.py for VDB build)."""
try:
documents = [(doc.metadata['source'], doc.page_content) for doc in documents]
@@ -188,14 +234,18 @@ def add_documents_from_langchain_chunks(self, documents: list[Document]):
self.add_documents(tqdm(documents, total=len(documents), desc="Indexing"))
except Exception as e:
- logger.warning(f"Unable to add documents to the index {e}")
+ logger.warning("Unable to add documents to the index: %s", e)
def add_documents_from_code_path(self,
code_path: str,
include_extensions: list[str],
use_langparser=True,
splitter=True):
- """Create an index from raw files."""
+ """Index raw source files from a directory, optionally using LangChain's LanguageParser.
+
+ When use_langparser=True, files are parsed with language-aware chunking.
+ When False, files are read as raw text (optionally tokenized for search).
+ """
doc_content = []
@@ -227,6 +277,6 @@ def add_documents_from_code_path(self,
content = tokenize_code(content)
doc_content.append((file_path, content))
except Exception as e:
- logger.warning(f"Error reading {file_path}: {e}")
+ logger.warning("Error reading %s: %s", file_path, e)
self.add_documents(doc_content)
diff --git a/src/vuln_analysis/utils/intel_utils.py b/src/vuln_analysis/utils/intel_utils.py
index 32e6942f..d8dee99c 100644
--- a/src/vuln_analysis/utils/intel_utils.py
+++ b/src/vuln_analysis/utils/intel_utils.py
@@ -26,6 +26,8 @@
logger = LoggingFactory.get_agent_logger(__name__)
+_MAX_RHSA_CANDIDATES = 20
+
def update_version(incoming_version, current_version, compare):
"""
@@ -200,12 +202,19 @@ def build_critical_context(cve_intel_list) -> tuple[list[str], list[dict], list[
critical_context.append(f"Search keywords: {', '.join(short_names)}")
for f in vf:
vulnerable_functions.add(f.rsplit('.', 1)[-1])
+            ver_range = vuln.get('vulnerable_version_range', '') if isinstance(vuln, dict) else getattr(vuln, 'vulnerable_version_range', '')
+            patched = vuln.get('first_patched_version', '') if isinstance(vuln, dict) else getattr(vuln, 'first_patched_version', '')
if pkg:
if isinstance(pkg, dict):
pkg_name = pkg.get("name", "")
pkg_eco = pkg.get("ecosystem", "")
if pkg_name:
critical_context.append(f"Vulnerable module ({pkg_eco}): {pkg_name}")
+ if ver_range:
+ ver_ctx = f"Vulnerable version range ({pkg_name}): {ver_range}"
+ if patched:
+ ver_ctx += f" | First patched version: {patched}"
+ critical_context.append(ver_ctx)
if pkg_name not in seen_packages:
seen_packages.add(pkg_name)
candidate_packages.append({"name": pkg_name, "source": "ghsa", "ecosystem": pkg_eco})
@@ -227,10 +236,14 @@ def build_critical_context(cve_intel_list) -> tuple[list[str], list[dict], list[
critical_context.append(f"KNOWN MITIGATIONS: {mit_text[:500]}")
if cve_intel.rhsa.package_state:
pkgs = list(set(p.package_name for p in cve_intel.rhsa.package_state if p.package_name))
+ rhsa_added = 0
for p in pkgs:
if p not in seen_packages:
seen_packages.add(p)
candidate_packages.append({"name": p, "source": "rhsa"})
+ rhsa_added += 1
+ if rhsa_added >= _MAX_RHSA_CANDIDATES:
+ break
if len(pkgs) > 5:
critical_context.append(
f"Affected across {len(pkgs)} Red Hat products (sample: {', '.join(pkgs[:5])}). "
diff --git a/src/vuln_analysis/utils/source_classification.py b/src/vuln_analysis/utils/source_classification.py
new file mode 100644
index 00000000..6a74fd12
--- /dev/null
+++ b/src/vuln_analysis/utils/source_classification.py
@@ -0,0 +1,72 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Classify search results as application code vs. dependency code and format combined output."""
+
+from typing import Callable, TypeVar
+
+from exploit_iq_commons.utils.dep_tree import ECOSYSTEM_DEP_DIRS
+
+# Tuple (not list) so str.startswith can check every vendor prefix in one call.
+_VENDOR_PREFIXES = tuple(ECOSYSTEM_DEP_DIRS.values())
+
+T = TypeVar("T")
+
+
+def is_dependency_path(file_path: str) -> bool:
+    """Return True if file_path starts with any ecosystem dependency directory prefix."""
+    return file_path.startswith(_VENDOR_PREFIXES)
+
+
+def filter_by_source_scope(dep_items: list[T], source_scope: list[str] | None,
+                           path_fn: Callable[[T], str]) -> list[T]:
+    """Filter dependency items to those whose path contains any source_scope substring.
+
+    Args:
+        dep_items: Dependency result items to filter.
+        source_scope: Path substrings to match. None or empty means no filtering.
+        path_fn: Extracts the file path string from an item.
+    """
+    if not source_scope:
+        return dep_items
+    return [item for item in dep_items if any(scope in path_fn(item) for scope in source_scope)]
+
+
+def format_app_dep_output(app_items: list[str], dep_items: list[str],
+                          total_app: int, total_dep: int,
+                          no_results_msg: str) -> str:
+    """Format results with app/dep section headers matching Code Keyword Search output.
+
+    Returns no_results_msg when there are no results at all; otherwise both
+    section headers are always present, each followed by its items (if any):
+
+    Main application (N of M results)
+
+    Application library dependencies (N of M results)
+
+    """
+    if total_app == 0 and total_dep == 0:
+        return no_results_msg
+
+    app_header = f"Main application ({len(app_items)} of {total_app} results)"
+    dep_header = f"Application library dependencies ({len(dep_items)} of {total_dep} results)"
+
+    parts = [app_header]
+    if app_items:
+        parts.append("\n\n".join(app_items))
+    parts.append(dep_header)
+    if dep_items:
+        parts.append("\n\n".join(dep_items))
+    return "\n".join(parts)
diff --git a/src/vuln_analysis/utils/token_utils.py b/src/vuln_analysis/utils/token_utils.py
new file mode 100644
index 00000000..c3e43589
--- /dev/null
+++ b/src/vuln_analysis/utils/token_utils.py
@@ -0,0 +1,110 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Token counting and truncation utilities for agent context management."""
+
+import tiktoken
+
+from vuln_analysis.tools.tool_names import ToolNames
+
+# Lazily initialized cl100k_base encoder. tiktoken may need to fetch encoding
+# data on first use, so defer that work (and any failure) out of import time.
+_tiktoken_enc = None
+
+
+def _get_encoder():
+    """Return the shared cl100k_base encoder, creating it on first call."""
+    global _tiktoken_enc
+    if _tiktoken_enc is None:
+        _tiktoken_enc = tiktoken.get_encoding("cl100k_base")
+    return _tiktoken_enc
+
+
+def count_tokens(text: str) -> int:
+    """Count the number of tokens in a text string using the cl100k_base encoding.
+
+    Falls back to a rough character-based estimate if encoding fails.
+    """
+    try:
+        return len(_get_encoder().encode(text))
+    except Exception:
+        # ~4 characters per token is a common heuristic for English text.
+        return len(text) // 4
+
+
+def estimate_tokens(runtime_prompt: str, messages: list, observation) -> int:
+    """Estimate total token count of the current agent context (prompt + messages + observation)."""
+    parts = [runtime_prompt]
+    for msg in messages:
+        if hasattr(msg, "content") and isinstance(msg.content, str):
+            parts.append(msg.content)
+    if observation is not None:
+        for item in (observation.memory or []):
+            parts.append(item)
+        for item in (observation.results or []):
+            parts.append(item)
+    return count_tokens("\n".join(parts))
+
+
+def truncate_tool_output(tool_output: str, tool_name: str, max_tokens: int = 400) -> str:
+    """Truncate tool output to fit within a token budget.
+
+    Uses tool-specific strategies: Call Chain Analyzer keeps the header,
+    Code Keyword Search preserves section headers, and the default strategy
+    keeps head (70%) and tail (30%) of the output.
+    """
+    token_count = count_tokens(tool_output)
+    if token_count <= max_tokens:
+        return tool_output
+
+    lines = tool_output.split('\n')
+
+    if tool_name == ToolNames.CALL_CHAIN_ANALYZER:
+        head = '\n'.join(lines[:3])
+        remaining = token_count - count_tokens(head)
+        return f"{head}\n[... truncated {remaining} tokens ...]"
+
+    if tool_name == ToolNames.CODE_KEYWORD_SEARCH:
+        kept_lines = []
+        kept_tokens = 0
+        for line in lines:
+            is_header = line.startswith("---") or "Main application" in line or "library dependencies" in line or line.strip() == ""
+            line_tokens = count_tokens(line)
+            if is_header or kept_tokens + line_tokens <= max_tokens:
+                kept_lines.append(line)
+                kept_tokens += line_tokens
+            if kept_tokens >= max_tokens:
+                break
+        kept_lines.append(f"[... truncated {token_count - kept_tokens} tokens ...]")
+        return '\n'.join(kept_lines)
+
+    head_budget = int(max_tokens * 0.7)
+    tail_budget = max_tokens - head_budget
+    head_lines, head_tokens = [], 0
+    for line in lines:
+        lt = count_tokens(line)
+        if head_tokens + lt > head_budget:
+            break
+        head_lines.append(line)
+        head_tokens += lt
+    tail_lines, tail_tokens = [], 0
+    for line in reversed(lines):
+        lt = count_tokens(line)
+        if tail_tokens + lt > tail_budget:
+            break
+        tail_lines.insert(0, line)
+        tail_tokens += lt
+    truncated = token_count - head_tokens - tail_tokens
+    return '\n'.join(head_lines) + f"\n[... truncated {truncated} tokens ...]\n" + '\n'.join(tail_lines)
diff --git a/tests/agent_test_helpers.py b/tests/agent_test_helpers.py
new file mode 100644
index 00000000..2d60aed2
--- /dev/null
+++ b/tests/agent_test_helpers.py
@@ -0,0 +1,69 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Shared mock factories for agent unit tests."""
+
+from unittest.mock import MagicMock
+
+from vuln_analysis.tools.tool_names import ToolNames
+
+
+class MockTool:
+    """Lightweight stand-in for an agent tool; only the name attribute matters."""
+
+    def __init__(self, name: str):
+        self.name = name
+
+
+# One mock per tool the agents know about, in ToolNames declaration order.
+ALL_TOOLS = [
+    MockTool(ToolNames.CODE_SEMANTIC_SEARCH),
+    MockTool(ToolNames.DOCS_SEMANTIC_SEARCH),
+    MockTool(ToolNames.CODE_KEYWORD_SEARCH),
+    MockTool(ToolNames.FUNCTION_LOCATOR),
+    MockTool(ToolNames.CALL_CHAIN_ANALYZER),
+    MockTool(ToolNames.FUNCTION_CALLER_FINDER),
+    MockTool(ToolNames.CVE_WEB_SEARCH),
+    MockTool(ToolNames.CONTAINER_ANALYSIS_DATA),
+    MockTool(ToolNames.FUNCTION_LIBRARY_VERSION_FINDER),
+    MockTool(ToolNames.CONFIGURATION_SCANNER),
+    MockTool(ToolNames.IMPORT_USAGE_ANALYZER),
+]
+
+
+def make_builder(tools=None):
+    """Return a mock builder whose get_tools yields a fresh copy of the given tools (default: ALL_TOOLS)."""
+    builder = MagicMock()
+    builder.get_tools = MagicMock(return_value=list(tools if tools is not None else ALL_TOOLS))
+    return builder
+
+
+def make_config(**overrides):
+    """Return a mock agent config; any attribute below can be overridden by keyword."""
+    config = MagicMock()
+    config.tool_names = overrides.get("tool_names", [])
+    config.transitive_search_tool_enabled = overrides.get("transitive_search_tool_enabled", True)
+    config.cve_web_search_enabled = overrides.get("cve_web_search_enabled", True)
+    config.max_iterations = overrides.get("max_iterations", 10)
+    return config
+
+
+def make_state(code_vdb_path="/path", doc_vdb_path="/path", code_index_path="/path"):
+    """Return a mock pipeline state carrying the three artifact paths the agents read."""
+    state = MagicMock()
+    state.code_vdb_path = code_vdb_path
+    state.doc_vdb_path = doc_vdb_path
+    state.code_index_path = code_index_path
+    return state
diff --git a/tests/test_agent_registry.py b/tests/test_agent_registry.py
new file mode 100644
index 00000000..6260495d
--- /dev/null
+++ b/tests/test_agent_registry.py
@@ -0,0 +1,102 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for agent_registry: the register_agent decorator, get_agent_class
+lookup, and get_all_agent_types enumeration."""
+
+import pytest
+
+from vuln_analysis.functions.agent_registry import (
+    _AGENT_REGISTRY,
+    get_agent_class,
+    get_all_agent_types,
+    register_agent,
+)
+
+
+class TestRealAgentRegistration:
+    """Importing the agent modules must register both concrete agents."""
+
+    def test_reachability_registered(self):
+        import vuln_analysis.functions.reachability_agent  # noqa: F401
+        assert "reachability" in _AGENT_REGISTRY
+
+    def test_code_understanding_registered(self):
+        import vuln_analysis.functions.code_understanding_agent  # noqa: F401
+        assert "code_understanding" in _AGENT_REGISTRY
+
+    def test_get_all_agent_types_contains_both(self):
+        import vuln_analysis.functions.code_understanding_agent  # noqa: F401
+        import vuln_analysis.functions.reachability_agent  # noqa: F401
+        registered = get_all_agent_types()
+        assert "reachability" in registered
+        assert "code_understanding" in registered
+
+    def test_get_agent_class_reachability(self):
+        from vuln_analysis.functions.reachability_agent import ReachabilityAgent
+        assert get_agent_class("reachability") is ReachabilityAgent
+
+    def test_get_agent_class_code_understanding(self):
+        from vuln_analysis.functions.code_understanding_agent import CodeUnderstandingAgent
+        assert get_agent_class("code_understanding") is CodeUnderstandingAgent
+
+
+class TestGetAgentClassErrors:
+    """get_agent_class must fail loudly for unknown agent types."""
+
+    def test_unknown_type_raises_key_error(self):
+        with pytest.raises(KeyError, match="Unknown agent type 'nonexistent'"):
+            get_agent_class("nonexistent")
+
+    def test_error_message_lists_registered_types(self):
+        import vuln_analysis.functions.code_understanding_agent  # noqa: F401
+        import vuln_analysis.functions.reachability_agent  # noqa: F401
+        with pytest.raises(KeyError) as exc_info:
+            get_agent_class("bogus")
+        message = str(exc_info.value)
+        assert "reachability" in message
+        assert "code_understanding" in message
+
+
+class TestRegisterAgentDecorator:
+    """Exercise the decorator mechanics against a throwaway registry state."""
+
+    def setup_method(self):
+        # Snapshot the registry so dummy registrations cannot leak across tests.
+        self._registry_snapshot = dict(_AGENT_REGISTRY)
+
+    def teardown_method(self):
+        _AGENT_REGISTRY.clear()
+        _AGENT_REGISTRY.update(self._registry_snapshot)
+
+    def test_decorator_registers_class(self):
+        @register_agent("test_dummy")
+        class DummyAgent:
+            pass
+
+        assert get_agent_class("test_dummy") is DummyAgent
+
+    def test_decorator_returns_original_class(self):
+        @register_agent("test_dummy2")
+        class DummyAgent:
+            pass
+
+        assert DummyAgent.__name__ == "DummyAgent"
+
+    def test_re_registration_replaces_class(self):
+        @register_agent("test_replace")
+        class First:
+            pass
+
+        @register_agent("test_replace")
+        class Second:
+            pass
+
+        assert get_agent_class("test_replace") is Second
+
+    def test_registered_type_appears_in_all_types(self):
+        @register_agent("test_listed")
+        class Listed:
+            pass
+
+        assert "test_listed" in get_all_agent_types()
diff --git a/tests/test_base_graph_agent.py b/tests/test_base_graph_agent.py
new file mode 100644
index 00000000..2ea1bcf4
--- /dev/null
+++ b/tests/test_base_graph_agent.py
@@ -0,0 +1,805 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Unit tests for BaseGraphAgent: should_continue routing, default hooks, agent_type property,
+thought_node context pruning.
+"""
+
+import pytest
+from unittest.mock import MagicMock, AsyncMock, patch
+from langchain_core.messages import SystemMessage, AIMessage, HumanMessage, ToolMessage
+
+from vuln_analysis.functions.base_graph_agent import BaseGraphAgent
+from vuln_analysis.functions.react_internals import Thought, ToolCall, Observation
+
+
+class _ConcreteAgent(BaseGraphAgent):
+ """Minimal concrete subclass for testing base class behavior."""
+
+ async def pre_process_node(self, state):
+ return state
+
+ @staticmethod
+ def get_tools(builder, config, state):
+ return []
+
+ @staticmethod
+ def create_rules_tracker():
+ return MagicMock()
+
+
+def _make_agent():
+ mock_llm = MagicMock()
+ mock_llm.with_structured_output = MagicMock(return_value=MagicMock())
+ config = MagicMock()
+ config.max_iterations = 10
+ return _ConcreteAgent(tools=[], llm=mock_llm, config=config)
+
+
+class TestShouldContinue:
+ """Test should_continue routing logic."""
+
+ @pytest.mark.asyncio
+ async def test_returns_end_on_finish_mode(self):
+ agent = _make_agent()
+ thought = Thought(thought="done", mode="finish", actions=None, final_answer="answer")
+ state = {"thought": thought, "step": 3, "max_steps": 10}
+ result = await agent.should_continue(state)
+ assert result == "__end__"
+
+ @pytest.mark.asyncio
+ async def test_returns_forced_finish_at_max_steps(self):
+ agent = _make_agent()
+ thought = Thought(
+ thought="still working",
+ mode="act",
+ actions=ToolCall(tool="some_tool", query="q", reason="testing"),
+ final_answer=None,
+ )
+ state = {"thought": thought, "step": 10, "max_steps": 10}
+ result = await agent.should_continue(state)
+ assert result == "forced_finish_node"
+
+ @pytest.mark.asyncio
+ async def test_returns_forced_finish_beyond_max_steps(self):
+ agent = _make_agent()
+ thought = Thought(
+ thought="still working",
+ mode="act",
+ actions=ToolCall(tool="some_tool", query="q", reason="testing"),
+ final_answer=None,
+ )
+ state = {"thought": thought, "step": 15, "max_steps": 10}
+ result = await agent.should_continue(state)
+ assert result == "forced_finish_node"
+
+ @pytest.mark.asyncio
+ async def test_returns_tool_node_when_continuing(self):
+ agent = _make_agent()
+ thought = Thought(
+ thought="need more info",
+ mode="act",
+ actions=ToolCall(tool="some_tool", query="q", reason="testing"),
+ final_answer=None,
+ )
+ state = {"thought": thought, "step": 3, "max_steps": 10}
+ result = await agent.should_continue(state)
+ assert result == "tool_node"
+
+ @pytest.mark.asyncio
+ async def test_returns_thought_node_when_thought_is_none(self):
+ agent = _make_agent()
+ state = {"thought": None, "step": 0, "max_steps": 10}
+ result = await agent.should_continue(state)
+ assert result == "thought_node"
+
+ @pytest.mark.asyncio
+ async def test_forced_finish_when_thought_none_at_max_steps(self):
+ """Step limit must be enforced even when thought is None (e.g. after
+ check_finish_allowed repeatedly blocks). Without this, the agent
+ self-loops thought_node→thought_node until GraphRecursionError."""
+ agent = _make_agent()
+ state = {"thought": None, "step": 10, "max_steps": 10}
+ result = await agent.should_continue(state)
+ assert result == "forced_finish_node"
+
+ @pytest.mark.asyncio
+ async def test_forced_finish_when_thought_none_beyond_max_steps(self):
+ agent = _make_agent()
+ state = {"thought": None, "step": 15, "max_steps": 10}
+ result = await agent.should_continue(state)
+ assert result == "forced_finish_node"
+
+ @pytest.mark.asyncio
+ async def test_uses_config_max_iterations_as_fallback(self):
+ agent = _make_agent()
+ thought = Thought(
+ thought="working",
+ mode="act",
+ actions=ToolCall(tool="some_tool", query="q", reason="testing"),
+ final_answer=None,
+ )
+ state = {"thought": thought, "step": 10}
+ result = await agent.should_continue(state)
+ assert result == "forced_finish_node"
+
+
+class TestDefaultHooks:
+ """Test default hook implementations on BaseGraphAgent."""
+
+ def test_post_observation_returns_empty_dict(self):
+ agent = _make_agent()
+ result = agent.post_observation(state={}, tool_used="X", tool_output="Y", tool_input_detail="Z")
+ assert result == {}
+
+ def test_should_truncate_returns_false(self):
+ agent = _make_agent()
+ result = agent.should_truncate_tool_output(state={}, tool_used="X")
+ assert result is False
+
+ def test_agent_type_property(self):
+ agent = _make_agent()
+ assert agent.agent_type == "base"
+
+ def test_build_comprehension_context_returns_full_context(self):
+ agent = _make_agent()
+ state = {"critical_context": ["CVE Description: test vuln", "Vulnerable module: xstream"]}
+ result = agent.build_comprehension_context(state)
+ assert "CVE Description: test vuln" in result
+ assert "Vulnerable module: xstream" in result
+
+ def test_build_comprehension_context_empty_state(self):
+ agent = _make_agent()
+ assert agent.build_comprehension_context({}) == "N/A"
+
+ def test_build_comprehension_context_empty_list(self):
+ agent = _make_agent()
+ assert agent.build_comprehension_context({"critical_context": []}) == "N/A"
+
+ @patch("vuln_analysis.functions.base_graph_agent.ctx_state")
+ def test_sanitize_findings_replaces_wrong_cve(self, mock_ctx):
+ intel = MagicMock()
+ intel.vuln_id = "CVE-2021-43859"
+ ws = MagicMock()
+ ws.cve_intel = [intel]
+ mock_ctx.get.return_value = ws
+
+ agent = _make_agent()
+ findings = ["Found CVE-2020-26217 in code", "Package present"]
+ result = agent.sanitize_findings(findings, {})
+ assert "CVE-2020-26217" not in result[0]
+ assert "the investigated vulnerability" in result[0]
+ assert result[1] == "Package present"
+
+ @patch("vuln_analysis.functions.base_graph_agent.ctx_state")
+ def test_sanitize_findings_keeps_correct_cve(self, mock_ctx):
+ intel = MagicMock()
+ intel.vuln_id = "CVE-2021-43859"
+ ws = MagicMock()
+ ws.cve_intel = [intel]
+ mock_ctx.get.return_value = ws
+
+ agent = _make_agent()
+ findings = ["Affects CVE-2021-43859"]
+ result = agent.sanitize_findings(findings, {})
+ assert result == ["Affects CVE-2021-43859"]
+
+ @patch("vuln_analysis.functions.base_graph_agent.ctx_state")
+ def test_sanitize_findings_empty_list(self, mock_ctx):
+ ws = MagicMock()
+ ws.cve_intel = []
+ mock_ctx.get.return_value = ws
+
+ agent = _make_agent()
+ assert agent.sanitize_findings([], {}) == []
+
+
+class TestInit:
+ """Test BaseGraphAgent constructor wires up LLM wrappers."""
+
+ def test_creates_four_structured_output_llms(self):
+ mock_llm = MagicMock()
+ config = MagicMock()
+ config.max_iterations = 10
+ agent = _ConcreteAgent(tools=["t1", "t2"], llm=mock_llm, config=config)
+
+ assert mock_llm.with_structured_output.call_count == 4
+ assert agent.tools == ["t1", "t2"]
+ assert agent.config is config
+
+
+def _make_thought_response(mode="finish", final_answer="done"):
+ return Thought(thought="thinking", mode=mode, actions=None, final_answer=final_answer)
+
+
+def _long_content(n_words=500):
+ return " ".join(["word"] * n_words)
+
+
+class TestThoughtNodePruning:
+ """Test that thought_node prunes messages when tokens exceed the limit."""
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_prunes_middle_messages_when_over_limit(self, mock_tracer):
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 100
+ agent.thought_llm.ainvoke = AsyncMock(return_value=_make_thought_response())
+
+ long = _long_content(200)
+ state = {
+ "runtime_prompt": "system prompt",
+ "messages": [
+ HumanMessage(content=long),
+ AIMessage(content=long),
+ ToolMessage(content=long, tool_call_id="tc1"),
+ AIMessage(content=long),
+ ToolMessage(content="recent tool output", tool_call_id="tc2"),
+ HumanMessage(content="recent question"),
+ ],
+ "observation": None,
+ "step": 2,
+ }
+
+ await agent.thought_node(state)
+
+ invoked_messages = agent.thought_llm.ainvoke.call_args[0][0]
+ num_original = 1 + 6 # system prompt + 6 state messages
+ assert len(invoked_messages) < num_original
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_no_pruning_when_under_limit(self, mock_tracer):
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 50000
+ agent.thought_llm.ainvoke = AsyncMock(return_value=_make_thought_response())
+
+ state = {
+ "runtime_prompt": "short prompt",
+ "messages": [
+ HumanMessage(content="hello"),
+ AIMessage(content="response"),
+ ],
+ "observation": None,
+ "step": 1,
+ }
+
+ await agent.thought_node(state)
+
+ invoked_messages = agent.thought_llm.ainvoke.call_args[0][0]
+ contents = [m.content for m in invoked_messages if hasattr(m, "content")]
+ assert "hello" in contents
+ assert "response" in contents
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_preserves_system_prompt_and_last_message(self, mock_tracer):
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 50
+ agent.thought_llm.ainvoke = AsyncMock(return_value=_make_thought_response())
+
+ long = _long_content(200)
+ state = {
+ "runtime_prompt": "system prompt",
+ "messages": [
+ HumanMessage(content=long),
+ AIMessage(content=long),
+ ToolMessage(content=long, tool_call_id="tc1"),
+ HumanMessage(content="latest question"),
+ ],
+ "observation": None,
+ "step": 3,
+ }
+
+ await agent.thought_node(state)
+
+ invoked_messages = agent.thought_llm.ainvoke.call_args[0][0]
+ contents = [m.content for m in invoked_messages if hasattr(m, "content")]
+ assert "system prompt" in contents
+ assert "latest question" in contents
+ assert long not in contents
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_pruning_includes_observation_context_in_count(self, mock_tracer):
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 200
+ agent.thought_llm.ainvoke = AsyncMock(return_value=_make_thought_response())
+
+ long = _long_content(100)
+ obs = Observation(
+ memory=[_long_content(50)],
+ results=[_long_content(50)],
+ )
+ state = {
+ "runtime_prompt": "system prompt",
+ "messages": [
+ HumanMessage(content=long),
+ AIMessage(content=long),
+ ToolMessage(content="tool out", tool_call_id="tc1"),
+ HumanMessage(content="question"),
+ ],
+ "observation": obs,
+ "step": 2,
+ }
+
+ await agent.thought_node(state)
+
+ invoked_messages = agent.thought_llm.ainvoke.call_args[0][0]
+ assert any("KNOWLEDGE" in m.content for m in invoked_messages if hasattr(m, "content") and isinstance(m.content, str))
+ contents = [m.content for m in invoked_messages if hasattr(m, "content")]
+ assert long not in contents
+
+
+class TestThoughtNodeBadToolArguments:
+ """Test that thought_node recovers from bad tool arguments instead of crashing.
+
+ Mirrors the old AgentExecutor's handle_parsing_errors behavior: when the LLM
+ produces a ToolCall with missing required fields, the agent should get an error
+ message and retry rather than killing the entire graph.
+ """
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_recovers_from_missing_arguments(self, mock_tracer):
+ """When all ToolCall fields are None, thought_node returns an error
+ HumanMessage with thought=None so should_continue loops back."""
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 50000
+
+ bad_actions = ToolCall(
+ tool="Function Library Version Finder",
+ package_name=None,
+ function_name=None,
+ query=None,
+ tool_input=None,
+ reason="check version",
+ )
+ bad_response = Thought(
+ thought="check the version",
+ mode="act",
+ actions=bad_actions,
+ final_answer=None,
+ )
+ agent.thought_llm.ainvoke = AsyncMock(return_value=bad_response)
+
+ state = {
+ "runtime_prompt": "system prompt",
+ "messages": [HumanMessage(content="Is SslHandler used?")],
+ "observation": None,
+ "step": 2,
+ }
+
+ result = await agent.thought_node(state)
+
+ assert result["thought"] is None
+ assert result["step"] == 3
+ assert result["output"] == "waiting for the agent to respond"
+ assert len(result["messages"]) == 2
+ ai_msg = result["messages"][0]
+ assert isinstance(ai_msg, AIMessage)
+ assert "check the version" in ai_msg.content
+ error_msg = result["messages"][1]
+ assert isinstance(error_msg, HumanMessage)
+ assert "ERROR" in error_msg.content
+ assert "Function Library Version Finder" in error_msg.content
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_recovery_routes_back_to_thought_node_bad_args(self, mock_tracer):
+ """After a bad-arguments recovery, should_continue returns 'thought_node'
+ because thought is None — the agent gets another chance."""
+ agent = _make_agent()
+
+ state = {"thought": None, "step": 3, "max_steps": 10}
+ route = await agent.should_continue(state)
+ assert route == "thought_node"
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_recovery_still_counts_toward_max_steps(self, mock_tracer):
+ """A bad-arguments iteration increments step, so the agent hits
+ forced_finish_node when step reaches max_steps."""
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 50000
+
+ bad_actions = ToolCall(
+ tool="Some Tool",
+ reason="testing",
+ )
+ bad_response = Thought(
+ thought="trying something",
+ mode="act",
+ actions=bad_actions,
+ final_answer=None,
+ )
+ agent.thought_llm.ainvoke = AsyncMock(return_value=bad_response)
+
+ state = {
+ "runtime_prompt": "prompt",
+ "messages": [HumanMessage(content="question")],
+ "observation": None,
+ "step": 9,
+ "max_steps": 10,
+ }
+
+ result = await agent.thought_node(state)
+
+ assert result["step"] == 10
+ assert result["thought"] is None
+
+ route = await agent.should_continue(result)
+ assert route == "forced_finish_node"
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_valid_tool_call_still_works(self, mock_tracer):
+ """Verify that valid tool calls are not affected by the ValueError handling."""
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 50000
+
+ good_actions = ToolCall(
+ tool="Configuration Scanner",
+ query="netty SSL settings",
+ reason="check config",
+ )
+ good_response = Thought(
+ thought="scan for config",
+ mode="act",
+ actions=good_actions,
+ final_answer=None,
+ )
+ agent.thought_llm.ainvoke = AsyncMock(return_value=good_response)
+
+ state = {
+ "runtime_prompt": "system prompt",
+ "messages": [HumanMessage(content="question")],
+ "observation": None,
+ "step": 1,
+ }
+
+ result = await agent.thought_node(state)
+
+ assert result["thought"] is good_response
+ assert result["step"] == 2
+ ai_msg = result["messages"][0]
+ assert isinstance(ai_msg, AIMessage)
+ assert ai_msg.tool_calls[0]["name"] == "Configuration Scanner"
+ assert ai_msg.tool_calls[0]["args"] == {"query": "netty SSL settings"}
+
+
+class TestCheckFinishAllowedBlocking:
+ """Test that blocked finishes include AIMessage and respect step limits."""
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_blocked_finish_includes_ai_message(self, mock_tracer):
+ """When check_finish_allowed blocks, the LLM's finish attempt must be
+ recorded as an AIMessage so the chat model sees its own response and
+ the rejection in proper Human/AI alternation."""
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 50000
+ agent.check_finish_allowed = MagicMock(
+ return_value=(False, "You MUST use Function Locator first.")
+ )
+
+ finish_response = Thought(
+ thought="I have enough info",
+ mode="finish",
+ actions=None,
+ final_answer="The function is not reachable.",
+ )
+ agent.thought_llm.ainvoke = AsyncMock(return_value=finish_response)
+
+ state = {
+ "runtime_prompt": "system prompt",
+ "messages": [HumanMessage(content="Is the function reachable?")],
+ "observation": None,
+ "step": 2,
+ }
+
+ result = await agent.thought_node(state)
+
+ assert result["thought"] is None
+ assert result["step"] == 3
+ assert len(result["messages"]) == 2
+ ai_msg = result["messages"][0]
+ assert isinstance(ai_msg, AIMessage)
+ assert "The function is not reachable." in ai_msg.content
+ human_msg = result["messages"][1]
+ assert isinstance(human_msg, HumanMessage)
+ assert "Function Locator" in human_msg.content
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_blocked_finish_ai_message_falls_back_to_thought(self, mock_tracer):
+ """When final_answer is None, the AIMessage should use the thought text."""
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 50000
+ agent.check_finish_allowed = MagicMock(
+ return_value=(False, "Call CCA first.")
+ )
+
+ finish_response = Thought(
+ thought="seems done",
+ mode="finish",
+ actions=None,
+ final_answer=None,
+ )
+ agent.thought_llm.ainvoke = AsyncMock(return_value=finish_response)
+
+ state = {
+ "runtime_prompt": "prompt",
+ "messages": [HumanMessage(content="question")],
+ "observation": None,
+ "step": 0,
+ }
+
+ result = await agent.thought_node(state)
+
+ ai_msg = result["messages"][0]
+ assert isinstance(ai_msg, AIMessage)
+ assert "seems done" in ai_msg.content
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_blocked_finish_at_max_steps_routes_to_forced_finish(self, mock_tracer):
+ """If check_finish_allowed blocks at step 9 (incrementing to 10),
+ should_continue must route to forced_finish_node, not self-loop."""
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 50000
+ agent.check_finish_allowed = MagicMock(
+ return_value=(False, "Call FL and CCA first.")
+ )
+
+ finish_response = Thought(
+ thought="done",
+ mode="finish",
+ actions=None,
+ final_answer="answer",
+ )
+ agent.thought_llm.ainvoke = AsyncMock(return_value=finish_response)
+
+ state = {
+ "runtime_prompt": "prompt",
+ "messages": [HumanMessage(content="question")],
+ "observation": None,
+ "step": 9,
+ "max_steps": 10,
+ }
+
+ result = await agent.thought_node(state)
+ assert result["step"] == 10
+ assert result["thought"] is None
+
+ route = await agent.should_continue(result)
+ assert route == "forced_finish_node"
+
+ @pytest.mark.asyncio
+ @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER")
+ async def test_bad_args_includes_ai_message(self, mock_tracer):
+ """Bad tool arguments error must include an AIMessage with the LLM's
+ original thought for proper chat alternation."""
+ agent = _make_agent()
+ agent.config.context_window_token_limit = 50000
+
+ bad_actions = ToolCall(
+ tool="Function Locator",
+ reason="locate function",
+ )
+ bad_response = Thought(
+ thought="Let me find the function",
+ mode="act",
+ actions=bad_actions,
+ final_answer=None,
+ )
+ agent.thought_llm.ainvoke = AsyncMock(return_value=bad_response)
+
+ state = {
+ "runtime_prompt": "prompt",
+ "messages": [HumanMessage(content="question")],
+ "observation": None,
+ "step": 3,
+ }
+
+ result = await agent.thought_node(state)
+
+ assert len(result["messages"]) == 2
+ ai_msg = result["messages"][0]
+ assert isinstance(ai_msg, AIMessage)
+ assert "Let me find the function" in ai_msg.content
+ error_msg = result["messages"][1]
+ assert isinstance(error_msg, HumanMessage)
+ assert "ERROR" in error_msg.content
+
+
+class TestSelectPackage:
+ """Tests for _select_package image-match fast path and LLM fallback."""
+
+ def _make_workflow_state(self, image_name="registry.redhat.io/openshift4/ose-docker-builder",
+ git_repo="https://github.com/openshift/builder"):
+ si = MagicMock()
+ si.git_repo = git_repo
+ image = MagicMock()
+ image.name = image_name
+ image.source_info = [si]
+ ws = MagicMock()
+ ws.original_input.input.image = image
+ return ws
+
+ @pytest.mark.asyncio
+ async def test_image_match_skips_llm(self):
+        """When a candidate name matches the image/repo, the LLM is not called."""
+ agent = _make_agent()
+ candidates = [
+ {"name": "builder", "source": "rhsa"},
+ {"name": "kernel", "source": "rhsa"},
+ {"name": "glibc", "source": "rhsa"},
+ ]
+ ws = self._make_workflow_state()
+
+ with patch("vuln_analysis.utils.intel_utils.filter_context_to_package",
+ side_effect=lambda ctx, pkg, cands: ctx):
+ ctx, selected = await agent._select_package(
+ "go", candidates, ["CVE desc"], ws,
+ )
+
+ assert selected == "builder"
+ agent.package_filter_llm.ainvoke.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_no_match_calls_llm(self):
+        """When no candidate matches the image, the LLM is called."""
+ agent = _make_agent()
+ mock_selection = MagicMock()
+ mock_selection.selected_package = "xstream"
+ mock_selection.reason = "ecosystem match"
+ agent.package_filter_llm.ainvoke = AsyncMock(return_value=mock_selection)
+
+ candidates = [
+ {"name": "xstream", "source": "ghsa", "ecosystem": "Maven"},
+ {"name": "kernel", "source": "rhsa"},
+ ]
+ ws = self._make_workflow_state(image_name="registry.redhat.io/infinispan/server",
+ git_repo="https://github.com/infinispan/infinispan")
+
+ with patch("vuln_analysis.utils.intel_utils.filter_context_to_package",
+ side_effect=lambda ctx, pkg, cands: ctx):
+ ctx, selected = await agent._select_package(
+ "java", candidates, ["CVE desc"], ws,
+ )
+
+ assert selected == "xstream"
+ agent.package_filter_llm.ainvoke.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_image_match_with_many_candidates(self):
+ """1000+ candidates with image match -> LLM skipped, no overflow."""
+ agent = _make_agent()
+ candidates = [{"name": f"rhsa-product-{i}", "source": "rhsa"} for i in range(1200)]
+ candidates.append({"name": "builder", "source": "rhsa"})
+ ws = self._make_workflow_state()
+
+ with patch("vuln_analysis.utils.intel_utils.filter_context_to_package",
+ side_effect=lambda ctx, pkg, cands: ctx):
+ ctx, selected = await agent._select_package(
+ "go", candidates, ["CVE desc"], ws,
+ )
+
+ assert selected == "builder"
+ agent.package_filter_llm.ainvoke.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_single_candidate_no_llm(self):
+ """Single candidate is used directly without LLM."""
+ agent = _make_agent()
+ candidates = [{"name": "jinja2", "source": "ghsa"}]
+ ws = self._make_workflow_state()
+
+ with patch("vuln_analysis.utils.intel_utils.filter_context_to_package",
+ side_effect=lambda ctx, pkg, cands: ctx):
+ ctx, selected = await agent._select_package(
+ "python", candidates, ["CVE desc"], ws,
+ )
+
+ assert selected == "jinja2"
+ agent.package_filter_llm.ainvoke.assert_not_called()
+
+
+class TestForcedFinishNode:
+ """Tests for forced_finish_node using observation memory instead of full history."""
+
+ @pytest.mark.asyncio
+ async def test_uses_observation_memory_not_full_history(self):
+        """forced_finish_node should build the prompt from observation memory,
+ not the full conversation history, to avoid token overflow."""
+ agent = _make_agent()
+ mock_response = Thought(
+ thought="summarizing", mode="finish", actions=None,
+ final_answer="Based on evidence, the function is not reachable.",
+ )
+ agent.thought_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+ obs = Observation(
+ results=["CCA returned (False, [])"],
+ memory=["Package validated: commons-beanutils:1.9.4",
+ "FL found PropertyUtilsBean.getProperty",
+ "CCA: function not reachable from app code"],
+ )
+ big_messages = [
+ HumanMessage(content=f"tool output {i}" * 200) for i in range(10)
+ ]
+ state = {
+ "step": 10, "max_steps": 10,
+ "runtime_prompt": "You are a security analyst.",
+ "messages": big_messages,
+ "observation": obs,
+ "thought": None,
+ }
+
+ with patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER"):
+ result = await agent.forced_finish_node(state)
+
+ call_args = agent.thought_llm.ainvoke.call_args[0][0]
+ assert not any(
+ msg.content.startswith("tool output") for msg in call_args
+ if hasattr(msg, "content")
+ ), "Full conversation history should not be in the prompt"
+ knowledge_msg = [
+ msg for msg in call_args
+ if hasattr(msg, "content") and "KNOWLEDGE" in msg.content
+ ]
+ assert len(knowledge_msg) == 1, "Observation memory should be in the prompt as KNOWLEDGE block"
+ assert "LATEST FINDINGS" in knowledge_msg[0].content, "Recent findings should also be included"
+ assert "CCA returned (False, [])" in knowledge_msg[0].content
+ assert result["output"] == "Based on evidence, the function is not reachable."
+
+ @pytest.mark.asyncio
+ async def test_works_without_observation(self):
+ """forced_finish_node should work even when no observations exist."""
+ agent = _make_agent()
+ mock_response = Thought(
+ thought="no evidence", mode="finish", actions=None,
+ final_answer="No evidence found.",
+ )
+ agent.thought_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+ state = {
+ "step": 10, "max_steps": 10,
+ "runtime_prompt": "You are a security analyst.",
+ "messages": [HumanMessage(content="some message")],
+ "observation": None,
+ "thought": None,
+ }
+
+ with patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER"):
+ result = await agent.forced_finish_node(state)
+
+ call_args = agent.thought_llm.ainvoke.call_args[0][0]
+ assert len(call_args) == 2 # system prompt + forced finish prompt
+ assert result["output"] == "No evidence found."
+
+ @pytest.mark.asyncio
+ async def test_fallback_on_non_finish_response(self):
+ """forced_finish_node returns default message when LLM doesn't finish."""
+ agent = _make_agent()
+ mock_response = Thought(
+ thought="I want to call another tool", mode="act",
+ actions=ToolCall(tool="Function Locator", package_name="pkg", function_name="fn", reason="test"),
+ final_answer=None,
+ )
+ agent.thought_llm.ainvoke = AsyncMock(return_value=mock_response)
+
+ state = {
+ "step": 10, "max_steps": 10,
+ "runtime_prompt": "You are a security analyst.",
+ "messages": [],
+ "observation": None,
+ "thought": None,
+ }
+
+ with patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER"):
+ result = await agent.forced_finish_node(state)
+
+ assert "Failed to generate a final answer" in result["output"]
diff --git a/tests/test_base_tool_descriptions.py b/tests/test_base_tool_descriptions.py
index 37929060..f457f59f 100644
--- a/tests/test_base_tool_descriptions.py
+++ b/tests/test_base_tool_descriptions.py
@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Tests for the base build_tool_descriptions() function.
+"""Tests for the base build_tool_descriptions() function.
-Tests that the consolidated base function provides simple tool descriptions
+Verifies that the consolidated base function provides simple tool descriptions
that can be formatted by specialized functions for different contexts.
"""
@@ -76,7 +75,7 @@ def test_base_all_tools():
result = build_tool_descriptions(tool_names)
- # Should have 7 descriptions
+ # Should have 7 descriptions (CONTAINER_ANALYSIS_DATA has no description in the base function)
assert len(result) == 7
# Verify all tools are present
@@ -155,36 +154,6 @@ def test_mod_few_shot_structure():
print("✓ MOD_FEW_SHOT structure validated")
-def test_cve_web_search_description_warns_about_versions():
- """Test that CVE Web Search description includes version warning."""
- # Without Version Finder: generic version warning
- tool_names = [ToolNames.CVE_WEB_SEARCH]
- result = build_tool_descriptions(tool_names)
- assert len(result) == 1
- desc = result[0]
- assert "MULTIPLE versions" in desc
- assert ToolNames.FUNCTION_LIBRARY_VERSION_FINDER not in desc
-
- # With Version Finder (Java): references the tool
- tool_names = [ToolNames.CVE_WEB_SEARCH, ToolNames.FUNCTION_LIBRARY_VERSION_FINDER]
- result = build_tool_descriptions(tool_names)
- descs = " ".join(result)
- assert "MULTIPLE versions" in descs
- assert ToolNames.FUNCTION_LIBRARY_VERSION_FINDER in descs
-
- print("✓ CVE Web Search description includes version warning")
-
-
-def test_agent_prompt_contains_version_awareness_instructions():
- """Test that agent prompt template contains version awareness instructions."""
- from vuln_analysis.utils.prompting import get_agent_prompt
-
- prompt = get_agent_prompt()
- assert "VERSION" in prompt
-
- print("✓ Agent prompt contains version awareness instructions")
-
-
if __name__ == "__main__":
print("Running Base Tool Descriptions tests...\n")
@@ -194,7 +163,5 @@ def test_agent_prompt_contains_version_awareness_instructions():
test_base_empty_list()
test_checklist_formats_descriptions()
test_mod_few_shot_structure()
- test_cve_web_search_description_warns_about_versions()
- test_agent_prompt_contains_version_awareness_instructions()
print("\n✅ All base tool descriptions tests passed!")
diff --git a/tests/test_build_code_understanding_tools.py b/tests/test_build_code_understanding_tools.py
new file mode 100644
index 00000000..1eb88038
--- /dev/null
+++ b/tests/test_build_code_understanding_tools.py
@@ -0,0 +1,276 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for CodeUnderstandingAgent: tool selection, availability, and comprehension hooks."""
+
+from unittest.mock import MagicMock, patch
+
+from agent_test_helpers import MockTool, make_builder, make_config, make_state
+from vuln_analysis.functions.code_understanding_agent import CodeUnderstandingAgent
+from vuln_analysis.functions.code_understanding_internals import CodeUnderstandingRulesTracker
+from vuln_analysis.tools.tool_names import ToolNames
+
+
+def _get_tools(builder=None, config=None, state=None):
+ return CodeUnderstandingAgent.get_tools(
+ builder or make_builder(),
+ config or make_config(),
+ state or make_state(),
+ )
+
+
+class TestGetTools:
+ """Test CodeUnderstandingAgent.get_tools selection and availability logic."""
+
+ def test_filters_to_exactly_4_cu_tools(self):
+ result = _get_tools()
+ assert len(result) == 4
+
+ def test_output_tool_names(self):
+ result = _get_tools()
+ result_names = {t.name for t in result}
+ expected_names = {
+ ToolNames.DOCS_SEMANTIC_SEARCH,
+ ToolNames.CODE_KEYWORD_SEARCH,
+ ToolNames.CONFIGURATION_SCANNER,
+ ToolNames.IMPORT_USAGE_ANALYZER,
+ }
+ assert result_names == expected_names
+
+ def test_excludes_reachability_tools(self):
+ tools = [
+ MockTool(ToolNames.FUNCTION_LOCATOR),
+ MockTool(ToolNames.CALL_CHAIN_ANALYZER),
+ MockTool(ToolNames.FUNCTION_CALLER_FINDER),
+ MockTool(ToolNames.DOCS_SEMANTIC_SEARCH),
+ MockTool(ToolNames.CODE_KEYWORD_SEARCH),
+ ]
+ result = _get_tools(builder=make_builder(tools))
+ result_names = {t.name for t in result}
+ assert ToolNames.FUNCTION_LOCATOR not in result_names
+ assert ToolNames.CALL_CHAIN_ANALYZER not in result_names
+ assert ToolNames.FUNCTION_CALLER_FINDER not in result_names
+ assert len(result) == 2
+
+ def test_excludes_web_and_container_tools(self):
+ tools = [
+ MockTool(ToolNames.CVE_WEB_SEARCH),
+ MockTool(ToolNames.CONTAINER_ANALYSIS_DATA),
+ MockTool(ToolNames.DOCS_SEMANTIC_SEARCH),
+ MockTool(ToolNames.CODE_KEYWORD_SEARCH),
+ ]
+ result = _get_tools(builder=make_builder(tools))
+ result_names = {t.name for t in result}
+ assert ToolNames.CVE_WEB_SEARCH not in result_names
+ assert ToolNames.CONTAINER_ANALYSIS_DATA not in result_names
+ assert len(result) == 2
+
+ def test_excludes_version_finder(self):
+ tools = [
+ MockTool(ToolNames.FUNCTION_LIBRARY_VERSION_FINDER),
+ MockTool(ToolNames.DOCS_SEMANTIC_SEARCH),
+ MockTool(ToolNames.CODE_KEYWORD_SEARCH),
+ ]
+ result = _get_tools(builder=make_builder(tools))
+ result_names = {t.name for t in result}
+ assert ToolNames.FUNCTION_LIBRARY_VERSION_FINDER not in result_names
+ assert len(result) == 2
+
+ def test_empty_builder_returns_empty(self):
+ result = _get_tools(builder=make_builder(tools=[]))
+ assert result == []
+
+ def test_no_matching_tools_returns_empty(self):
+ tools = [
+ MockTool(ToolNames.FUNCTION_LOCATOR),
+ MockTool(ToolNames.CALL_CHAIN_ANALYZER),
+ MockTool(ToolNames.FUNCTION_CALLER_FINDER),
+ MockTool(ToolNames.CVE_WEB_SEARCH),
+ ]
+ result = _get_tools(builder=make_builder(tools))
+ assert result == []
+
+ def test_preserves_tool_object_identity(self):
+ docs_tool = MockTool(ToolNames.DOCS_SEMANTIC_SEARCH)
+ keyword_tool = MockTool(ToolNames.CODE_KEYWORD_SEARCH)
+ locator_tool = MockTool(ToolNames.FUNCTION_LOCATOR)
+ builder = make_builder(tools=[docs_tool, keyword_tool, locator_tool])
+ result = _get_tools(builder=builder)
+ assert len(result) == 2
+ assert docs_tool in result
+ assert keyword_tool in result
+ assert locator_tool not in result
+
+ def test_partial_overlap(self):
+ tools = [
+ MockTool(ToolNames.DOCS_SEMANTIC_SEARCH),
+ MockTool(ToolNames.CODE_KEYWORD_SEARCH),
+ MockTool(ToolNames.FUNCTION_LOCATOR),
+ MockTool(ToolNames.CVE_WEB_SEARCH),
+ ]
+ result = _get_tools(builder=make_builder(tools))
+ result_names = {t.name for t in result}
+ expected_names = {
+ ToolNames.DOCS_SEMANTIC_SEARCH,
+ ToolNames.CODE_KEYWORD_SEARCH,
+ }
+ assert len(result) == 2
+ assert result_names == expected_names
+
+
+class TestGetToolsAvailability:
+ """get_tools filters out tools whose infrastructure prerequisites are not met."""
+
+ def test_filters_docs_semantic_search_when_no_vdb(self):
+ state = make_state(doc_vdb_path=None)
+ result = _get_tools(state=state)
+ assert ToolNames.DOCS_SEMANTIC_SEARCH not in {t.name for t in result}
+
+ def test_filters_code_keyword_search_when_no_index(self):
+ state = make_state(code_index_path=None)
+ result = _get_tools(state=state)
+ assert ToolNames.CODE_KEYWORD_SEARCH not in {t.name for t in result}
+
+ def test_cu_only_tools_always_kept(self):
+ state = make_state(code_vdb_path=None, doc_vdb_path=None, code_index_path=None)
+ result = _get_tools(state=state)
+ result_names = {t.name for t in result}
+ assert ToolNames.CONFIGURATION_SCANNER in result_names
+
+ def test_filters_import_usage_analyzer_when_no_index(self):
+ state = make_state(code_index_path=None)
+ result = _get_tools(state=state)
+ assert ToolNames.IMPORT_USAGE_ANALYZER not in {t.name for t in result}
+
+ def test_import_usage_analyzer_available_with_index(self):
+ state = make_state(code_index_path="/some/path")
+ result = _get_tools(state=state)
+ assert ToolNames.IMPORT_USAGE_ANALYZER in {t.name for t in result}
+
+
+class TestCodeUnderstandingAgentMeta:
+ """Test create_rules_tracker and agent_type for CodeUnderstandingAgent."""
+
+ def test_create_rules_tracker_returns_cu_tracker(self):
+ tracker = CodeUnderstandingAgent.create_rules_tracker()
+ assert isinstance(tracker, CodeUnderstandingRulesTracker)
+
+ def test_create_rules_tracker_returns_fresh_instance(self):
+ t1 = CodeUnderstandingAgent.create_rules_tracker()
+ t2 = CodeUnderstandingAgent.create_rules_tracker()
+ assert t1 is not t2
+
+ def test_agent_type_is_cu(self):
+ mock_llm = MagicMock()
+ mock_llm.with_structured_output = MagicMock(return_value=MagicMock())
+ config = MagicMock()
+ config.max_iterations = 10
+ agent = CodeUnderstandingAgent(tools=[], llm=mock_llm, config=config)
+ assert agent.agent_type == "cu"
+
+
+def _make_cu_agent():
+ mock_llm = MagicMock()
+ mock_llm.with_structured_output = MagicMock(return_value=MagicMock())
+ config = MagicMock()
+ config.max_iterations = 10
+ return CodeUnderstandingAgent(tools=[], llm=mock_llm, config=config)
+
+
+def _mock_ctx_state(*vuln_ids):
+ """Return a mock workflow state with cve_intel entries for the given vuln IDs."""
+ intel_list = []
+ for vid in vuln_ids:
+ intel = MagicMock()
+ intel.vuln_id = vid
+ intel_list.append(intel)
+ ws = MagicMock()
+ ws.cve_intel = intel_list
+ return ws
+
+
+class TestCUComprehensionHooks:
+ """Test CU agent comprehension context reduction and CVE sanitization."""
+
+ @patch("vuln_analysis.functions.code_understanding_agent.ctx_state")
+ def test_build_comprehension_context_contains_vuln_id_and_package(self, mock_ctx):
+ mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859")
+ agent = _make_cu_agent()
+ state = {"app_package": "com.thoughtworks.xstream:xstream"}
+ result = agent.build_comprehension_context(state)
+ assert "CVE-2021-43859" in result
+ assert "com.thoughtworks.xstream:xstream" in result
+
+ @patch("vuln_analysis.functions.code_understanding_agent.ctx_state")
+ def test_build_comprehension_context_includes_grounding_instruction(self, mock_ctx):
+ mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859")
+ agent = _make_cu_agent()
+ result = agent.build_comprehension_context({"app_package": "pkg"})
+ assert "Only extract facts explicitly stated in the tool output" in result
+ assert "Do not add CVE IDs" in result
+
+ @patch("vuln_analysis.functions.code_understanding_agent.ctx_state")
+ def test_build_comprehension_context_excludes_critical_context(self, mock_ctx):
+ mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859")
+ agent = _make_cu_agent()
+ state = {
+ "app_package": "xstream",
+ "critical_context": ["GHSA description: XStream can cause DoS", "NVD: high severity"],
+ }
+ result = agent.build_comprehension_context(state)
+ assert "GHSA description" not in result
+ assert "NVD" not in result
+
+ @patch("vuln_analysis.functions.code_understanding_agent.ctx_state")
+ def test_build_comprehension_context_unknown_fallbacks(self, mock_ctx):
+ ws = MagicMock()
+ ws.cve_intel = []
+ mock_ctx.get.return_value = ws
+ agent = _make_cu_agent()
+ result = agent.build_comprehension_context({})
+ assert "unknown" in result
+
+ @patch("vuln_analysis.functions.base_graph_agent.ctx_state")
+ def test_sanitize_findings_replaces_wrong_cve(self, mock_ctx):
+ mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859")
+ agent = _make_cu_agent()
+ findings = ["XStream 1.4.18 is vulnerable to CVE-2020-26217"]
+ result = agent.sanitize_findings(findings, {})
+ assert result == ["XStream 1.4.18 is vulnerable to the investigated vulnerability"]
+
+ @patch("vuln_analysis.functions.base_graph_agent.ctx_state")
+ def test_sanitize_findings_keeps_correct_cve(self, mock_ctx):
+ mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859")
+ agent = _make_cu_agent()
+ findings = ["Affects CVE-2021-43859"]
+ result = agent.sanitize_findings(findings, {})
+ assert result == ["Affects CVE-2021-43859"]
+
+ @patch("vuln_analysis.functions.base_graph_agent.ctx_state")
+ def test_sanitize_findings_replaces_multiple_wrong_cves(self, mock_ctx):
+ mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859")
+ agent = _make_cu_agent()
+ findings = ["CVE-2020-26217 and CVE-2019-10086 affect this, also CVE-2021-43859"]
+ result = agent.sanitize_findings(findings, {})
+ assert "CVE-2020-26217" not in result[0]
+ assert "CVE-2019-10086" not in result[0]
+ assert "CVE-2021-43859" in result[0]
+ assert result[0].count("the investigated vulnerability") == 2
+
+ @patch("vuln_analysis.functions.base_graph_agent.ctx_state")
+ def test_sanitize_findings_no_cve_ids_unchanged(self, mock_ctx):
+ mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859")
+ agent = _make_cu_agent()
+ findings = ["XStream 1.4.18 found in dependencies", "Package is present"]
+ result = agent.sanitize_findings(findings, {})
+ assert result == findings
+
+ @patch("vuln_analysis.functions.base_graph_agent.ctx_state")
+ def test_sanitize_findings_multi_cve_intel(self, mock_ctx):
+ mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859", "CVE-2021-39144")
+ agent = _make_cu_agent()
+ findings = ["CVE-2021-43859 and CVE-2021-39144 and CVE-2020-26217"]
+ result = agent.sanitize_findings(findings, {})
+ assert "CVE-2021-43859" in result[0]
+ assert "CVE-2021-39144" in result[0]
+ assert "CVE-2020-26217" not in result[0]
diff --git a/tests/test_code_understanding_internals.py b/tests/test_code_understanding_internals.py
new file mode 100644
index 00000000..ad52a8c5
--- /dev/null
+++ b/tests/test_code_understanding_internals.py
@@ -0,0 +1,176 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from vuln_analysis.functions.code_understanding_internals import (
+ CodeUnderstandingRulesTracker,
+ CU_AGENT_SYS_PROMPT,
+ CU_AGENT_THOUGHT_INSTRUCTIONS,
+)
+
+
+class TestCodeUnderstandingRulesTracker:
+ def test_iua_allowed_without_survey(self):
+ """IUA can be called at any time — no gating."""
+ tracker = CodeUnderstandingRulesTracker()
+ tracker.set_allowed_tools(["Import Usage Analyzer"])
+
+ violated, msg = tracker.check_thought_behavior(
+ "Import Usage Analyzer",
+ "com.example.package",
+ ["imports found"]
+ )
+
+ assert violated is False
+ assert msg == ""
+
+ def test_allowed_tool_passes(self):
+ tracker = CodeUnderstandingRulesTracker()
+ tracker.set_allowed_tools(["Code Keyword Search"])
+
+ violated, msg = tracker.check_thought_behavior(
+ "Code Keyword Search",
+ "some query",
+ ["results"]
+ )
+
+ assert violated is False
+
+ def test_docs_semantic_search_passes(self):
+ tracker = CodeUnderstandingRulesTracker()
+ tracker.set_allowed_tools(["Docs Semantic Search"])
+
+ violated, msg = tracker.check_thought_behavior(
+ "Docs Semantic Search",
+ "query",
+ ["docs"]
+ )
+
+ assert violated is False
+
+ def test_check_rule7_fires(self):
+ tracker = CodeUnderstandingRulesTracker()
+ tracker.set_allowed_tools(["Code Keyword Search"])
+
+ tracker.check_thought_behavior(
+ "Code Keyword Search",
+ "com.example.Class",
+ []
+ )
+
+ violated, msg = tracker.check_thought_behavior(
+ "Code Keyword Search",
+ "com.example.Another",
+ []
+ )
+
+ assert violated is True
+ assert "Rule 7" in msg
+ assert "dots" in msg
+
+ def test_check_allowed_tools_rejects_unknown(self):
+ tracker = CodeUnderstandingRulesTracker()
+ tracker.set_allowed_tools(["Code Keyword Search"])
+
+ violated, msg = tracker.check_thought_behavior(
+ "Unknown Tool",
+ "query",
+ ["results"]
+ )
+
+ assert violated is True
+ assert "AVAILABLE_TOOLS" in msg
+ assert "Code Keyword Search" in msg
+
+ def test_check_passes_and_adds_action(self):
+ tracker = CodeUnderstandingRulesTracker()
+ tracker.set_allowed_tools(["Code Keyword Search"])
+
+ assert "Code Keyword Search" not in tracker.action_history
+
+ violated, msg = tracker.check_thought_behavior(
+ "Code Keyword Search",
+ "import xstream",
+ ["result1", "result2"]
+ )
+
+ assert violated is False
+ assert msg == ""
+ assert "Code Keyword Search" in tracker.action_history
+ assert len(tracker.action_history["Code Keyword Search"]) == 1
+ assert tracker.action_history["Code Keyword Search"][0]["input"] == "import xstream"
+
+ def test_duplicate_config_scanner_blocked(self):
+ tracker = CodeUnderstandingRulesTracker()
+ tracker.set_allowed_tools(["Configuration Scanner"])
+ tracker.check_thought_behavior(
+ "Configuration Scanner", "deserialization beanutils", ["config found"]
+ )
+ violated, msg = tracker.check_thought_behavior(
+ "Configuration Scanner", "deserialization beanutils", ["config found"]
+ )
+ assert violated is True
+ assert "already called" in msg
+
+ def test_config_scanner_different_query_allowed(self):
+ tracker = CodeUnderstandingRulesTracker()
+ tracker.set_allowed_tools(["Configuration Scanner"])
+ tracker.check_thought_behavior(
+ "Configuration Scanner", "deserialization beanutils", ["config found"]
+ )
+ violated, msg = tracker.check_thought_behavior(
+ "Configuration Scanner", "security allowlist", ["other config"]
+ )
+ assert violated is False
+ assert msg == ""
+
+ def test_check_failing_does_not_add_action(self):
+ tracker = CodeUnderstandingRulesTracker()
+ tracker.set_allowed_tools(["Configuration Scanner"])
+
+ violated, msg = tracker.check_thought_behavior(
+ "Unknown Tool",
+ "query",
+ ["results"]
+ )
+
+ assert violated is True
+ assert "Unknown Tool" not in tracker.action_history
+
+
+class TestCUConstants:
+ def test_sys_prompt_is_nonempty_string(self):
+ assert isinstance(CU_AGENT_SYS_PROMPT, str)
+ assert len(CU_AGENT_SYS_PROMPT) > 0
+ assert len(CU_AGENT_SYS_PROMPT.strip()) > 0
+
+ def test_sys_prompt_mentions_code_understanding(self):
+ assert "code understanding" in CU_AGENT_SYS_PROMPT.lower()
+
+ def test_sys_prompt_does_not_mention_call_chain_analyzer(self):
+ assert "Call Chain Analyzer" not in CU_AGENT_SYS_PROMPT
+ assert "call chain analyzer" not in CU_AGENT_SYS_PROMPT.lower()
+
+ def test_thought_instructions_have_rules_tags(self):
+    assert "<rules>" in CU_AGENT_THOUGHT_INSTRUCTIONS
+    assert "</rules>" in CU_AGENT_THOUGHT_INSTRUCTIONS
+
+ def test_thought_instructions_have_three_examples(self):
+    assert "<example_1>" in CU_AGENT_THOUGHT_INSTRUCTIONS
+    assert "<example_2>" in CU_AGENT_THOUGHT_INSTRUCTIONS
+    assert "<example_3>" in CU_AGENT_THOUGHT_INSTRUCTIONS
+
+ def test_thought_instructions_rules_numbered_1_to_9(self):
+ for i in range(1, 10):
+ assert f"{i}." in CU_AGENT_THOUGHT_INSTRUCTIONS
diff --git a/tests/test_code_understanding_prompt_factory.py b/tests/test_code_understanding_prompt_factory.py
new file mode 100644
index 00000000..e05484b3
--- /dev/null
+++ b/tests/test_code_understanding_prompt_factory.py
@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from vuln_analysis.tools.tool_names import ToolNames
+from vuln_analysis.utils.code_understanding_prompt_factory import (
+ CU_TOOL_GENERAL_DESCRIPTIONS,
+ CU_TOOL_SELECTION_STRATEGY,
+)
+
+
+class TestCUToolGeneralDescriptions:
+ def test_has_4_entries(self):
+ assert len(CU_TOOL_GENERAL_DESCRIPTIONS) == 4
+
+ def test_keys_match_cu_tool_names(self):
+ expected_keys = {
+ ToolNames.DOCS_SEMANTIC_SEARCH,
+ ToolNames.CODE_KEYWORD_SEARCH,
+ ToolNames.CONFIGURATION_SCANNER,
+ ToolNames.IMPORT_USAGE_ANALYZER,
+ }
+ assert set(CU_TOOL_GENERAL_DESCRIPTIONS.keys()) == expected_keys
+
+ def test_values_non_empty(self):
+ for key, value in CU_TOOL_GENERAL_DESCRIPTIONS.items():
+ assert isinstance(value, str), f"Value for '{key}' is not a string"
+ assert len(value.strip()) > 0, f"Value for '{key}' is empty"
+
+
+class TestCUToolSelectionStrategy:
+ def test_has_5_ecosystems(self):
+ expected_ecosystems = {"python", "go", "java", "javascript", "c"}
+ assert set(CU_TOOL_SELECTION_STRATEGY.keys()) == expected_ecosystems
+
+ def test_strategies_non_empty(self):
+ for ecosystem, strategy in CU_TOOL_SELECTION_STRATEGY.items():
+ assert isinstance(strategy, str), f"Strategy for '{ecosystem}' is not a string"
+ assert len(strategy.strip()) > 0, f"Strategy for '{ecosystem}' is empty"
+
+ def test_each_mentions_at_least_3_tool_names(self):
+ cu_tool_names = [
+ ToolNames.DOCS_SEMANTIC_SEARCH,
+ ToolNames.CODE_KEYWORD_SEARCH,
+ ToolNames.CONFIGURATION_SCANNER,
+ ToolNames.IMPORT_USAGE_ANALYZER,
+ ]
+
+ for ecosystem, strategy in CU_TOOL_SELECTION_STRATEGY.items():
+ mentioned_tools = sum(1 for tool_name in cu_tool_names if tool_name in strategy)
+ assert mentioned_tools >= 3, f"Strategy for '{ecosystem}' mentions only {mentioned_tools} tool(s)"
diff --git a/tests/test_configuration_scanner.py b/tests/test_configuration_scanner.py
new file mode 100644
index 00000000..ca9a462c
--- /dev/null
+++ b/tests/test_configuration_scanner.py
@@ -0,0 +1,420 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from vuln_analysis.tools.configuration_scanner import (
+ _is_config_file,
+ _is_in_config_dir,
+ _collect_config_files,
+ search_config_content,
+)
+
+
+class TestIsConfigFile:
+ @pytest.mark.parametrize(
+ "file_path,expected",
+ [
+ # Named config files — matched anywhere
+ ("config.yaml", True),
+ ("config.yml", True),
+ ("config.xml", True),
+ ("application.properties", True),
+ ("application.yml", True),
+ ("application.yaml", True),
+ ("settings.toml", True),
+ ("settings.yaml", True),
+ ("settings.yml", True),
+ ("beans.xml", True),
+ ("web.xml", True),
+ ("Dockerfile", True),
+ ("Dockerfile.prod", True),
+ ("Dockerfile.dev", True),
+ ("docker-compose.yml", True),
+ ("docker-compose.prod.yml", True),
+ # Config-specific extensions — matched anywhere
+ ("app.env", True),
+ (".env", True),
+ ("nginx.conf", True),
+ ("config.ini", True),
+ ("app.properties", True),
+ # Removed: build/dependency files (handled by other tools)
+ ("pom.xml", False),
+ ("build.gradle", False),
+ ("build.gradle.kts", False),
+ ("go.mod", False),
+ ("go.sum", False),
+ ("package.json", False),
+ ("requirements.txt", False),
+ ("setup.py", False),
+ ("setup.cfg", False),
+ ("pyproject.toml", False),
+ # Removed: catch-all extensions (only matched inside config dirs)
+ ("app.yml", False),
+ ("app.cfg", False),
+ ("data.yaml", False),
+ ("docker-compose-dev.yaml", False),
+ # Non-config files
+ ("main.py", False),
+ ("utils.go", False),
+ ("README.md", False),
+ ("app.js", False),
+ ("test.txt", False),
+ ("data.json", False),
+ ],
+ )
+ def test_config_file_patterns(self, file_path, expected):
+ assert _is_config_file(file_path) == expected
+
+ @pytest.mark.parametrize(
+ "file_path,expected",
+ [
+ ("DOCKERFILE", True),
+ ("CONFIG.YAML", True),
+ ("APPLICATION.PROPERTIES", True),
+ ("DOCKER-COMPOSE.YML", True),
+ ("BEANS.XML", True),
+ ("SETTINGS.TOML", True),
+ ],
+ )
+ def test_case_insensitive(self, file_path, expected):
+ assert _is_config_file(file_path) == expected
+
+ def test_file_in_subdirectory(self):
+ assert _is_config_file("path/to/config.yaml") is True
+ assert _is_config_file("deep/nested/path/application.properties") is True
+
+
+class TestIsInConfigDir:
+ @pytest.mark.parametrize(
+ "file_path,expected",
+ [
+ ("config/app.txt", True),
+ ("conf/settings.txt", True),
+ ("conf.d/default.conf", True),
+ ("etc/nginx/nginx.conf", True),
+ ("src/main/resources/app.yml", True),
+ ("app/config/database.yml", True),
+ ("project/conf/server.xml", True),
+ ("src/main/java/App.java", False),
+ ("lib/utils.py", False),
+ ("tests/test_app.py", False),
+ ("data/sample.csv", False),
+ ],
+ )
+ def test_config_dir_patterns(self, file_path, expected):
+ assert _is_in_config_dir(file_path) == expected
+
+ @pytest.mark.parametrize(
+ "file_path,expected",
+ [
+ ("Config/app.txt", True),
+ ("RESOURCES/app.yml", True),
+ ("ETC/nginx.conf", True),
+ ("CONF.D/default.conf", True),
+ ],
+ )
+ def test_case_insensitive_dir(self, file_path, expected):
+ assert _is_in_config_dir(file_path) == expected
+
+
+class TestCollectConfigFiles:
+ def test_finds_config_files(self, tmp_path):
+ (tmp_path / "config.yaml").write_text("key: value")
+ (tmp_path / "application.properties").write_text("app.name=test")
+ (tmp_path / "nginx.conf").write_text("server {}")
+
+ result = _collect_config_files(str(tmp_path))
+
+ assert len(result) == 3
+ paths = {path for path, _ in result}
+ assert "config.yaml" in paths
+ assert "application.properties" in paths
+ assert "nginx.conf" in paths
+
+ for path, content in result:
+ assert len(content) > 0
+
+ def test_finds_files_in_config_dir(self, tmp_path):
+ config_dir = tmp_path / "config"
+ config_dir.mkdir()
+ (config_dir / "database.txt").write_text("db_config")
+ (config_dir / "server.txt").write_text("server_config")
+
+ result = _collect_config_files(str(tmp_path))
+
+ assert len(result) == 2
+ paths = {path for path, _ in result}
+ assert "config/database.txt" in paths
+ assert "config/server.txt" in paths
+
+ def test_excludes_git_dir(self, tmp_path):
+ git_dir = tmp_path / ".git"
+ git_dir.mkdir()
+ (git_dir / "config").write_text("git config")
+ (tmp_path / "application.yml").write_text("app: config")
+
+ result = _collect_config_files(str(tmp_path))
+
+ assert len(result) == 1
+ paths = {path for path, _ in result}
+ assert "application.yml" in paths
+ assert ".git/config" not in paths
+
+ def test_excludes_pycache(self, tmp_path):
+ pycache_dir = tmp_path / "__pycache__"
+ pycache_dir.mkdir()
+ (pycache_dir / "config.yaml").write_text("cached")
+ (tmp_path / "application.yml").write_text("app: config")
+
+ result = _collect_config_files(str(tmp_path))
+
+ assert len(result) == 1
+ paths = {path for path, _ in result}
+ assert "application.yml" in paths
+ assert "__pycache__/config.yaml" not in paths
+
+ def test_excludes_node_modules(self, tmp_path):
+ node_modules_dir = tmp_path / "node_modules"
+ node_modules_dir.mkdir()
+ (node_modules_dir / "package.json").write_text("{}")
+ (tmp_path / "application.yml").write_text("app: config")
+
+ result = _collect_config_files(str(tmp_path))
+
+ assert len(result) == 1
+ paths = {path for path, _ in result}
+ assert "application.yml" in paths
+ assert "node_modules/package.json" not in paths
+
+ def test_excludes_tox(self, tmp_path):
+ tox_dir = tmp_path / ".tox"
+ tox_dir.mkdir()
+ (tox_dir / "config.ini").write_text("[tox]")
+ (tmp_path / "application.yml").write_text("app: config")
+
+ result = _collect_config_files(str(tmp_path))
+
+ assert len(result) == 1
+ paths = {path for path, _ in result}
+ assert "application.yml" in paths
+ assert ".tox/config.ini" not in paths
+
+ def test_skips_large_files(self, tmp_path):
+ large_content = "x" * 500_001
+ (tmp_path / "large.properties").write_text(large_content)
+ (tmp_path / "small.properties").write_text("key=value")
+
+ result = _collect_config_files(str(tmp_path))
+
+ assert len(result) == 1
+ paths = {path for path, _ in result}
+ assert "small.properties" in paths
+ assert "large.properties" not in paths
+
+ def test_empty_repo_returns_empty(self, tmp_path):
+ result = _collect_config_files(str(tmp_path))
+ assert result == []
+
+ def test_nested_directory_structure(self, tmp_path):
+ src_dir = tmp_path / "src" / "main" / "resources"
+ src_dir.mkdir(parents=True)
+ (src_dir / "application.yml").write_text("app: test")
+
+ conf_dir = tmp_path / "conf"
+ conf_dir.mkdir()
+ (conf_dir / "server.txt").write_text("server config")
+
+ (tmp_path / "config.yaml").write_text("key: value")
+
+ result = _collect_config_files(str(tmp_path))
+
+ assert len(result) == 3
+ paths = {path for path, _ in result}
+ assert "src/main/resources/application.yml" in paths
+ assert "conf/server.txt" in paths
+ assert "config.yaml" in paths
+
+ def test_handles_read_errors_gracefully(self, tmp_path):
+ (tmp_path / "good.properties").write_text("key=value")
+ bad_file = tmp_path / "bad.properties"
+ bad_file.write_bytes(b"\xff\xfe invalid utf-8")
+ bad_file.chmod(0o000)
+
+ result = _collect_config_files(str(tmp_path))
+
+ paths = {path for path, _ in result}
+ assert "good.properties" in paths
+
+ bad_file.chmod(0o644)
+
+ def test_returns_relative_paths(self, tmp_path):
+ subdir = tmp_path / "config"
+ subdir.mkdir()
+ (subdir / "app.yml").write_text("app: config")
+
+ result = _collect_config_files(str(tmp_path))
+
+ assert len(result) == 1
+ path, _ = result[0]
+ assert path == "config/app.yml"
+ assert not path.startswith("/")
+ assert str(tmp_path) not in path
+
+ def test_returns_file_contents(self, tmp_path):
+ expected_content = "database:\n host: localhost\n port: 5432"
+ (tmp_path / "config.yaml").write_text(expected_content)
+
+ result = _collect_config_files(str(tmp_path))
+
+ assert len(result) == 1
+ path, content = result[0]
+ assert path == "config.yaml"
+ assert content == expected_content
+
+
+class TestSearchConfigContent:
+ def test_empty_config_files(self):
+ result = search_config_content([], ["keyword"])
+ assert result == []
+
+ def test_no_matches(self):
+ config_files = [("app.yml", "database:\n host: localhost")]
+ result = search_config_content(config_files, ["xstream"])
+ assert result == []
+
+ def test_single_match(self):
+ config_files = [("app.yml", "line1\nxstream_enabled: true\nline3")]
+ result = search_config_content(config_files, ["xstream"], context_lines=0)
+ assert len(result) == 1
+ assert "app.yml" in result[0]
+ assert "xstream_enabled: true" in result[0]
+
+ def test_source_label(self):
+ config_files = [("app.yml", "match_keyword")]
+ result = search_config_content(config_files, ["match"], source_label="git://repo", context_lines=0)
+ assert "source: git://repo" in result[0]
+
+ def test_multiple_files(self):
+ config_files = [
+ ("a.yml", "xstream: yes"),
+ ("b.yml", "no match here"),
+ ("c.xml", "xstream config"),
+ ]
+ result = search_config_content(config_files, ["xstream"], context_lines=0)
+ assert len(result) == 2
+ assert "a.yml" in result[0]
+ assert "c.xml" in result[1]
+
+ def test_max_results_cap(self):
+ config_files = [("big.yml", "\n".join(f"keyword_{i}" for i in range(50)))]
+ result = search_config_content(config_files, ["keyword"], max_results=5, context_lines=0)
+ assert len(result) == 5
+
+ def test_context_lines(self):
+ content = "line1\nline2\nMATCH_LINE\nline4\nline5"
+ config_files = [("app.yml", content)]
+ result = search_config_content(config_files, ["match_line"], context_lines=1)
+ assert len(result) == 1
+ assert "line2" in result[0]
+ assert "> 3: MATCH_LINE" in result[0]
+ assert "line4" in result[0]
+ assert "line1" not in result[0]
+ assert "line5" not in result[0]
+
+ def test_case_insensitive_keywords(self):
+ config_files = [("app.yml", "XStream_Config: enabled")]
+ result = search_config_content(config_files, ["xstream"], context_lines=0)
+ assert len(result) == 1
+
+ def test_multiple_keywords_match_any(self):
+ config_files = [
+ ("a.yml", "xstream: true"),
+ ("b.yml", "deserialization: false"),
+ ("c.yml", "no relevant content"),
+ ]
+ result = search_config_content(config_files, ["xstream", "deserialization"], context_lines=0)
+ assert len(result) == 2
+
+ def test_line_number_in_output(self):
+ config_files = [("app.yml", "line1\nline2\nxstream: true")]
+ result = search_config_content(config_files, ["xstream"], context_lines=0)
+ assert "line 3" in result[0]
+
+ def test_default_source_label(self):
+ config_files = [("app.yml", "match_keyword")]
+ result = search_config_content(config_files, ["match"], context_lines=0)
+ assert "source: unknown" in result[0]
+
+
+class TestConfigScannerSourceAwareness:
+ """Tests for app/dep classification and source scoping in Configuration Scanner."""
+
+ def _make_configs(self):
+ return [
+ ("config.yaml", "xstream: app-level"),
+ ("src/main/resources/application.yml", "xstream: app-config"),
+ ("dependencies-sources/xstream-1.4/config.properties", "xstream: dep-xstream"),
+ ("dependencies-sources/commons-io-2.6/config.properties", "xstream: dep-commons"),
+ ]
+
+ def test_app_config_prioritized_over_dependency(self):
+ from vuln_analysis.utils.source_classification import (
+ is_dependency_path, format_app_dep_output,
+ )
+ configs = self._make_configs()
+ app_configs = [c for c in configs if not is_dependency_path(c[0])]
+ dep_configs = [c for c in configs if is_dependency_path(c[0])]
+
+ app_matches = search_config_content(app_configs, ["xstream"], max_results=10, context_lines=0)
+ dep_matches = search_config_content(dep_configs, ["xstream"], max_results=10, context_lines=0)
+
+ result = format_app_dep_output(app_matches, dep_matches, len(app_matches), len(dep_matches),
+ "No matches")
+ assert "Main application (2 of 2 results)" in result
+ assert "Application library dependencies (2 of 2 results)" in result
+ app_pos = result.index("config.yaml")
+ dep_pos = result.index("Application library dependencies")
+ assert app_pos < dep_pos
+
+ def test_source_scope_filters_dependency_configs(self):
+ from vuln_analysis.utils.source_classification import (
+ is_dependency_path, filter_by_source_scope,
+ )
+ configs = self._make_configs()
+ dep_configs = [c for c in configs if is_dependency_path(c[0])]
+
+ filtered = filter_by_source_scope(dep_configs, ["xstream"], lambda x: x[0])
+ assert len(filtered) == 1
+ assert "xstream-1.4" in filtered[0][0]
+
+ def test_app_configs_unaffected_by_source_scope(self):
+ from vuln_analysis.utils.source_classification import is_dependency_path
+ configs = self._make_configs()
+ app_configs = [c for c in configs if not is_dependency_path(c[0])]
+
+ app_matches = search_config_content(app_configs, ["xstream"], max_results=10, context_lines=0)
+ assert len(app_matches) == 2
+
+ def test_app_dep_headers_in_output(self):
+ from vuln_analysis.utils.source_classification import format_app_dep_output
+ result = format_app_dep_output(["match1"], ["match2"], 1, 1, "No matches")
+ assert "Main application" in result
+ assert "Application library dependencies" in result
+
+ def test_no_matches_preserves_message(self):
+ from vuln_analysis.utils.source_classification import format_app_dep_output
+ result = format_app_dep_output([], [], 0, 0, "No configuration entries found matching: xstream")
+ assert result == "No configuration entries found matching: xstream"
diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py
new file mode 100644
index 00000000..bdbeef34
--- /dev/null
+++ b/tests/test_dispatcher.py
@@ -0,0 +1,160 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from unittest.mock import AsyncMock
+from pydantic import ValidationError
+from langchain_core.messages import HumanMessage
+
+from vuln_analysis.functions.dispatcher import (
+ QuestionRouting,
+ build_routing_prompt,
+ dispatch_question,
+)
+
+
+class TestQuestionRouting:
+ def test_valid_reachability_type(self):
+ routing = QuestionRouting(agent_type="reachability", reason="test reason")
+ assert routing.agent_type == "reachability"
+ assert routing.reason == "test reason"
+
+ def test_valid_code_understanding_type(self):
+ routing = QuestionRouting(agent_type="code_understanding", reason="test reason")
+ assert routing.agent_type == "code_understanding"
+ assert routing.reason == "test reason"
+
+ def test_invalid_type_rejected(self):
+ with pytest.raises(ValidationError) as exc_info:
+ QuestionRouting(agent_type="invalid", reason="test")
+ assert "agent_type" in str(exc_info.value)
+
+ def test_missing_agent_type_rejected(self):
+ with pytest.raises(ValidationError) as exc_info:
+ QuestionRouting(reason="test")
+ assert "agent_type" in str(exc_info.value)
+
+ def test_missing_reason_rejected(self):
+ with pytest.raises(ValidationError) as exc_info:
+ QuestionRouting(agent_type="reachability")
+ assert "reason" in str(exc_info.value)
+
+
+class TestBuildRoutingPrompt:
+ def test_context_and_question_inserted(self):
+ context = "CVE-2024-1234, package: foo:bar"
+ question = "Is the function vulnerable()?"
+ result = build_routing_prompt(context, question)
+
+ assert context in result
+ assert question in result
+
+ def test_all_few_shot_examples_present(self):
+ result = build_routing_prompt("context", "question")
+
+ expected_examples = [
+ "XStream.fromXML()",
+ "XML parser",
+ "BeanUtils.populate()",
+ "commons-beanutils",
+ "parseXML()",
+ "external entity processing",
+ "newTransformer()",
+ "deserialize()",
+ ]
+
+ for example in expected_examples:
+ assert example in result, f"Missing example: {example}"
+
+ def test_both_agent_types_mentioned(self):
+ result = build_routing_prompt("context", "question")
+
+ assert "reachability" in result
+ assert "code_understanding" in result
+
+ def test_empty_inputs(self):
+ result = build_routing_prompt("", "")
+ assert isinstance(result, str)
+ assert len(result) > 0
+
+ def test_special_chars_in_inputs(self):
+ context = "CVE {test} with % and {{nested}}"
+ question = "Is {function}() reachable? % complete"
+
+ result = build_routing_prompt(context, question)
+
+ assert context in result
+ assert question in result
+
+
+class TestDispatchQuestion:
+ @pytest.mark.asyncio
+ async def test_dispatch_returns_routing_result(self):
+ expected_result = QuestionRouting(
+ agent_type="reachability",
+ reason="Test reason"
+ )
+
+ mock_llm = AsyncMock()
+ mock_llm.ainvoke.return_value = expected_result
+
+ result = await dispatch_question(
+ routing_llm=mock_llm,
+ question="Test question?",
+ context_block="Test context",
+ )
+
+ assert result == expected_result
+ assert result.agent_type == "reachability"
+ assert result.reason == "Test reason"
+
+ @pytest.mark.asyncio
+ async def test_dispatch_passes_human_message(self):
+ mock_result = QuestionRouting(
+ agent_type="code_understanding",
+ reason="Config check"
+ )
+
+ mock_llm = AsyncMock()
+ mock_llm.ainvoke.return_value = mock_result
+
+ question = "Is the XML parser configured?"
+ context = "CVE-2024-5678"
+
+ await dispatch_question(
+ routing_llm=mock_llm,
+ question=question,
+ context_block=context,
+ )
+
+ mock_llm.ainvoke.assert_called_once()
+
+ call_args = mock_llm.ainvoke.call_args[0][0]
+ assert len(call_args) == 1
+ assert isinstance(call_args[0], HumanMessage)
+ assert question in call_args[0].content
+ assert context in call_args[0].content
+
+ @pytest.mark.asyncio
+ async def test_dispatch_propagates_exception(self):
+ mock_llm = AsyncMock()
+ mock_llm.ainvoke.side_effect = ValueError("LLM error")
+
+ with pytest.raises(ValueError, match="LLM error"):
+ await dispatch_question(
+ routing_llm=mock_llm,
+ question="Test question",
+ context_block="Test context",
+ )
diff --git a/tests/test_import_usage_analyzer.py b/tests/test_import_usage_analyzer.py
new file mode 100644
index 00000000..98d77001
--- /dev/null
+++ b/tests/test_import_usage_analyzer.py
@@ -0,0 +1,278 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+from exploit_iq_commons.utils.dep_tree import Ecosystem
+from exploit_iq_commons.utils.functions_parsers.lang_functions_parsers_factory import get_language_function_parser
+from vuln_analysis.tools.import_usage_analyzer import (
+ _find_usage_in_file,
+ analyze_imports,
+)
+
+
+def _get_patterns(ecosystem: Ecosystem, package_name: str) -> list[re.Pattern]:
+ parser = get_language_function_parser(ecosystem, tree=None)
+ return parser.get_import_search_patterns(package_name)
+
+
+class TestGetImportSearchPatterns:
+ def test_python_import_pattern(self):
+ patterns = _get_patterns(Ecosystem.PYTHON, "xml.etree")
+
+ assert any(p.search("import xml.etree.ElementTree") for p in patterns)
+ assert any(p.search("from xml.etree import ElementTree") for p in patterns)
+
+ def test_java_import_pattern(self):
+ patterns = _get_patterns(Ecosystem.JAVA, "com.thoughtworks.xstream")
+
+ assert any(p.search("import com.thoughtworks.xstream.XStream;") for p in patterns)
+ assert any(p.search("import static com.thoughtworks.xstream.XStream;") for p in patterns)
+
+ def test_go_import_pattern(self):
+ patterns = _get_patterns(Ecosystem.GO, "encoding/xml")
+
+ assert any(p.search('import "encoding/xml"') for p in patterns)
+ assert any(p.search('import (\n "encoding/xml"\n)') for p in patterns)
+
+ def test_javascript_require_pattern(self):
+ patterns = _get_patterns(Ecosystem.JAVASCRIPT, "xml2js")
+
+ assert any(p.search("const xml = require('xml2js')") for p in patterns)
+ assert any(p.search('const xml = require("xml2js")') for p in patterns)
+
+ def test_javascript_import_from_pattern(self):
+ patterns = _get_patterns(Ecosystem.JAVASCRIPT, "xml2js")
+
+ assert any(p.search("import xml from 'xml2js'") for p in patterns)
+ assert any(p.search('import { parseString } from "xml2js"') for p in patterns)
+
+ def test_c_include_pattern(self):
+ patterns = _get_patterns(Ecosystem.C_CPP, "libxml")
+
+        assert any(p.search('#include <libxml/tree.h>') for p in patterns)
+ assert any(p.search('#include "libxml/tree.h"') for p in patterns)
+
+ def test_special_chars_escaped(self):
+ patterns = _get_patterns(Ecosystem.PYTHON, "foo.bar[baz]")
+
+ assert len(patterns) > 0
+ for p in patterns:
+ assert isinstance(p, re.Pattern)
+
+ def test_case_insensitive(self):
+ patterns = _get_patterns(Ecosystem.PYTHON, "XmlParser")
+
+ assert any(p.search("import xmlparser") for p in patterns)
+ assert any(p.search("import XMLPARSER") for p in patterns)
+ assert any(p.search("import XmlParser") for p in patterns)
+
+
+class TestFindUsageInFile:
+ def test_finds_usage_sites(self):
+ content = """import xml.etree.ElementTree
+tree = ElementTree.parse('data.xml')
+root = tree.getroot()
+for child in root:
+ print(child.tag)
+"""
+ usages = _find_usage_in_file(content, ["xml.etree.ElementTree"])
+
+ assert len(usages) > 0
+ assert any("ElementTree.parse" in u for u in usages)
+
+ def test_skips_import_lines(self):
+ content = """import xml.etree.ElementTree
+from xml.etree import ElementTree
+tree = ElementTree.parse('data.xml')
+"""
+ usages = _find_usage_in_file(content, ["ElementTree"])
+
+ assert len(usages) == 1
+ assert "L3:" in usages[0]
+ assert "parse" in usages[0]
+
+ def test_skips_from_lines(self):
+ content = """from xml.etree import ElementTree
+tree = ElementTree.parse('data.xml')
+"""
+ usages = _find_usage_in_file(content, ["ElementTree"])
+
+ assert len(usages) == 1
+ assert "L2:" in usages[0]
+
+ def test_skips_include_lines(self):
+        content = """#include <libxml/parser.h>
+xmlDocPtr doc = xmlParseFile("data.xml");
+"""
+ usages = _find_usage_in_file(content, ["xmlParseFile"])
+
+ assert len(usages) == 1
+ assert "L2:" in usages[0]
+
+ def test_max_usages_cap(self):
+ content = "\n".join([f"var x{i} = XStream();" for i in range(20)])
+
+ usages = _find_usage_in_file(content, ["XStream"], max_usages=5)
+
+ assert len(usages) == 5
+
+ def test_word_boundary_matching(self):
+ content = """class MyXStreamHelper:
+ pass
+xs = XStream()
+"""
+ usages = _find_usage_in_file(content, ["XStream"])
+
+ assert len(usages) == 1
+ assert any("L3:" in u and "XStream()" in u for u in usages)
+
+ def test_dotted_name_uses_short_component(self):
+ content = """import com.thoughtworks.xstream.XStream;
+XStream xs = new XStream();
+"""
+ usages = _find_usage_in_file(content, ["com.thoughtworks.xstream.XStream"])
+
+ assert len(usages) == 1
+ assert "L2:" in usages[0]
+ assert "XStream xs" in usages[0]
+
+ def test_no_usages_returns_empty(self):
+ content = """import xml.etree.ElementTree
+from xml.etree import ElementTree
+"""
+ usages = _find_usage_in_file(content, ["ElementTree"])
+
+ assert usages == []
+
+ def test_empty_content(self):
+ usages = _find_usage_in_file("", ["XStream"])
+
+ assert usages == []
+
+ def test_empty_names(self):
+ content = "xs = XStream();"
+ usages = _find_usage_in_file(content, [])
+
+ assert usages == []
+
+ def test_multiple_names_same_line(self):
+ content = """import foo, bar
+result = foo.process(bar.data)
+"""
+ usages = _find_usage_in_file(content, ["foo", "bar"])
+
+ assert len(usages) == 2
+ assert all("L2:" in u for u in usages)
+ assert any("foo.process" in u for u in usages)
+ assert any("bar.data" in u for u in usages)
+
+ def test_line_number_format(self):
+ content = """import XStream
+xs = XStream()
+"""
+ usages = _find_usage_in_file(content, ["XStream"])
+
+ assert usages[0].startswith(" L2:")
+
+ def test_strips_line_content(self):
+ content = """import XStream
+ xs = XStream()
+"""
+ usages = _find_usage_in_file(content, ["XStream"])
+
+ assert "xs = XStream()" in usages[0]
+ assert usages[0].count(" " * 8) == 0
+
+
+class _MockSearcher:
+ def __init__(self, docs):
+ self.docs = docs
+ self.num_docs = len(docs)
+
+ def doc(self, doc_id):
+ if doc_id >= len(self.docs):
+ raise IndexError(f"doc_id {doc_id} out of range")
+ return self.docs[doc_id]
+
+
+def _import_searcher(*docs):
+ return _MockSearcher(list(docs))
+
+
+class TestAnalyzeImportsSourceAwareness:
+
+ def _java_patterns(self, package_name):
+ return _get_patterns(Ecosystem.JAVA, package_name)
+
+ def test_app_imports_prioritized(self):
+ docs = [
+ {"file_path": ["dependencies-sources/xstream-1.4/XStreamConverter.java"],
+ "content": ["import com.thoughtworks.xstream.XStream;\nXStream xs = new XStream();"]},
+ {"file_path": ["src/main/java/App.java"],
+ "content": ["import com.thoughtworks.xstream.XStream;\nXStream converter = new XStream();"]},
+ ]
+ patterns = self._java_patterns("com.thoughtworks.xstream")
+ result = analyze_imports(_import_searcher(*docs), patterns, "com.thoughtworks.xstream")
+ assert "Main application" in result
+ assert "Application library dependencies" in result
+ app_pos = result.index("Main application")
+ dep_pos = result.index("Application library dependencies")
+ assert app_pos < dep_pos
+ assert "src/main/java/App.java" in result
+
+ def test_source_scope_filters_dep_imports(self):
+ docs = [
+ {"file_path": ["dependencies-sources/xstream-1.4/Converter.java"],
+ "content": ["import com.thoughtworks.xstream.XStream;"]},
+ {"file_path": ["dependencies-sources/commons-io/IOUtils.java"],
+ "content": ["import com.thoughtworks.xstream.XStream;"]},
+ ]
+ patterns = self._java_patterns("com.thoughtworks.xstream")
+ result = analyze_imports(_import_searcher(*docs), patterns, "com.thoughtworks.xstream",
+ source_scope=["xstream"])
+ assert "xstream-1.4" in result
+ assert "commons-io" not in result
+
+ def test_no_imports_unchanged(self):
+ docs = [
+ {"file_path": ["src/App.java"], "content": ["public class App {}"]},
+ ]
+ patterns = self._java_patterns("com.thoughtworks.xstream")
+ result = analyze_imports(_import_searcher(*docs), patterns, "com.thoughtworks.xstream",
+ ecosystem_label="java")
+ assert "No imports of 'com.thoughtworks.xstream' found" in result
+
+ def test_app_dep_headers_in_output(self):
+ docs = [
+ {"file_path": ["src/App.java"],
+ "content": ["import com.thoughtworks.xstream.XStream;"]},
+ {"file_path": ["vendor/lib/Helper.java"],
+ "content": ["import com.thoughtworks.xstream.XStream;"]},
+ ]
+ patterns = self._java_patterns("com.thoughtworks.xstream")
+ result = analyze_imports(_import_searcher(*docs), patterns, "com.thoughtworks.xstream")
+ assert "Main application" in result
+ assert "Application library dependencies" in result
+
+ def test_only_dep_imports(self):
+ docs = [
+ {"file_path": ["dependencies-sources/lib/Foo.java"],
+ "content": ["import com.thoughtworks.xstream.XStream;"]},
+ ]
+ patterns = self._java_patterns("com.thoughtworks.xstream")
+ result = analyze_imports(_import_searcher(*docs), patterns, "com.thoughtworks.xstream")
+ assert "Main application (0 of 0 results)" in result
+ assert "Application library dependencies (1 of 1 results)" in result
diff --git a/tests/test_intel_utils.py b/tests/test_intel_utils.py
new file mode 100644
index 00000000..831cea72
--- /dev/null
+++ b/tests/test_intel_utils.py
@@ -0,0 +1,69 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Tests for intel_utils: build_critical_context RHSA candidate capping."""
+
+from exploit_iq_commons.data_models.cve_intel import CveIntel, CveIntelRhsa
+from vuln_analysis.utils.intel_utils import build_critical_context, _MAX_RHSA_CANDIDATES
+
+
+class TestBuildCriticalContextRhsaCap:
+ """Tests for RHSA candidate package capping in build_critical_context."""
+
+ def _make_cve_intel_with_rhsa_packages(self, count):
+ """Create a CveIntel with `count` RHSA package_state entries."""
+ package_states = [
+ CveIntelRhsa.PackageState(package_name=f"rhsa-product-{i}")
+ for i in range(count)
+ ]
+ rhsa = CveIntelRhsa(package_state=package_states)
+ return CveIntel(vuln_id="CVE-2026-99999", rhsa=rhsa)
+
+ def test_rhsa_cap_limits_candidates(self):
+ """RHSA with 100+ packages should only add _MAX_RHSA_CANDIDATES to candidates."""
+ cve_intel = self._make_cve_intel_with_rhsa_packages(100)
+ _, candidates, _ = build_critical_context([cve_intel])
+
+ rhsa_candidates = [c for c in candidates if c["source"] == "rhsa"]
+ assert len(rhsa_candidates) == _MAX_RHSA_CANDIDATES
+
+ def test_rhsa_below_cap_all_included(self):
+ """RHSA with fewer than _MAX_RHSA_CANDIDATES packages includes all."""
+ cve_intel = self._make_cve_intel_with_rhsa_packages(5)
+ _, candidates, _ = build_critical_context([cve_intel])
+
+ rhsa_candidates = [c for c in candidates if c["source"] == "rhsa"]
+ assert len(rhsa_candidates) == 5
+
+ def test_rhsa_cap_with_1000_packages(self):
+ """Reproduces production scenario: 1000+ RHSA packages capped at limit."""
+ cve_intel = self._make_cve_intel_with_rhsa_packages(1234)
+ _, candidates, _ = build_critical_context([cve_intel])
+
+ rhsa_candidates = [c for c in candidates if c["source"] == "rhsa"]
+ assert len(rhsa_candidates) == _MAX_RHSA_CANDIDATES
+
+ def test_rhsa_cap_does_not_affect_ghsa(self):
+ """GHSA candidates are not affected by the RHSA cap."""
+ cve_intel = self._make_cve_intel_with_rhsa_packages(100)
+ # Manually add GHSA data
+ from exploit_iq_commons.data_models.cve_intel import CveIntelGhsa
+ cve_intel.ghsa = CveIntelGhsa(
+ ghsa_id="GHSA-test-0001",
+ vulnerabilities=[{"package": {"name": "xstream", "ecosystem": "Maven"}}],
+ )
+ _, candidates, _ = build_critical_context([cve_intel])
+
+ ghsa_candidates = [c for c in candidates if c["source"] == "ghsa"]
+ rhsa_candidates = [c for c in candidates if c["source"] == "rhsa"]
+ assert len(ghsa_candidates) == 1
+ assert len(rhsa_candidates) == _MAX_RHSA_CANDIDATES
+
+ def test_context_note_still_shows_total_count(self):
+ """The critical_context note should still mention the full package count."""
+ cve_intel = self._make_cve_intel_with_rhsa_packages(50)
+ context, _, _ = build_critical_context([cve_intel])
+
+ affected_notes = [c for c in context if "Affected across" in c]
+ assert len(affected_notes) == 1
+ assert "50" in affected_notes[0]
\ No newline at end of file
diff --git a/tests/test_process_steps.py b/tests/test_process_steps.py
new file mode 100644
index 00000000..381d4492
--- /dev/null
+++ b/tests/test_process_steps.py
@@ -0,0 +1,294 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Unit tests for _process_steps in cve_agent.py.
+
+Focus: tracker/graph alignment on fallback, initial_state construction,
+semaphore usage, concurrent step processing.
+"""
+
+import asyncio
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from langchain_core.messages import HumanMessage
+
+from vuln_analysis.functions.cve_agent import _process_steps
+from vuln_analysis.functions.dispatcher import QuestionRouting
+from vuln_analysis.functions.react_internals import ReachabilityRulesTracker
+from vuln_analysis.functions.code_understanding_internals import CodeUnderstandingRulesTracker
+
+
+@pytest.fixture
+def mock_workflow_state():
+ state = MagicMock()
+ state.cve_intel = []
+ return state
+
+
+@pytest.fixture
+def mock_graph():
+ graph = AsyncMock()
+ graph.ainvoke = AsyncMock(return_value={"input": "q", "output": "answer"})
+ return graph
+
+
+@pytest.fixture
+def patch_externals(mock_workflow_state):
+ """Patch ctx_state, build_critical_context, dispatch_question, and AGENT_TRACER."""
+ with (
+ patch("vuln_analysis.functions.cve_agent.ctx_state") as mock_ctx,
+ patch("vuln_analysis.functions.cve_agent.build_critical_context") as mock_bcc,
+ patch("vuln_analysis.functions.cve_agent.dispatch_question") as mock_dispatch,
+ patch("vuln_analysis.functions.cve_agent.AGENT_TRACER") as mock_tracer,
+ ):
+ mock_ctx.get.return_value = mock_workflow_state
+ mock_bcc.return_value = (["ctx_line"], [{"name": "pkg"}], ["vuln_func"])
+ mock_tracer.push_active_function.return_value.__enter__ = MagicMock()
+ mock_tracer.push_active_function.return_value.__exit__ = MagicMock(return_value=False)
+ yield {
+ "ctx_state": mock_ctx,
+ "build_critical_context": mock_bcc,
+ "dispatch_question": mock_dispatch,
+ "tracer": mock_tracer,
+ }
+
+
+class TestFallbackAlignment:
+ """When routed agent_type is not in agents dict, both graph AND tracker
+ must fall back to reachability."""
+
+ @pytest.mark.asyncio
+ async def test_cu_fallback_uses_reachability_graph(self, patch_externals, mock_graph):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="code_understanding", reason="config question"
+ )
+ agents = {"reachability": mock_graph}
+
+ await _process_steps(agents, MagicMock(), ["Is it configured?"], None)
+
+ mock_graph.ainvoke.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_cu_fallback_uses_reachability_tracker(self, patch_externals, mock_graph):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="code_understanding", reason="config question"
+ )
+ agents = {"reachability": mock_graph}
+
+ await _process_steps(agents, MagicMock(), ["Is it configured?"], None)
+
+ call_args = mock_graph.ainvoke.call_args[0][0]
+ tracker = call_args["rules_tracker"]
+ assert isinstance(tracker, ReachabilityRulesTracker)
+ assert not isinstance(tracker, CodeUnderstandingRulesTracker)
+
+ @pytest.mark.asyncio
+ async def test_direct_route_uses_correct_tracker(self, patch_externals):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="code_understanding", reason="config question"
+ )
+ cu_graph = AsyncMock()
+ cu_graph.ainvoke = AsyncMock(return_value={"input": "q", "output": "a"})
+ reach_graph = AsyncMock()
+ agents = {"reachability": reach_graph, "code_understanding": cu_graph}
+
+ await _process_steps(agents, MagicMock(), ["Is it configured?"], None)
+
+ call_args = cu_graph.ainvoke.call_args[0][0]
+ tracker = call_args["rules_tracker"]
+ assert isinstance(tracker, CodeUnderstandingRulesTracker)
+ reach_graph.ainvoke.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_reachability_route_uses_reachability_tracker(self, patch_externals, mock_graph):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="reachability", reason="function call check"
+ )
+ agents = {"reachability": mock_graph}
+
+ await _process_steps(agents, MagicMock(), ["Is func reachable?"], None)
+
+ call_args = mock_graph.ainvoke.call_args[0][0]
+ tracker = call_args["rules_tracker"]
+ assert isinstance(tracker, ReachabilityRulesTracker)
+
+
+class TestInitialState:
+ """Verify initial_state dict has correct keys and values."""
+
+ @pytest.mark.asyncio
+ async def test_initial_state_keys(self, patch_externals, mock_graph):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="reachability", reason="test"
+ )
+ agents = {"reachability": mock_graph}
+
+ await _process_steps(agents, MagicMock(), ["test question"], None)
+
+ call_args = mock_graph.ainvoke.call_args[0][0]
+ assert call_args["input"] == "test question"
+ assert call_args["step"] == 0
+ assert call_args["max_steps"] == 10
+ assert call_args["thought"] is None
+ assert call_args["observation"] is None
+ assert call_args["output"] == "waiting for the agent to respond"
+ assert isinstance(call_args["messages"][0], HumanMessage)
+ assert call_args["messages"][0].content == "test question"
+
+ @pytest.mark.asyncio
+ async def test_precomputed_intel_passed(self, patch_externals, mock_graph):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="reachability", reason="test"
+ )
+ agents = {"reachability": mock_graph}
+
+ await _process_steps(agents, MagicMock(), ["q"], None)
+
+ call_args = mock_graph.ainvoke.call_args[0][0]
+ intel = call_args["precomputed_intel"]
+ assert intel == (["ctx_line"], [{"name": "pkg"}], ["vuln_func"])
+
+ @pytest.mark.asyncio
+ async def test_custom_max_iterations(self, patch_externals, mock_graph):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="reachability", reason="test"
+ )
+ agents = {"reachability": mock_graph}
+
+ await _process_steps(agents, MagicMock(), ["q"], None, max_iterations=25)
+
+ call_args = mock_graph.ainvoke.call_args[0][0]
+ assert call_args["max_steps"] == 25
+
+ @pytest.mark.asyncio
+ async def test_graph_config_recursion_limit(self, patch_externals, mock_graph):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="reachability", reason="test"
+ )
+ agents = {"reachability": mock_graph}
+
+ await _process_steps(agents, MagicMock(), ["q"], None)
+
+ config_arg = mock_graph.ainvoke.call_args[1]["config"]
+ assert config_arg == {"recursion_limit": 50}
+
+
+class TestConcurrency:
+ """Test concurrent step processing and semaphore behavior."""
+
+ @pytest.mark.asyncio
+ async def test_multiple_steps_all_processed(self, patch_externals, mock_graph):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="reachability", reason="test"
+ )
+ agents = {"reachability": mock_graph}
+ steps = ["q1", "q2", "q3"]
+
+ results = await _process_steps(agents, MagicMock(), steps, None)
+
+ assert len(results) == 3
+ assert mock_graph.ainvoke.call_count == 3
+
+ @pytest.mark.asyncio
+ async def test_semaphore_limits_concurrency(self, patch_externals):
+ max_concurrent = 0
+ current_concurrent = 0
+
+ async def slow_invoke(state, config=None):
+ nonlocal max_concurrent, current_concurrent
+ current_concurrent += 1
+ max_concurrent = max(max_concurrent, current_concurrent)
+ await asyncio.sleep(0.01)
+ current_concurrent -= 1
+ return {"input": state["input"], "output": "done"}
+
+ mock_graph = AsyncMock()
+ mock_graph.ainvoke = slow_invoke
+
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="reachability", reason="test"
+ )
+ agents = {"reachability": mock_graph}
+ semaphore = asyncio.Semaphore(2)
+
+ await _process_steps(agents, MagicMock(), ["q1", "q2", "q3", "q4"], semaphore)
+
+ assert max_concurrent <= 2
+
+ @pytest.mark.asyncio
+ async def test_no_semaphore_allows_full_concurrency(self, patch_externals):
+ max_concurrent = 0
+ current_concurrent = 0
+
+ async def slow_invoke(state, config=None):
+ nonlocal max_concurrent, current_concurrent
+ current_concurrent += 1
+ max_concurrent = max(max_concurrent, current_concurrent)
+ await asyncio.sleep(0.01)
+ current_concurrent -= 1
+ return {"input": state["input"], "output": "done"}
+
+ mock_graph = AsyncMock()
+ mock_graph.ainvoke = slow_invoke
+
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="reachability", reason="test"
+ )
+ agents = {"reachability": mock_graph}
+
+ await _process_steps(agents, MagicMock(), ["q1", "q2", "q3", "q4"], None)
+
+ assert max_concurrent == 4
+
+ @pytest.mark.asyncio
+ async def test_exception_in_step_returned_not_raised(self, patch_externals):
+ mock_graph = AsyncMock()
+ mock_graph.ainvoke = AsyncMock(side_effect=RuntimeError("graph failed"))
+
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="reachability", reason="test"
+ )
+ agents = {"reachability": mock_graph}
+
+ results = await _process_steps(agents, MagicMock(), ["q1"], None)
+
+ assert len(results) == 1
+ assert isinstance(results[0], RuntimeError)
+
+ @pytest.mark.asyncio
+ async def test_empty_steps_returns_empty(self, patch_externals, mock_graph):
+ agents = {"reachability": mock_graph}
+ results = await _process_steps(agents, MagicMock(), [], None)
+ assert results == []
+
+
+class TestTracingSpan:
+ """Test that the tracing span receives actual_type, not routed type."""
+
+ @pytest.mark.asyncio
+ async def test_span_logs_actual_type_on_fallback(self, patch_externals, mock_graph):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="code_understanding", reason="config"
+ )
+ agents = {"reachability": mock_graph}
+
+ await _process_steps(agents, MagicMock(), ["q"], None)
+
+ call_args = patch_externals["tracer"].push_active_function.call_args
+ assert "[reachability]" in call_args[1]["input_data"]
+
+ @pytest.mark.asyncio
+ async def test_span_logs_routed_type_when_available(self, patch_externals):
+ patch_externals["dispatch_question"].return_value = QuestionRouting(
+ agent_type="code_understanding", reason="config"
+ )
+ cu_graph = AsyncMock()
+ cu_graph.ainvoke = AsyncMock(return_value={"input": "q", "output": "a"})
+ agents = {"reachability": MagicMock(), "code_understanding": cu_graph}
+
+ await _process_steps(agents, MagicMock(), ["q"], None)
+
+ call_args = patch_externals["tracer"].push_active_function.call_args
+ assert "[code_understanding]" in call_args[1]["input_data"]
diff --git a/tests/test_python_segmenter.py b/tests/test_python_segmenter.py
index c909d0c0..d2158caf 100644
--- a/tests/test_python_segmenter.py
+++ b/tests/test_python_segmenter.py
@@ -89,7 +89,7 @@
("import os", False, "py3 import statement"),
],
)
-def test_is_python2_code(self, code: str, expected: bool, description: str):
+def test_is_python2_code(code: str, expected: bool, description: str):
"""Test that Python 2/3 patterns are correctly detected."""
result = is_python2_code(code)
assert result is expected, f"Expected {expected} for {description}, got {result}"
\ No newline at end of file
diff --git a/tests/test_reachability_agent.py b/tests/test_reachability_agent.py
new file mode 100644
index 00000000..590a8042
--- /dev/null
+++ b/tests/test_reachability_agent.py
@@ -0,0 +1,221 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for ReachabilityAgent: get_tools, create_rules_tracker,
+agent_type, should_truncate_tool_output."""
+
+import pytest
+from unittest.mock import MagicMock
+
+from agent_test_helpers import MockTool, ALL_TOOLS, make_builder, make_config, make_state
+from vuln_analysis.functions.reachability_agent import ReachabilityAgent
+from vuln_analysis.functions.react_internals import ReachabilityRulesTracker
+from vuln_analysis.tools.tool_names import ToolNames
+
+
+def _make_reachability_agent(tools=None):
+ mock_llm = MagicMock()
+ mock_llm.with_structured_output = MagicMock(return_value=MagicMock())
+ config = MagicMock()
+ config.max_iterations = 10
+ return ReachabilityAgent(tools=tools or [], llm=mock_llm, config=config)
+
+
+class TestGetTools:
+ """ReachabilityAgent.get_tools selects reachability tools and filters by availability."""
+
+ def test_keeps_reachability_tools(self):
+ builder = make_builder()
+ config = make_config()
+ state = make_state()
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ result_names = {t.name for t in result}
+ assert ToolNames.FUNCTION_LOCATOR in result_names
+ assert ToolNames.CALL_CHAIN_ANALYZER in result_names
+ assert ToolNames.CODE_KEYWORD_SEARCH in result_names
+ assert ToolNames.CVE_WEB_SEARCH in result_names
+
+ def test_excludes_cu_only_tools(self):
+ builder = make_builder()
+ config = make_config()
+ state = make_state()
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ result_names = {t.name for t in result}
+ assert ToolNames.CONFIGURATION_SCANNER not in result_names
+ assert ToolNames.IMPORT_USAGE_ANALYZER not in result_names
+
+ def test_excludes_container_analysis_data(self):
+ builder = make_builder()
+ config = make_config()
+ state = make_state()
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ result_names = {t.name for t in result}
+ assert ToolNames.CONTAINER_ANALYSIS_DATA not in result_names
+
+ def test_keeps_all_8_reachability_tools(self):
+ builder = make_builder()
+ config = make_config()
+ state = make_state()
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ assert len(result) == 8
+
+ def test_empty_builder_returns_empty(self):
+ builder = make_builder(tools=[])
+ config = make_config()
+ state = make_state()
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ assert result == []
+
+ def test_preserves_tool_order(self):
+ ordered_tools = [
+ MockTool(ToolNames.CVE_WEB_SEARCH),
+ MockTool(ToolNames.FUNCTION_LOCATOR),
+ MockTool(ToolNames.CALL_CHAIN_ANALYZER),
+ ]
+ builder = make_builder(tools=ordered_tools)
+ config = make_config()
+ state = make_state()
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ assert [t.name for t in result] == [
+ ToolNames.CVE_WEB_SEARCH,
+ ToolNames.FUNCTION_LOCATOR,
+ ToolNames.CALL_CHAIN_ANALYZER,
+ ]
+
+ def test_unknown_tools_excluded(self):
+ tools = [MockTool(ToolNames.FUNCTION_LOCATOR), MockTool("Some Future Tool")]
+ builder = make_builder(tools=tools)
+ config = make_config()
+ state = make_state()
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ assert len(result) == 1
+ assert result[0].name == ToolNames.FUNCTION_LOCATOR
+
+
+class TestGetToolsAvailability:
+ """get_tools filters out tools whose infrastructure prerequisites are not met."""
+
+ def test_filters_code_semantic_search_when_no_vdb(self):
+ builder = make_builder()
+ config = make_config()
+ state = make_state(code_vdb_path=None)
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ assert ToolNames.CODE_SEMANTIC_SEARCH not in {t.name for t in result}
+
+ def test_filters_docs_semantic_search_when_no_vdb(self):
+ builder = make_builder()
+ config = make_config()
+ state = make_state(doc_vdb_path=None)
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ assert ToolNames.DOCS_SEMANTIC_SEARCH not in {t.name for t in result}
+
+ def test_filters_code_keyword_search_when_no_index(self):
+ builder = make_builder()
+ config = make_config()
+ state = make_state(code_index_path=None)
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ assert ToolNames.CODE_KEYWORD_SEARCH not in {t.name for t in result}
+
+ def test_filters_transitive_tools_when_no_index(self):
+ builder = make_builder()
+ config = make_config()
+ state = make_state(code_index_path=None)
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ result_names = {t.name for t in result}
+ assert ToolNames.CALL_CHAIN_ANALYZER not in result_names
+ assert ToolNames.FUNCTION_CALLER_FINDER not in result_names
+ assert ToolNames.FUNCTION_LOCATOR not in result_names
+
+ def test_filters_cve_web_search_when_disabled(self):
+ builder = make_builder()
+ config = make_config(cve_web_search_enabled=False)
+ state = make_state()
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ assert ToolNames.CVE_WEB_SEARCH not in {t.name for t in result}
+
+ def test_filters_transitive_tools_when_disabled(self):
+ builder = make_builder()
+ config = make_config(transitive_search_tool_enabled=False)
+ state = make_state()
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ result_names = {t.name for t in result}
+ assert ToolNames.CALL_CHAIN_ANALYZER not in result_names
+ assert ToolNames.FUNCTION_CALLER_FINDER not in result_names
+ assert ToolNames.FUNCTION_LOCATOR not in result_names
+
+ def test_version_finder_always_kept(self):
+ builder = make_builder()
+ config = make_config()
+ state = make_state(code_vdb_path=None, doc_vdb_path=None, code_index_path=None)
+ result = ReachabilityAgent.get_tools(builder, config, state)
+ assert ToolNames.FUNCTION_LIBRARY_VERSION_FINDER in {t.name for t in result}
+
+
+class TestReachabilityDuplicateCall:
+
+ def test_duplicate_call_blocked(self):
+ tracker = ReachabilityRulesTracker()
+ tracker.set_allowed_tools(["Function Locator"])
+ tracker.set_target_package("commons-beanutils")
+ tracker.check_thought_behavior("Function Locator", "commons-beanutils,getProperty", ["result"])
+ violated, msg = tracker.check_thought_behavior("Function Locator", "commons-beanutils,getProperty", ["result"])
+ assert violated is True
+ assert "already called" in msg
+
+
+class TestCreateRulesTracker:
+
+ def test_returns_reachability_rules_tracker(self):
+ tracker = ReachabilityAgent.create_rules_tracker()
+ assert isinstance(tracker, ReachabilityRulesTracker)
+
+ def test_returns_fresh_instance_each_call(self):
+ t1 = ReachabilityAgent.create_rules_tracker()
+ t2 = ReachabilityAgent.create_rules_tracker()
+ assert t1 is not t2
+
+
+class TestAgentType:
+
+ def test_agent_type_is_reachability(self):
+ agent = _make_reachability_agent()
+ assert agent.agent_type == "reachability"
+
+
+class TestShouldTruncateToolOutput:
+
+ def test_true_for_java(self):
+ agent = _make_reachability_agent()
+ assert agent.should_truncate_tool_output({"ecosystem": "java"}, "any_tool") is True
+
+ def test_true_for_java_uppercase(self):
+ agent = _make_reachability_agent()
+ assert agent.should_truncate_tool_output({"ecosystem": "Java"}, "any_tool") is True
+
+ def test_false_for_go(self):
+ agent = _make_reachability_agent()
+ assert agent.should_truncate_tool_output({"ecosystem": "go"}, "any_tool") is False
+
+ def test_false_for_python(self):
+ agent = _make_reachability_agent()
+ assert agent.should_truncate_tool_output({"ecosystem": "python"}, "any_tool") is False
+
+ def test_false_for_empty_ecosystem(self):
+ agent = _make_reachability_agent()
+ assert agent.should_truncate_tool_output({"ecosystem": ""}, "any_tool") is False
+
+ def test_false_when_ecosystem_missing(self):
+ agent = _make_reachability_agent()
+ assert agent.should_truncate_tool_output({}, "any_tool") is False
+
+
+class TestInit:
+
+ def test_creates_fifth_classification_llm(self):
+ mock_llm = MagicMock()
+ mock_llm.with_structured_output = MagicMock(return_value=MagicMock())
+ config = MagicMock()
+ config.max_iterations = 10
+ agent = ReachabilityAgent(tools=[], llm=mock_llm, config=config)
+ assert mock_llm.with_structured_output.call_count == 5
+ assert hasattr(agent, "_classification_llm")
diff --git a/tests/test_react_internals_rules.py b/tests/test_react_internals_rules.py
new file mode 100644
index 00000000..2440aecd
--- /dev/null
+++ b/tests/test_react_internals_rules.py
@@ -0,0 +1,418 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from vuln_analysis.functions.react_internals import (
+ BaseRulesTracker,
+ ReachabilityRulesTracker,
+ AgentState,
+ _find_image_matching_candidate,
+)
+
+
+class TestBaseRulesTracker:
+ def test_init_defaults(self):
+ tracker = BaseRulesTracker()
+ assert tracker.action_history == {}
+ assert tracker.target_package is None
+ assert tracker.allowed_tools == []
+
+ def test_set_allowed_tools(self):
+ tracker = BaseRulesTracker()
+ tools = ["Function Locator", "Code Keyword Search"]
+ tracker.set_allowed_tools(tools)
+ assert tracker.allowed_tools == tools
+
+ def test_set_target_package(self):
+ tracker = BaseRulesTracker()
+ tracker.set_target_package("commons-beanutils")
+ assert tracker.target_package == "commons-beanutils"
+
+ def test_is_empty_result_empty_list(self):
+ assert BaseRulesTracker._is_empty_result([]) is True
+
+ def test_is_empty_result_bracket_string(self):
+ assert BaseRulesTracker._is_empty_result("[]") is True
+
+ def test_is_empty_result_empty_string(self):
+ assert BaseRulesTracker._is_empty_result("") is True
+
+ def test_is_empty_result_whitespace(self):
+ assert BaseRulesTracker._is_empty_result(" ") is True
+
+ def test_is_empty_result_none(self):
+        assert BaseRulesTracker._is_empty_result(None) is False  # None is deliberately NOT classified as "empty"
+
+ def test_is_empty_result_int(self):
+        assert BaseRulesTracker._is_empty_result(0) is False  # falsy non-container values are NOT classified as "empty"
+
+ def test_is_empty_result_nonempty_list(self):
+ assert BaseRulesTracker._is_empty_result(["item"]) is False
+
+ def test_add_action_accumulates(self):
+ tracker = BaseRulesTracker()
+ tracker.add_action("Function Locator", "pkg,fn1", ["result1"])
+ tracker.add_action("Function Locator", "pkg,fn2", ["result2"])
+ assert len(tracker.action_history["Function Locator"]) == 2
+ assert tracker.action_history["Function Locator"][0]["input"] == "pkg,fn1"
+ assert tracker.action_history["Function Locator"][1]["input"] == "pkg,fn2"
+
+
+class TestDuplicateCallRule:
+ def test_duplicate_call_blocked(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools(["Function Locator"])
+ tracker.check_thought_behavior("Function Locator", "pkg,fn", ["result"])
+ violated, msg = tracker.check_thought_behavior("Function Locator", "pkg,fn", ["result"])
+ assert violated is True
+ assert "already called" in msg
+
+ def test_different_input_allowed(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools(["Function Locator"])
+ tracker.check_thought_behavior("Function Locator", "pkg,fn1", ["result"])
+ violated, msg = tracker.check_thought_behavior("Function Locator", "pkg,fn2", ["result"])
+ assert violated is False
+ assert msg == ""
+
+ def test_different_tool_same_input_allowed(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools(["Function Locator", "Code Keyword Search"])
+ tracker.check_thought_behavior("Function Locator", "pkg,fn", ["result"])
+ violated, msg = tracker.check_thought_behavior("Code Keyword Search", "pkg,fn", ["result"])
+ assert violated is False
+ assert msg == ""
+
+ def test_first_call_always_passes(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools(["Function Locator"])
+ violated, msg = tracker.check_thought_behavior("Function Locator", "pkg,fn", ["result"])
+ assert violated is False
+ assert msg == ""
+
+
+class TestBaseRule7:
+ def test_rule7_non_cks_tool_ignored(self):
+ tracker = BaseRulesTracker()
+ tracker.add_action("Code Keyword Search", "org.Class", [])
+ result = tracker._rule_number_7("Function Locator", "org.Class", [])
+ assert result is False
+
+ def test_rule7_no_dot_in_query_ignored(self):
+ tracker = BaseRulesTracker()
+ tracker.add_action("Code Keyword Search", "ClassName", [])
+ result = tracker._rule_number_7("Code Keyword Search", "ClassName", [])
+ assert result is False
+
+ def test_rule7_nonempty_output_ignored(self):
+ tracker = BaseRulesTracker()
+ tracker.add_action("Code Keyword Search", "org.Class", [])
+ result = tracker._rule_number_7("Code Keyword Search", "org.Class", ["match"])
+ assert result is False
+
+ def test_rule7_no_prior_history_ignored(self):
+ tracker = BaseRulesTracker()
+ result = tracker._rule_number_7("Code Keyword Search", "org.Class", [])
+ assert result is False
+
+ def test_rule7_prior_non_dotted_ignored(self):
+ tracker = BaseRulesTracker()
+ tracker.add_action("Code Keyword Search", "ClassName", [])
+ result = tracker._rule_number_7("Code Keyword Search", "org.Class", [])
+ assert result is False
+
+ def test_rule7_prior_had_results_ignored(self):
+ tracker = BaseRulesTracker()
+ tracker.add_action("Code Keyword Search", "org.Class", ["match"])
+ result = tracker._rule_number_7("Code Keyword Search", "org.Class", [])
+ assert result is False
+
+ def test_rule7_consecutive_dotted_empty_fires(self):
+ tracker = BaseRulesTracker()
+ tracker.add_action("Code Keyword Search", "org.Class", [])
+ result = tracker._rule_number_7("Code Keyword Search", "org.Another", [])
+ assert result is True
+
+
+class TestBaseRuleAllowedTools:
+ def test_allowed_tools_in_list_passes(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools(["Function Locator", "Code Keyword Search"])
+ result = tracker._rule_use_allowed_tools("Function Locator")
+ assert result is False
+
+ def test_allowed_tools_not_in_list_fails(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools(["Function Locator"])
+ result = tracker._rule_use_allowed_tools("CVE Web Search")
+ assert result is True
+
+ def test_allowed_tools_empty_list_rejects_all(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools([])
+ result = tracker._rule_use_allowed_tools("Function Locator")
+ assert result is True
+
+
+class TestBaseCheckThoughtBehavior:
+ def test_happy_path_returns_false_and_adds_to_history(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools(["Function Locator"])
+ violated, msg = tracker.check_thought_behavior("Function Locator", "pkg,fn", ["result"])
+ assert violated is False
+ assert msg == ""
+ assert "Function Locator" in tracker.action_history
+ assert tracker.action_history["Function Locator"][0]["input"] == "pkg,fn"
+
+ def test_duplicate_priority_over_rule7(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools(["Code Keyword Search"])
+ tracker.add_action("Code Keyword Search", "org.Class", [])
+ violated, msg = tracker.check_thought_behavior("Code Keyword Search", "org.Class", [])
+ assert violated is True
+ assert "already called" in msg
+ assert "Rule 7" not in msg
+
+ def test_rule7_priority_over_allowed_tools(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools(["Other Tool"])
+ tracker.add_action("Code Keyword Search", "org.Class", [])
+ violated, msg = tracker.check_thought_behavior("Code Keyword Search", "org.Another", [])
+ assert violated is True
+ assert "Rule 7" in msg
+ assert "AVAILABLE_TOOLS" not in msg
+
+ def test_allowed_tools_error_message_format(self):
+ tracker = BaseRulesTracker()
+ tracker.set_allowed_tools(["Function Locator", "Code Keyword Search"])
+ violated, msg = tracker.check_thought_behavior("CVE Web Search", "query", [])
+ assert violated is True
+ assert "AVAILABLE_TOOLS" in msg
+ assert "['Function Locator', 'Code Keyword Search']" in msg
+
+
+class TestReachabilityRulesTracker:
+ def test_extends_base(self):
+ tracker = ReachabilityRulesTracker()
+ assert isinstance(tracker, BaseRulesTracker)
+
+ def test_set_target_functions_creates_dict(self):
+ tracker = ReachabilityRulesTracker()
+ tracker.set_target_functions(["fn1", "fn2"])
+ assert tracker.target_functions == {"fn1": False, "fn2": False}
+
+ def test_rule8_target_package_enforcement(self):
+ tracker = ReachabilityRulesTracker()
+ tracker.set_target_package("commons-beanutils")
+ result = tracker._rule_number_8("Function Locator", "wrong-package,fn", [])
+ assert result is True
+
+ def test_rule8_prefix_matching_for_java_gav(self):
+ tracker = ReachabilityRulesTracker()
+ tracker.set_target_package("commons-beanutils:commons-beanutils")
+ result = tracker._rule_number_8("Function Locator", "commons-beanutils:commons-beanutils:1.9.4,fn", [])
+ assert result is False
+
+ def test_rule8_skipped_after_first_call(self):
+ tracker = ReachabilityRulesTracker()
+ tracker.set_target_package("commons-beanutils")
+ tracker.set_allowed_tools(["Function Locator"])
+ tracker.add_action("Function Locator", "commons-beanutils,fn", ["result"])
+ result = tracker._rule_number_8("Function Locator", "wrong-package,fn", [])
+ assert result is False
+
+ def test_rule8_no_target_package_skipped(self):
+ tracker = ReachabilityRulesTracker()
+ result = tracker._rule_number_8("Function Locator", "any-package,fn", [])
+ assert result is False
+
+ def test_rule9_vulnerable_functions_first(self):
+ tracker = ReachabilityRulesTracker()
+ tracker.set_target_functions(["getProperty", "setProperty"])
+ violated, msg = tracker._rule_number_9("Call Chain Analyzer", "pkg,someOtherFunction")
+ assert violated is True
+ assert "Rule 9" in msg
+ assert "getProperty" in msg or "setProperty" in msg
+
+ def test_rule9_passes_after_checking_vulnerable(self):
+ tracker = ReachabilityRulesTracker()
+ tracker.set_target_functions(["getProperty"])
+ violated, msg = tracker._rule_number_9("Call Chain Analyzer", "pkg,PropertyUtilsBean.getProperty")
+ assert violated is False
+ assert msg == ""
+ violated, msg = tracker._rule_number_9("Call Chain Analyzer", "pkg,someOtherFunction")
+ assert violated is False
+ assert msg == ""
+
+ def test_reachability_check_order(self):
+ tracker = ReachabilityRulesTracker()
+ tracker.set_allowed_tools(["Other Tool"])
+ tracker.set_target_package("commons-beanutils")
+ tracker.set_target_functions(["getProperty"])
+
+ tracker.add_action("Code Keyword Search", "org.Class", [])
+ violated, msg = tracker.check_thought_behavior("Code Keyword Search", "org.Another", [])
+ assert violated is True
+ assert "Rule 7" in msg
+
+ tracker2 = ReachabilityRulesTracker()
+ tracker2.set_allowed_tools(["Call Chain Analyzer"])
+ tracker2.set_target_package("commons-beanutils")
+ violated, msg = tracker2.check_thought_behavior("Call Chain Analyzer", "wrong-package,fn", [])
+ assert violated is True
+ assert "Rule 8" in msg
+
+ tracker3 = ReachabilityRulesTracker()
+ tracker3.set_allowed_tools(["Other Tool"])
+ violated, msg = tracker3.check_thought_behavior("Call Chain Analyzer", "pkg,fn", [])
+ assert violated is True
+ assert "AVAILABLE_TOOLS" in msg
+
+ tracker4 = ReachabilityRulesTracker()
+ tracker4.set_allowed_tools(["Call Chain Analyzer"])
+ tracker4.set_target_package("pkg")
+ tracker4.set_target_functions(["getProperty"])
+ violated, msg = tracker4.check_thought_behavior("Call Chain Analyzer", "pkg,otherFunction", [])
+ assert violated is True
+ assert "Rule 9" in msg
+
+
+class TestCheckFinishAllowed:
+ """Tests for ReachabilityRulesTracker.check_finish_allowed."""
+
+ def test_blocks_finish_when_cca_never_called(self):
+ """Reachability agent must call CCA before finishing."""
+ tracker = ReachabilityRulesTracker()
+ allowed, msg = tracker.check_finish_allowed(cca_results=[])
+ assert allowed is False
+ assert "Function Locator" in msg
+ assert "Call Chain Analyzer" in msg
+
+ def test_allows_finish_when_cca_returned_false(self):
+ """CCA was called and returned not-reachable — finish is allowed."""
+ tracker = ReachabilityRulesTracker()
+ tracker.add_action("Call Chain Analyzer", "pkg,fn", "(False, [])")
+ allowed, msg = tracker.check_finish_allowed(cca_results=[False])
+ assert allowed is True
+
+ def test_allows_finish_when_cca_returned_true_non_java(self):
+ """Non-Java: CCA found reachable — finish allowed (no FLVF requirement)."""
+ tracker = ReachabilityRulesTracker()
+ tracker.set_ecosystem("go")
+ tracker.add_action("Call Chain Analyzer", "pkg,fn", "(True, [path])")
+ allowed, msg = tracker.check_finish_allowed(cca_results=[True])
+ assert allowed is True
+
+ def test_blocks_finish_java_cca_true_no_flvf(self):
+ """Java: CCA found reachable but FLVF not called — block finish."""
+ tracker = ReachabilityRulesTracker()
+ tracker.set_ecosystem("java")
+ tracker.add_action("Call Chain Analyzer", "pkg,fn", "(True, [path])")
+ allowed, msg = tracker.check_finish_allowed(cca_results=[True])
+ assert allowed is False
+ assert "VERSION CHECK" in msg
+ assert "Function Library Version Finder" in msg
+
+ def test_allows_finish_java_cca_true_with_flvf(self):
+ """Java: CCA found reachable and FLVF was called — finish allowed."""
+ tracker = ReachabilityRulesTracker()
+ tracker.set_ecosystem("java")
+ tracker.add_action("Call Chain Analyzer", "pkg,fn", "(True, [path])")
+ tracker.add_action("Function Library Version Finder", "pkg", "1.9.4")
+ allowed, msg = tracker.check_finish_allowed(cca_results=[True])
+ assert allowed is True
+
+ def test_allows_finish_java_cca_false(self):
+ """Java: CCA returned not-reachable — no FLVF requirement."""
+ tracker = ReachabilityRulesTracker()
+ tracker.set_ecosystem("java")
+ tracker.add_action("Call Chain Analyzer", "pkg,fn", "(False, [])")
+ allowed, msg = tracker.check_finish_allowed(cca_results=[False])
+ assert allowed is True
+
+ def test_allows_finish_when_cca_in_history_but_results_empty(self):
+ """Edge case: CCA was called (in action_history) but cca_results is
+ empty — e.g., output parsing failed. Don't block since CCA was attempted."""
+ tracker = ReachabilityRulesTracker()
+ tracker.add_action("Call Chain Analyzer", "pkg,fn", "some unparseable output")
+ allowed, msg = tracker.check_finish_allowed(cca_results=[])
+ assert allowed is True
+
+
+class TestAgentState:
+ def test_default_tracker_type_annotation(self):
+ assert "rules_tracker" in AgentState.__annotations__
+
+ def test_default_is_reachability_field(self):
+ assert "is_reachability" in AgentState.__annotations__
+
+
+class TestFindImageMatchingCandidate:
+ """Tests for _find_image_matching_candidate used by package filter fast path."""
+
+ def test_match_in_image_name(self):
+ candidates = [
+ {"name": "builder", "source": "rhsa"},
+ {"name": "kernel", "source": "rhsa"},
+ ]
+ result = _find_image_matching_candidate(
+ candidates, "registry.redhat.io/openshift4/ose-docker-builder", None,
+ )
+ assert result == "builder"
+
+ def test_match_in_repo(self):
+ candidates = [
+ {"name": "kernel", "source": "rhsa"},
+ {"name": "infinispan", "source": "rhsa"},
+ ]
+ result = _find_image_matching_candidate(
+ candidates, "registry.redhat.io/some-image", "https://github.com/infinispan/infinispan",
+ )
+ assert result == "infinispan"
+
+ def test_no_match(self):
+ candidates = [
+ {"name": "kernel", "source": "rhsa"},
+ {"name": "glibc", "source": "rhsa"},
+ ]
+ result = _find_image_matching_candidate(
+ candidates, "registry.redhat.io/openshift4/ose-docker-builder",
+ "https://github.com/openshift/builder",
+ )
+ assert result is None
+
+ def test_short_name_skipped(self):
+ """Candidate names shorter than 3 chars are not matched."""
+ candidates = [{"name": "go", "source": "rhsa"}]
+ result = _find_image_matching_candidate(
+ candidates, "registry.redhat.io/golang-builder", None,
+ )
+ assert result is None
+
+ def test_no_image_no_repo(self):
+ candidates = [{"name": "builder", "source": "rhsa"}]
+ result = _find_image_matching_candidate(candidates, None, None)
+ assert result is None
+
+ def test_first_match_wins(self):
+ """When multiple candidates could match, the first one wins."""
+ candidates = [
+ {"name": "openshift", "source": "rhsa"},
+ {"name": "builder", "source": "rhsa"},
+ ]
+ result = _find_image_matching_candidate(
+ candidates, "registry.redhat.io/openshift4/ose-docker-builder",
+ "https://github.com/openshift/builder",
+ )
+ assert result == "openshift"
diff --git a/tests/test_source_classification.py b/tests/test_source_classification.py
new file mode 100644
index 00000000..53f07212
--- /dev/null
+++ b/tests/test_source_classification.py
@@ -0,0 +1,141 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from vuln_analysis.utils.source_classification import (
+ is_dependency_path,
+ filter_by_source_scope,
+ format_app_dep_output,
+)
+
+
+class TestIsDependencyPath:
+ @pytest.mark.parametrize("path", [
+ "dependencies-sources/commons-beanutils/PropertyUtilsBean.java",
+ "vendor/github.com/foo/bar/bar.go",
+ "transitive_env/lib/requests/models.py",
+ "node_modules/express/index.js",
+ "rpm_libs/libxml2/parser.c",
+ ])
+ def test_dependency_paths(self, path):
+ assert is_dependency_path(path) is True
+
+ @pytest.mark.parametrize("path", [
+ "src/main/java/com/example/App.java",
+ "cmd/main.go",
+ "app.py",
+ "pom.xml",
+ "src/main/resources/application.yml",
+ ])
+ def test_app_paths(self, path):
+ assert is_dependency_path(path) is False
+
+ def test_empty_path(self):
+ assert is_dependency_path("") is False
+
+
+class TestFilterBySourceScope:
+ def test_filters_by_scope(self):
+ items = [
+ ("dependencies-sources/xstream/XStream.java", "entry1"),
+ ("dependencies-sources/commons-io/IOUtils.java", "entry2"),
+ ("dependencies-sources/xstream/Converter.java", "entry3"),
+ ]
+ result = filter_by_source_scope(items, ["xstream"], lambda x: x[0])
+ assert len(result) == 2
+ assert result[0][1] == "entry1"
+ assert result[1][1] == "entry3"
+
+ def test_none_scope_returns_all(self):
+ items = [("a", "x"), ("b", "y")]
+ result = filter_by_source_scope(items, None, lambda x: x[0])
+ assert result == items
+
+ def test_empty_scope_returns_all(self):
+ items = [("a", "x"), ("b", "y")]
+ result = filter_by_source_scope(items, [], lambda x: x[0])
+ assert result == items
+
+ def test_multiple_scope_terms(self):
+ items = [
+ ("vendor/foo/main.go", "e1"),
+ ("vendor/bar/main.go", "e2"),
+ ("vendor/baz/main.go", "e3"),
+ ]
+ result = filter_by_source_scope(items, ["foo", "baz"], lambda x: x[0])
+ assert len(result) == 2
+ assert result[0][1] == "e1"
+ assert result[1][1] == "e3"
+
+ def test_no_matches_returns_empty(self):
+ items = [("vendor/foo/main.go", "e1")]
+ result = filter_by_source_scope(items, ["nonexistent"], lambda x: x[0])
+ assert result == []
+
+
+class TestFormatAppDepOutput:
+ def test_both_sections(self):
+ result = format_app_dep_output(
+ ["app_match_1", "app_match_2"],
+ ["dep_match_1"],
+ total_app=2, total_dep=1,
+ no_results_msg="No results",
+ )
+ assert "Main application (2 of 2 results)" in result
+ assert "Application library dependencies (1 of 1 results)" in result
+ assert "app_match_1" in result
+ assert "dep_match_1" in result
+
+ def test_empty_returns_no_results_msg(self):
+ result = format_app_dep_output([], [], 0, 0, "Nothing found")
+ assert result == "Nothing found"
+
+ def test_app_only(self):
+ result = format_app_dep_output(
+ ["app1"], [], total_app=1, total_dep=0,
+ no_results_msg="No results",
+ )
+ assert "Main application (1 of 1 results)" in result
+ assert "Application library dependencies (0 of 0 results)" in result
+ assert "app1" in result
+
+ def test_dep_only(self):
+ result = format_app_dep_output(
+ [], ["dep1"], total_app=0, total_dep=1,
+ no_results_msg="No results",
+ )
+ assert "Main application (0 of 0 results)" in result
+ assert "Application library dependencies (1 of 1 results)" in result
+ assert "dep1" in result
+
+ def test_trimmed_counts(self):
+ result = format_app_dep_output(
+ ["a1", "a2"], ["d1"],
+ total_app=5, total_dep=3,
+ no_results_msg="No results",
+ )
+ assert "Main application (2 of 5 results)" in result
+ assert "Application library dependencies (1 of 3 results)" in result
+
+ def test_app_before_dep(self):
+ result = format_app_dep_output(
+ ["APP_SECTION"], ["DEP_SECTION"],
+ total_app=1, total_dep=1,
+ no_results_msg="No results",
+ )
+ app_pos = result.index("APP_SECTION")
+ dep_pos = result.index("DEP_SECTION")
+ assert app_pos < dep_pos
diff --git a/tests/test_transitive_detection.py b/tests/test_transitive_detection.py
index 5a5fa5a0..cd65acb9 100644
--- a/tests/test_transitive_detection.py
+++ b/tests/test_transitive_detection.py
@@ -1,28 +1,114 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Tests for detect_ecosystem in dep_tree.py."""
+
import pytest
-from exploit_iq_commons.utils.transitive_code_searcher_tool import TransitiveCodeSearcher
-
-C_CPP_DETECTION_SCENARIOS = [
- # Positive cases
- (["src/main.c"], True),
- (["src/utils.cpp"], True),
- (["include/lib.h"], True),
- (["src/deep/nested/logic.cxx"], True),
- # Negative cases
- (["Makefile", "app.py"], False),
- (["README.md", "LICENSE"], False),
- (["main.cs"], False),
- (["style.css"], False),
- ([], False),
- # Ignored directories
- ([".git/objects/obj.c"], False),
- ([".venv/lib/site/test.c"], False),
- ([".config/settings.h"], False),
+from exploit_iq_commons.utils.dep_tree import Ecosystem, detect_ecosystem
+
+
+# --- C/C++ detection: requires manifest + source files ---
+
+C_CPP_POSITIVE_SCENARIOS = [
+    # (manifest filename, C/C++ source paths) pairs — each combination should detect C_CPP
+ ("CMakeLists.txt", ["src/main.c"]),
+ ("CMakeLists.txt", ["src/utils.cpp"]),
+ ("Makefile", ["include/lib.h"]),
+ ("meson.build", ["src/deep/nested/logic.cxx"]),
+ ("CMakeLists.txt", ["lib.cc"]),
+ ("CMakeLists.txt", ["api.hpp"]),
]
-@pytest.mark.parametrize("file_paths, expected", C_CPP_DETECTION_SCENARIOS)
-def test_has_c_cpp_sources_scenarios(tmp_path, file_paths, expected):
+
+C_CPP_NEGATIVE_SCENARIOS = [
+ # manifest present but no C/C++ source files
+ ("Makefile", ["app.py"]),
+ ("CMakeLists.txt", ["README.md", "LICENSE"]),
+ ("Makefile", ["main.cs"]),
+ ("CMakeLists.txt", ["style.css"]),
+ ("Makefile", []),
+ # C/C++ sources but NO manifest — should not detect
+ (None, ["src/main.c"]),
+ (None, ["src/utils.cpp", "include/lib.h"]),
+ # manifest + source files in ignored dotfile directories
+ ("CMakeLists.txt", [".git/objects/obj.c"]),
+ ("CMakeLists.txt", [".venv/lib/site/test.c"]),
+ ("CMakeLists.txt", [".config/settings.h"]),
+]
+
+
+@pytest.mark.parametrize("manifest, file_paths", C_CPP_POSITIVE_SCENARIOS)
+def test_c_cpp_detected(tmp_path, manifest, file_paths):
+ (tmp_path / manifest).touch()
for path in file_paths:
full_path = tmp_path / path
full_path.parent.mkdir(parents=True, exist_ok=True)
full_path.touch()
- result = TransitiveCodeSearcher._has_c_cpp_sources(tmp_path)
- assert result is expected, f"Failed for scenario: {file_paths}"
\ No newline at end of file
+ assert detect_ecosystem(tmp_path) == Ecosystem.C_CPP
+
+
+@pytest.mark.parametrize("manifest, file_paths", C_CPP_NEGATIVE_SCENARIOS)
+def test_c_cpp_not_detected(tmp_path, manifest, file_paths):
+ if manifest:
+ (tmp_path / manifest).touch()
+ for path in file_paths:
+ full_path = tmp_path / path
+ full_path.parent.mkdir(parents=True, exist_ok=True)
+ full_path.touch()
+ assert detect_ecosystem(tmp_path) != Ecosystem.C_CPP
+
+
+# --- Other ecosystem detection ---
+
+OTHER_ECOSYSTEM_SCENARIOS = [
+ ("go.mod", Ecosystem.GO),
+ ("requirements.txt", Ecosystem.PYTHON),
+ ("pyproject.toml", Ecosystem.PYTHON),
+ ("setup.py", Ecosystem.PYTHON),
+ ("package.json", Ecosystem.JAVASCRIPT),
+ ("pom.xml", Ecosystem.JAVA),
+]
+
+
+@pytest.mark.parametrize("manifest, expected", OTHER_ECOSYSTEM_SCENARIOS)
+def test_ecosystem_detected_by_manifest(tmp_path, manifest, expected):
+ (tmp_path / manifest).touch()
+ assert detect_ecosystem(tmp_path) == expected
+
+
+def test_no_manifest_returns_none(tmp_path):
+ (tmp_path / "README.md").touch()
+ assert detect_ecosystem(tmp_path) is None
+
+
+def test_empty_dir_returns_none(tmp_path):
+ assert detect_ecosystem(tmp_path) is None
+
+
+# --- Manifest priority order: Go > Python > JS > Java > C/C++ (spot-checked pairwise below) ---
+
+def test_go_takes_priority_over_python(tmp_path):
+ (tmp_path / "go.mod").touch()
+ (tmp_path / "requirements.txt").touch()
+ assert detect_ecosystem(tmp_path) == Ecosystem.GO
+
+
+def test_python_takes_priority_over_java(tmp_path):
+ (tmp_path / "requirements.txt").touch()
+ (tmp_path / "pom.xml").touch()
+ assert detect_ecosystem(tmp_path) == Ecosystem.PYTHON
+
+
+def test_java_takes_priority_over_c_cpp(tmp_path):
+ (tmp_path / "pom.xml").touch()
+ (tmp_path / "CMakeLists.txt").touch()
+ (tmp_path / "main.c").touch()
+ assert detect_ecosystem(tmp_path) == Ecosystem.JAVA
+
+
+def test_configure_in_auto_subdir(tmp_path):
+    """A configure script under an auto/ subdirectory counts as a C/C++ manifest."""
+ (tmp_path / "auto").mkdir()
+ (tmp_path / "auto" / "configure").touch()
+ (tmp_path / "src").mkdir()
+ (tmp_path / "src" / "main.c").touch()
+ assert detect_ecosystem(tmp_path) == Ecosystem.C_CPP