diff --git a/.tekton/on-cm-runner.yaml b/.tekton/on-cm-runner.yaml index 27d0f667..6630f8f2 100644 --- a/.tekton/on-cm-runner.yaml +++ b/.tekton/on-cm-runner.yaml @@ -269,7 +269,7 @@ spec: resources: requests: memory: "1Gi" - cpu: "500m" + cpu: "300m" readinessProbe: tcpSocket: port: 9092 diff --git a/Makefile b/Makefile index 74b2a19c..cf733572 100644 --- a/Makefile +++ b/Makefile @@ -71,8 +71,8 @@ test: test-unit test-llm-metrics ## Run all tests. @echo "All tests have been run." test-unit: ## Run unit tests. - @echo "Running unit tests in $(SRC_DIR)..." - @python -m pytest $(SRC_DIR) $(PYTEST_OPTS) + @echo "Running unit tests in $(SRC_DIR) and tests/..." + @python -m pytest $(SRC_DIR) tests/ $(PYTEST_OPTS) test-llm-metrics: ## Run LLM metrics tests. @echo "Running LLM metrics tests..." diff --git a/src/exploit_iq_commons/utils/dep_tree.py b/src/exploit_iq_commons/utils/dep_tree.py index 68a42657..e3c9c531 100644 --- a/src/exploit_iq_commons/utils/dep_tree.py +++ b/src/exploit_iq_commons/utils/dep_tree.py @@ -142,6 +142,10 @@ class DependencyTreeBuilder(ABC): supported ecosystem. """ + # Directory name where dependency source files are stored (e.g. "vendor", "node_modules"). + # Each subclass sets this to its ecosystem's convention. 
+ DEP_SOURCE_DIR: str = "" + @abstractmethod # Build a sort of "upside down" tree - a dict containing mapping of each # package to a list of all consuming packages @@ -157,6 +161,8 @@ def install_dependencies(self, manifest_path: Path): class CCppDependencyTreeBuilder(DependencyTreeBuilder): + DEP_SOURCE_DIR = C_DEP_LIBS_NAME + # Pre-compiled regex patterns (optimization: compile once, use many times) INCLUDE_COMBINED_RE = re.compile( r'#include\s*([<"])([^>"]+)[>"]' @@ -181,7 +187,7 @@ def __init__(self): "bench", "benchmark", "demo", "sample" ] self.C_STANDARD_LIB = "glibc" - self.RPM_LIBS_DIR = C_DEP_LIBS_NAME + self.RPM_LIBS_DIR = self.DEP_SOURCE_DIR self.output_json_path = None self.ccp_dep_tree = None @@ -788,6 +794,7 @@ def find_project_name(self, root_dir="."): class GoDependencyTreeBuilder(DependencyTreeBuilder): + DEP_SOURCE_DIR = "vendor" def install_dependencies(self, manifest_path: Path): self.download_go_mod_vendor(manifest_path) @@ -897,6 +904,7 @@ def extract_package_name(self, package_name: str) -> str: return package_name class JavaDependencyTreeBuilder(DependencyTreeBuilder): + DEP_SOURCE_DIR = "dependencies-sources" def __init__(self, query: str): self._query = query @@ -911,7 +919,7 @@ def __check_file_exists(self, dir_path: str | Path, filename: str) -> bool: def install_dependencies(self, manifest_path: Path): mvn_command = "mvn" settings_path = os.getenv('JAVA_MAVEN_DEFAULT_SETTINGS_FILE_PATH','../../../../kustomize/base/settings.xml') - source_path = "dependencies-sources" + source_path = self.DEP_SOURCE_DIR if self.__check_file_exists(manifest_path, "mvnw"): mvn_command = "./mvnw" @@ -1168,6 +1176,7 @@ def looks_like_version(v: str) -> bool: return depth, coord class PythonDependencyTreeBuilder(DependencyTreeBuilder): + DEP_SOURCE_DIR = TRANSITIVE_ENV_NAME def build_tree(self, manifest_path: Path) -> defaultdict[Any, list]: venv_python = f'{manifest_path}/{TRANSITIVE_ENV_NAME}/bin/python' @@ -1565,6 +1574,7 @@ def install_dependency(self, 
dependency, repo_path): logger.warning('Failed to install dependency %s', dependency) class JavaScriptDependencyTreeBuilder(DependencyTreeBuilder): + DEP_SOURCE_DIR = "node_modules" def build_tree(self, manifest_path: Path) -> dict[str, list[str]]: @@ -1639,6 +1649,25 @@ def get_dependency_tree_builder(programming_language: Ecosystem, query: str = "" ) +# Maps each ecosystem to its builder class — used to build ECOSYSTEM_DEP_DIRS +# from class-level DEP_SOURCE_DIR without instantiating builders. +_ECOSYSTEM_BUILDER_MAP: dict[Ecosystem, type[DependencyTreeBuilder]] = { + Ecosystem.C_CPP: CCppDependencyTreeBuilder, + Ecosystem.GO: GoDependencyTreeBuilder, + Ecosystem.JAVA: JavaDependencyTreeBuilder, + Ecosystem.PYTHON: PythonDependencyTreeBuilder, + Ecosystem.JAVASCRIPT: JavaScriptDependencyTreeBuilder, +} + +# Dynamic mapping of ecosystem → dependency source directory prefix. +# Built from each builder's DEP_SOURCE_DIR class attribute. +ECOSYSTEM_DEP_DIRS: dict[str, str] = { + eco.value: cls.DEP_SOURCE_DIR + "/" + for eco, cls in _ECOSYSTEM_BUILDER_MAP.items() + if cls.DEP_SOURCE_DIR +} + + class DependencyTree: """ A class that represents a dependency tree to access an appropriate diff --git a/src/exploit_iq_commons/utils/document_embedding.py b/src/exploit_iq_commons/utils/document_embedding.py index 035e8615..642757b2 100644 --- a/src/exploit_iq_commons/utils/document_embedding.py +++ b/src/exploit_iq_commons/utils/document_embedding.py @@ -217,6 +217,52 @@ def lazy_parse(self, blob: Blob) -> typing.Iterator[Document]: ) +class _LoggingEmbeddingProxy: + """Wraps an Embeddings instance to log per-batch progress during VDB creation. + + FAISS calls embed_documents once with all texts; the NIM SDK loops + internally in batches of max_batch_size calling _embed per batch. + This proxy intercepts embed_documents and does the batching itself + so it can log progress between batches. + + NOTE: Tightly coupled with langchain_nvidia_ai_endpoints.NVIDIAEmbeddings. 
+ Calls the private _embed(texts, model_type="passage") method directly. + Other Embeddings implementations (langchain ABC) don't expose _embed, + so this proxy will break if the embedding type changes or if NVIDIAEmbeddings + renames/removes _embed in a future version. + """ + + def __init__(self, embedding, total_chunks: int, start_time: float): + self._embedding = embedding + self._total_chunks = total_chunks + self._start_time = start_time + self._embedded = 0 + + def embed_documents(self, texts): + batch_size = getattr(self._embedding, "max_batch_size", 128) + all_embeddings = [] + for i in range(0, len(texts), batch_size): + batch = texts[i:i + batch_size] + batch_start = time.time() + all_embeddings.extend(self._embedding._embed(batch, model_type="passage")) + self._embedded += len(batch) + elapsed = time.time() - self._start_time + rate = self._embedded / elapsed if elapsed > 0 else 0 + remaining_min = ((self._total_chunks - self._embedded) / rate / 60) if rate > 0 else 0 + logger.info("Embedding progress: %d / %d chunks (%.1f%%) - batch took %.2fs - ETA %.1f min", + self._embedded, self._total_chunks, + self._embedded / self._total_chunks * 100, + time.time() - batch_start, + remaining_min) + return all_embeddings + + def embed_query(self, text): + return self._embedding.embed_query(text) + + def __getattr__(self, name): + return getattr(self._embedding, name) + + class DocumentEmbedding: """ A class to create a FAISS database from a list of source documents. 
The source documents are collected from git @@ -374,8 +420,10 @@ def collect_documents(self, source_info: SourceDocumentsInfo) -> list[Document]: """ repo_path = self.get_repo_path(source_info) + cache_name = source_info.type if source_info.type != "code" else "" documents, documents_were_in_cache = retrieve_from_cache(self._pickle_cache_directory, - source_info.git_repo, source_info.ref) + source_info.git_repo, source_info.ref, + documents_name=cache_name) if documents_were_in_cache or len(documents) > 0: return documents @@ -387,7 +435,8 @@ def collect_documents(self, source_info: SourceDocumentsInfo) -> list[Document]: with repo_lock: # Re-check cache — another thread may have populated it while we waited. documents, documents_were_in_cache = retrieve_from_cache(self._pickle_cache_directory, - source_info.git_repo, source_info.ref) + source_info.git_repo, source_info.ref, + documents_name=cache_name) if documents_were_in_cache or len(documents) > 0: return documents @@ -403,7 +452,8 @@ def collect_documents(self, source_info: SourceDocumentsInfo) -> list[Document]: documents = loader.load() logger.info("Collected documents for '%s', Document count: %d", repo_path, len(documents)) - save_to_cache(self._pickle_cache_directory, source_info.git_repo, source_info.ref, documents) + save_to_cache(self._pickle_cache_directory, source_info.git_repo, source_info.ref, documents, + documents_name=cache_name) return documents def create_vdb(self, source_infos: list[SourceDocumentsInfo], output_path: PathLike): @@ -465,8 +515,12 @@ def create_vdb(self, source_infos: list[SourceDocumentsInfo], output_path: PathL embedding_start_time = time.time() + # Wrap embedding in a proxy that logs batch progress + total_chunks = len(chunked_documents) + logging_embedding = _LoggingEmbeddingProxy(self._embedding, total_chunks, embedding_start_time) + # Create the FAISS database - db = FAISS.from_documents(chunked_documents, self._embedding) + db = FAISS.from_documents(chunked_documents, 
logging_embedding) logger.info("Completed embedding in %.2f seconds for '%s'", time.time() - embedding_start_time, output_path) @@ -513,7 +567,7 @@ def build_vdbs(self, # Create embeddings for each source type for source_type in ["code", "doc"]: - if ignore_code_embedding: + if ignore_code_embedding and source_type == "code": continue # Filter the source documents diff --git a/src/exploit_iq_commons/utils/functions_parsers/c_lang_function_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/c_lang_function_parsers.py index 37b73af1..7289e39c 100644 --- a/src/exploit_iq_commons/utils/functions_parsers/c_lang_function_parsers.py +++ b/src/exploit_iq_commons/utils/functions_parsers/c_lang_function_parsers.py @@ -763,3 +763,9 @@ def is_call_allowed(self, pkg_docs: list[Document], caller_function: Document, c return False return True + + def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]: + escaped = re.escape(package_name) + return [ + re.compile(rf'#include\s*[<"]({escaped}[^>"]*)[>"]', re.IGNORECASE | re.MULTILINE), + ] diff --git a/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py index 78805062..ed7e9470 100644 --- a/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py +++ b/src/exploit_iq_commons/utils/functions_parsers/golang_functions_parsers.py @@ -598,4 +598,11 @@ def is_package_imported(self, code_content: str, identifier: str, callee_package package_name = import_package_line.split(r"\s")[1] if package_name.strip().lower() == callee_package.strip().lower(): return True - return False \ No newline at end of file + return False + + def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]: + escaped = re.escape(package_name) + return [ + re.compile(rf'import\s+"({escaped}[^"]*)"', re.IGNORECASE | re.MULTILINE), + re.compile(rf'import\s+\(\s*[^)]*"({escaped}[^"]*)"', re.IGNORECASE | 
re.MULTILINE), + ] \ No newline at end of file diff --git a/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py index a5aff80f..231041d6 100644 --- a/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py +++ b/src/exploit_iq_commons/utils/functions_parsers/java_functions_parsers.py @@ -1570,17 +1570,20 @@ def _iter_fqcn_candidates(raw_type_token: str): yield t return - # Simple token: try target allow-list by simple name + # Explicit import takes precedence: the caller's import statement + # is the definitive type binding for this simple name. + imp = _explicit_imports.get(t) + if imp: + yield imp + return + + # Fallback: try target allow-list by simple name (only when + # the caller has no explicit import for this token). cands = _target_by_simple.get(t) if cands: for fq in cands: yield fq - # Explicit import disambiguation - imp = _explicit_imports.get(t) - if imp: - yield imp - # Wildcard imports: only usable if we can cheaply construct candidates # (we only build candidates that are already in target_class_names to avoid work) if _wild_import_pkgs and cands: @@ -1748,7 +1751,9 @@ def _extract_ctor_type(expr: str) -> str: caller_function_index=_caller_key(), target_class_names=target_class_names, function=function, - code_documents=code_documents + code_documents=code_documents, + caller_explicit_imports=_explicit_imports, + caller_package=_caller_pkg, ) if traced_cast: return True @@ -1807,7 +1812,9 @@ def _extract_ctor_type(expr: str) -> str: caller_function_index=caller_key, target_class_names=target_class_names, function=function, - code_documents=code_documents + code_documents=code_documents, + caller_explicit_imports=_explicit_imports, + caller_package=_caller_pkg, ) if traced: return True @@ -1850,6 +1857,8 @@ def __trace_down_package( target_class_names: frozenset[str], function: Document, code_documents: dict[str, Document], + 
caller_explicit_imports: dict[str, str] | None = None, + caller_package: str = "", ) -> bool: variables_mappings = functions_local_variables_index.get(caller_function_index, {}) # CHANGED: safe fallback parts = expression.split(".") @@ -1880,6 +1889,22 @@ def _fqcn_candidates_from_token(type_token: str) -> list[str]: # Derive simple name and match against allow-list by suffix simple = t.rsplit(".", 1)[-1] + + if caller_explicit_imports: + imported_fqcn = caller_explicit_imports.get(simple) + if imported_fqcn and imported_fqcn not in target_class_names: + return [] + + if (not caller_explicit_imports or simple not in caller_explicit_imports) and caller_package: + same_pkg_fqcn = f"{caller_package}.{simple}" + if same_pkg_fqcn not in target_class_names: + for td in type_documents: + td_src = td.metadata.get('source', '') + if simple in td_src and td_src.endswith(f"/{simple}.java"): + pkg_path = caller_package.replace('.', '/') + if pkg_path in td_src: + return [] + out: list[str] = [] dot_suffix = "." 
+ simple dollar_suffix = "$" + simple @@ -1916,7 +1941,9 @@ def _has_matching_type(type_token: str) -> bool: struct_initializer_expression=struct_initializer_expression, type_documents=type_documents, value_list=value_list, - target_class_names=target_class_names + target_class_names=target_class_names, + caller_explicit_imports=caller_explicit_imports, + caller_package=caller_package, ) # Property/member is not in function, check if it's member/property of a type @@ -3312,11 +3339,18 @@ def _simple_to_fqcns_index(target_class_names: frozenset[str]) -> dict[str, tupl idx.setdefault(simple, []).append(fq) return {k: tuple(v) for k, v in idx.items()} - def _fqcn_candidates_from_token(self, type_token: str, target_class_names: frozenset[str]) -> tuple[str, ...]: + def _fqcn_candidates_from_token(self, type_token: str, target_class_names: frozenset[str], + caller_explicit_imports: dict[str, str] | None = None, + caller_package: str = "", + type_documents: list | None = None) -> tuple[str, ...]: """ Map a possibly-simple type token to FQCN candidates strictly within `target_class_names`. 
- If token already equals an allowed FQCN => (token,) - Else => all allowed FQCNs that share the same simple name + - If caller_explicit_imports maps the simple name to a FQCN NOT in target_class_names, + the caller's type is definitively different => () + - If the caller's own package contains a class with this simple name and that FQCN + is NOT in target_class_names, the caller's type is the same-package one => () """ t = self._normalize_type_token(type_token) if not t: @@ -3326,6 +3360,19 @@ def _fqcn_candidates_from_token(self, type_token: str, target_class_names: froze simple = self._simple_name_from_type_token(t) if not simple: return () + if caller_explicit_imports: + imported_fqcn = caller_explicit_imports.get(simple) + if imported_fqcn and imported_fqcn not in target_class_names: + return () + if (not caller_explicit_imports or simple not in caller_explicit_imports) and caller_package and type_documents: + same_pkg_fqcn = f"{caller_package}.{simple}" + if same_pkg_fqcn not in target_class_names: + pkg_path = caller_package.replace('.', '/') + suffix = f"/{simple}.java" + for td in type_documents: + td_src = td.metadata.get('source', '') + if td_src.endswith(suffix) and pkg_path in td_src: + return () return self._simple_to_fqcns_index(target_class_names).get(simple, ()) def _has_matching_type_in_package( @@ -3334,11 +3381,16 @@ def _has_matching_type_in_package( type_token: str, type_documents: list[Document], target_class_names: frozenset[str], + caller_explicit_imports: dict[str, str] | None = None, + caller_package: str = "", ) -> bool: """ Calls __get_type_docs_matched_with_callee_type with FQCN candidates only. 
""" - for fq in self._fqcn_candidates_from_token(type_token, target_class_names): + for fq in self._fqcn_candidates_from_token(type_token, target_class_names, + caller_explicit_imports=caller_explicit_imports, + caller_package=caller_package, + type_documents=type_documents): if self.__get_type_docs_matched_with_callee_type(callee_package, fq, type_documents, target_class_names): return True return False @@ -3459,11 +3511,15 @@ def __lookup_package( type_documents, value_list, target_class_names: frozenset[str], + caller_explicit_imports: dict[str, str] | None = None, + caller_package: str = "", ) -> bool: if not struct_initializer_expression and resolved_type not in JAVA_METHOD_PRIM_TYPES: if resolved_type and resolved_type not in JAVA_METHOD_PRIM_TYPES: if self._has_matching_type_in_package( - callee_package, resolved_type, type_documents, target_class_names + callee_package, resolved_type, type_documents, target_class_names, + caller_explicit_imports=caller_explicit_imports, + caller_package=caller_package, ): return True @@ -3473,7 +3529,9 @@ def __lookup_package( elif struct_initializer_expression: struct_type = struct_initializer_expression.group(0) # TODO list of expressions if self._has_matching_type_in_package( - callee_package, struct_type, type_documents, target_class_names + callee_package, struct_type, type_documents, target_class_names, + caller_explicit_imports=caller_explicit_imports, + caller_package=caller_package, ): return True @@ -3652,4 +3710,10 @@ def get_package_name(self, function: Document, package_name: str) -> str: artifact_version = f"{parts[1]}:{parts[2]}" else: artifact_version = package_name - return package_name if jar_name == artifact_version else '' \ No newline at end of file + return package_name if jar_name == artifact_version else '' + + def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]: + escaped = re.escape(package_name) + return [ + re.compile(rf"import\s+(?:static\s+)?({escaped}[\w.]*)\s*;", 
re.IGNORECASE | re.MULTILINE), + ] \ No newline at end of file diff --git a/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py b/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py index 0275b64f..3c6017d7 100644 --- a/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py +++ b/src/exploit_iq_commons/utils/functions_parsers/javascript_functions_parser.py @@ -1011,6 +1011,14 @@ def document_imports_package(self, documents: dict[str, Document], package_name: importing_docs.append(doc) return importing_docs + def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]: + escaped = re.escape(package_name) + return [ + re.compile(rf"""require\s*\(\s*['"]({escaped}[^'"]*)['"]\s*\)""", re.IGNORECASE | re.MULTILINE), + re.compile(rf"""import\s+.*?\s+from\s+['"]({escaped}[^'"]*)['"]\s*""", re.IGNORECASE | re.MULTILINE), + re.compile(rf"""import\s+['"]({escaped}[^'"]*)['"]\s*""", re.IGNORECASE | re.MULTILINE), + ] + def is_a_package(self, package_name: str, doc: Document) -> bool: return package_name in self.get_package_names(doc) diff --git a/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py b/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py index f7160593..ab1da406 100644 --- a/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py +++ b/src/exploit_iq_commons/utils/functions_parsers/lang_functions_parsers.py @@ -198,4 +198,7 @@ def create_dummy_for_standard_lib(self, package_name: str) -> bool: # is this call allowed by the language? 
def is_call_allowed(self,pkg_docs: list[Document], caller_function: Document, callee_function: Document) -> bool: - return True \ No newline at end of file + return True + + def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]: + return [re.compile(re.escape(package_name), re.IGNORECASE)] \ No newline at end of file diff --git a/src/exploit_iq_commons/utils/functions_parsers/python_functions_parser.py b/src/exploit_iq_commons/utils/functions_parsers/python_functions_parser.py index f019cbbf..fe7eeec7 100644 --- a/src/exploit_iq_commons/utils/functions_parsers/python_functions_parser.py +++ b/src/exploit_iq_commons/utils/functions_parsers/python_functions_parser.py @@ -417,6 +417,13 @@ def is_package_imported(self, code_content: str, identifier: str, callee_package return True return False + def get_import_search_patterns(self, package_name: str) -> list[re.Pattern]: + escaped = re.escape(package_name) + return [ + re.compile(rf"import\s+({escaped}[\w.]*)", re.IGNORECASE | re.MULTILINE), + re.compile(rf"from\s+({escaped}[\w.]*)\s+import\s+", re.IGNORECASE | re.MULTILINE), + ] + def is_a_package(self, package_name: str, doc: Document) -> bool: return (not self.is_root_package(doc) and self.get_package_name(function=doc, package_name=package_name)) \ No newline at end of file diff --git a/src/vuln_analysis/configs/config-http-nim.yml b/src/vuln_analysis/configs/config-http-nim.yml index d6321ae6..1d3ae4c2 100644 --- a/src/vuln_analysis/configs/config-http-nim.yml +++ b/src/vuln_analysis/configs/config-http-nim.yml @@ -78,6 +78,13 @@ functions: max_retries: 5 Container Analysis Data: _type: container_image_analysis_data + Configuration Scanner: + _type: configuration_scanner + max_results: 15 + context_lines: 5 + Import Usage Analyzer: + _type: import_usage_analyzer + max_files: 20 cve_agent_executor: _type: cve_agent_executor llm_name: cve_agent_executor_llm @@ -90,6 +97,8 @@ functions: - Function Caller Finder - Function Locator - Function 
Library Version Finder + - Configuration Scanner + - Import Usage Analyzer max_concurrency: null max_iterations: 10 prompt_examples: false diff --git a/src/vuln_analysis/configs/config-http-openai.yml b/src/vuln_analysis/configs/config-http-openai.yml index a67e18c1..80c16d7b 100644 --- a/src/vuln_analysis/configs/config-http-openai.yml +++ b/src/vuln_analysis/configs/config-http-openai.yml @@ -85,6 +85,13 @@ functions: max_retries: 5 Container Analysis Data: _type: container_image_analysis_data + Configuration Scanner: + _type: configuration_scanner + max_results: 15 + context_lines: 5 + Import Usage Analyzer: + _type: import_usage_analyzer + max_files: 20 cve_agent_executor: _type: cve_agent_executor llm_name: cve_agent_executor_llm @@ -97,6 +104,8 @@ functions: - Function Caller Finder - Function Locator - Function Library Version Finder + - Configuration Scanner + - Import Usage Analyzer max_concurrency: null max_iterations: 10 prompt_examples: false diff --git a/src/vuln_analysis/configs/config-tracing.yml b/src/vuln_analysis/configs/config-tracing.yml index 397258e6..a4d89ff9 100644 --- a/src/vuln_analysis/configs/config-tracing.yml +++ b/src/vuln_analysis/configs/config-tracing.yml @@ -89,6 +89,13 @@ functions: max_retries: 5 Container Analysis Data: _type: container_image_analysis_data + Configuration Scanner: + _type: configuration_scanner + max_results: 15 + context_lines: 5 + Import Usage Analyzer: + _type: import_usage_analyzer + max_files: 20 cve_agent_executor: _type: cve_agent_executor llm_name: cve_agent_executor_llm @@ -101,6 +108,8 @@ functions: - Function Caller Finder - Function Locator - Function Library Version Finder + - Configuration Scanner + - Import Usage Analyzer max_concurrency: null max_iterations: 10 prompt_examples: false diff --git a/src/vuln_analysis/configs/config.yml b/src/vuln_analysis/configs/config.yml index 779b8efe..9ba5f23b 100644 --- a/src/vuln_analysis/configs/config.yml +++ b/src/vuln_analysis/configs/config.yml @@ -62,6 
+62,13 @@ functions: _type: container_image_analysis_data Function Library Version Finder: _type: calling_function_library_version_finder + Configuration Scanner: + _type: configuration_scanner + max_results: 15 + context_lines: 5 + Import Usage Analyzer: + _type: import_usage_analyzer + max_files: 20 cve_agent_executor: _type: cve_agent_executor llm_name: cve_agent_executor_llm @@ -71,6 +78,8 @@ functions: # - Code Keyword Search # Uncomment to enable keyword search - CVE Web Search - Function Library Version Finder + - Configuration Scanner + - Import Usage Analyzer max_concurrency: null max_iterations: 10 prompt_examples: false diff --git a/src/vuln_analysis/functions/agent_registry.py b/src/vuln_analysis/functions/agent_registry.py new file mode 100644 index 00000000..4c12f81a --- /dev/null +++ b/src/vuln_analysis/functions/agent_registry.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vuln_analysis.functions.base_graph_agent import BaseGraphAgent + +_AGENT_REGISTRY: dict[str, type[BaseGraphAgent]] = {} + + +def register_agent(agent_type: str): + """Class decorator that registers a BaseGraphAgent subclass under the given type name.""" + def wrapper(cls): + _AGENT_REGISTRY[agent_type] = cls + return cls + return wrapper + + +def get_agent_class(agent_type: str) -> type[BaseGraphAgent]: + if agent_type not in _AGENT_REGISTRY: + raise KeyError(f"Unknown agent type '{agent_type}'. Registered: {list(_AGENT_REGISTRY.keys())}") + return _AGENT_REGISTRY[agent_type] + + +def get_all_agent_types() -> list[str]: + return list(_AGENT_REGISTRY.keys()) diff --git a/src/vuln_analysis/functions/base_graph_agent.py b/src/vuln_analysis/functions/base_graph_agent.py new file mode 100644 index 00000000..063e3baf --- /dev/null +++ b/src/vuln_analysis/functions/base_graph_agent.py @@ -0,0 +1,473 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +import uuid +from abc import ABC, abstractmethod + +from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage +from langgraph.graph import StateGraph, END, START +from langgraph.prebuilt import ToolNode + +from exploit_iq_commons.logging.loggers_factory import LoggingFactory +from nat.builder.context import Context +from vuln_analysis.tools.tool_names import ToolNames +from vuln_analysis.functions.react_internals import ( + AgentState, + BaseRulesTracker, + Thought, + Observation, + CodeFindings, + PackageSelection, + _build_tool_arguments, + FORCED_FINISH_PROMPT, + COMPREHENSION_PROMPT, + MEMORY_UPDATE_PROMPT, +) +from vuln_analysis.runtime_context import ctx_state +from vuln_analysis.utils.token_utils import count_tokens, estimate_tokens, truncate_tool_output + +logger = LoggingFactory.get_agent_logger(__name__) +AGENT_TRACER = Context.get() + +_TOOL_AVAILABILITY = { + ToolNames.CODE_SEMANTIC_SEARCH: lambda config, state: state.code_vdb_path is not None, + ToolNames.DOCS_SEMANTIC_SEARCH: lambda config, state: state.doc_vdb_path is not None, + ToolNames.CODE_KEYWORD_SEARCH: lambda config, state: state.code_index_path is not None, + ToolNames.IMPORT_USAGE_ANALYZER: lambda config, state: state.code_index_path is not None, + ToolNames.CVE_WEB_SEARCH: lambda config, state: config.cve_web_search_enabled, + ToolNames.CALL_CHAIN_ANALYZER: lambda config, state: config.transitive_search_tool_enabled and state.code_index_path is not None, + ToolNames.FUNCTION_CALLER_FINDER: lambda config, state: config.transitive_search_tool_enabled and state.code_index_path is not None, + ToolNames.FUNCTION_LOCATOR: lambda config, state: config.transitive_search_tool_enabled and state.code_index_path is not None, +} + + +def _is_tool_available(tool_name, config, state): + check = _TOOL_AVAILABILITY.get(tool_name) + return check(config, state) if check else True + + +class BaseGraphAgent(ABC): + """Template for LangGraph-based CVE investigation 
agents. + + Subclasses must implement: + - pre_process_node: initialize agent state (prompts, package selection, etc.) + - get_tools: load and select which tools this agent uses + - create_rules_tracker: return the appropriate RulesTracker instance + + Shared graph nodes (thought, observation, forced_finish, should_continue) + and the graph wiring are provided by this base class. + """ + + def __init__(self, tools: list, llm, config): + self.tools = tools + self.config = config + self.thought_llm = llm.with_structured_output(Thought) + self.comprehension_llm = llm.with_structured_output(CodeFindings) + self.observation_llm = llm.with_structured_output(Observation) + self.package_filter_llm = llm.with_structured_output(PackageSelection) + + @property + def agent_type(self) -> str: + """Short identifier for tracing spans (e.g. 'reachability', 'cu').""" + return "base" + + @staticmethod + def _load_all_tools(builder, config) -> list: + from aiq.builder.framework_enum import LLMFrameworkEnum + return builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN) + + async def _select_package( + self, ecosystem: str, candidate_packages: list[dict], critical_context: list[str], + workflow_state, + ) -> tuple[list[str], str | None]: + """Select target package from candidates using LLM filter. + Returns (filtered_critical_context, selected_package). 
+ """ + from langchain_core.messages import HumanMessage + from vuln_analysis.functions.react_internals import build_package_filter_prompt, _find_image_matching_candidate + from vuln_analysis.utils.intel_utils import filter_context_to_package + + selected_package = None + if len(candidate_packages) > 1: + image_input = workflow_state.original_input.input.image + image_name = image_input.name + image_repo = image_input.source_info[0].git_repo if image_input.source_info else None + + matched = _find_image_matching_candidate(candidate_packages, image_name, image_repo) + if matched: + selected_package = matched + logger.info("Package filter matched '%s' from image/repo (no LLM call needed, %d candidates)", + selected_package, len(candidate_packages)) + else: + filter_prompt = build_package_filter_prompt( + ecosystem, candidate_packages, + image_name=image_name, image_repo=image_repo, + critical_context=critical_context, + ) + selection = await self.package_filter_llm.ainvoke([HumanMessage(content=filter_prompt)]) + selected_package = selection.selected_package + logger.info("Package filter selected '%s' from %d candidates (reason: %s)", + selected_package, len(candidate_packages), selection.reason) + critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages) + elif len(candidate_packages) == 1: + selected_package = candidate_packages[0].get("name") + logger.info("Single candidate package: '%s'", selected_package) + critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages) + return critical_context, selected_package + + # --- Template methods (subclasses MUST override) --- + + @abstractmethod + async def pre_process_node(self, state: AgentState) -> AgentState: + """Initialize agent state: select package, build prompts, set rules.""" + ... 
+ + @staticmethod + @abstractmethod + def get_tools(builder, config, state) -> list: + """Load tools from the builder and select those this agent needs.""" + ... + + @staticmethod + @abstractmethod + def create_rules_tracker() -> BaseRulesTracker: + """Return a fresh RulesTracker instance for a single question.""" + ... + + # --- Optional hooks (subclasses MAY override) --- + + def build_comprehension_context(self, state: AgentState) -> str: + """Return the vulnerability context string for the comprehension prompt. + Override in subclasses to reduce context and prevent hallucination.""" + ctx_lines = state.get("critical_context", []) + return "\n".join(ctx_lines) if ctx_lines else "N/A" + + def sanitize_findings(self, findings: list[str], state: AgentState) -> list[str]: + """Replace hallucinated CVE IDs in comprehension findings. + Only CVE IDs present in the current investigation are kept.""" + workflow_state = ctx_state.get() + allowed_ids = {intel.vuln_id for intel in workflow_state.cve_intel} + sanitized = [] + for finding in findings: + def _replace(m): + return m.group(0) if m.group(0) in allowed_ids else "the investigated vulnerability" + sanitized.append(re.sub(r"CVE-\d{4}-\d+", _replace, finding)) + return sanitized + + def post_observation(self, state: AgentState, tool_used: str, + tool_output: str, tool_input_detail: str) -> dict: + """Per-agent post-processing after observation_node core logic. + Returns dict to merge into observation_node output. Default: empty.""" + return {} + + def should_truncate_tool_output(self, state: AgentState, tool_used: str) -> bool: + """Whether to apply Java-specific truncation to tool output. Default: False.""" + return False + + def check_finish_allowed(self, state: AgentState) -> tuple[bool, str]: + """Hook to block premature finish. Returns (allowed, error_message). 
+ Override in subclasses to enforce rules before allowing the agent to finish.""" + return True, "" + + # --- Shared concrete graph nodes --- + + async def thought_node(self, state: AgentState) -> AgentState: + step_num = state.get("step", 0) + with AGENT_TRACER.push_active_function(f"{self.agent_type}_thought", input_data=f"step:{step_num}") as span: + try: + active_prompt = state.get("runtime_prompt") + messages = [SystemMessage(content=active_prompt)] + state["messages"] + obs = state.get("observation", None) + if obs is not None: + memory_list = obs.memory if obs.memory else ["No prior knowledge."] + recent_findings = obs.results if obs.results else ["No recent findings."] + memory_context = "\n".join(f"- {m}" for m in memory_list) + findings_context = "\n".join(f"- {f}" for f in recent_findings) + context_block = f"KNOWLEDGE:\n{memory_context}\nLATEST FINDINGS:\n{findings_context}" + messages.append(SystemMessage(content=context_block)) + + max_message_tokens = self.config.context_window_token_limit + raw_total = sum( + count_tokens(m.content) for m in messages + if hasattr(m, "content") and isinstance(m.content, str) + ) + # cl100k_base undercounts by ~25% vs Llama 3.1's tokenizer + total = int(raw_total * 1.25) + if total > max_message_tokens and len(messages) > 2: + # Keep system prompt (messages[0]) and KNOWLEDGE block (messages[-1]). + # ToolMessages are prunable because observation_node already distilled + # their content into the KNOWLEDGE block via comprehension. 
+ prunable = messages[1:-1] + for msg in prunable: + if total <= max_message_tokens: + break + raw_tok = count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0 + messages.remove(msg) + total -= int(raw_tok * 1.25) + logger.info( + "thought_node pruning: estimated tokens now ~%d (limit %d) at step %d", + total, max_message_tokens, step_num, + ) + + response: Thought = await self.thought_llm.ainvoke(messages) + + final_answer = "waiting for the agent to respond" + if response.mode == "finish": + finish_ok, finish_msg = self.check_finish_allowed(state) + if not finish_ok: + logger.info("%s finish blocked by rules at step %d: %s", self.agent_type, step_num, finish_msg) + span.set_output({"mode": "finish_blocked", "step": step_num + 1}) + blocked_ai = AIMessage(content=response.final_answer or response.thought or "I want to finish.") + return { + "messages": [blocked_ai, HumanMessage(content=finish_msg)], + "thought": None, + "step": step_num + 1, + "max_steps": self.config.max_iterations, + "output": "waiting for the agent to respond", + } + ai_message = AIMessage(content=response.final_answer) + final_answer = response.final_answer + elif response.actions is None: + logger.warning("%s LLM returned mode='act' but actions is None, forcing finish", self.agent_type) + ai_message = AIMessage(content=response.thought or "No actions provided, finishing.") + response = Thought( + thought=response.thought or "No actions provided", + mode="finish", + actions=None, + final_answer=response.thought or "Insufficient evidence to provide a definitive answer." 
+ ) + final_answer = response.final_answer + else: + tool_name = response.actions.tool + try: + arguments = _build_tool_arguments(response.actions) + except ValueError as e: + logger.warning( + "%s bad tool arguments at step %d: %s", self.agent_type, step_num, e, + ) + span.set_output({"error": str(e), "step": step_num + 1}) + error_ai = AIMessage(content=response.thought or "I want to call a tool.") + return { + "messages": [error_ai, HumanMessage( + content=f"ERROR: {e}. Provide the required arguments and try again." + )], + "thought": None, + "step": step_num + 1, + "max_steps": self.config.max_iterations, + "output": "waiting for the agent to respond", + } + tool_call_id = str(uuid.uuid4()) + ai_message = AIMessage( + content=response.thought, + tool_calls=[{ + "name": tool_name, + "args": arguments, + "id": tool_call_id + }] + ) + + span.set_output({"mode": response.mode, "step": step_num + 1}) + return { + "messages": [ai_message], + "thought": response, + "step": step_num + 1, + "max_steps": self.config.max_iterations, + "output": final_answer + } + except Exception as e: + logger.exception("%s thought_node failed at step %d", self.agent_type, step_num) + span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num}) + raise + + async def should_continue(self, state: AgentState) -> str: + if state.get("step", 0) >= state.get("max_steps", self.config.max_iterations): + return "forced_finish_node" + thought = state.get("thought", None) + if thought is None: + return "thought_node" + if thought.mode == "finish": + return END + return "tool_node" + + @staticmethod + def _build_observation_context(obs) -> str | None: + """Format observation memory and recent findings into a context block.""" + if obs is None: + return None + memory_list = obs.memory if obs.memory else [] + recent_findings = obs.results if obs.results else [] + if not memory_list and not recent_findings: + return None + parts = [] + if memory_list: + 
parts.append("KNOWLEDGE:\n" + "\n".join(f"- {m}" for m in memory_list)) + if recent_findings: + parts.append("LATEST FINDINGS:\n" + "\n".join(f"- {f}" for f in recent_findings)) + return "\n".join(parts) + + async def forced_finish_node(self, state: AgentState) -> AgentState: + step_num = state.get("step", 0) + with AGENT_TRACER.push_active_function(f"{self.agent_type}_forced_finish", input_data=f"step:{step_num}") as span: + try: + active_prompt = state.get("runtime_prompt") + messages = [SystemMessage(content=active_prompt)] + context_block = self._build_observation_context(state.get("observation", None)) + if context_block: + messages.append(SystemMessage(content=context_block)) + question = state.get("input", "") + finish_prompt = f"QUESTION: {question}\n\n{FORCED_FINISH_PROMPT}" if question else FORCED_FINISH_PROMPT + messages.append(HumanMessage(content=finish_prompt)) + response: Thought = await self.thought_llm.ainvoke(messages) + if response.mode == "finish" and response.final_answer: + ai_message = AIMessage(content=response.final_answer) + final_answer = response.final_answer + else: + final_answer = "Failed to generate a final answer within the maximum allowed steps." 
+ ai_message = AIMessage(content=final_answer) + response = Thought( + thought=response.thought or "Max steps exceeded", + mode="finish", + actions=None, + final_answer=final_answer + ) + span.set_output({"final_answer_length": len(final_answer), "step": step_num}) + return { + "messages": [ai_message], + "thought": response, + "step": step_num, + "max_steps": state.get("max_steps", self.config.max_iterations), + "observation": state.get("observation", None), + "output": final_answer + } + except Exception as e: + logger.exception("%s forced_finish_node failed at step %d", self.agent_type, step_num) + span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num}) + raise + + async def observation_node(self, state: AgentState) -> AgentState: + tool_message = state["messages"][-1] + last_thought_text = state["thought"].thought if state.get("thought") else "No previous thought." + tool_used = state["thought"].actions.tool if state.get("thought") and state["thought"].actions else "Unknown" + tool_input_detail = "" + if state.get("thought") and state["thought"].actions: + actions = state["thought"].actions + if actions.package_name and actions.function_name: + tool_input_detail = f"{actions.package_name},{actions.function_name}" + elif actions.query: + tool_input_detail = actions.query + elif actions.tool_input: + tool_input_detail = actions.tool_input + previous_memory = state.get("observation").memory if state.get("observation") else ["No data gathered yet."] + rules_tracker = state.get("rules_tracker") + with AGENT_TRACER.push_active_function(f"{self.agent_type}_observation", input_data=f"tool:{tool_used}") as span: + try: + tool_output_for_llm = tool_message.content + result, error_message = rules_tracker.check_thought_behavior(tool_used, tool_input_detail, tool_output_for_llm) + if result: + span.set_output({"rule_error": error_message}) + return {"messages": [HumanMessage(content=error_message)]} + + if 
self.should_truncate_tool_output(state, tool_used): + truncated_output = truncate_tool_output(tool_output_for_llm, tool_used) + else: + truncated_output = tool_output_for_llm + + critical_context_text = self.build_comprehension_context(state) + comp_prompt = COMPREHENSION_PROMPT.format( + goal=state.get('input'), + selected_package=state.get('app_package') or "N/A", + critical_context=critical_context_text, + tool_used=tool_used, + tool_input_detail=tool_input_detail, + last_thought_text=last_thought_text, + tool_output=truncated_output, + ) + code_findings: CodeFindings = await self.comprehension_llm.ainvoke([SystemMessage(content=comp_prompt)]) + + sanitized = self.sanitize_findings(code_findings.findings, state) + findings_text = "\n".join(f"- {f}" for f in sanitized) + + mem_prompt = MEMORY_UPDATE_PROMPT.format( + goal=state.get('input'), + selected_package=state.get('app_package') or "N/A", + previous_memory=previous_memory, + findings=findings_text, + tool_outcome=code_findings.tool_outcome, + ) + new_observation: Observation = await self.observation_llm.ainvoke([SystemMessage(content=mem_prompt)]) + + messages = state["messages"] + active_prompt = state.get("runtime_prompt") + estimated = estimate_tokens(active_prompt, messages, new_observation) + prune_messages = [] + orig_estimated = estimated + + span_trace_dict = {"comprehension_findings": sanitized, "tool_outcome": code_findings.tool_outcome} + + if estimated > self.config.context_window_token_limit and len(messages) > 3: + prunable = messages[1:-2] + for msg in prunable: + prune_messages.append(RemoveMessage(id=msg.id)) + estimated -= count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0 + if estimated <= self.config.context_window_token_limit: + break + logger.info( + "Context pruning: removed %d messages, estimated tokens now ~%d (limit %d)", + len(prune_messages), estimated, self.config.context_window_token_limit, + ) + span_trace_dict["orig_estimated"] = 
orig_estimated + span_trace_dict["estimated"] = estimated + + # Agent-specific post-processing hook + extra = self.post_observation(state, tool_used, tool_output_for_llm, tool_input_detail) + + span.set_output(span_trace_dict) + base_result = { + "messages": prune_messages, + "observation": new_observation, + "step": state.get("step", 0), + } + base_result.update(extra) + return base_result + except Exception as e: + logger.exception("%s observation_node failed", self.agent_type) + span.set_output({"error": str(e), "exception_type": type(e).__name__}) + raise + + # --- Graph wiring --- + # If overridden, should_continue must also be updated to match the new node names. + + async def build_graph(self): + tool_node = ToolNode(self.tools, handle_tool_errors=True) + + flow = StateGraph(AgentState) + flow.add_node("thought_node", self.thought_node) + flow.add_node("tool_node", tool_node) + flow.add_node("forced_finish_node", self.forced_finish_node) + flow.add_node("pre_process_node", self.pre_process_node) + flow.add_node("observation_node", self.observation_node) + flow.add_edge(START, "pre_process_node") + flow.add_edge("pre_process_node", "thought_node") + flow.add_conditional_edges( + "thought_node", + self.should_continue, + {END: END, "tool_node": "tool_node", "forced_finish_node": "forced_finish_node", "thought_node": "thought_node"} + ) + flow.add_edge("tool_node", "observation_node") + flow.add_edge("observation_node", "thought_node") + flow.add_edge("forced_finish_node", END) + + return flow.compile() diff --git a/src/vuln_analysis/functions/code_understanding_agent.py b/src/vuln_analysis/functions/code_understanding_agent.py new file mode 100644 index 00000000..7368dcd2 --- /dev/null +++ b/src/vuln_analysis/functions/code_understanding_agent.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from exploit_iq_commons.logging.loggers_factory import LoggingFactory +from nat.builder.context import Context +from vuln_analysis.functions.agent_registry import register_agent +from vuln_analysis.functions.base_graph_agent import BaseGraphAgent, _is_tool_available + +from vuln_analysis.functions.react_internals import ( + AgentState, + Observation, +) +from vuln_analysis.functions.code_understanding_internals import ( + CU_AGENT_SYS_PROMPT, + CU_AGENT_THOUGHT_INSTRUCTIONS, + CodeUnderstandingRulesTracker, +) +from vuln_analysis.utils.code_understanding_prompt_factory import ( + CU_TOOL_SELECTION_STRATEGY, + CU_TOOL_GENERAL_DESCRIPTIONS, +) +from vuln_analysis.utils.intel_utils import build_critical_context +from vuln_analysis.runtime_context import ctx_state, cu_source_scope +from vuln_analysis.tools.tool_names import ToolNames + +logger = LoggingFactory.get_agent_logger(__name__) +AGENT_TRACER = Context.get() + + +def _build_cu_system_prompt(descriptions: str, tool_guidance: str) -> str: + return f"""{CU_AGENT_SYS_PROMPT} + + +{descriptions} + + + +{tool_guidance} + + +{CU_AGENT_THOUGHT_INSTRUCTIONS} + +RESPONSE: +{{""" + + +def _build_cu_tool_guidance(ecosystem: str, available_tools: list) -> tuple[str, str]: + tool_names = [t.name for t in available_tools] + tool_desc_lines = [f"{t.name}: {t.description}" for t in available_tools] + + lang = ecosystem.lower() if ecosystem 
else "" + if lang in CU_TOOL_SELECTION_STRATEGY: + logger.debug("Using %s-specific tool strategy for Code Understanding agent", lang) + guidance = CU_TOOL_SELECTION_STRATEGY[lang] + else: + logger.debug("No ecosystem-specific strategy for '%s', using generic Code Understanding tool guidance", lang) + guidance_list = [] + for name in tool_names: + if name in CU_TOOL_GENERAL_DESCRIPTIONS: + guidance_list.append(f"{name}: {CU_TOOL_GENERAL_DESCRIPTIONS[name]}") + guidance = "\n".join(guidance_list) if guidance_list else "Use the available tools to investigate the question." + + descriptions = "\n".join(tool_desc_lines) + return guidance, descriptions + + +@register_agent("code_understanding") +class CodeUnderstandingAgent(BaseGraphAgent): + + @property + def agent_type(self) -> str: + return "cu" + + _CU_TOOLS = frozenset({ + ToolNames.DOCS_SEMANTIC_SEARCH, + ToolNames.CODE_KEYWORD_SEARCH, + ToolNames.CONFIGURATION_SCANNER, + ToolNames.IMPORT_USAGE_ANALYZER, + }) + + @staticmethod + def get_tools(builder, config, state) -> list: + all_tools = BaseGraphAgent._load_all_tools(builder, config) + return [ + t for t in all_tools + if t.name in CodeUnderstandingAgent._CU_TOOLS + and _is_tool_available(t.name, config, state) + ] + + @staticmethod + def create_rules_tracker() -> CodeUnderstandingRulesTracker: + return CodeUnderstandingRulesTracker() + + def build_comprehension_context(self, state: AgentState) -> str: + """Return minimal grounding context to prevent CVE ID hallucination. + + The CU agent's comprehension LLM only needs the CVE ID and package name + to ground its findings. Feeding the full GHSA/NVD context activates + parametric knowledge and causes the model to inject CVE IDs from training. 
+ """ + workflow_state = ctx_state.get() + vuln_id = workflow_state.cve_intel[0].vuln_id if workflow_state.cve_intel else "unknown" + package = state.get("app_package") or "unknown" + return ( + f"Investigating {vuln_id} in package {package}.\n" + "Only extract facts explicitly stated in the tool output. " + "Do not add CVE IDs, vulnerability names, or advisory details " + "from your own knowledge." + ) + + async def pre_process_node(self, state: AgentState) -> AgentState: + workflow_state = ctx_state.get() + ecosystem = ( + workflow_state.original_input.input.image.ecosystem.value + if workflow_state.original_input.input.image.ecosystem + else "" + ) + with AGENT_TRACER.push_active_function("cu_pre_process", input_data=f"ecosystem:{ecosystem}") as span: + try: + precomputed = state.get("precomputed_intel") + if precomputed is not None: + critical_context = list(precomputed[0]) + candidate_packages = [dict(p) for p in precomputed[1]] + else: + critical_context, candidate_packages, _ = build_critical_context( + workflow_state.cve_intel + ) + + critical_context, selected_package = await self._select_package( + ecosystem, candidate_packages, critical_context, workflow_state, + ) + + critical_context.append( + "TASK: Investigate usage, configuration, and presence of the vulnerable " + "component in the container. Focus on how the component is used, " + "not on call-chain reachability." 
+ ) + + scope_parts = [] + image_input = workflow_state.original_input.input.image + if image_input.source_info: + for si in image_input.source_info: + if si.git_repo: + repo_name = si.git_repo.rstrip("/").rsplit("/", 1)[-1] + if repo_name.endswith(".git"): + repo_name = repo_name[:-4] + scope_parts.append(repo_name) + if selected_package: + scope_parts.append(selected_package) + cu_source_scope.set(scope_parts if scope_parts else None) + logger.debug("Code Understanding source scope set to: %s", scope_parts) + + tool_guidance, descriptions = _build_cu_tool_guidance(ecosystem, self.tools) + runtime_prompt = _build_cu_system_prompt(descriptions, tool_guidance) + active_tool_names = [t.name for t in self.tools] + + rules_tracker = state.get("rules_tracker") + rules_tracker.set_target_package(selected_package) + rules_tracker.set_allowed_tools(active_tool_names) + + span.set_output({ + "selected_package": selected_package, + "agent_type": "code_understanding", + }) + + return { + "ecosystem": ecosystem, + "runtime_prompt": runtime_prompt, + "is_reachability": "no", + "observation": Observation(memory=critical_context, results=[]), + "critical_context": critical_context, + "app_package": selected_package, + } + except Exception as e: + logger.exception("cu_pre_process_node failed") + span.set_output({"error": str(e), "exception_type": type(e).__name__}) + raise diff --git a/src/vuln_analysis/functions/code_understanding_internals.py b/src/vuln_analysis/functions/code_understanding_internals.py new file mode 100644 index 00000000..e6fc2bd2 --- /dev/null +++ b/src/vuln_analysis/functions/code_understanding_internals.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from exploit_iq_commons.logging.loggers_factory import LoggingFactory +from vuln_analysis.functions.react_internals import BaseRulesTracker + +logger = LoggingFactory.get_agent_logger(__name__) + + +class CodeUnderstandingRulesTracker(BaseRulesTracker): + """Behavioral rules for the Code Understanding sub-agent.""" + + def check_thought_behavior(self, action: str, action_input: str, output) -> tuple[bool, str]: + """Check all Code Understanding rules in priority order. + + Returns (True, error_message) if a rule is violated, (False, '') otherwise. + Rule check order: duplicate call -> Rule 7 (dotted keywords) -> allowed tools. + """ + if self._rule_duplicate_call(action, action_input): + logger.debug("CU duplicate call rule triggered: '%s' with same input", action) + return True, ( + f"You already called {action} with this exact input. " + "You MUST use a DIFFERENT tool or a DIFFERENT input query. " + "Check KNOWLEDGE for what was already tried." + ) + if self._rule_number_7(action, action_input, output): + logger.debug("Code Understanding Rule 7 triggered: dotted query with empty results for tool '%s'", action) + return True, ( + "You are NOT following Rule 7. Your query contains dots and returned " + "no results. You MUST retry with just the final component. Follow the rules." + ) + if self._rule_use_allowed_tools(action): + logger.debug("Code Understanding allowed-tools rule triggered: '%s' not in %s", action, self.allowed_tools) + return True, ( + f"You are NOT following AVAILABLE_TOOLS. You MUST use the allowed tools " + f"{self.allowed_tools}. 
Follow the rules." + ) + self.add_action(action, action_input, output) + return False, "" + + +CU_AGENT_SYS_PROMPT = ( + "You are a security analyst investigating a code understanding question about a CVE vulnerability.\n" + "Your goal is to collect evidence about how the codebase uses, configures, or depends on " + "the affected component. This is NOT a reachability question -- do NOT trace call chains.\n" + "GENERAL RULES:\n" + "- Base conclusions ONLY on tool results, not assumptions.\n" + "- If a search returns no results, that is evidence the component is absent or not configured.\n" + "- DISTINGUISH where findings come from: main application code vs. dependency libraries.\n" + "- A component being present in dependencies does NOT mean the application uses it.\n" + "- Configuration in framework dependencies (e.g., Spring defaults) IS relevant evidence.\n" + "- Import in an intermediate library (e.g., Spring imports XStream, app imports Spring) IS relevant.\n" + "ANSWER QUALITY:\n" + "- Answer the SPECIFIC question asked with evidence. Do NOT just report what tools found.\n" + "- Always state: WHAT you checked, WHAT you found, and WHY it leads to your conclusion.\n" + "- If tool results conflict, state the conflict explicitly.\n" + "- When citing evidence, explain HOW it relates to the question." +) + +CU_AGENT_THOUGHT_INSTRUCTIONS = """ +1. Output valid JSON only. thought < 100 words. final_answer < 150 words. +2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. +3. Docs Semantic Search, Code Keyword Search: use query field. +4. Configuration Scanner: use query field with keywords describing what to look for. +5. Import Usage Analyzer: use query field with the package/module name. +6. Do NOT call the same tool with the same input twice. Check KNOWLEDGE for "CALLED:" entries. +7. If Code Keyword Search returns no results and the query contains dots, retry with just the final component. +8. 
Before concluding, if the question involves a specific library or package, you MUST use Import Usage Analyzer to check how it is imported and used across sources. +9. GUIDELINE: Synthesize findings across documentation, code, configuration, and imports before drawing conclusions. + + +{{"thought": "Check configuration files for security settings related to the vulnerability", "mode": "act", "actions": {{"tool": "Configuration Scanner", "package_name": null, "function_name": null, "query": "deserialization allowlist denylist security", "tool_input": null, "reason": "Check if vulnerable feature is enabled or disabled in config"}}, "final_answer": null}} + + +{{"thought": "Search source code for imports of the vulnerable library", "mode": "act", "actions": {{"tool": "Code Keyword Search", "package_name": null, "function_name": null, "query": "import vulnerable_library", "tool_input": null, "reason": "Find where the package is imported in code"}}, "final_answer": null}} + + +{{"thought": "Check how the vulnerable library is imported and used across the codebase", "mode": "act", "actions": {{"tool": "Import Usage Analyzer", "package_name": null, "function_name": null, "query": "vulnerable_library", "tool_input": null, "reason": "Find all imports and usage sites of the vulnerable package"}}, "final_answer": null}} +""" \ No newline at end of file diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index f140850d..d1fa7d4d 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -12,49 +12,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os + import asyncio -import json -from pathlib import Path -from vuln_analysis.runtime_context import ctx_state -import typing + from aiq.builder.builder import Builder from aiq.builder.framework_enum import LLMFrameworkEnum from aiq.builder.function_info import FunctionInfo from aiq.cli.register_workflow import register_function from aiq.data_models.function import FunctionBaseConfig -from langchain.agents import AgentExecutor -from langchain.agents import create_react_agent -from langchain.agents.agent import RunnableAgent -from langchain.agents.mrkl.output_parser import MRKLOutputParser from langchain_core.exceptions import OutputParserException -from langchain_core.prompts import PromptTemplate +from langchain_core.messages import HumanMessage from pydantic import Field + from vuln_analysis.data_models.state import AgentMorpheusEngineState -from vuln_analysis.tools.tool_names import ToolNames -from vuln_analysis.tools.transitive_code_search import package_name_from_locator_query from vuln_analysis.utils.error_handling_decorator import ToolRaisedException -from vuln_analysis.utils.prompting import get_agent_prompt +from vuln_analysis.utils.intel_utils import build_critical_context +from vuln_analysis.runtime_context import ctx_state from exploit_iq_commons.logging.loggers_factory import LoggingFactory, trace_id - -from vuln_analysis.functions.react_internals import build_system_prompt, build_classification_prompt, build_package_filter_prompt, AgentState, Thought, Observation, Classification, PackageSelection, CodeFindings, SystemRulesTracker, _build_tool_arguments, FORCED_FINISH_PROMPT, COMPREHENSION_PROMPT, MEMORY_UPDATE_PROMPT, AGENT_SYS_PROMPT_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY, AGENT_THOUGHT_INSTRUCTIONS_GO -from vuln_analysis.utils.prompting import build_tool_descriptions -from vuln_analysis.utils.prompt_factory import TOOL_SELECTION_STRATEGY, TOOL_SELECTION_STRATEGY_NON_REACHABILITY, TOOL_ECOSYSTEM_REGISTRY, 
FEW_SHOT_EXAMPLES -from vuln_analysis.utils.intel_utils import build_critical_context, enrich_go_from_osv, filter_context_to_package -from exploit_iq_commons.utils.git_utils import sanitize_git_url_for_path -from exploit_iq_commons.utils.data_utils import DEFAULT_GIT_DIRECTORY -from langgraph.graph import StateGraph, END, START -from langgraph.prebuilt import ToolNode -from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, RemoveMessage - -import uuid -import tiktoken from nat.builder.context import Context +from vuln_analysis.functions.agent_registry import get_agent_class, get_all_agent_types +from vuln_analysis.functions.dispatcher import QuestionRouting, dispatch_question +# Import agent modules to trigger @register_agent decorators +import vuln_analysis.functions.reachability_agent # noqa: F401 +import vuln_analysis.functions.code_understanding_agent # noqa: F401 + logger = LoggingFactory.get_agent_logger(__name__) AGENT_TRACER = Context.get() + class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"): """ Defines a function that iterates through checklist items using provided tools and gathered intel. @@ -88,672 +75,74 @@ class CVEAgentExecutorToolConfig(FunctionBaseConfig, name="cve_agent_executor"): description="Estimated token threshold for pruning old messages in observation node." 
) -async def common_build_tools(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState) -> tuple[list[typing.Any], list[str], list[str]]: - - tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - # Filter tools that are not available based on state - tools = [ - tool for tool in tools - if not ((tool.name == ToolNames.CODE_SEMANTIC_SEARCH and state.code_vdb_path is None) or - (tool.name == ToolNames.DOCS_SEMANTIC_SEARCH and state.doc_vdb_path is None) or - (tool.name == ToolNames.CODE_KEYWORD_SEARCH and state.code_index_path is None) or - (tool.name == ToolNames.CVE_WEB_SEARCH and not config.cve_web_search_enabled) or - (tool.name == ToolNames.CALL_CHAIN_ANALYZER and (not config.transitive_search_tool_enabled or - state.code_index_path is None)) or - (tool.name == ToolNames.FUNCTION_CALLER_FINDER and (not config.transitive_search_tool_enabled or - state.code_index_path is None)) or - (tool.name == ToolNames.FUNCTION_LOCATOR and (not config.transitive_search_tool_enabled or - state.code_index_path is None)) - ) - ] - # Get tool names after filtering for dynamic guidance - enabled_tool_names = [tool.name for tool in tools] - tool_descriptions_list = [t.name + ": " + t.description for t in tools] - # Build tool selection guidance with strategic context - tool_descriptions = build_tool_descriptions(enabled_tool_names) - return tools, tool_descriptions, tool_descriptions_list - -def _validate_go_vendor_packages( - source_info: list, - candidate_packages: list[dict], -) -> tuple[list[dict], list[str]]: - """Check which Go candidate packages actually exist in the vendor directory.""" - code_si = next((si for si in source_info if si.type == "code"), None) - if code_si is None: - return candidate_packages, [] - - repo_path = Path(DEFAULT_GIT_DIRECTORY) / sanitize_git_url_for_path(code_si.git_repo) - vendor_path = repo_path / "vendor" - if not vendor_path.is_dir(): - return candidate_packages, [] - - 
validated = [] - removed = [] - for pkg in candidate_packages: - pkg_name = pkg.get("name", "") - if (vendor_path / pkg_name).is_dir(): - validated.append(pkg) - else: - removed.append(pkg_name) - - if validated: - return validated, removed - return candidate_packages, [] - - -async def _enrich_go_candidates( - cve_intel: list, - source_info: list, - critical_context: list[str], - candidate_packages: list[dict], - vulnerable_functions_set: set[str], -) -> tuple[list[dict], list[str]]: - """Enrich Go candidates via OSV and validate against vendor directory.""" - ghsa_has_packages = any(c.get("source") == "ghsa" for c in candidate_packages) - if not ghsa_has_packages or not vulnerable_functions_set: - intel = cve_intel[0] if cve_intel else None - if intel: - await enrich_go_from_osv(intel, critical_context, candidate_packages, vulnerable_functions_set) - - if candidate_packages: - candidate_packages, removed_pkgs = _validate_go_vendor_packages( - source_info, candidate_packages - ) - if removed_pkgs: - logger.info("Go vendor validation removed %d packages not in vendor/: %s", len(removed_pkgs), removed_pkgs) - - return candidate_packages, sorted(vulnerable_functions_set) - - -async def _create_graph_agent(config: CVEAgentExecutorToolConfig, builder: Builder, state: AgentMorpheusEngineState): - - tools, tool_guidance_list, tool_descriptions_list = await common_build_tools(config, builder, state) - llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - thought_llm = llm.with_structured_output(Thought) - comprehension_llm = llm.with_structured_output(CodeFindings) - observation_llm = llm.with_structured_output(Observation) - reachability_llm = llm.with_structured_output(Classification) - package_filter_llm = llm.with_structured_output(PackageSelection) - tool_guidance = "\n".join(tool_guidance_list) - descriptions = "\n".join(tool_descriptions_list) - default_system_prompt = build_system_prompt(descriptions, tool_guidance) - 
tool_node = ToolNode(tools, handle_tool_errors=True) - TOOL_NODE = "tool_node" - THOUGHT_NODE = "thought_node" - FORCED_FINISH_NODE = "forced_finish_node" - PRE_PROCESS_NODE = "pre_process_node" - OBSERVATION_NODE = "observation_node" - - _tiktoken_enc = tiktoken.get_encoding("cl100k_base") - - def _count_tokens(text: str) -> int: - """Count tokens using tiktoken cl100k_base encoding (~90-95% accurate for Llama 3.1).""" - try: - return len(_tiktoken_enc.encode(text)) - except Exception: - return len(text) // 4 - - def _truncate_tool_output(tool_output: str, tool_name: str, max_tokens: int = 400) -> str: - """Truncate tool output to fit observation prompt within completion token budget. - - Code Keyword Search is the primary source of bloat (3700+ chars of redundant - source snippets). CCA and FL outputs are typically small. - Preserves the most informative parts based on tool type. - """ - token_count = _count_tokens(tool_output) - if token_count <= max_tokens: - return tool_output - - lines = tool_output.split('\n') - - if tool_name == ToolNames.CALL_CHAIN_ANALYZER: - head = '\n'.join(lines[:3]) - remaining = token_count - _count_tokens(head) - return f"{head}\n[... truncated {remaining} tokens ...]" - - if tool_name == ToolNames.CODE_KEYWORD_SEARCH: - kept_lines = [] - kept_tokens = 0 - for line in lines: - is_header = line.startswith("---") or "Main application" in line or "library dependencies" in line or line.strip() == "" - line_tokens = _count_tokens(line) - if is_header or kept_tokens + line_tokens <= max_tokens: - kept_lines.append(line) - kept_tokens += line_tokens - if kept_tokens >= max_tokens: - break - kept_lines.append(f"[... 
truncated {token_count - kept_tokens} tokens ...]") - return '\n'.join(kept_lines) - - # Default: head (70%) + tail (30%) - head_budget = int(max_tokens * 0.7) - tail_budget = max_tokens - head_budget - head_lines = [] - head_tokens = 0 - for line in lines: - lt = _count_tokens(line) - if head_tokens + lt > head_budget: - break - head_lines.append(line) - head_tokens += lt - tail_lines = [] - tail_tokens = 0 - for line in reversed(lines): - lt = _count_tokens(line) - if tail_tokens + lt > tail_budget: - break - tail_lines.insert(0, line) - tail_tokens += lt - truncated = token_count - head_tokens - tail_tokens - return '\n'.join(head_lines) + f"\n[... truncated {truncated} tokens ...]\n" + '\n'.join(tail_lines) - - def _estimate_tokens(runtime_prompt: str, messages: list, observation: Observation | None) -> int: - """Estimate the token count thought_node will send to the LLM.""" - parts = [runtime_prompt] - for msg in messages: - if hasattr(msg, "content") and isinstance(msg.content, str): - parts.append(msg.content) - if observation is not None: - for item in (observation.memory or []): - parts.append(item) - for item in (observation.results or []): - parts.append(item) - return _count_tokens("\n".join(parts)) - - def _build_tool_guidance_for_ecosystem(ecosystem: str, available_tools: list, is_reachability: str = "yes") -> tuple[str, str]: - """Build tool guidance using language-specific strategies when available.""" - filtered_tools = [ - t for t in available_tools - if (t.name != ToolNames.FUNCTION_CALLER_FINDER or ecosystem == "go") and - (t.name != ToolNames.FUNCTION_LIBRARY_VERSION_FINDER or ecosystem == "java") - ] - list_of_tool_names = [t.name for t in filtered_tools] - list_of_tool_descriptions = [t.name + ": " + t.description for t in filtered_tools] - - strategy = TOOL_SELECTION_STRATEGY if is_reachability == "yes" else TOOL_SELECTION_STRATEGY_NON_REACHABILITY - lang = ecosystem.lower() if ecosystem else "" - if lang in strategy: - tool_guidance_local = 
strategy[lang] - if lang == "java": - tool_guidance_local += ( - " Use Function Library Version Finder to verify the installed version of a library " - "before concluding exploitability (e.g., input 'commons-beanutils')." - ) - if is_reachability == "yes": - hint = FEW_SHOT_EXAMPLES.get(lang, "") - if hint: - tool_guidance_local += f"\nHint: {hint}" - else: - tool_guidance_list_local = build_tool_descriptions(list_of_tool_names) - tool_guidance_local = "\n".join(tool_guidance_list_local) - - descriptions_local = "\n".join(list_of_tool_descriptions) - return tool_guidance_local, descriptions_local - - async def pre_process_node(state: AgentState) -> AgentState: - workflow_state = ctx_state.get() - ecosystem = workflow_state.original_input.input.image.ecosystem.value if workflow_state.original_input.input.image.ecosystem else "" - with AGENT_TRACER.push_active_function("pre_process node", input_data=f"ecosystem:{ecosystem}") as span: - try: - critical_context, candidate_packages, vulnerable_functions = build_critical_context(workflow_state.cve_intel) - vulnerable_functions_set = set(vulnerable_functions) - - if ecosystem == "go": - candidate_packages, vulnerable_functions = await _enrich_go_candidates( - workflow_state.cve_intel, - workflow_state.original_input.input.image.source_info, - critical_context, - candidate_packages, - vulnerable_functions_set, - ) - - selected_package = None - app_package = None - if len(candidate_packages) > 1: - image_input = workflow_state.original_input.input.image - image_name = image_input.name - source_repos = image_input.source_info - image_repo = source_repos[0].git_repo if source_repos else None - filter_prompt = build_package_filter_prompt( - ecosystem, candidate_packages, - image_name=image_name, image_repo=image_repo, - critical_context=critical_context, - ) - selection: PackageSelection = await package_filter_llm.ainvoke([HumanMessage(content=filter_prompt)]) - selected_package = selection.selected_package - app_package = 
selected_package - logger.info("Package filter selected '%s' from %d candidates (reason: %s)", - selected_package, len(candidate_packages), selection.reason) - critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages) - elif len(candidate_packages) == 1: - selected_package = candidate_packages[0].get("name") - app_package = selected_package - logger.info("Single candidate package after validation: '%s'", selected_package) - critical_context = filter_context_to_package(critical_context, selected_package, candidate_packages) - - critical_context.append( - "TASK: Investigate usage and reachability of the vulnerable function/module in the container. " - "Use the vulnerable module name from GHSA as primary investigation target." - ) - - question = state.get("input") or "" - context_block = "\n".join(critical_context) - classification_prompt = build_classification_prompt(context_block, question) - classification_result: Classification = await reachability_llm.ainvoke([HumanMessage(content=classification_prompt)]) - span.set_output({ - "critical_context": critical_context, - "candidate_packages": candidate_packages, - "selected_package": selected_package, - "app_package": app_package if selected_package else None, - "reachability_question": classification_result.is_reachability, - }) - - is_reachability = classification_result.is_reachability - - if is_reachability == "yes": - tool_guidance_local, descriptions_local = _build_tool_guidance_for_ecosystem(ecosystem, tools) - go_instructions = {"instructions": AGENT_THOUGHT_INSTRUCTIONS_GO} if ecosystem == "go" else {} - runtime_prompt = build_system_prompt(descriptions_local, tool_guidance_local, **go_instructions) - active_tool_names = [t.name for t in tools] - else: - reachability_tool_names = { ToolNames.CALL_CHAIN_ANALYZER, ToolNames.FUNCTION_CALLER_FINDER} - non_reach_tools = [t for t in tools if t.name not in reachability_tool_names] - tool_guidance_local, descriptions_local = 
_build_tool_guidance_for_ecosystem(ecosystem, non_reach_tools, is_reachability="no") - runtime_prompt = build_system_prompt( - descriptions_local, tool_guidance_local, - instructions=AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY, - sys_prompt=AGENT_SYS_PROMPT_NON_REACHABILITY, - ) - active_tool_names = [t.name for t in non_reach_tools] - logger.info("Non-reachability question detected; removed reachability tools from prompt") - rules_tracker = state.get("rules_tracker") - app_package = app_package if selected_package else None - rules_tracker.set_target_package(app_package) - rules_tracker.set_allowed_tools(active_tool_names) - if is_reachability == "yes": - rules_tracker.set_target_functions(vulnerable_functions) - return { - "ecosystem": ecosystem, - "runtime_prompt": runtime_prompt, - "is_reachability": is_reachability, - "observation": Observation(memory=critical_context, results=[]), - "critical_context": critical_context, - "app_package": app_package if selected_package else None, - } - except Exception as e: - logger.exception("pre_process_node failed") - span.set_output({"error": str(e), "exception_type": type(e).__name__}) - raise - - - async def thought_node(state: AgentState) -> AgentState: - step_num = state.get("step", 0) - with AGENT_TRACER.push_active_function("thought node", input_data=f"step:{step_num}") as span: - try: - active_prompt = state.get("runtime_prompt") or default_system_prompt - messages = [SystemMessage(content=active_prompt)] + state["messages"] - obs = state.get("observation", None) - if obs is not None: - memory_list = obs.memory if obs.memory else ["No prior knowledge."] - recent_findings = obs.results if obs.results else ["No recent findings."] - memory_context = "\n".join(f"- {m}" for m in memory_list) - findings_context = "\n".join(f"- {f}" for f in recent_findings) - context_block = f"KNOWLEDGE:\n{memory_context}\nLATEST FINDINGS:\n{findings_context}" - messages.append(SystemMessage(content=context_block)) - response: Thought = 
await thought_llm.ainvoke(messages) - - final_answer = "waiting for the agent to respond" - if response.mode == "finish": - ai_message = AIMessage(content=response.final_answer) - final_answer = response.final_answer - elif response.actions is None: - logger.warning("LLM returned mode='act' but actions is None, forcing finish") - ai_message = AIMessage(content=response.thought or "No actions provided, finishing.") - response = Thought( - thought=response.thought or "No actions provided", - mode="finish", - actions=None, - final_answer=response.thought or "Insufficient evidence to provide a definitive answer." - ) - final_answer = response.final_answer - else: - tool_name = response.actions.tool - arguments = _build_tool_arguments(response.actions) - tool_call_id = str(uuid.uuid4()) - ai_message = AIMessage( - content=response.thought, - tool_calls=[{ - "name": tool_name, - "args": arguments, - "id": tool_call_id - }] - ) - - span.set_output({"mode": response.mode, "step": step_num + 1}) - return { - "messages": [ai_message], - "thought": response, - "step": step_num + 1, - "max_steps": config.max_iterations, - "output": final_answer - } - except Exception as e: - logger.exception("thought_node failed at step %d", step_num) - span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num}) - raise - - async def should_continue(state: AgentState) -> str: - thought = state.get("thought", None) - if thought is not None and thought.mode == "finish": - return END - if state.get("step", 0) >= state.get("max_steps", config.max_iterations): - return FORCED_FINISH_NODE - return TOOL_NODE - - async def forced_finish_node(state: AgentState) -> AgentState: - step_num = state.get("step", 0) - with AGENT_TRACER.push_active_function("forced_finish node", input_data=f"step:{step_num}") as span: - try: - active_prompt = state.get("runtime_prompt") or default_system_prompt - messages = [SystemMessage(content=active_prompt)] + state["messages"] - 
messages.append(HumanMessage(content=FORCED_FINISH_PROMPT)) - obs = state.get("observation", None) - if obs is not None and obs.memory: - memory_context = "\n".join(f"- {m}" for m in obs.memory) - messages.append(SystemMessage(content=f"KNOWLEDGE:\n{memory_context}")) - response: Thought = await thought_llm.ainvoke(messages) - if response.mode == "finish" and response.final_answer: - ai_message = AIMessage(content=response.final_answer) - final_answer = response.final_answer - else: - final_answer = "Failed to generate a final answer within the maximum allowed steps." - ai_message = AIMessage(content=final_answer) - response = Thought( - thought=response.thought or "Max steps exceeded", - mode="finish", - actions=None, - final_answer=final_answer - ) - span.set_output({"final_answer_length": len(final_answer), "step": step_num}) - return { - "messages": [ai_message], - "thought": response, - "step": step_num, - "max_steps": state.get("max_steps", config.max_iterations), - "observation": state.get("observation", None), - "output": final_answer - } - except Exception as e: - logger.exception("forced_finish_node failed at step %d", step_num) - span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num}) - raise - - async def observation_node(state: AgentState) -> AgentState: - tool_message = state["messages"][-1] - last_thought_text = state["thought"].thought if state.get("thought") else "No previous thought." 
- tool_used = state["thought"].actions.tool if state.get("thought") and state["thought"].actions else "Unknown" - tool_input_detail = "" - if state.get("thought") and state["thought"].actions: - actions = state["thought"].actions - if actions.package_name and actions.function_name: - tool_input_detail = f"{actions.package_name},{actions.function_name}" - elif actions.query: - tool_input_detail = actions.query - elif actions.tool_input: - tool_input_detail = actions.tool_input - previous_memory = state.get("observation").memory if state.get("observation") else ["No data gathered yet."] - rules_tracker = state.get("rules_tracker") - with AGENT_TRACER.push_active_function("observation node", input_data=f"tool used:{tool_used}") as span: - try: - tool_output_for_llm = tool_message.content - result, error_message = rules_tracker.check_thought_behavior(tool_used, tool_input_detail, tool_output_for_llm) - if result: - span.set_output({"rule_error": error_message}) - return {"messages": [HumanMessage(content=error_message)]} - - if state.get("ecosystem", "").lower() == "java": - truncated_output = _truncate_tool_output(tool_output_for_llm, tool_used) - else: - truncated_output = tool_output_for_llm - - # Step 1: Comprehension -- reads raw tool output, produces compact findings - ctx_lines = state.get("critical_context", []) - critical_context_text = "\n".join(ctx_lines) if ctx_lines else "N/A" - comp_prompt = COMPREHENSION_PROMPT.format( - goal=state.get('input'), - selected_package=state.get('app_package') or "N/A", - critical_context=critical_context_text, - tool_used=tool_used, - tool_input_detail=tool_input_detail, - last_thought_text=last_thought_text, - tool_output=truncated_output, - ) - code_findings: CodeFindings = await comprehension_llm.ainvoke([SystemMessage(content=comp_prompt)]) - - findings_text = "\n".join(f"- {f}" for f in code_findings.findings) - - # Step 2: Memory update -- merges compressed findings into cumulative memory - mem_prompt = 
MEMORY_UPDATE_PROMPT.format( - goal=state.get('input'), - selected_package=state.get('app_package') or "N/A", - previous_memory=previous_memory, - findings=findings_text, - tool_outcome=code_findings.tool_outcome, - ) - new_observation: Observation = await observation_llm.ainvoke([SystemMessage(content=mem_prompt)]) - - messages = state["messages"] - active_prompt = state.get("runtime_prompt") or default_system_prompt - estimated = _estimate_tokens(active_prompt, messages, new_observation) - prune_messages = [] - orig_estimated = estimated - - span_trace_dict = {"comprehension_findings": code_findings.findings, "tool_outcome": code_findings.tool_outcome} - - if estimated > config.context_window_token_limit and len(messages) > 3: - prunable = messages[1:-2] - for msg in prunable: - prune_messages.append(RemoveMessage(id=msg.id)) - estimated -= _count_tokens(msg.content) if hasattr(msg, "content") and isinstance(msg.content, str) else 0 - if estimated <= config.context_window_token_limit: - break - logger.info( - "Context pruning: removed %d messages, estimated tokens now ~%d (limit %d)", - len(prune_messages), estimated, config.context_window_token_limit, - ) - span_trace_dict["orig_estimated"] = orig_estimated - span_trace_dict["estimated"] = estimated - span.set_output(span_trace_dict) - cca_results = list(state.get("cca_results", [])) - if tool_used == ToolNames.CALL_CHAIN_ANALYZER: - stripped = tool_output_for_llm.strip().lstrip("([") - first_token = stripped.split(",", 1)[0].strip().lower() - if first_token == "true": - cca_results.append(True) - elif first_token == "false": - cca_results.append(False) - package_validated = state.get("package_validated") - if tool_used == ToolNames.FUNCTION_LOCATOR and state.get("is_reachability") == "yes": - input_pkg = package_name_from_locator_query(tool_input_detail) - target_pkg = (state.get("app_package") or "").strip().lower() - if target_pkg and input_pkg == target_pkg: - if "Package is valid" in tool_output_for_llm: - 
package_validated = True - elif "Package is not valid" in tool_output_for_llm and package_validated is None: - package_validated = False - return { - "messages": prune_messages, - "observation": new_observation, - "step": state.get("step", 0), - "cca_results": cca_results, - "package_validated": package_validated, - } - except Exception as e: - logger.exception("observation_node failed") - span.set_output({"error": str(e), "exception_type": type(e).__name__}) - raise - - async def create_graph(): - flow = StateGraph(AgentState) - flow.add_node(THOUGHT_NODE, thought_node) - flow.add_node(TOOL_NODE, tool_node) - flow.add_node(FORCED_FINISH_NODE, forced_finish_node) - flow.add_node(PRE_PROCESS_NODE, pre_process_node) - flow.add_node(OBSERVATION_NODE, observation_node) - flow.add_edge(START, PRE_PROCESS_NODE) - flow.add_edge(PRE_PROCESS_NODE, THOUGHT_NODE) - flow.add_conditional_edges( - THOUGHT_NODE, - should_continue, - {END: END, TOOL_NODE: TOOL_NODE, FORCED_FINISH_NODE: FORCED_FINISH_NODE} - ) - flow.add_edge(TOOL_NODE, OBSERVATION_NODE) - flow.add_edge(OBSERVATION_NODE, THOUGHT_NODE) - flow.add_edge(FORCED_FINISH_NODE, END) - - app = flow.compile() - if config.verbose: - app.get_graph().draw_mermaid_png(output_file_path="flow.png") - return app - return await create_graph() -async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, - state: AgentMorpheusEngineState) -> AgentExecutor: - - tools, tool_descriptions,_ = await common_build_tools(config, builder, state) - - llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - tool_guidance = "\n".join(tool_descriptions) - - # Get prompt template - prompt_template_str = get_agent_prompt(config.prompt, config.prompt_examples) - - # Create prompt with tool_selection_strategy as partial variable - prompt = PromptTemplate.from_template( - prompt_template_str, - partial_variables={ - 'tool_selection_strategy': tool_guidance if tool_guidance else "Use available 
tools as appropriate." - } - ) - # using langchain create_react_agent to create the agent - agent = create_react_agent(llm=llm, - tools=tools, - prompt=prompt, - output_parser=MRKLOutputParser(), - stop_sequence=["\nObservation:", "\n\tObservation:"]) - - agent_executor = AgentExecutor( - agent=agent, - tools=tools, - early_stopping_method="force", - handle_parsing_errors="Check your output and make sure it conforms, use the Action/Action Input syntax", - max_iterations=config.max_iterations, - return_intermediate_steps=config.return_intermediate_steps, - verbose=config.verbose) - - # Disable streaming for accurate token counts - if isinstance(agent_executor.agent, RunnableAgent): - agent_executor.agent.stream_runnable = False - - return agent_executor - -async def _process_steps(agent, steps, semaphore, max_iterations: int = 10): - +async def _process_steps(agents: dict, routing_llm, steps, semaphore, max_iterations: int = 10): + workflow_state = ctx_state.get() + critical_context, candidate_packages, vulnerable_functions = build_critical_context(workflow_state.cve_intel) + context_block = "\n".join(critical_context) + precomputed_intel = (critical_context, candidate_packages, vulnerable_functions) async def _process_step(step): - async def call_agent(initial_state,config=None): - if config: - return await agent.ainvoke(initial_state,config=config) - else: - return await agent.ainvoke(initial_state) + routing = await dispatch_question(routing_llm, step, context_block) + agent_type = routing.agent_type + + actual_type = agent_type if agent_type in agents else "reachability" + compiled_graph = agents[actual_type] + tracker = get_agent_class(actual_type).create_rules_tracker() + + initial_state = { + "input": step, + "messages": [HumanMessage(content=step)], + "step": 0, + "max_steps": max_iterations, + "thought": None, + "observation": None, + "output": "waiting for the agent to respond", + "rules_tracker": tracker, + "precomputed_intel": precomputed_intel, + } + 
graph_config = {"recursion_limit": 50} - initial_state = {"input": step} - config = None - if not isinstance(agent, AgentExecutor): - initial_state = { - "input": step, - "messages": [HumanMessage(content=step)], - "step": 0, - "max_steps": max_iterations, - "thought": None, - "observation": None, - "output": "waiting for the agent to respond", - "rules_tracker": SystemRulesTracker(), - } - config = { - "recursion_limit": 50 - } - with AGENT_TRACER.push_active_function("checklist_question", input_data=step[:80]): + with AGENT_TRACER.push_active_function( + "checklist_question", + input_data=f"[{actual_type}] {step[:80]}", + ): if semaphore: async with semaphore: - return await call_agent(initial_state, config) + return await compiled_graph.ainvoke(initial_state, config=graph_config) else: - return await call_agent(initial_state, config) + return await compiled_graph.ainvoke(initial_state, config=graph_config) return await asyncio.gather(*(_process_step(step) for step in steps), return_exceptions=True) -def _parse_intermediate_step(step: tuple[typing.Any, typing.Any]) -> dict[str, typing.Any]: - """ - Parse an agent intermediate step into an AgentIntermediateStep object. Return the dictionary representation for - compatibility with cudf. - """ - if len(step) != 2: - raise ValueError(f"Expected 2 values in each intermediate step but got {len(step)}.") - - action, output = step - - return {"tool_name": action.tool, "action_log": action.log, "tool_input": action.tool_input, "tool_output": output} - - def _postprocess_results(results: list[list[dict]], replace_exceptions: bool, replace_exceptions_value: str | None, - checklist_questions: list[list]) -> tuple[list[list[str]], list[list[list]]]: - """ - Post-process results into lists of outputs and intermediate steps. Replace exceptions with placholder values if - config.replace_exceptions = True. 
- :param values: + checklist_questions: list[list]) -> list[list[dict]]: + """Post-process graph agent results into a uniform list of output dicts. + + Replaces exceptions with placeholder values if replace_exceptions is True. """ outputs = [[] for _ in range(len(results))] for i, answer_list in enumerate(results): for j, answer in enumerate(answer_list): - # Handle exceptions returned by the agent - # OutputParserException is not a subclass of Exception, so we need to check for it separately if isinstance(answer, (ToolRaisedException, OutputParserException, Exception)): if replace_exceptions: - # If the agent encounters a parsing error or a server error after retries, replace the error - # with default values to prevent the pipeline from crashing outputs[i].append({"input": checklist_questions[i][j], "output": replace_exceptions_value, "intermediate_steps": None, "cca_results": [], "package_validated": None}) if isinstance(answer, ToolRaisedException): - tool_raised_exception: ToolRaisedException = answer - logger.warning(f"An exception encountered during tool execution, in result [{i}][{j}]. 
for " - f"question : {checklist_questions[i][j]}" - f".tool raised exception details=> {tool_raised_exception}" - f", replacing with default output -> {replace_exceptions_value}") + logger.warning( + "Tool execution exception in result[%d][%d], replacing with default output: %s", + i, j, type(answer).__name__) else: logger.warning( - "General Exception encountered in result[%d][%d]: %s, for question -> %s, " - "Replacing with default output: \"%s\" and intermediate_steps: None", - i, - j, - str(answer), - checklist_questions[i][j], - replace_exceptions_value) + "General exception in result[%d][%d]: %s, replacing with default output", + i, j, type(answer).__name__) - # For successful agent responses, extract the output, and intermediate steps if available else: - # intermediate_steps availability depends on config.return_intermediate_steps - if "intermediate_steps" in answer: - results[i][j]["intermediate_steps"] = [ - _parse_intermediate_step(step) for step in answer["intermediate_steps"] - ] - else: - results[i][j]["intermediate_steps"] = None - outputs[i].append({"input": answer["input"], "output": answer["output"], - "intermediate_steps": results[i][j]["intermediate_steps"], + "intermediate_steps": None, "cca_results": answer.get("cca_results", []), "package_validated": answer.get("package_validated")}) return outputs @@ -768,13 +157,38 @@ async def _arun(state: AgentMorpheusEngineState) -> AgentMorpheusEngineState: ctx_state.set(state) checklist_plans = state.checklist_plans - - agent = await _create_graph_agent(config, builder, state) - results = await asyncio.gather(*(_process_steps(agent, steps, semaphore, config.max_iterations) - for steps in checklist_plans.values()), return_exceptions=True) - results = _postprocess_results(results, config.replace_exceptions, config.replace_exceptions_value, - list(checklist_plans.values())) + llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) + + agents = {} + for agent_type in 
get_all_agent_types(): + agent_cls = get_agent_class(agent_type) + tools = agent_cls.get_tools(builder, config, state) + if tools: + agent_instance = agent_cls(tools=tools, llm=llm, config=config) + agents[agent_type] = await agent_instance.build_graph() + + routing_llm = llm.with_structured_output(QuestionRouting) + + results = await asyncio.gather( + *( + _process_steps( + agents, + routing_llm, + steps, + semaphore, + config.max_iterations, + ) + for steps in checklist_plans.values() + ), + return_exceptions=True, + ) + results = _postprocess_results( + results, + config.replace_exceptions, + config.replace_exceptions_value, + list(checklist_plans.values()), + ) state.checklist_results = dict(zip(checklist_plans.keys(), results)) with AGENT_TRACER.push_active_function("agent_finish", input_data={ diff --git a/src/vuln_analysis/functions/dispatcher.py b/src/vuln_analysis/functions/dispatcher.py new file mode 100644 index 00000000..0ab1240b --- /dev/null +++ b/src/vuln_analysis/functions/dispatcher.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Literal + +from langchain_core.messages import HumanMessage +from pydantic import BaseModel, Field + +from exploit_iq_commons.logging.loggers_factory import LoggingFactory +from nat.builder.context import Context + +logger = LoggingFactory.get_agent_logger(__name__) +AGENT_TRACER = Context.get() + + +class QuestionRouting(BaseModel): + """Structured output for LLM-based question routing classification. + + The dispatcher classifies each checklist question into one of two agent types: + - reachability: call chain tracing (is vulnerable code reachable from app code?) + - code_understanding: configuration/version/presence/usage investigation + """ + agent_type: Literal["reachability", "code_understanding"] = Field( + description="Route to 'reachability' if the question asks about call chains, " + "code paths, whether vulnerable code is used/called/reachable, " + "or whether untrusted data can reach a function. " + "Route to 'code_understanding' for configuration, version, " + "or application-level settings questions." + ) + reason: str = Field( + description="One-sentence justification for the routing decision." + ) + + +ROUTING_PROMPT_TEMPLATE = """You are classifying a CVE investigation question to route it to the correct sub-agent. + +Two sub-agents are available: +- **reachability**: Traces call chains and code paths. Use when the question asks whether vulnerable code is CALLED, USED, or REACHABLE from application code, or whether untrusted data can reach a specific function. IMPORTANT: If the question asks whether a class/function/module from the VULNERABLE PACKAGE (listed in the context below) is used, imported, or called — route to reachability. The reachability agent has Call Chain Analyzer and Function Locator to verify actual code-level usage and call paths, which is needed to distinguish "imported but never called on vulnerable path" from "actively used." 
+- **code_understanding**: Investigates configuration, version, presence, and general application behavior. Use when the question asks about application-level configuration (e.g., XML parsing settings, TLS options), environment setup, or properties — things that do NOT require tracing whether a specific vulnerable function is called. + +Context (CVE / vulnerable packages): +{context_block} + +Question: {question} + +Examples: +- "Is the vulnerable function XStream.fromXML() called from application code?" → reachability +- "Is the application configured to use the affected XML parser?" → code_understanding +- "Can untrusted data reach BeanUtils.populate() through the call chain?" → reachability +- "Is the vulnerable version of commons-beanutils installed?" → code_understanding +- "Is the function parseXML() reachable from any HTTP handler?" → reachability +- "Does the application enable external entity processing in its XML configuration?" → code_understanding +- "Is SslHandler used in the application?" → reachability (SslHandler is from the vulnerable package) +- "Is HttpPostStandardRequestDecoder.offer() called by application code?" → reachability +- "Is the vulnerable newTransformer() method invoked anywhere in the application's XSLT processing?" → reachability +- "Does application code call deserialize() or readObject() on the affected library?" → reachability +- "Does the code sanitize input against path traversal characters and command injection metacharacters?" → code_understanding +- "Are there input validation mechanisms to prevent malicious Markdown input from reaching the paragraph function?" → code_understanding +- "Can a single request force the application server to load unbounded data into memory?" → code_understanding +- "Can malformed input cause an unhandled exception that crashes or restarts the process?" 
→ code_understanding + +Classify the question above.""" + + +def build_routing_prompt(context_block: str, question: str) -> str: + """Format the routing prompt template with CVE context and question text.""" + return ROUTING_PROMPT_TEMPLATE.format( + context_block=context_block, + question=question, + ) + + +async def dispatch_question( + routing_llm, + question: str, + context_block: str, +) -> QuestionRouting: + """Classify a checklist question and return routing decision. + + Uses the routing LLM (with structured output) to decide whether the question + should go to the reachability agent (call chain tracing) or the code understanding + agent (configuration/presence/usage investigation). + """ + prompt = build_routing_prompt(context_block, question) + with AGENT_TRACER.push_active_function("dispatch_question", input_data=f"question:{question[:80]}") as span: + logger.debug("Dispatching question to routing LLM: %.80s", question) + result: QuestionRouting = await routing_llm.ainvoke([HumanMessage(content=prompt)]) + logger.info( + "Question routed to '%s' (reason: %s): %.80s", + result.agent_type, + result.reason, + question, + ) + span.set_output({"agent_type": result.agent_type, "reason": result.reason}) + return result \ No newline at end of file diff --git a/src/vuln_analysis/functions/reachability_agent.py b/src/vuln_analysis/functions/reachability_agent.py new file mode 100644 index 00000000..6462939e --- /dev/null +++ b/src/vuln_analysis/functions/reachability_agent.py @@ -0,0 +1,337 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from langchain_core.messages import HumanMessage, SystemMessage, AIMessage + +from exploit_iq_commons.logging.loggers_factory import LoggingFactory +from nat.builder.context import Context +from vuln_analysis.functions.agent_registry import register_agent +from vuln_analysis.functions.base_graph_agent import BaseGraphAgent, _is_tool_available +from vuln_analysis.functions.react_internals import ( + AgentState, + Classification, + Observation, + Thought, + ReachabilityRulesTracker, + build_reachability_system_prompt, + build_classification_prompt, + FORCED_FINISH_PROMPT, + REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS_GO, + REACHABILITY_AGENT_NON_REACH_SYS_PROMPT, + REACHABILITY_AGENT_NON_REACH_THOUGHT_INSTRUCTIONS, +) +from vuln_analysis.runtime_context import ctx_state +from vuln_analysis.tools.tool_names import ToolNames +from vuln_analysis.tools.transitive_code_search import package_name_from_locator_query +from vuln_analysis.utils.intel_utils import build_critical_context, enrich_go_from_osv +from vuln_analysis.utils.prompting import build_tool_descriptions +from vuln_analysis.utils.prompt_factory import ( + TOOL_SELECTION_STRATEGY, + TOOL_SELECTION_STRATEGY_NON_REACHABILITY, + FEW_SHOT_EXAMPLES, +) +from pathlib import Path +from exploit_iq_commons.utils.git_utils import sanitize_git_url_for_path +from exploit_iq_commons.utils.data_utils import DEFAULT_GIT_DIRECTORY + +logger = LoggingFactory.get_agent_logger(__name__) +AGENT_TRACER = Context.get() + + +def _validate_go_vendor_packages(source_info, candidate_packages): + code_si = next((si for si in 
source_info if si.type == "code"), None) + if code_si is None: + return candidate_packages, [] + repo_path = Path(DEFAULT_GIT_DIRECTORY) / sanitize_git_url_for_path(code_si.git_repo) + vendor_path = repo_path / "vendor" + if not vendor_path.is_dir(): + return candidate_packages, [] + validated = [] + removed = [] + for pkg in candidate_packages: + pkg_name = pkg.get("name", "") + if (vendor_path / pkg_name).is_dir(): + validated.append(pkg) + else: + removed.append(pkg_name) + if validated: + return validated, removed + return candidate_packages, [] + + +async def _enrich_go_candidates(cve_intel, source_info, critical_context, candidate_packages, vulnerable_functions_set): + ghsa_has_packages = any(c.get("source") == "ghsa" for c in candidate_packages) + if not ghsa_has_packages or not vulnerable_functions_set: + intel = cve_intel[0] if cve_intel else None + if intel: + await enrich_go_from_osv(intel, critical_context, candidate_packages, vulnerable_functions_set) + if candidate_packages: + candidate_packages, removed_pkgs = _validate_go_vendor_packages(source_info, candidate_packages) + if removed_pkgs: + logger.info("Go vendor validation removed %d packages not in vendor/: %s", len(removed_pkgs), removed_pkgs) + return candidate_packages, sorted(vulnerable_functions_set) + + +@register_agent("reachability") +class ReachabilityAgent(BaseGraphAgent): + + def __init__(self, tools, llm, config): + super().__init__(tools, llm, config) + self._classification_llm = llm.with_structured_output(Classification) + + @property + def agent_type(self) -> str: + return "reachability" + + _REACHABILITY_TOOLS = frozenset({ + ToolNames.FUNCTION_LOCATOR, + ToolNames.CALL_CHAIN_ANALYZER, + ToolNames.FUNCTION_LIBRARY_VERSION_FINDER, + ToolNames.FUNCTION_CALLER_FINDER, + ToolNames.CODE_KEYWORD_SEARCH, + ToolNames.CVE_WEB_SEARCH, + ToolNames.CODE_SEMANTIC_SEARCH, + ToolNames.DOCS_SEMANTIC_SEARCH, + }) + + @staticmethod + def get_tools(builder, config, state) -> list: + all_tools = 
BaseGraphAgent._load_all_tools(builder, config) + return [ + t for t in all_tools + if t.name in ReachabilityAgent._REACHABILITY_TOOLS + and _is_tool_available(t.name, config, state) + ] + + @staticmethod + def create_rules_tracker() -> ReachabilityRulesTracker: + return ReachabilityRulesTracker() + + def should_truncate_tool_output(self, state: AgentState, tool_used: str) -> bool: + return state.get("ecosystem", "").lower() == "java" + + def _build_tool_guidance_for_ecosystem(self, ecosystem: str, available_tools: list, + is_reachability: str = "yes") -> tuple[str, str]: + filtered_tools = [ + t for t in available_tools + if (t.name != ToolNames.FUNCTION_CALLER_FINDER or ecosystem == "go") and + (t.name != ToolNames.FUNCTION_LIBRARY_VERSION_FINDER or ecosystem == "java") + ] + list_of_tool_names = [t.name for t in filtered_tools] + list_of_tool_descriptions = [t.name + ": " + t.description for t in filtered_tools] + + strategy = TOOL_SELECTION_STRATEGY if is_reachability == "yes" else TOOL_SELECTION_STRATEGY_NON_REACHABILITY + lang = ecosystem.lower() if ecosystem else "" + if lang in strategy: + tool_guidance_local = strategy[lang] + if lang == "java": + tool_guidance_local += ( + " Use Function Library Version Finder to verify the installed version of a library " + "before concluding exploitability (e.g., input 'commons-beanutils')." 
+ ) + if is_reachability == "yes": + hint = FEW_SHOT_EXAMPLES.get(lang, "") + if hint: + tool_guidance_local += f"\nHint: {hint}" + else: + tool_guidance_list_local = build_tool_descriptions(list_of_tool_names) + tool_guidance_local = "\n".join(tool_guidance_list_local) + + descriptions_local = "\n".join(list_of_tool_descriptions) + return tool_guidance_local, descriptions_local + + async def pre_process_node(self, state: AgentState) -> AgentState: + workflow_state = ctx_state.get() + ecosystem = workflow_state.original_input.input.image.ecosystem.value if workflow_state.original_input.input.image.ecosystem else "" + with AGENT_TRACER.push_active_function("pre_process node", input_data=f"ecosystem:{ecosystem}") as span: + try: + precomputed = state.get("precomputed_intel") + if precomputed is not None: + critical_context = list(precomputed[0]) + candidate_packages = [dict(p) for p in precomputed[1]] + vulnerable_functions = list(precomputed[2]) + else: + critical_context, candidate_packages, vulnerable_functions = build_critical_context(workflow_state.cve_intel) + vulnerable_functions_set = set(vulnerable_functions) + + if ecosystem == "go": + candidate_packages, vulnerable_functions = await _enrich_go_candidates( + workflow_state.cve_intel, + workflow_state.original_input.input.image.source_info, + critical_context, + candidate_packages, + vulnerable_functions_set, + ) + + critical_context, selected_package = await self._select_package( + ecosystem, candidate_packages, critical_context, workflow_state, + ) + app_package = selected_package + + critical_context.append( + "TASK: Investigate usage and reachability of the vulnerable function/module in the container. " + "Use the vulnerable module name from GHSA as primary investigation target." 
+ ) + + question = state.get("input") or "" + context_block = "\n".join(critical_context) + classification_prompt = build_classification_prompt(context_block, question) + classification_result: Classification = await self._classification_llm.ainvoke([HumanMessage(content=classification_prompt)]) + span.set_output({ + "critical_context": critical_context, + "candidate_packages": candidate_packages, + "selected_package": selected_package, + "app_package": app_package if selected_package else None, + "reachability_question": classification_result.is_reachability, + }) + + is_reachability = classification_result.is_reachability + + if is_reachability == "yes": + tool_guidance_local, descriptions_local = self._build_tool_guidance_for_ecosystem(ecosystem, self.tools) + go_instructions = {"instructions": REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS_GO} if ecosystem == "go" else {} + runtime_prompt = build_reachability_system_prompt(descriptions_local, tool_guidance_local, **go_instructions) + active_tool_names = [t.name for t in self.tools] + else: + reachability_tool_names = {ToolNames.CALL_CHAIN_ANALYZER, ToolNames.FUNCTION_CALLER_FINDER} + non_reach_tools = [t for t in self.tools if t.name not in reachability_tool_names] + tool_guidance_local, descriptions_local = self._build_tool_guidance_for_ecosystem(ecosystem, non_reach_tools, is_reachability="no") + runtime_prompt = build_reachability_system_prompt( + descriptions_local, tool_guidance_local, + instructions=REACHABILITY_AGENT_NON_REACH_THOUGHT_INSTRUCTIONS, + sys_prompt=REACHABILITY_AGENT_NON_REACH_SYS_PROMPT, + ) + active_tool_names = [t.name for t in non_reach_tools] + logger.info("Non-reachability question detected; removed reachability tools from prompt") + + rules_tracker = state.get("rules_tracker") + app_package = app_package if selected_package else None + rules_tracker.set_ecosystem(ecosystem) + rules_tracker.set_target_package(app_package) + rules_tracker.set_allowed_tools(active_tool_names) + if 
is_reachability == "yes": + rules_tracker.set_target_functions(vulnerable_functions) + return { + "ecosystem": ecosystem, + "runtime_prompt": runtime_prompt, + "is_reachability": is_reachability, + "observation": Observation(memory=critical_context, results=[]), + "critical_context": critical_context, + "app_package": app_package, + } + except Exception as e: + logger.exception("pre_process_node failed") + span.set_output({"error": str(e), "exception_type": type(e).__name__}) + raise + + def check_finish_allowed(self, state: AgentState) -> tuple[bool, str]: + if state.get("is_reachability") != "yes": + return True, "" + rules_tracker = state.get("rules_tracker") + cca_results = state.get("cca_results", []) + return rules_tracker.check_finish_allowed(cca_results) + + async def forced_finish_node(self, state: AgentState) -> AgentState: + """Override forced_finish to inject a no-CCA warning for reachability questions. + + When the reachability agent exhausts all iterations without ever calling CCA, + the LLM has no reachability evidence. Without this override the LLM can + hallucinate exploitability based on library presence alone. The injected + prompt steers it toward 'insufficient evidence / not exploitable'. 
+ """ + cca_results = state.get("cca_results", []) + is_reachability = state.get("is_reachability") + if is_reachability == "yes" and not cca_results: + step_num = state.get("step", 0) + with AGENT_TRACER.push_active_function( + f"{self.agent_type}_forced_finish", input_data=f"step:{step_num}" + ) as span: + try: + active_prompt = state.get("runtime_prompt") + messages = [SystemMessage(content=active_prompt)] + context_block = self._build_observation_context(state.get("observation", None)) + if context_block: + messages.append(SystemMessage(content=context_block)) + question = state.get("input", "") + no_cca_prompt = ( + (f"QUESTION: {question}\n\n" if question else "") + + FORCED_FINISH_PROMPT + "\n\n" + "CRITICAL: Call Chain Analyzer was NEVER called during this investigation. " + "Without Call Chain Analyzer verification, there is NO evidence that the " + "vulnerable function is reachable from application code. Library presence " + "in dependencies alone does NOT constitute exploitability. " + "You MUST conclude that there is insufficient evidence to confirm " + "exploitability — the function was NOT confirmed reachable." + ) + messages.append(HumanMessage(content=no_cca_prompt)) + + response: Thought = await self.thought_llm.ainvoke(messages) + if response.mode == "finish" and response.final_answer: + final_answer = response.final_answer + else: + final_answer = ( + "Insufficient evidence: Call Chain Analyzer was never invoked. " + "Cannot confirm the vulnerable function is reachable from application code." 
+ ) + response = Thought( + thought="Max steps exceeded without CCA verification", + mode="finish", + actions=None, + final_answer=final_answer, + ) + ai_message = AIMessage(content=final_answer) + logger.info("Reachability forced_finish: no CCA calls, injected no-CCA warning") + span.set_output({"final_answer_length": len(final_answer), "step": step_num, "no_cca_warning": True}) + return { + "messages": [ai_message], + "thought": response, + "step": step_num, + "max_steps": state.get("max_steps", self.config.max_iterations), + "observation": state.get("observation", None), + "output": final_answer, + } + except Exception as e: + logger.exception("%s forced_finish_node failed at step %d", self.agent_type, step_num) + span.set_output({"error": str(e), "exception_type": type(e).__name__, "step": step_num}) + raise + return await super().forced_finish_node(state) + + def post_observation(self, state: AgentState, tool_used: str, + tool_output: str, tool_input_detail: str) -> dict: + extra = {} + cca_results = list(state.get("cca_results", [])) + if tool_used == ToolNames.CALL_CHAIN_ANALYZER: + stripped = tool_output.strip().lstrip("([") + first_token = stripped.split(",", 1)[0].strip().lower() + if first_token == "true": + cca_results.append(True) + elif first_token == "false": + cca_results.append(False) + extra["cca_results"] = cca_results + + package_validated = state.get("package_validated") + rules_tracker = state.get("rules_tracker") + if tool_used == ToolNames.FUNCTION_LOCATOR and state.get("is_reachability") == "yes": + input_pkg = package_name_from_locator_query(tool_input_detail) + target_pkg = (state.get("app_package") or "").strip().lower() + if input_pkg and "Package is valid" in tool_output: + rules_tracker.add_validated_package(input_pkg) + if target_pkg and input_pkg == target_pkg: + package_validated = True + elif input_pkg and "Package is not valid" in tool_output: + if target_pkg and input_pkg == target_pkg and package_validated is None: + 
package_validated = False + extra["package_validated"] = package_validated + return extra diff --git a/src/vuln_analysis/functions/react_internals.py b/src/vuln_analysis/functions/react_internals.py index 9561c349..80340067 100644 --- a/src/vuln_analysis/functions/react_internals.py +++ b/src/vuln_analysis/functions/react_internals.py @@ -17,11 +17,16 @@ from typing import Any from typing import Literal from langgraph.graph import MessagesState -#---- REACT Schemas ----# + +from exploit_iq_commons.logging.loggers_factory import LoggingFactory +from vuln_analysis.tools.tool_names import ToolNames + +logger = LoggingFactory.get_agent_logger(__name__) + +# ---- Pydantic Schemas ---- # class ToolCall(BaseModel): tool: str = Field(description="Exact tool name from AVAILABLE_TOOLS") - #tool_input: str = Field(description="The input for the tool. Example: Code Keyword Search: PQescapeLiteral") package_name: str | None = Field( default=None, description="Package/module name. REQUIRED when using Function Locator, Function Caller Finder, or Call Chain Analyzer. E.g. libpq, urllib, github.com/org/pkg" @@ -30,12 +35,12 @@ class ToolCall(BaseModel): default=None, description="Function or method name with optional args. REQUIRED with package_name for code path tools. E.g. PQescapeLiteral(), parse(), errors.New(\"x\")" ) - # For search tools (Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search) + # For search tools: Code Keyword Search, Code Semantic Search, Docs Semantic Search, CVE Web Search query: str | None = Field( default=None, description="Search query. Use for search tools when package_name/function_name don't apply" ) - # Fallback: if LLM uses tool_input for simple query-only tools + # Fallback for when LLM uses tool_input for simple query-only tools tool_input: str | None = Field( default=None, description="Legacy/fallback input. Prefer package_name+function_name or query." 
@@ -94,18 +99,18 @@ class PackageSelection(BaseModel): description="One-sentence justification for why this package is the best investigation target." ) -class SystemRulesTracker: +class BaseRulesTracker: + """Shared behavioral rule infrastructure for all sub-agents.""" def __init__(self): self.action_history = {} self.target_package = None self.allowed_tools = [] - self.target_functions: dict[str, bool] = {} + def set_allowed_tools(self, allowed_tools: list[str]): self.allowed_tools = allowed_tools + def set_target_package(self, target_package: str): self.target_package = target_package - def set_target_functions(self, functions: list[str]): - self.target_functions = {f: False for f in functions} @staticmethod def _is_empty_result(output) -> bool: @@ -127,6 +132,11 @@ def add_action(self, action: str, action_input: str, output): else: self.action_history[action].append(entry) + def _rule_duplicate_call(self, action: str, action_input: str) -> bool: + if action not in self.action_history: + return False + return any(prev["input"] == action_input for prev in self.action_history[action]) + def _rule_number_7(self, action: str, action_input: str, output) -> bool: if action != "Code Keyword Search": return False @@ -141,10 +151,73 @@ def _rule_number_7(self, action: str, action_input: str, output) -> bool: return True return False + def _rule_use_allowed_tools(self, action: str) -> bool: + if action not in self.allowed_tools: + return True + return False + + def check_thought_behavior(self, action: str, action_input: str, output) -> tuple[bool, str]: + if self._rule_duplicate_call(action, action_input): + return True, ( + f"You already called {action} with this exact input. " + "You MUST use a DIFFERENT tool or a DIFFERENT input query. " + "Check KNOWLEDGE for what was already tried." + ) + if self._rule_number_7(action, action_input, output): + return True, ("You are NOT following Rule 7. Your query contains dots and returned " + "no results. 
You MUST retry with just the final component. Follow the rules.") + if self._rule_use_allowed_tools(action): + return True, (f"You are NOT following AVAILABLE_TOOLS. You MUST use the allowed tools {self.allowed_tools}. Follow the rules.") + self.add_action(action, action_input, output) + return False, "" + + +class ReachabilityRulesTracker(BaseRulesTracker): + """Behavioral rules for the reachability sub-agent.""" + def __init__(self): + super().__init__() + self.target_functions: dict[str, bool] = {} + self.ecosystem: str = "" + self.validated_packages: set[str] = set() + + def set_ecosystem(self, ecosystem: str): + self.ecosystem = ecosystem.lower() if ecosystem else "" + + def check_finish_allowed(self, cca_results: list[bool]) -> tuple[bool, str]: + """Block finish when mandatory reachability tools were not used. + + Two rules enforced: + 1. CCA must be called before finishing any reachability question. + 2. (Java only) FLVF must be called when CCA found reachability. + """ + if not cca_results and ToolNames.CALL_CHAIN_ANALYZER not in self.action_history: + return False, ( + "You MUST use Function Locator and Call Chain Analyzer before concluding. " + "Code Keyword Search alone is NOT sufficient to determine reachability. " + "Call Function Locator to validate the package, then Call Chain Analyzer " + "to check if the vulnerable function is reachable from application code." + ) + if self.ecosystem == "java" and any(cca_results): + if ToolNames.FUNCTION_LIBRARY_VERSION_FINDER not in self.action_history: + return False, ( + "MANDATORY VERSION CHECK: Call Chain Analyzer found the vulnerable function is reachable, " + "but you have NOT verified the installed library version. " + "You MUST call Function Library Version Finder (e.g., input 'commons-beanutils') " + "to check whether the installed version is within the vulnerable range before concluding." 
+ ) + return True, "" + + def set_target_functions(self, functions: list[str]): + self.target_functions = {f: False for f in functions} + @staticmethod def _normalize_package_name(name: str) -> str: return name.strip().lower().replace("-", "_") + def add_validated_package(self, package_name: str): + """Register a package that FL confirmed as valid (e.g., uber-jar alternative).""" + self.validated_packages.add(self._normalize_package_name(package_name)) + def _rule_number_8(self, action: str, action_input: str, output) -> bool: if self.target_package is None: return False @@ -155,11 +228,12 @@ def _rule_number_8(self, action: str, action_input: str, output) -> bool: target_pkg = self._normalize_package_name(self.target_package) # Allow Java GAV with version suffix: input "group:artifact:version" # should match target "group:artifact" - if input_pkg != target_pkg and not input_pkg.startswith(target_pkg + ":"): - return True - return False - def _rule_use_allowed_tools(self, action: str) -> bool: - if action not in self.allowed_tools: + if input_pkg == target_pkg or input_pkg.startswith(target_pkg + ":"): + return False + # Allow packages that FL already validated (handles uber-jars like + # netty-all containing netty-codec-http classes) + if any(input_pkg == vp or input_pkg.startswith(vp + ":") for vp in self.validated_packages): + return False return True return False @@ -186,24 +260,35 @@ def _rule_number_9(self, action: str, action_input: str) -> tuple[bool, str]: return False, "" def check_thought_behavior(self, action: str, action_input: str, output) -> tuple[bool, str]: + if self._rule_duplicate_call(action, action_input): + logger.debug("Duplicate call rule triggered: '%s' with same input", action) + return True, ( + f"You already called {action} with this exact input. " + "You MUST use a DIFFERENT tool or a DIFFERENT input query. " + "Check KNOWLEDGE for what was already tried." 
+ ) if self._rule_number_7(action, action_input, output): + logger.debug("Reachability Rule 7 triggered: dotted query with empty results for tool '%s'", action) return True, ("You are NOT following Rule 7. Your query contains dots and returned " "no results. You MUST retry with just the final component. Follow the rules.") if self._rule_number_8(action, action_input, output): + logger.debug("Reachability Rule 8 triggered: wrong package for tool '%s', expected '%s'", action, self.target_package) return True, (f"You are NOT following Rule 8. You are using the wrong package name. You MUST use the target package name {self.target_package} see KNOWLEDGE as the package_name before trying alternative packages. Follow the rules.") if self._rule_use_allowed_tools(action): + logger.debug("Reachability allowed-tools rule triggered: '%s' not in %s", action, self.allowed_tools) return True, (f"You are NOT following AVAILABLE_TOOLS. You MUST use the allowed tools {self.allowed_tools}. Follow the rules.") rule9, msg9 = self._rule_number_9(action, action_input) if rule9: + logger.debug("Reachability Rule 9 triggered: must investigate target functions first for tool '%s'", action) return True, msg9 self.add_action(action, action_input, output) return False, "" - + + class AgentState(MessagesState): input: str = "" step: int = 0 max_steps: int = 10 - #memory: str | None = None thought: Thought | None = None observation: Observation | None = None output: str = "" @@ -211,14 +296,19 @@ class AgentState(MessagesState): runtime_prompt: str | None = None app_package: str | None = None is_reachability: str = "yes" - rules_tracker: SystemRulesTracker = SystemRulesTracker() + rules_tracker: BaseRulesTracker = ReachabilityRulesTracker() critical_context: list[str] = [] cca_results: list[bool] = [] package_validated: bool | None = None - -### --- End of REACT Schemas ----# -#---- REACT Prompt Templates ----# -AGENT_SYS_PROMPT = ( + precomputed_intel: tuple | None = None + + +# ---- 
Reachability Sub-Agent Prompts ---- # +# The following prompts (REACHABILITY_AGENT_SYS_PROMPT, REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS, +# REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS_GO) are used exclusively by the reachability sub-agent, +# which traces call chains to determine if vulnerable code is reachable from application code. + +REACHABILITY_AGENT_SYS_PROMPT = ( "You are a security analyst investigating CVE exploitability in container images.\n" "MANDATORY STEPS (follow in order, do NOT skip any):\n" "1. IDENTIFY the vulnerable component/function from the CVE description.\n" @@ -249,7 +339,6 @@ class AgentState(MessagesState): "- When citing evidence, explain HOW it relates to the question -- do not just state that something was found." ) -# Update LANGGRAPH_SYSTEM_PROMPT_TEMPLATE in react_internals.py LANGGRAPH_SYSTEM_PROMPT_TEMPLATE = """{sys_prompt} @@ -265,7 +354,7 @@ class AgentState(MessagesState): RESPONSE: {{""" -AGENT_THOUGHT_INSTRUCTIONS = """ +REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS = """ 1. Output valid JSON only. thought < 100 words. final_answer < 150 words. 2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. 3. Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query. @@ -286,7 +375,7 @@ class AgentState(MessagesState): {{"thought": "Function Locator confirmed the package. Now trace reachability with Call Chain Analyzer", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}} """ -AGENT_THOUGHT_INSTRUCTIONS_GO = """ +REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS_GO = """ 1. Output valid JSON only. thought < 100 words. final_answer < 150 words. 2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. 3. 
Function Locator, Function Caller Finder, Call Chain Analyzer: MUST set package_name AND function_name. Do NOT use query. @@ -312,7 +401,7 @@ class AgentState(MessagesState): {{"thought": "Function Caller Finder returned no callers, but this does not prove unreachable. MUST call Call Chain Analyzer to confirm reachability", "mode": "act", "actions": {{"tool": "Call Chain Analyzer", "package_name": "", "function_name": "", "query": null, "tool_input": null, "reason": "Check if function is reachable from application code"}}, "final_answer": null}} """ -AGENT_SYS_PROMPT_NON_REACHABILITY = ( +REACHABILITY_AGENT_NON_REACH_SYS_PROMPT = ( "You are a security analyst investigating CVE exploitability in container images.\n" "This is NOT a reachability question -- do NOT trace call chains.\n" "MANDATORY STEPS (follow in order, do NOT skip any):\n" @@ -345,7 +434,7 @@ class AgentState(MessagesState): "- When citing evidence, explain HOW it relates to the question -- do not just state that something was found." ) -AGENT_THOUGHT_INSTRUCTIONS_NON_REACHABILITY = """ +REACHABILITY_AGENT_NON_REACH_THOUGHT_INSTRUCTIONS = """ 1. Output valid JSON only. thought < 100 words. final_answer < 150 words. 2. mode="act" REQUIRES actions. mode="finish" REQUIRES final_answer. 3. Function Locator: MUST set package_name AND function_name. Do NOT use query. @@ -404,6 +493,12 @@ class AgentState(MessagesState): NEW OUTPUT: {tool_output} +CRITICAL — CALL CHAIN ANALYZER REACHABILITY: +When TOOL USED is "Call Chain Analyzer": +- If the result is POSITIVE (reachable): your first finding MUST be "REACHABLE via [package] - sufficient evidence." Do NOT hedge, qualify, or say "further investigation required." +- If the result is NEGATIVE (not reachable): your first finding MUST be "NOT reachable via [package]." +These override all other rules for Call Chain Analyzer results. + CODE COMPREHENSION RULES: 1. READ the actual code snippets in NEW OUTPUT. 
Do NOT just check whether something was "found" or "not found." 2. For each code snippet or function returned, determine: @@ -452,17 +547,13 @@ class AgentState(MessagesState): RESPONSE: {{""" -# Legacy prompt kept for backwards compatibility with older traces -OBSERVATION_NODE_PROMPT = COMPREHENSION_PROMPT - -### --- End of REACT Prompt Templates ----# -def build_system_prompt( +def build_reachability_system_prompt( tool_descriptions: str, tool_guidance: str, - instructions: str = AGENT_THOUGHT_INSTRUCTIONS, + instructions: str = REACHABILITY_AGENT_THOUGHT_INSTRUCTIONS, sys_prompt: str | None = None, ) -> str: - sys_prompt = sys_prompt or AGENT_SYS_PROMPT + sys_prompt = sys_prompt or REACHABILITY_AGENT_SYS_PROMPT return LANGGRAPH_SYSTEM_PROMPT_TEMPLATE.format( sys_prompt=sys_prompt, tools=tool_descriptions, @@ -545,6 +636,8 @@ def _build_tool_arguments(actions: ToolCall)->dict[str, Any]: return {"query": actions.query} if actions.tool_input: return {"query": actions.tool_input} # fallback + logger.warning("Tool '%s' called without required arguments (package_name=%s, function_name=%s, query=%s)", + actions.tool, actions.package_name, actions.function_name, actions.query) raise ValueError(f"Tool {actions.tool} requires package_name+function_name or query/tool_input") diff --git a/src/vuln_analysis/register.py b/src/vuln_analysis/register.py index dc847a87..fe3b4b75 100644 --- a/src/vuln_analysis/register.py +++ b/src/vuln_analysis/register.py @@ -47,6 +47,8 @@ from vuln_analysis.tools import container_image_analysis_data from vuln_analysis.tools import local_vdb from vuln_analysis.tools import serp +from vuln_analysis.tools import configuration_scanner +from vuln_analysis.tools import import_usage_analyzer from vuln_analysis.utils.error_handling_decorator import catch_pipeline_errors_async # pylint: enable=unused-import from vuln_analysis.utils.llm_engine_utils import postprocess_engine_output, finalize_preprocess_engine_input @@ -255,7 +257,7 @@ async def 
call_llm_engine_subgraph_node(message: AgentMorpheusEngineInput): graph = graph_builder.compile() def convert_str_to_agent_morpheus_input(input: str) -> AgentMorpheusInput: - logger.debug("Converting input to AgentMorpheusInput: %s", input) + logger.debug("Converting JSON string input to AgentMorpheusInput (length: %d)", len(input)) try: return AgentMorpheusInput.model_validate_json(input) except Exception as e: @@ -263,7 +265,7 @@ def convert_str_to_agent_morpheus_input(input: str) -> AgentMorpheusInput: raise e def convert_textio_to_agent_morpheus_input(input: TextIOWrapper) -> AgentMorpheusInput: - logger.debug("Converting input to AgentMorpheusInput: %s", input) + logger.debug("Converting TextIOWrapper input to AgentMorpheusInput") try: data = input.read() return AgentMorpheusInput.model_validate_json(data) @@ -273,7 +275,7 @@ def convert_textio_to_agent_morpheus_input(input: TextIOWrapper) -> AgentMorpheu raise e def convert_agent_morpheus_output_to_str(output: AgentMorpheusOutput) -> str: - logger.debug("Converting AgentMorpheusOutput to str: %s", output) + logger.debug("Converting AgentMorpheusOutput to JSON string") try: return output.model_dump_json() except Exception as e: diff --git a/src/vuln_analysis/runtime_context.py b/src/vuln_analysis/runtime_context.py index fceb67aa..adca2360 100644 --- a/src/vuln_analysis/runtime_context.py +++ b/src/vuln_analysis/runtime_context.py @@ -19,4 +19,9 @@ import contextvars # Holds the current AgentMorpheusEngineState for the active task -ctx_state = contextvars.ContextVar("ctx_state", default="default_value") \ No newline at end of file +ctx_state = contextvars.ContextVar("ctx_state", default="default_value") + +# Source scope for CU agent tools (Docs Semantic Search, Code Keyword Search). +# List of path substrings; tools keep results whose path matches any entry. +# None means no scoping (search all sources). 
+cu_source_scope = contextvars.ContextVar("cu_source_scope", default=None) \ No newline at end of file diff --git a/src/vuln_analysis/tools/configuration_scanner.py b/src/vuln_analysis/tools/configuration_scanner.py new file mode 100644 index 00000000..0ed1ada3 --- /dev/null +++ b/src/vuln_analysis/tools/configuration_scanner.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import os +import re +from collections import OrderedDict +from pathlib import Path + +from aiq.builder.builder import Builder +from aiq.builder.framework_enum import LLMFrameworkEnum +from aiq.builder.function_info import FunctionInfo +from aiq.cli.register_workflow import register_function +from aiq.data_models.function import FunctionBaseConfig +from pydantic import Field + +from exploit_iq_commons.logging.loggers_factory import LoggingFactory +from exploit_iq_commons.utils.git_utils import sanitize_git_url_for_path +from exploit_iq_commons.utils.data_utils import DEFAULT_GIT_DIRECTORY +from vuln_analysis.utils.error_handling_decorator import catch_tool_errors +from vuln_analysis.utils.source_classification import is_dependency_path, filter_by_source_scope, format_app_dep_output + +CONFIGURATION_SCANNER = "configuration_scanner" + +logger = LoggingFactory.get_agent_logger(__name__) + + +def format_context_snippet(lines: list[str], match_line: int, context_lines: int) -> str: + """Format source lines around a match, marking the match line with '>'.""" + start = max(0, match_line - context_lines) + end = min(len(lines), match_line + context_lines + 1) + return "\n".join( + f"{'>' if i == match_line else ' '} {i+1}: {lines[i]}" + for i in range(start, end) + ) + +# File patterns and directory names considered as configuration files +CONFIG_FILE_PATTERNS = [ + # Named config files — matched anywhere in the repo + "application.yml", "application.yaml", "application.properties", + "config.yaml", "config.yml", "config.xml", + "settings.toml", "settings.yaml", "settings.yml", + "web.xml", "beans.xml", + "Dockerfile", "Dockerfile.*", "docker-compose*.yml", + # Config-specific extensions — safe to match anywhere + "*.properties", "*.env", "*.conf", "*.ini", +] + +# Directory names that typically contain configuration files +CONFIG_DIR_PATTERNS = ["config", "conf", "conf.d", "etc", "resources"] + +# Avoids per-file regex compilation +_CONFIG_EXTENSIONS = 
[] +_CONFIG_EXACT_NAMES = [] +_CONFIG_WILDCARD_PATTERNS = [] + +for _pat in CONFIG_FILE_PATTERNS: + if _pat.startswith("*."): + _CONFIG_EXTENSIONS.append(_pat[1:]) + elif "*" in _pat: + _CONFIG_WILDCARD_PATTERNS.append( + re.compile(_pat.replace(".", r"\.").replace("*", ".*"), re.IGNORECASE) + ) + else: + _CONFIG_EXACT_NAMES.append(_pat.lower()) + + +class ConfigurationScannerToolConfig(FunctionBaseConfig, name=CONFIGURATION_SCANNER): + """Scans configuration files for vulnerability-relevant patterns.""" + max_results: int = Field(default=15, description="Maximum config entries to return") + context_lines: int = Field(default=5, description="Lines of context around matches") + + +def _is_config_file(file_path: str) -> bool: + """Check if a file path matches any known configuration file pattern.""" + name = os.path.basename(file_path) + lower_name = name.lower() + + # Check in order of speed: extension → exact name → wildcard regex + if any(lower_name.endswith(ext) for ext in _CONFIG_EXTENSIONS): + return True + if lower_name in _CONFIG_EXACT_NAMES: + return True + if any(p.match(lower_name) for p in _CONFIG_WILDCARD_PATTERNS): + return True + + return False + + +def _is_in_config_dir(file_path: str) -> bool: + """Check if a file resides under a known configuration directory.""" + parts = Path(file_path).parts + return any(p.lower() in CONFIG_DIR_PATTERNS for p in parts) + + +def _collect_config_files(repo_path: str) -> list[tuple[str, str]]: + """Walk repo and collect config file paths and contents. + + Skips .git, __pycache__, node_modules, .tox directories. + Skips files larger than 500KB to avoid memory issues with large generated configs. 
+ """ + config_files = [] + for root, dirs, files in os.walk(repo_path): + dirs[:] = [d for d in dirs if d not in (".git", "__pycache__", "node_modules", ".tox")] + for fname in files: + full_path = os.path.join(root, fname) + rel_path = os.path.relpath(full_path, repo_path) + if _is_config_file(rel_path) or _is_in_config_dir(rel_path): + try: + with open(full_path, "r", errors="ignore") as f: + content = f.read() + if len(content) < 500_000: + config_files.append((rel_path, content)) + else: + logger.warning("Skipping config file %s: size %d exceeds 500KB limit", rel_path, len(content)) + except Exception: + continue + logger.debug("Collected %d config files from %s", len(config_files), repo_path) + return config_files + + +def _count_config_matches(config_files: list[tuple[str, str]], keywords: list[str]) -> int: + """Count total keyword-matching lines across config files (no formatting, no cap).""" + count = 0 + for _, content in config_files: + for line in content.split("\n"): + if any(kw in line.lower() for kw in keywords): + count += 1 + return count + + +def search_config_content( + config_files: list[tuple[str, str]], + keywords: list[str], + max_results: int = 15, + context_lines: int = 5, + source_label: str = "unknown", +) -> list[str]: + """Match keywords against config file contents, returning formatted snippets.""" + matches = [] + for rel_path, content in config_files: + lines = content.split("\n") + for line_num, line in enumerate(lines): + line_lower = line.lower() + if any(kw in line_lower for kw in keywords): + snippet = format_context_snippet(lines, line_num, context_lines) + matches.append(f"--- {rel_path} (source: {source_label}) line {line_num+1} ---\n{snippet}") + if len(matches) >= max_results: + break + if len(matches) >= max_results: + break + return matches + + +@register_function(config_type=ConfigurationScannerToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) +async def configuration_scanner(config: 
ConfigurationScannerToolConfig, builder: Builder): + from vuln_analysis.runtime_context import ctx_state, cu_source_scope + + _CONFIG_CACHE_MAX_SIZE = 20 + _config_files_cache: OrderedDict[tuple, list[tuple[str, str]]] = OrderedDict() + _repo_locks: dict[tuple, asyncio.Lock] = {} + _repo_locks_guard = asyncio.Lock() + + @catch_tool_errors(CONFIGURATION_SCANNER) + async def _arun(query: str) -> str: + workflow_state = ctx_state.get() + source_infos = workflow_state.original_input.input.image.source_info + source_scope = cu_source_scope.get() + logger.debug("Configuration scanner: searching for '%s' across %d source(s), scope=%s", + query, len(source_infos), source_scope) + + keywords = [w.strip().lower() for w in re.split(r"[,\s]+", query) if len(w.strip()) >= 2] + all_app_configs = [] + all_dep_configs = [] + + for si in source_infos: + if not hasattr(si, "git_repo"): + continue + repo_path = Path(DEFAULT_GIT_DIRECTORY) / sanitize_git_url_for_path(si.git_repo) + if not repo_path.is_dir(): + continue + + repo_key = (si.git_repo, si.ref) + if repo_key in _config_files_cache: + async with _repo_locks_guard: + _config_files_cache.move_to_end(repo_key) + else: + async with _repo_locks_guard: + if repo_key not in _repo_locks: + _repo_locks[repo_key] = asyncio.Lock() + repo_lock = _repo_locks[repo_key] + async with repo_lock: + if repo_key not in _config_files_cache: + _config_files_cache[repo_key] = _collect_config_files(str(repo_path)) + if len(_config_files_cache) > _CONFIG_CACHE_MAX_SIZE: + _config_files_cache.popitem(last=False) + + for cfg in _config_files_cache[repo_key]: + if is_dependency_path(cfg[0]): + all_dep_configs.append(cfg) + else: + all_app_configs.append(cfg) + + all_dep_configs = filter_by_source_scope(all_dep_configs, source_scope, lambda x: x[0]) + + source_label = source_infos[0].git_repo if source_infos and hasattr(source_infos[0], "git_repo") else "unknown" + app_matches = search_config_content( + all_app_configs, keywords, + 
max_results=config.max_results, + context_lines=config.context_lines, + source_label=source_label, + ) + remaining = config.max_results - len(app_matches) + dep_matches = search_config_content( + all_dep_configs, keywords, + max_results=max(remaining, 0), + context_lines=config.context_lines, + source_label=source_label, + ) if remaining > 0 else [] + + no_results_msg = f"No configuration entries found matching: {query}" + total_app = _count_config_matches(all_app_configs, keywords) + total_dep = _count_config_matches(all_dep_configs, keywords) + logger.debug("Configuration scanner: %d app + %d dep match(es) for '%s'", + total_app, total_dep, query) + return format_app_dep_output(app_matches, dep_matches, total_app, total_dep, no_results_msg) + + yield FunctionInfo.from_fn( + _arun, + description=( + "Scans configuration files (YAML, XML, properties, build files, Dockerfiles) " + "for patterns related to the vulnerability. Input: keywords describing what " + "configuration to look for (e.g. 'XML external entity processing', " + "'deserialization enabled', 'xstream allowTypes'). " + "Searches all indexed sources including framework dependencies." + ), + ) \ No newline at end of file diff --git a/src/vuln_analysis/tools/import_usage_analyzer.py b/src/vuln_analysis/tools/import_usage_analyzer.py new file mode 100644 index 00000000..755c82f2 --- /dev/null +++ b/src/vuln_analysis/tools/import_usage_analyzer.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from aiq.builder.builder import Builder +from aiq.builder.framework_enum import LLMFrameworkEnum +from aiq.builder.function_info import FunctionInfo +from aiq.cli.register_workflow import register_function +from aiq.data_models.function import FunctionBaseConfig +from pydantic import Field + +from exploit_iq_commons.logging.loggers_factory import LoggingFactory +from exploit_iq_commons.utils.functions_parsers.lang_functions_parsers_factory import get_language_function_parser +from vuln_analysis.utils.error_handling_decorator import catch_tool_errors +from vuln_analysis.utils.source_classification import is_dependency_path, filter_by_source_scope, format_app_dep_output + +IMPORT_USAGE_ANALYZER = "import_usage_analyzer" + +logger = LoggingFactory.get_agent_logger(__name__) + + +def _find_usage_in_file(content: str, imported_names: list[str], max_usages: int = 5) -> list[str]: + """Find usage sites of imported names, excluding import lines themselves.""" + usages = [] + lines = content.split("\n") + for line_num, line in enumerate(lines): + for name in imported_names: + short_name = name.rsplit(".", 1)[-1] if "." 
in name else name + if re.search(rf'\b{re.escape(short_name)}\b', line) and not line.strip().startswith(("import ", "from ", "#include")): + usages.append(f" L{line_num+1}: {line.strip()}") + if len(usages) >= max_usages: + return usages + return usages + + +class ImportUsageAnalyzerToolConfig(FunctionBaseConfig, name=IMPORT_USAGE_ANALYZER): + """Analyzes imports and usage patterns of a package across indexed sources.""" + max_files: int = Field(default=20, description="Maximum files to report") + + +def analyze_imports(searcher, import_patterns: list[re.Pattern], package_name: str, + max_files: int = 20, source_scope: list[str] | None = None, + ecosystem_label: str = "") -> str: + """Scan a Tantivy searcher for import patterns, with app/dep source awareness. + + Results are separated into application code and dependency code sections, + with dependency results filtered by source_scope when provided. + """ + num_docs = searcher.num_docs + app_results = [] + dep_results = [] + + for doc_id in range(num_docs): + try: + raw = searcher.doc(doc_id) + file_path = raw["file_path"][0] + content = raw["content"][0] + except Exception: + continue + + found_imports = [] + for pattern in import_patterns: + for match in pattern.finditer(content): + found_imports.append(match.group(0)) + + if not found_imports: + continue + + imported_names = [] + for imp in found_imports: + parts = imp.split() + for p in parts: + cleaned = p.strip(";'\",<>(){}").strip() + if package_name.lower() in cleaned.lower() and len(cleaned) > 2: + imported_names.append(cleaned) + + usages = _find_usage_in_file(content, imported_names) + + entry = f"--- {file_path} ---\n" + entry += f" Imports: {'; '.join(found_imports[:3])}\n" + if usages: + entry += f" Usages ({len(usages)}):\n" + "\n".join(usages) + else: + entry += " No usage sites found beyond import." 
+ + if is_dependency_path(file_path): + dep_results.append((file_path, entry)) + else: + app_results.append(entry) + + dep_results = filter_by_source_scope(dep_results, source_scope, lambda x: x[0]) + dep_entries = [entry for _, entry in dep_results] + + no_results_msg = f"No imports of '{package_name}' found in indexed sources (ecosystem: {ecosystem_label})." + total_app = len(app_results) + total_dep = len(dep_entries) + + trimmed_app = app_results[:max_files] + remaining = max_files - len(trimmed_app) + trimmed_dep = dep_entries[:max(remaining, 0)] + + return format_app_dep_output(trimmed_app, trimmed_dep, total_app, total_dep, no_results_msg) + + +@register_function(config_type=ImportUsageAnalyzerToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) +async def import_usage_analyzer(config: ImportUsageAnalyzerToolConfig, builder: Builder): + from vuln_analysis.runtime_context import ctx_state, cu_source_scope + from vuln_analysis.utils.full_text_search import FullTextSearch + + @catch_tool_errors(IMPORT_USAGE_ANALYZER) + async def _arun(query: str) -> str: + workflow_state = ctx_state.get() + code_index_path = workflow_state.code_index_path + ecosystem = workflow_state.original_input.input.image.ecosystem + source_scope = cu_source_scope.get() + + fts = FullTextSearch.get_instance(cache_path=code_index_path) + if fts.is_empty(): + logger.debug("Import usage analyzer: no source code indexed at %s", code_index_path) + return "No source code indexed." 
+ + package_name = query.strip() + parser = get_language_function_parser(ecosystem, tree=None) if ecosystem else None + import_patterns = parser.get_import_search_patterns(package_name) if parser else [re.compile(re.escape(package_name), re.IGNORECASE)] + ecosystem_label = ecosystem.value if ecosystem else "" + logger.debug("Import usage analyzer: searching for '%s' (ecosystem: %s), scope=%s", + package_name, ecosystem_label, source_scope) + + result = analyze_imports( + fts.index.searcher(), import_patterns, package_name, + max_files=config.max_files, source_scope=source_scope, + ecosystem_label=ecosystem_label, + ) + logger.debug("Import usage analyzer: %s", result.split("\n", 1)[0]) + return result + + yield FunctionInfo.from_fn( + _arun, + description=( + "Finds all imports and usage patterns of a specific package/module across indexed sources. " + "Input: package or module name (e.g. 'encoding/xml', 'com.thoughtworks.xstream', " + "'urllib.parse'). Reports which files import it, how many, and how the package is used. " + "Searches all sources including framework dependencies." 
+ ), + ) \ No newline at end of file diff --git a/src/vuln_analysis/tools/lexical_full_search.py b/src/vuln_analysis/tools/lexical_full_search.py index 0b24fcc1..8117e990 100644 --- a/src/vuln_analysis/tools/lexical_full_search.py +++ b/src/vuln_analysis/tools/lexical_full_search.py @@ -37,19 +37,25 @@ class LexicalSearchToolConfig(FunctionBaseConfig, name=LEXICAL_CODE_SEARCH): @register_function(config_type=LexicalSearchToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) async def lexical_search(config: LexicalSearchToolConfig, builder: Builder): # pylint: disable=unused-argument - from vuln_analysis.runtime_context import ctx_state + from vuln_analysis.runtime_context import ctx_state, cu_source_scope from vuln_analysis.utils.full_text_search import FullTextSearch @catch_tool_errors(LEXICAL_CODE_SEARCH) async def _arun(query: str) -> str: workflow_state = ctx_state.get() code_index_path = workflow_state.code_index_path - full_text_search = FullTextSearch(cache_path=code_index_path) + full_text_search = FullTextSearch.get_instance(cache_path=code_index_path) if full_text_search.is_empty(): + logger.debug("Lexical search: index is empty at %s", code_index_path) raise ValueError(f"Invalid code index at: {code_index_path}, index is empty") - result = full_text_search.search_index(query, config.top_k) + # Pass source_scope from ContextVar to restrict dependency results + # to the current question's target package (set by CU pre_process_node) + source_scope = cu_source_scope.get() + result = full_text_search.search_index(query, config.top_k, source_scope=source_scope) + logger.debug("Lexical search: query='%.80s', top_k=%d, source_scope=%s", + query, config.top_k, source_scope) return result diff --git a/src/vuln_analysis/tools/local_vdb.py b/src/vuln_analysis/tools/local_vdb.py index 7549032f..98b701d6 100644 --- a/src/vuln_analysis/tools/local_vdb.py +++ b/src/vuln_analysis/tools/local_vdb.py @@ -47,15 +47,22 @@ async def load_vectordb_asretriever(config: 
LocalVDBRetrieverToolConfig, builder from langchain_community.vectorstores import FAISS from langchain_core.prompts import PromptTemplate - from vuln_analysis.runtime_context import ctx_state + from vuln_analysis.runtime_context import ctx_state, cu_source_scope embedder = await builder.get_embedder(embedder_name=config.embedder_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) + qa_prompt = PromptTemplate(template=("Use the following pieces of context to answer the question at the end. " + "If you don't know the answer, just say that you don't know, " + "don't try to make up an answer.\n\n{context}\n\n" + "Question: {question}\nHelpful Answer:"), + input_variables=['context', 'question']) + # FAISS deserialization is expensive — cache per db_source path + _retrieval_cache: dict[str, RetrievalQA] = {} + @catch_tool_errors(LOCAL_VDB_RETRIEVER) async def _arun(query: str) -> str | dict: - # workaround since the agent executor only accepts strings. workflow_state = ctx_state.get() if config.vdb_type == VdbType.CODE: @@ -65,21 +72,32 @@ async def _arun(query: str) -> str | dict: else: raise ValueError(f"Invalid VDB type: {config.vdb_type}. Must be one of {VdbType.CODE} or {VdbType.DOC}.") - qa_prompt = PromptTemplate(template=("Use the following pieces of context to answer the question at the end. 
" - "If you don't know the answer, just say that you don't know, " - "don't try to make up an answer.\n\n{context}\n\n" - "Question: {question}\nHelpful Answer:"), - input_variables=['context', 'question']) - - vector_db = FAISS.load_local(db_source, embedder, allow_dangerous_deserialization=True) - retrieval_qa_tool = RetrievalQA.from_chain_type(llm=llm, - chain_type="stuff", - chain_type_kwargs={"prompt": qa_prompt}, - retriever=vector_db.as_retriever(), - return_source_documents=config.return_source_documents) + if db_source not in _retrieval_cache: + logger.info("Loading FAISS index from %s (type: %s)", db_source, config.vdb_type) + vector_db = FAISS.load_local(db_source, embedder, allow_dangerous_deserialization=True) + _retrieval_cache[db_source] = RetrievalQA.from_chain_type( + llm=llm, + chain_type="stuff", + chain_type_kwargs={"prompt": qa_prompt}, + retriever=vector_db.as_retriever(), + return_source_documents=config.return_source_documents, + ) + retrieval_qa_tool = _retrieval_cache[db_source] output_dict = await retrieval_qa_tool.ainvoke(query) + if config.return_source_documents and "source_documents" in output_dict: + source_scope = cu_source_scope.get() + if source_scope: + pre_filter = len(output_dict["source_documents"]) + output_dict["source_documents"] = [ + doc for doc in output_dict["source_documents"] + if any(scope in doc.metadata.get("source", "") for scope in source_scope) + ] + if pre_filter != len(output_dict["source_documents"]): + logger.debug("Source scope %s filtered doc results from %d to %d", + source_scope, pre_filter, len(output_dict["source_documents"])) + # If returning source documents, include the result and source_documents keys in the output if config.return_source_documents: return {k: v for k, v in output_dict.items() if k in ["result", "source_documents"]} @@ -96,7 +114,9 @@ async def _arun(query: str) -> str | dict: elif config.vdb_type == VdbType.DOC: description = ( "Searches container documentation using semantic 
search. " - "Answers questions about application purpose, architecture, and features." + "Answers questions about application purpose, architecture, and features. " + "Queries must be specific — include a concrete term (library name, config property, protocol). " + "Vague queries like 'configuration' or 'security settings' return unhelpful results." ) else: raise ValueError(f"Invalid VDB type: {config.vdb_type}. Must be one of {VdbType.CODE} or {VdbType.DOC}.") diff --git a/src/vuln_analysis/tools/tests/test_transitive_code_search.py b/src/vuln_analysis/tools/tests/test_transitive_code_search.py index b432fc25..65037fe5 100644 --- a/src/vuln_analysis/tools/tests/test_transitive_code_search.py +++ b/src/vuln_analysis/tools/tests/test_transitive_code_search.py @@ -1075,3 +1075,27 @@ def test_query_cleaning_strips_unicode_quotes(raw_query, expected_cleaned): """Test that query cleaning handles both ASCII and Unicode smart quotes.""" cleaned = raw_query.strip().split("\n")[0].strip().strip("'\"\u2018\u2019\u201c\u201d").strip() assert cleaned == expected_cleaned, f"For input {repr(raw_query)}: got {repr(cleaned)}, expected {repr(expected_cleaned)}" + + +@pytest.mark.parametrize("raw_query, expected_cleaned", [ + # Standard ASCII quotes + ("'commons-beanutils:commons-beanutils:1.9.4'", "commons-beanutils:commons-beanutils:1.9.4"), + ('"commons-beanutils:commons-beanutils:1.9.4"', "commons-beanutils:commons-beanutils:1.9.4"), + # Unicode smart quotes (left/right single) + ("\u2018commons-beanutils:commons-beanutils:1.9.4\u2019", "commons-beanutils:commons-beanutils:1.9.4"), + # Unicode smart quotes (left/right double) + ("\u201ccommons-beanutils:commons-beanutils:1.9.4\u201d", "commons-beanutils:commons-beanutils:1.9.4"), + # Mixed: ASCII left, unicode right + ("'commons-beanutils:commons-beanutils:1.9.4\u2019", "commons-beanutils:commons-beanutils:1.9.4"), + ("\"commons-beanutils:commons-beanutils:1.9.4\u201d", "commons-beanutils:commons-beanutils:1.9.4"), + # No quotes 
+ ("commons-beanutils:commons-beanutils:1.9.4", "commons-beanutils:commons-beanutils:1.9.4"), + # Whitespace + quotes + (" 'commons-beanutils:commons-beanutils:1.9.4' ", "commons-beanutils:commons-beanutils:1.9.4"), + # Trailing newline junk from LLM + ("'commons-beanutils:commons-beanutils:1.9.4'\nPlease wait...", "commons-beanutils:commons-beanutils:1.9.4"), +]) +def test_query_cleaning_strips_unicode_quotes(raw_query, expected_cleaned): + """Test that query cleaning handles both ASCII and Unicode smart quotes.""" + cleaned = raw_query.strip().split("\n")[0].strip().strip("'\"\u2018\u2019\u201c\u201d").strip() + assert cleaned == expected_cleaned, f"For input {repr(raw_query)}: got {repr(cleaned)}, expected {repr(expected_cleaned)}" diff --git a/src/vuln_analysis/tools/tool_names.py b/src/vuln_analysis/tools/tool_names.py index fa1ddf72..f07d40d5 100644 --- a/src/vuln_analysis/tools/tool_names.py +++ b/src/vuln_analysis/tools/tool_names.py @@ -22,45 +22,51 @@ class ToolNames: - """ - Constants for agent tool names matching YAML configuration keys. - + """Constants for agent tool names matching YAML configuration keys. + These names correspond to the function keys in the YAML config files - (e.g., config.yml, config-http-openai-local.yml). + (e.g., config.yml, config-http-nim.yml). 
""" - + # Code analysis tools CODE_SEMANTIC_SEARCH = "Code Semantic Search" - """Searches container source code using semantic vector embeddings""" - + """Searches container source code using semantic vector embeddings.""" + DOCS_SEMANTIC_SEARCH = "Docs Semantic Search" - """Searches container documentation using semantic vector embeddings""" - + """Searches container documentation using semantic vector embeddings.""" + CODE_KEYWORD_SEARCH = "Code Keyword Search" - """Searches container source code for exact keyword matches""" - - # Code path analysis + """Searches container source code for exact keyword matches.""" + + # Code path analysis tools FUNCTION_LOCATOR = "Function Locator" - """Mandatory first step for code path analysis. Validates package names, locates functions using fuzzy matching, provides ecosystem type.""" + """Mandatory first step for code path analysis. Validates package names and locates functions using fuzzy matching.""" CALL_CHAIN_ANALYZER = "Call Chain Analyzer" - """Checks if a function is reachable from application code""" - + """Checks if a function is reachable from application code.""" + FUNCTION_CALLER_FINDER = "Function Caller Finder" - """Golang only. Finds functions calling specific standard library methods.""" - + """Go only. Finds functions calling specific standard library methods.""" + # External and cached data sources CVE_WEB_SEARCH = "CVE Web Search" - """Searches the web for CVE and vulnerability information""" - + """Searches the web for CVE and vulnerability information.""" + CONTAINER_ANALYSIS_DATA = "Container Analysis Data" - """Retrieves pre-analyzed data from earlier container scan steps""" + """Retrieves pre-analyzed data from earlier container scan steps.""" FUNCTION_LIBRARY_VERSION_FINDER = "Function Library Version Finder" - """Checks in which library version the function is used""" + """Java only. 
Checks in which library version the function is used.""" + + # Code Understanding agent tools + CONFIGURATION_SCANNER = "Configuration Scanner" + """Scans configuration files (YAML, XML, properties, build files) for vulnerability-relevant patterns.""" + + IMPORT_USAGE_ANALYZER = "Import Usage Analyzer" + """Finds all imports and usage patterns of a specific package across indexed sources.""" -# Export as module-level constants +# Module-level constants for convenience imports CODE_SEMANTIC_SEARCH = ToolNames.CODE_SEMANTIC_SEARCH DOCS_SEMANTIC_SEARCH = ToolNames.DOCS_SEMANTIC_SEARCH CODE_KEYWORD_SEARCH = ToolNames.CODE_KEYWORD_SEARCH @@ -70,6 +76,8 @@ class ToolNames: CVE_WEB_SEARCH = ToolNames.CVE_WEB_SEARCH CONTAINER_ANALYSIS_DATA = ToolNames.CONTAINER_ANALYSIS_DATA FUNCTION_LIBRARY_VERSION_FINDER = ToolNames.FUNCTION_LIBRARY_VERSION_FINDER +CONFIGURATION_SCANNER = ToolNames.CONFIGURATION_SCANNER +IMPORT_USAGE_ANALYZER = ToolNames.IMPORT_USAGE_ANALYZER @@ -82,6 +90,8 @@ class ToolNames: 'FUNCTION_CALLER_FINDER', 'CVE_WEB_SEARCH', 'CONTAINER_ANALYSIS_DATA', - 'FUNCTION_LOCATOR', + 'FUNCTION_LOCATOR', 'FUNCTION_LIBRARY_VERSION_FINDER', -] \ No newline at end of file + 'CONFIGURATION_SCANNER', + 'IMPORT_USAGE_ANALYZER', +] diff --git a/src/vuln_analysis/utils/code_understanding_prompt_factory.py b/src/vuln_analysis/utils/code_understanding_prompt_factory.py new file mode 100644 index 00000000..8063b183 --- /dev/null +++ b/src/vuln_analysis/utils/code_understanding_prompt_factory.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
"""Per-ecosystem tool strategies and generic tool descriptions for the Code Understanding agent."""

from vuln_analysis.tools.tool_names import ToolNames

# Shared final step appended to every ecosystem strategy below — keeps the
# Docs Semantic Search guidance identical across ecosystems.
_DOCS_SEARCH_STEP = (
    "4. Use Docs Semantic Search for broader architectural questions about the application "
    "— queries must include a specific term (library name, config property, protocol)."
)

# Generic tool descriptions used when no ecosystem-specific strategy is available.
# Keyed by ToolNames constants so lookups stay in sync with tool registration.
CU_TOOL_GENERAL_DESCRIPTIONS: dict[str, str] = {
    ToolNames.DOCS_SEMANTIC_SEARCH: (
        "Searches container documentation using semantic search. "
        "Answers questions about application purpose, architecture, features, and dependencies. "
        "IMPORTANT: Queries must be specific — include at least one concrete term "
        "(class name, config property, library name, or protocol). "
        "GOOD queries: 'deserialization whitelist configuration', "
        "'SSL/TLS certificate validation settings', 'input sanitization before parsing'. "
        "BAD queries: 'security settings', 'configuration', 'error handling'. "
        "Vague queries return unhelpful results."
    ),
    ToolNames.CODE_KEYWORD_SEARCH: (
        "Performs keyword search on container source code for exact text matches. "
        "Input should be a function name, class name, import statement, or code pattern."
    ),
    ToolNames.CONFIGURATION_SCANNER: (
        "Scans configuration files (YAML, XML, properties, build files, Dockerfiles) "
        "for feature flags, enabled components, and security settings. "
        "Searches all sources including framework dependencies."
    ),
    ToolNames.IMPORT_USAGE_ANALYZER: (
        "Finds all imports and usage patterns of a package/module across sources. "
        "Reports which files import it, usage count, and how it's used. "
        "Searches all sources including framework dependencies."
    ),
}

# Per-ecosystem tool selection strategies guiding the agent's investigation order.
# Keyed by lowercase ecosystem name; every strategy follows the same shape:
# 1) config scan, 2) exact keyword lookup, 3) import tracing, 4) shared docs step.
CU_TOOL_SELECTION_STRATEGY: dict[str, str] = {
    "python": (
        "1. Start with Configuration Scanner to check if affected features are enabled or configured. "
        "2. Use Code Keyword Search for exact import/function lookups (e.g. 'import xml.etree'). "
        "3. Use Import Usage Analyzer to understand how the affected package is imported and used. "
        + _DOCS_SEARCH_STEP
    ),
    "go": (
        "1. Start with Configuration Scanner to check configs for feature flags and security settings. "
        "2. Use Code Keyword Search for import paths (e.g. 'encoding/xml'). "
        "3. Use Import Usage Analyzer to trace imports of the affected package. "
        + _DOCS_SEARCH_STEP
    ),
    "java": (
        "1. Start with Configuration Scanner to check configs, pom.xml, and build.gradle for feature flags and security settings. "
        "2. Use Code Keyword Search for import statements (e.g. 'import javax.xml'). "
        "3. Use Import Usage Analyzer to trace how the affected library is imported across sources. "
        + _DOCS_SEARCH_STEP
    ),
    "javascript": (
        "1. Start with Configuration Scanner to check package.json, .env, and config files. "
        "2. Use Code Keyword Search for require/import patterns (e.g. 'require(\"xml2js\")'). "
        "3. Use Import Usage Analyzer to find all imports of the affected package. "
        + _DOCS_SEARCH_STEP
    ),
    "c": (
        "1. Start with Configuration Scanner to check Makefiles, CMakeLists.txt, and config headers. "
        "2. Use Code Keyword Search for function names and #include patterns. "
        "3. Use Import Usage Analyzer to trace #include usage of the affected library. "
        + _DOCS_SEARCH_STEP
    ),
}
+ + Wraps a Tantivy index with document management, query sanitization, and + result separation into app code vs. dependency code. Instances are cached + via get_instance() with LRU eviction to avoid holding too many open indexes. + """ INDEX_TYPE = "tantivy" + # LRU cache of FullTextSearch instances keyed by cache_path. + # Prevents unbounded growth of open Tantivy index handles. + _instances: "OrderedDict[str, FullTextSearch]" = None + _INSTANCE_CACHE_MAX = 10 + _instances_lock = threading.Lock() def __init__(self, cache_path: str = None, tokenizer=False): if cache_path: @@ -104,13 +111,37 @@ def __init__(self, cache_path: str = None, tokenizer=False): self.index.reload() self.tokenizer = tokenizer + @classmethod + def get_instance(cls, cache_path: str, tokenizer=False) -> "FullTextSearch": + """Return a cached instance for the given path, with LRU eviction. + + On cache hit, moves the entry to the end (most recently used). + On cache miss, creates a new instance and evicts the least recently used + if the cache exceeds _INSTANCE_CACHE_MAX. 
+ """ + with cls._instances_lock: + if cls._instances is None: + cls._instances = OrderedDict() + if cache_path in cls._instances: + cls._instances.move_to_end(cache_path) + return cls._instances[cache_path] + instance = cls(cache_path=cache_path, tokenizer=tokenizer) + with cls._instances_lock: + if cache_path in cls._instances: + return cls._instances[cache_path] + cls._instances[cache_path] = instance + if len(cls._instances) > cls._INSTANCE_CACHE_MAX: + evicted = cls._instances.popitem(last=False) + logger.debug("Evicted LRU FullTextSearch instance: %s", evicted[0]) + return instance + @classmethod def get_index_directory(cls, base_path: str, hash_value: str) -> Path: """Returns the path where the index should be stored""" return Path(base_path) / cls.INDEX_TYPE / hash_value def _build_schema(self): - """Build schema for the code index""" + """Build Tantivy schema: file_path + content (both stored for retrieval) + doc_id.""" schema_builder = SchemaBuilder() schema_builder.add_text_field("file_path", stored=True) schema_builder.add_text_field("content", stored=True) @@ -119,17 +150,28 @@ def _build_schema(self): return schema def add_documents(self, documents: Iterable): - + """Index an iterable of (file_path, content) tuples into Tantivy.""" writer = self.index.writer() for doc_id, (title, text) in enumerate(documents): writer.add_document(TantivyDocument(file_path=title, content=text, doc_id=doc_id)) writer.commit() def is_empty(self): + """Check if the index contains any documents.""" return self.index.searcher().num_docs == 0 - def search_index(self, query: str, top_k: int = 10) -> str: + def search_index(self, query: str, top_k: int = 10, source_scope: list[str] | None = None) -> str: + """Search the index using Tantivy's ranked query engine. + + Returns results separated into application code and dependency code sections. + When source_scope is provided, dependency results are filtered to only + include files matching the scope (e.g. 
specific package names). + Args: + query: Search query string (sanitized internally) + top_k: Maximum total results to return (app results prioritized) + source_scope: Optional list of path substrings to filter dependency results + """ self.index.reload() try: if self.tokenizer: @@ -141,16 +183,20 @@ def search_index(self, query: str, top_k: int = 10) -> str: app_docs = [] dep_docs = [] - - vendors_list = list(ECOSYSTEM_DEP_DIRS.values()) for _, doc_id in results: raw = searcher.doc(doc_id) doc = {"source": raw["file_path"][0], "content": raw["content"][0]} - if any(doc["source"].startswith(vendor) for vendor in vendors_list): + if is_dependency_path(doc["source"]): dep_docs.append(doc) else: app_docs.append(doc) + pre_filter_count = len(dep_docs) + dep_docs = filter_by_source_scope(dep_docs, source_scope, lambda d: d["source"]) + if pre_filter_count != len(dep_docs): + logger.debug("Source scope %s filtered dep results from %d to %d", + source_scope, pre_filter_count, len(dep_docs)) + total_app = len(app_docs) total_dep = len(dep_docs) @@ -174,11 +220,11 @@ def search_index(self, query: str, top_k: int = 10) -> str: return "\n".join(parts) except Exception as e: - logger.exception(e) + logger.exception("Search failed for query") return "There was an error searching for the query." 
def add_documents_from_langchain_chunks(self, documents: list[Document]): - """Create an index from langchain chunked documents""" + """Index pre-chunked LangChain documents (used by document_embedding.py for VDB build).""" try: documents = [(doc.metadata['source'], doc.page_content) for doc in documents] @@ -188,14 +234,18 @@ def add_documents_from_langchain_chunks(self, documents: list[Document]): self.add_documents(tqdm(documents, total=len(documents), desc="Indexing")) except Exception as e: - logger.warning(f"Unable to add documents to the index {e}") + logger.warning("Unable to add documents to the index: %s", e) def add_documents_from_code_path(self, code_path: str, include_extensions: list[str], use_langparser=True, splitter=True): - """Create an index from raw files.""" + """Index raw source files from a directory, optionally using LangChain's LanguageParser. + + When use_langparser=True, files are parsed with language-aware chunking. + When False, files are read as raw text (optionally tokenized for search). 
+ """ doc_content = [] @@ -227,6 +277,6 @@ def add_documents_from_code_path(self, content = tokenize_code(content) doc_content.append((file_path, content)) except Exception as e: - logger.warning(f"Error reading {file_path}: {e}") + logger.warning("Error reading %s: %s", file_path, e) self.add_documents(doc_content) diff --git a/src/vuln_analysis/utils/intel_utils.py b/src/vuln_analysis/utils/intel_utils.py index 32e6942f..d8dee99c 100644 --- a/src/vuln_analysis/utils/intel_utils.py +++ b/src/vuln_analysis/utils/intel_utils.py @@ -26,6 +26,8 @@ logger = LoggingFactory.get_agent_logger(__name__) +_MAX_RHSA_CANDIDATES = 20 + def update_version(incoming_version, current_version, compare): """ @@ -200,12 +202,19 @@ def build_critical_context(cve_intel_list) -> tuple[list[str], list[dict], list[ critical_context.append(f"Search keywords: {', '.join(short_names)}") for f in vf: vulnerable_functions.add(f.rsplit('.', 1)[-1]) + ver_range = vuln.get('vulnerable_version_range', '') if isinstance(vuln, dict) else getattr(v, 'vulnerable_version_range', '') + patched = vuln.get('first_patched_version', '') if isinstance(vuln, dict) else getattr(v, 'first_patched_version', '') if pkg: if isinstance(pkg, dict): pkg_name = pkg.get("name", "") pkg_eco = pkg.get("ecosystem", "") if pkg_name: critical_context.append(f"Vulnerable module ({pkg_eco}): {pkg_name}") + if ver_range: + ver_ctx = f"Vulnerable version range ({pkg_name}): {ver_range}" + if patched: + ver_ctx += f" | First patched version: {patched}" + critical_context.append(ver_ctx) if pkg_name not in seen_packages: seen_packages.add(pkg_name) candidate_packages.append({"name": pkg_name, "source": "ghsa", "ecosystem": pkg_eco}) @@ -227,10 +236,14 @@ def build_critical_context(cve_intel_list) -> tuple[list[str], list[dict], list[ critical_context.append(f"KNOWN MITIGATIONS: {mit_text[:500]}") if cve_intel.rhsa.package_state: pkgs = list(set(p.package_name for p in cve_intel.rhsa.package_state if p.package_name)) + rhsa_added = 
0 for p in pkgs: if p not in seen_packages: seen_packages.add(p) candidate_packages.append({"name": p, "source": "rhsa"}) + rhsa_added += 1 + if rhsa_added >= _MAX_RHSA_CANDIDATES: + break if len(pkgs) > 5: critical_context.append( f"Affected across {len(pkgs)} Red Hat products (sample: {', '.join(pkgs[:5])}). " diff --git a/src/vuln_analysis/utils/source_classification.py b/src/vuln_analysis/utils/source_classification.py new file mode 100644 index 00000000..6a74fd12 --- /dev/null +++ b/src/vuln_analysis/utils/source_classification.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, TypeVar + +from exploit_iq_commons.utils.dep_tree import ECOSYSTEM_DEP_DIRS + +_VENDORS = list(ECOSYSTEM_DEP_DIRS.values()) + +T = TypeVar("T") + + +def is_dependency_path(file_path: str) -> bool: + """Return True if file_path starts with any ecosystem dependency directory prefix.""" + return any(file_path.startswith(vendor) for vendor in _VENDORS) + + +def filter_by_source_scope(dep_items: list[T], source_scope: list[str] | None, + path_fn: Callable[[T], str]) -> list[T]: + """Filter dependency items to those whose path contains any source_scope substring. + + Args: + dep_items: Dependency result items to filter. + source_scope: Path substrings to match. None means no filtering. 
def format_app_dep_output(app_items: list[str], dep_items: list[str],
                          total_app: int, total_dep: int,
                          no_results_msg: str) -> str:
    """Format results with app/dep section headers matching Code Keyword Search output.

    Emits, in order: the "Main application" header, the app results (if any),
    the "Application library dependencies" header, the dep results (if any).
    Headers report how many items are shown out of how many were found.

    Args:
        app_items: Rendered application-code result snippets.
        dep_items: Rendered dependency-code result snippets.
        total_app: Total application results found (may exceed len(app_items)).
        total_dep: Total dependency results found (may exceed len(dep_items)).
        no_results_msg: Returned verbatim when both totals are zero.
    """
    if total_app == 0 and total_dep == 0:
        return no_results_msg

    # Both sections share the same shape, so render them in one pass;
    # headers are always emitted, bodies only when there are items.
    parts: list[str] = []
    for label, items, total in (
        ("Main application", app_items, total_app),
        ("Application library dependencies", dep_items, total_dep),
    ):
        parts.append(f"{label} ({len(items)} of {total} results)")
        if items:
            parts.append("\n\n".join(items))
    return "\n".join(parts)
+ +"""Token counting and truncation utilities for agent context management.""" + +import tiktoken + +from vuln_analysis.tools.tool_names import ToolNames + +_tiktoken_enc = tiktoken.get_encoding("cl100k_base") + + +def count_tokens(text: str) -> int: + """Count the number of tokens in a text string using the cl100k_base encoding. + + Falls back to a rough character-based estimate if encoding fails. + """ + try: + return len(_tiktoken_enc.encode(text)) + except Exception: + return len(text) // 4 + + +def estimate_tokens(runtime_prompt: str, messages: list, observation) -> int: + """Estimate total token count of the current agent context (prompt + messages + observation).""" + parts = [runtime_prompt] + for msg in messages: + if hasattr(msg, "content") and isinstance(msg.content, str): + parts.append(msg.content) + if observation is not None: + for item in (observation.memory or []): + parts.append(item) + for item in (observation.results or []): + parts.append(item) + return count_tokens("\n".join(parts)) + + +def truncate_tool_output(tool_output: str, tool_name: str, max_tokens: int = 400) -> str: + """Truncate tool output to fit within a token budget. + + Uses tool-specific strategies: Call Chain Analyzer keeps the header, + Code Keyword Search preserves section headers, and the default strategy + keeps head (70%) and tail (30%) of the output. + """ + token_count = count_tokens(tool_output) + if token_count <= max_tokens: + return tool_output + + lines = tool_output.split('\n') + + if tool_name == ToolNames.CALL_CHAIN_ANALYZER: + head = '\n'.join(lines[:3]) + remaining = token_count - count_tokens(head) + return f"{head}\n[... 
truncated {remaining} tokens ...]" + + if tool_name == ToolNames.CODE_KEYWORD_SEARCH: + kept_lines = [] + kept_tokens = 0 + for line in lines: + is_header = line.startswith("---") or "Main application" in line or "library dependencies" in line or line.strip() == "" + line_tokens = count_tokens(line) + if is_header or kept_tokens + line_tokens <= max_tokens: + kept_lines.append(line) + kept_tokens += line_tokens + if kept_tokens >= max_tokens: + break + kept_lines.append(f"[... truncated {token_count - kept_tokens} tokens ...]") + return '\n'.join(kept_lines) + + head_budget = int(max_tokens * 0.7) + tail_budget = max_tokens - head_budget + head_lines, head_tokens = [], 0 + for line in lines: + lt = count_tokens(line) + if head_tokens + lt > head_budget: + break + head_lines.append(line) + head_tokens += lt + tail_lines, tail_tokens = [], 0 + for line in reversed(lines): + lt = count_tokens(line) + if tail_tokens + lt > tail_budget: + break + tail_lines.insert(0, line) + tail_tokens += lt + truncated = token_count - head_tokens - tail_tokens + return '\n'.join(head_lines) + f"\n[... truncated {truncated} tokens ...]\n" + '\n'.join(tail_lines) diff --git a/tests/agent_test_helpers.py b/tests/agent_test_helpers.py new file mode 100644 index 00000000..2d60aed2 --- /dev/null +++ b/tests/agent_test_helpers.py @@ -0,0 +1,61 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
from unittest.mock import MagicMock

from vuln_analysis.tools.tool_names import ToolNames


class MockTool:
    """Minimal tool stand-in exposing only the .name attribute agents inspect."""

    def __init__(self, name: str):
        self.name = name


# One mock per known tool, in a fixed, deterministic order.
ALL_TOOLS = [
    MockTool(tool_name)
    for tool_name in (
        ToolNames.CODE_SEMANTIC_SEARCH,
        ToolNames.DOCS_SEMANTIC_SEARCH,
        ToolNames.CODE_KEYWORD_SEARCH,
        ToolNames.FUNCTION_LOCATOR,
        ToolNames.CALL_CHAIN_ANALYZER,
        ToolNames.FUNCTION_CALLER_FINDER,
        ToolNames.CVE_WEB_SEARCH,
        ToolNames.CONTAINER_ANALYSIS_DATA,
        ToolNames.FUNCTION_LIBRARY_VERSION_FINDER,
        ToolNames.CONFIGURATION_SCANNER,
        ToolNames.IMPORT_USAGE_ANALYZER,
    )
]


def make_builder(tools=None):
    """Build a mock tool builder whose get_tools() returns a fresh list copy.

    Defaults to ALL_TOOLS when no explicit tool list is given.
    """
    selected = ALL_TOOLS if tools is None else tools
    builder = MagicMock()
    builder.get_tools = MagicMock(return_value=list(selected))
    return builder


def make_config(**overrides):
    """Build a mock agent config; keyword overrides replace the defaults below."""
    config = MagicMock()
    for attr, default in (
        ("tool_names", []),
        ("transitive_search_tool_enabled", True),
        ("cve_web_search_enabled", True),
    ):
        setattr(config, attr, overrides.get(attr, default))
    config.max_iterations = 10
    return config


def make_state(code_vdb_path="/path", doc_vdb_path="/path", code_index_path="/path"):
    """Build a mock pipeline state carrying the three artifact paths agents read."""
    state = MagicMock()
    state.code_vdb_path = code_vdb_path
    state.doc_vdb_path = doc_vdb_path
    state.code_index_path = code_index_path
    return state
+""" + +import pytest + +from vuln_analysis.functions.agent_registry import ( + _AGENT_REGISTRY, + register_agent, + get_agent_class, + get_all_agent_types, +) + + +class TestRealAgentRegistration: + """Verify that importing the agent modules registers both agents.""" + + def test_reachability_registered(self): + import vuln_analysis.functions.reachability_agent # noqa: F401 + assert "reachability" in _AGENT_REGISTRY + + def test_code_understanding_registered(self): + import vuln_analysis.functions.code_understanding_agent # noqa: F401 + assert "code_understanding" in _AGENT_REGISTRY + + def test_get_all_agent_types_contains_both(self): + import vuln_analysis.functions.reachability_agent # noqa: F401 + import vuln_analysis.functions.code_understanding_agent # noqa: F401 + types = get_all_agent_types() + assert "reachability" in types + assert "code_understanding" in types + + def test_get_agent_class_reachability(self): + from vuln_analysis.functions.reachability_agent import ReachabilityAgent + assert get_agent_class("reachability") is ReachabilityAgent + + def test_get_agent_class_code_understanding(self): + from vuln_analysis.functions.code_understanding_agent import CodeUnderstandingAgent + assert get_agent_class("code_understanding") is CodeUnderstandingAgent + + +class TestGetAgentClassErrors: + """Test error cases for get_agent_class.""" + + def test_unknown_type_raises_key_error(self): + with pytest.raises(KeyError, match="Unknown agent type 'nonexistent'"): + get_agent_class("nonexistent") + + def test_error_message_lists_registered_types(self): + import vuln_analysis.functions.reachability_agent # noqa: F401 + import vuln_analysis.functions.code_understanding_agent # noqa: F401 + with pytest.raises(KeyError) as exc_info: + get_agent_class("bogus") + msg = str(exc_info.value) + assert "reachability" in msg + assert "code_understanding" in msg + + +class TestRegisterAgentDecorator: + """Test the register_agent decorator mechanics using a dummy class.""" + + 
def setup_method(self): + self._saved = dict(_AGENT_REGISTRY) + + def teardown_method(self): + _AGENT_REGISTRY.clear() + _AGENT_REGISTRY.update(self._saved) + + def test_decorator_registers_class(self): + @register_agent("test_dummy") + class DummyAgent: + pass + + assert get_agent_class("test_dummy") is DummyAgent + + def test_decorator_returns_original_class(self): + @register_agent("test_dummy2") + class DummyAgent: + pass + + assert DummyAgent.__name__ == "DummyAgent" + + def test_re_registration_replaces_class(self): + @register_agent("test_replace") + class First: + pass + + @register_agent("test_replace") + class Second: + pass + + assert get_agent_class("test_replace") is Second + + def test_registered_type_appears_in_all_types(self): + @register_agent("test_listed") + class Listed: + pass + + assert "test_listed" in get_all_agent_types() diff --git a/tests/test_base_graph_agent.py b/tests/test_base_graph_agent.py new file mode 100644 index 00000000..2ea1bcf4 --- /dev/null +++ b/tests/test_base_graph_agent.py @@ -0,0 +1,805 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for BaseGraphAgent: should_continue routing, default hooks, agent_type property, +thought_node context pruning. 
+""" + +import pytest +from unittest.mock import MagicMock, AsyncMock, patch +from langchain_core.messages import SystemMessage, AIMessage, HumanMessage, ToolMessage + +from vuln_analysis.functions.base_graph_agent import BaseGraphAgent +from vuln_analysis.functions.react_internals import Thought, ToolCall, Observation + + +class _ConcreteAgent(BaseGraphAgent): + """Minimal concrete subclass for testing base class behavior.""" + + async def pre_process_node(self, state): + return state + + @staticmethod + def get_tools(builder, config, state): + return [] + + @staticmethod + def create_rules_tracker(): + return MagicMock() + + +def _make_agent(): + mock_llm = MagicMock() + mock_llm.with_structured_output = MagicMock(return_value=MagicMock()) + config = MagicMock() + config.max_iterations = 10 + return _ConcreteAgent(tools=[], llm=mock_llm, config=config) + + +class TestShouldContinue: + """Test should_continue routing logic.""" + + @pytest.mark.asyncio + async def test_returns_end_on_finish_mode(self): + agent = _make_agent() + thought = Thought(thought="done", mode="finish", actions=None, final_answer="answer") + state = {"thought": thought, "step": 3, "max_steps": 10} + result = await agent.should_continue(state) + assert result == "__end__" + + @pytest.mark.asyncio + async def test_returns_forced_finish_at_max_steps(self): + agent = _make_agent() + thought = Thought( + thought="still working", + mode="act", + actions=ToolCall(tool="some_tool", query="q", reason="testing"), + final_answer=None, + ) + state = {"thought": thought, "step": 10, "max_steps": 10} + result = await agent.should_continue(state) + assert result == "forced_finish_node" + + @pytest.mark.asyncio + async def test_returns_forced_finish_beyond_max_steps(self): + agent = _make_agent() + thought = Thought( + thought="still working", + mode="act", + actions=ToolCall(tool="some_tool", query="q", reason="testing"), + final_answer=None, + ) + state = {"thought": thought, "step": 15, "max_steps": 10} + 
result = await agent.should_continue(state) + assert result == "forced_finish_node" + + @pytest.mark.asyncio + async def test_returns_tool_node_when_continuing(self): + agent = _make_agent() + thought = Thought( + thought="need more info", + mode="act", + actions=ToolCall(tool="some_tool", query="q", reason="testing"), + final_answer=None, + ) + state = {"thought": thought, "step": 3, "max_steps": 10} + result = await agent.should_continue(state) + assert result == "tool_node" + + @pytest.mark.asyncio + async def test_returns_thought_node_when_thought_is_none(self): + agent = _make_agent() + state = {"thought": None, "step": 0, "max_steps": 10} + result = await agent.should_continue(state) + assert result == "thought_node" + + @pytest.mark.asyncio + async def test_forced_finish_when_thought_none_at_max_steps(self): + """Step limit must be enforced even when thought is None (e.g. after + check_finish_allowed repeatedly blocks). Without this, the agent + self-loops thought_node→thought_node until GraphRecursionError.""" + agent = _make_agent() + state = {"thought": None, "step": 10, "max_steps": 10} + result = await agent.should_continue(state) + assert result == "forced_finish_node" + + @pytest.mark.asyncio + async def test_forced_finish_when_thought_none_beyond_max_steps(self): + agent = _make_agent() + state = {"thought": None, "step": 15, "max_steps": 10} + result = await agent.should_continue(state) + assert result == "forced_finish_node" + + @pytest.mark.asyncio + async def test_uses_config_max_iterations_as_fallback(self): + agent = _make_agent() + thought = Thought( + thought="working", + mode="act", + actions=ToolCall(tool="some_tool", query="q", reason="testing"), + final_answer=None, + ) + state = {"thought": thought, "step": 10} + result = await agent.should_continue(state) + assert result == "forced_finish_node" + + +class TestDefaultHooks: + """Test default hook implementations on BaseGraphAgent.""" + + def 
test_post_observation_returns_empty_dict(self): + agent = _make_agent() + result = agent.post_observation(state={}, tool_used="X", tool_output="Y", tool_input_detail="Z") + assert result == {} + + def test_should_truncate_returns_false(self): + agent = _make_agent() + result = agent.should_truncate_tool_output(state={}, tool_used="X") + assert result is False + + def test_agent_type_property(self): + agent = _make_agent() + assert agent.agent_type == "base" + + def test_build_comprehension_context_returns_full_context(self): + agent = _make_agent() + state = {"critical_context": ["CVE Description: test vuln", "Vulnerable module: xstream"]} + result = agent.build_comprehension_context(state) + assert "CVE Description: test vuln" in result + assert "Vulnerable module: xstream" in result + + def test_build_comprehension_context_empty_state(self): + agent = _make_agent() + assert agent.build_comprehension_context({}) == "N/A" + + def test_build_comprehension_context_empty_list(self): + agent = _make_agent() + assert agent.build_comprehension_context({"critical_context": []}) == "N/A" + + @patch("vuln_analysis.functions.base_graph_agent.ctx_state") + def test_sanitize_findings_replaces_wrong_cve(self, mock_ctx): + intel = MagicMock() + intel.vuln_id = "CVE-2021-43859" + ws = MagicMock() + ws.cve_intel = [intel] + mock_ctx.get.return_value = ws + + agent = _make_agent() + findings = ["Found CVE-2020-26217 in code", "Package present"] + result = agent.sanitize_findings(findings, {}) + assert "CVE-2020-26217" not in result[0] + assert "the investigated vulnerability" in result[0] + assert result[1] == "Package present" + + @patch("vuln_analysis.functions.base_graph_agent.ctx_state") + def test_sanitize_findings_keeps_correct_cve(self, mock_ctx): + intel = MagicMock() + intel.vuln_id = "CVE-2021-43859" + ws = MagicMock() + ws.cve_intel = [intel] + mock_ctx.get.return_value = ws + + agent = _make_agent() + findings = ["Affects CVE-2021-43859"] + result = 
agent.sanitize_findings(findings, {}) + assert result == ["Affects CVE-2021-43859"] + + @patch("vuln_analysis.functions.base_graph_agent.ctx_state") + def test_sanitize_findings_empty_list(self, mock_ctx): + ws = MagicMock() + ws.cve_intel = [] + mock_ctx.get.return_value = ws + + agent = _make_agent() + assert agent.sanitize_findings([], {}) == [] + + +class TestInit: + """Test BaseGraphAgent constructor wires up LLM wrappers.""" + + def test_creates_four_structured_output_llms(self): + mock_llm = MagicMock() + config = MagicMock() + config.max_iterations = 10 + agent = _ConcreteAgent(tools=["t1", "t2"], llm=mock_llm, config=config) + + assert mock_llm.with_structured_output.call_count == 4 + assert agent.tools == ["t1", "t2"] + assert agent.config is config + + +def _make_thought_response(mode="finish", final_answer="done"): + return Thought(thought="thinking", mode=mode, actions=None, final_answer=final_answer) + + +def _long_content(n_words=500): + return " ".join(["word"] * n_words) + + +class TestThoughtNodePruning: + """Test that thought_node prunes messages when tokens exceed the limit.""" + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_prunes_middle_messages_when_over_limit(self, mock_tracer): + agent = _make_agent() + agent.config.context_window_token_limit = 100 + agent.thought_llm.ainvoke = AsyncMock(return_value=_make_thought_response()) + + long = _long_content(200) + state = { + "runtime_prompt": "system prompt", + "messages": [ + HumanMessage(content=long), + AIMessage(content=long), + ToolMessage(content=long, tool_call_id="tc1"), + AIMessage(content=long), + ToolMessage(content="recent tool output", tool_call_id="tc2"), + HumanMessage(content="recent question"), + ], + "observation": None, + "step": 2, + } + + await agent.thought_node(state) + + invoked_messages = agent.thought_llm.ainvoke.call_args[0][0] + num_original = 1 + 6 # system prompt + 6 state messages + assert 
len(invoked_messages) < num_original + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_no_pruning_when_under_limit(self, mock_tracer): + agent = _make_agent() + agent.config.context_window_token_limit = 50000 + agent.thought_llm.ainvoke = AsyncMock(return_value=_make_thought_response()) + + state = { + "runtime_prompt": "short prompt", + "messages": [ + HumanMessage(content="hello"), + AIMessage(content="response"), + ], + "observation": None, + "step": 1, + } + + await agent.thought_node(state) + + invoked_messages = agent.thought_llm.ainvoke.call_args[0][0] + contents = [m.content for m in invoked_messages if hasattr(m, "content")] + assert "hello" in contents + assert "response" in contents + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_preserves_system_prompt_and_last_message(self, mock_tracer): + agent = _make_agent() + agent.config.context_window_token_limit = 50 + agent.thought_llm.ainvoke = AsyncMock(return_value=_make_thought_response()) + + long = _long_content(200) + state = { + "runtime_prompt": "system prompt", + "messages": [ + HumanMessage(content=long), + AIMessage(content=long), + ToolMessage(content=long, tool_call_id="tc1"), + HumanMessage(content="latest question"), + ], + "observation": None, + "step": 3, + } + + await agent.thought_node(state) + + invoked_messages = agent.thought_llm.ainvoke.call_args[0][0] + contents = [m.content for m in invoked_messages if hasattr(m, "content")] + assert "system prompt" in contents + assert "latest question" in contents + assert long not in contents + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_pruning_includes_observation_context_in_count(self, mock_tracer): + agent = _make_agent() + agent.config.context_window_token_limit = 200 + agent.thought_llm.ainvoke = AsyncMock(return_value=_make_thought_response()) + + long = 
_long_content(100) + obs = Observation( + memory=[_long_content(50)], + results=[_long_content(50)], + ) + state = { + "runtime_prompt": "system prompt", + "messages": [ + HumanMessage(content=long), + AIMessage(content=long), + ToolMessage(content="tool out", tool_call_id="tc1"), + HumanMessage(content="question"), + ], + "observation": obs, + "step": 2, + } + + await agent.thought_node(state) + + invoked_messages = agent.thought_llm.ainvoke.call_args[0][0] + assert any("KNOWLEDGE" in m.content for m in invoked_messages if hasattr(m, "content") and isinstance(m.content, str)) + contents = [m.content for m in invoked_messages if hasattr(m, "content")] + assert long not in contents + + +class TestThoughtNodeBadToolArguments: + """Test that thought_node recovers from bad tool arguments instead of crashing. + + Mirrors the old AgentExecutor's handle_parsing_errors behavior: when the LLM + produces a ToolCall with missing required fields, the agent should get an error + message and retry rather than killing the entire graph. 
+ """ + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_recovers_from_missing_arguments(self, mock_tracer): + """When all ToolCall fields are None, thought_node returns an error + HumanMessage with thought=None so should_continue loops back.""" + agent = _make_agent() + agent.config.context_window_token_limit = 50000 + + bad_actions = ToolCall( + tool="Function Library Version Finder", + package_name=None, + function_name=None, + query=None, + tool_input=None, + reason="check version", + ) + bad_response = Thought( + thought="check the version", + mode="act", + actions=bad_actions, + final_answer=None, + ) + agent.thought_llm.ainvoke = AsyncMock(return_value=bad_response) + + state = { + "runtime_prompt": "system prompt", + "messages": [HumanMessage(content="Is SslHandler used?")], + "observation": None, + "step": 2, + } + + result = await agent.thought_node(state) + + assert result["thought"] is None + assert result["step"] == 3 + assert result["output"] == "waiting for the agent to respond" + assert len(result["messages"]) == 2 + ai_msg = result["messages"][0] + assert isinstance(ai_msg, AIMessage) + assert "check the version" in ai_msg.content + error_msg = result["messages"][1] + assert isinstance(error_msg, HumanMessage) + assert "ERROR" in error_msg.content + assert "Function Library Version Finder" in error_msg.content + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_recovery_routes_back_to_thought_node_bad_args(self, mock_tracer): + """After a bad-arguments recovery, should_continue returns 'thought_node' + because thought is None — the agent gets another chance.""" + agent = _make_agent() + + state = {"thought": None, "step": 3, "max_steps": 10} + route = await agent.should_continue(state) + assert route == "thought_node" + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def 
test_recovery_still_counts_toward_max_steps(self, mock_tracer): + """A bad-arguments iteration increments step, so the agent hits + forced_finish_node when step reaches max_steps.""" + agent = _make_agent() + agent.config.context_window_token_limit = 50000 + + bad_actions = ToolCall( + tool="Some Tool", + reason="testing", + ) + bad_response = Thought( + thought="trying something", + mode="act", + actions=bad_actions, + final_answer=None, + ) + agent.thought_llm.ainvoke = AsyncMock(return_value=bad_response) + + state = { + "runtime_prompt": "prompt", + "messages": [HumanMessage(content="question")], + "observation": None, + "step": 9, + "max_steps": 10, + } + + result = await agent.thought_node(state) + + assert result["step"] == 10 + assert result["thought"] is None + + route = await agent.should_continue(result) + assert route == "forced_finish_node" + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_valid_tool_call_still_works(self, mock_tracer): + """Verify that valid tool calls are not affected by the ValueError handling.""" + agent = _make_agent() + agent.config.context_window_token_limit = 50000 + + good_actions = ToolCall( + tool="Configuration Scanner", + query="netty SSL settings", + reason="check config", + ) + good_response = Thought( + thought="scan for config", + mode="act", + actions=good_actions, + final_answer=None, + ) + agent.thought_llm.ainvoke = AsyncMock(return_value=good_response) + + state = { + "runtime_prompt": "system prompt", + "messages": [HumanMessage(content="question")], + "observation": None, + "step": 1, + } + + result = await agent.thought_node(state) + + assert result["thought"] is good_response + assert result["step"] == 2 + ai_msg = result["messages"][0] + assert isinstance(ai_msg, AIMessage) + assert ai_msg.tool_calls[0]["name"] == "Configuration Scanner" + assert ai_msg.tool_calls[0]["args"] == {"query": "netty SSL settings"} + + +class TestCheckFinishAllowedBlocking: 
+ """Test that blocked finishes include AIMessage and respect step limits.""" + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_blocked_finish_includes_ai_message(self, mock_tracer): + """When check_finish_allowed blocks, the LLM's finish attempt must be + recorded as an AIMessage so the chat model sees its own response and + the rejection in proper Human/AI alternation.""" + agent = _make_agent() + agent.config.context_window_token_limit = 50000 + agent.check_finish_allowed = MagicMock( + return_value=(False, "You MUST use Function Locator first.") + ) + + finish_response = Thought( + thought="I have enough info", + mode="finish", + actions=None, + final_answer="The function is not reachable.", + ) + agent.thought_llm.ainvoke = AsyncMock(return_value=finish_response) + + state = { + "runtime_prompt": "system prompt", + "messages": [HumanMessage(content="Is the function reachable?")], + "observation": None, + "step": 2, + } + + result = await agent.thought_node(state) + + assert result["thought"] is None + assert result["step"] == 3 + assert len(result["messages"]) == 2 + ai_msg = result["messages"][0] + assert isinstance(ai_msg, AIMessage) + assert "The function is not reachable." 
in ai_msg.content + human_msg = result["messages"][1] + assert isinstance(human_msg, HumanMessage) + assert "Function Locator" in human_msg.content + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_blocked_finish_ai_message_falls_back_to_thought(self, mock_tracer): + """When final_answer is None, the AIMessage should use the thought text.""" + agent = _make_agent() + agent.config.context_window_token_limit = 50000 + agent.check_finish_allowed = MagicMock( + return_value=(False, "Call CCA first.") + ) + + finish_response = Thought( + thought="seems done", + mode="finish", + actions=None, + final_answer=None, + ) + agent.thought_llm.ainvoke = AsyncMock(return_value=finish_response) + + state = { + "runtime_prompt": "prompt", + "messages": [HumanMessage(content="question")], + "observation": None, + "step": 0, + } + + result = await agent.thought_node(state) + + ai_msg = result["messages"][0] + assert isinstance(ai_msg, AIMessage) + assert "seems done" in ai_msg.content + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_blocked_finish_at_max_steps_routes_to_forced_finish(self, mock_tracer): + """If check_finish_allowed blocks at step 9 (incrementing to 10), + should_continue must route to forced_finish_node, not self-loop.""" + agent = _make_agent() + agent.config.context_window_token_limit = 50000 + agent.check_finish_allowed = MagicMock( + return_value=(False, "Call FL and CCA first.") + ) + + finish_response = Thought( + thought="done", + mode="finish", + actions=None, + final_answer="answer", + ) + agent.thought_llm.ainvoke = AsyncMock(return_value=finish_response) + + state = { + "runtime_prompt": "prompt", + "messages": [HumanMessage(content="question")], + "observation": None, + "step": 9, + "max_steps": 10, + } + + result = await agent.thought_node(state) + assert result["step"] == 10 + assert result["thought"] is None + + route = await 
agent.should_continue(result) + assert route == "forced_finish_node" + + @pytest.mark.asyncio + @patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER") + async def test_bad_args_includes_ai_message(self, mock_tracer): + """Bad tool arguments error must include an AIMessage with the LLM's + original thought for proper chat alternation.""" + agent = _make_agent() + agent.config.context_window_token_limit = 50000 + + bad_actions = ToolCall( + tool="Function Locator", + reason="locate function", + ) + bad_response = Thought( + thought="Let me find the function", + mode="act", + actions=bad_actions, + final_answer=None, + ) + agent.thought_llm.ainvoke = AsyncMock(return_value=bad_response) + + state = { + "runtime_prompt": "prompt", + "messages": [HumanMessage(content="question")], + "observation": None, + "step": 3, + } + + result = await agent.thought_node(state) + + assert len(result["messages"]) == 2 + ai_msg = result["messages"][0] + assert isinstance(ai_msg, AIMessage) + assert "Let me find the function" in ai_msg.content + error_msg = result["messages"][1] + assert isinstance(error_msg, HumanMessage) + assert "ERROR" in error_msg.content + + +class TestSelectPackage: + """Tests for _select_package image-match fast path and LLM fallback.""" + + def _make_workflow_state(self, image_name="registry.redhat.io/openshift4/ose-docker-builder", + git_repo="https://github.com/openshift/builder"): + si = MagicMock() + si.git_repo = git_repo + image = MagicMock() + image.name = image_name + image.source_info = [si] + ws = MagicMock() + ws.original_input.input.image = image + return ws + + @pytest.mark.asyncio + async def test_image_match_skips_llm(self): + """When a candidate name matches the image/repo, LLM is not called.""" + agent = _make_agent() + candidates = [ + {"name": "builder", "source": "rhsa"}, + {"name": "kernel", "source": "rhsa"}, + {"name": "glibc", "source": "rhsa"}, + ] + ws = self._make_workflow_state() + + with 
patch("vuln_analysis.utils.intel_utils.filter_context_to_package", + side_effect=lambda ctx, pkg, cands: ctx): + ctx, selected = await agent._select_package( + "go", candidates, ["CVE desc"], ws, + ) + + assert selected == "builder" + agent.package_filter_llm.ainvoke.assert_not_called() + + @pytest.mark.asyncio + async def test_no_match_calls_llm(self): + """When no candidate matches the image, LLM is called.""" + agent = _make_agent() + mock_selection = MagicMock() + mock_selection.selected_package = "xstream" + mock_selection.reason = "ecosystem match" + agent.package_filter_llm.ainvoke = AsyncMock(return_value=mock_selection) + + candidates = [ + {"name": "xstream", "source": "ghsa", "ecosystem": "Maven"}, + {"name": "kernel", "source": "rhsa"}, + ] + ws = self._make_workflow_state(image_name="registry.redhat.io/infinispan/server", + git_repo="https://github.com/infinispan/infinispan") + + with patch("vuln_analysis.utils.intel_utils.filter_context_to_package", + side_effect=lambda ctx, pkg, cands: ctx): + ctx, selected = await agent._select_package( + "java", candidates, ["CVE desc"], ws, + ) + + assert selected == "xstream" + agent.package_filter_llm.ainvoke.assert_called_once() + + @pytest.mark.asyncio + async def test_image_match_with_many_candidates(self): + """1000+ candidates with image match -> LLM skipped, no overflow.""" + agent = _make_agent() + candidates = [{"name": f"rhsa-product-{i}", "source": "rhsa"} for i in range(1200)] + candidates.append({"name": "builder", "source": "rhsa"}) + ws = self._make_workflow_state() + + with patch("vuln_analysis.utils.intel_utils.filter_context_to_package", + side_effect=lambda ctx, pkg, cands: ctx): + ctx, selected = await agent._select_package( + "go", candidates, ["CVE desc"], ws, + ) + + assert selected == "builder" + agent.package_filter_llm.ainvoke.assert_not_called() + + @pytest.mark.asyncio + async def test_single_candidate_no_llm(self): + """Single candidate is used directly without LLM.""" + agent = 
_make_agent() + candidates = [{"name": "jinja2", "source": "ghsa"}] + ws = self._make_workflow_state() + + with patch("vuln_analysis.utils.intel_utils.filter_context_to_package", + side_effect=lambda ctx, pkg, cands: ctx): + ctx, selected = await agent._select_package( + "python", candidates, ["CVE desc"], ws, + ) + + assert selected == "jinja2" + agent.package_filter_llm.ainvoke.assert_not_called() + + +class TestForcedFinishNode: + """Tests for forced_finish_node using observation memory instead of full history.""" + + @pytest.mark.asyncio + async def test_uses_observation_memory_not_full_history(self): + """forced_finish_node should build prompt from observation memory, + not the full conversation history, to avoid token overflow.""" + agent = _make_agent() + mock_response = Thought( + thought="summarizing", mode="finish", actions=None, + final_answer="Based on evidence, the function is not reachable.", + ) + agent.thought_llm.ainvoke = AsyncMock(return_value=mock_response) + + obs = Observation( + results=["CCA returned (False, [])"], + memory=["Package validated: commons-beanutils:1.9.4", + "FL found PropertyUtilsBean.getProperty", + "CCA: function not reachable from app code"], + ) + big_messages = [ + HumanMessage(content=f"tool output {i}" * 200) for i in range(10) + ] + state = { + "step": 10, "max_steps": 10, + "runtime_prompt": "You are a security analyst.", + "messages": big_messages, + "observation": obs, + "thought": None, + } + + with patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER"): + result = await agent.forced_finish_node(state) + + call_args = agent.thought_llm.ainvoke.call_args[0][0] + assert not any( + msg.content.startswith("tool output") for msg in call_args + if hasattr(msg, "content") + ), "Full conversation history should not be in the prompt" + knowledge_msg = [ + msg for msg in call_args + if hasattr(msg, "content") and "KNOWLEDGE" in msg.content + ] + assert len(knowledge_msg) == 1, "Observation memory should be in the 
prompt as KNOWLEDGE block" + assert "LATEST FINDINGS" in knowledge_msg[0].content, "Recent findings should also be included" + assert "CCA returned (False, [])" in knowledge_msg[0].content + assert result["output"] == "Based on evidence, the function is not reachable." + + @pytest.mark.asyncio + async def test_works_without_observation(self): + """forced_finish_node should work even when no observations exist.""" + agent = _make_agent() + mock_response = Thought( + thought="no evidence", mode="finish", actions=None, + final_answer="No evidence found.", + ) + agent.thought_llm.ainvoke = AsyncMock(return_value=mock_response) + + state = { + "step": 10, "max_steps": 10, + "runtime_prompt": "You are a security analyst.", + "messages": [HumanMessage(content="some message")], + "observation": None, + "thought": None, + } + + with patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER"): + result = await agent.forced_finish_node(state) + + call_args = agent.thought_llm.ainvoke.call_args[0][0] + assert len(call_args) == 2 # system prompt + forced finish prompt + assert result["output"] == "No evidence found." 
+ + @pytest.mark.asyncio + async def test_fallback_on_non_finish_response(self): + """forced_finish_node returns default message when LLM doesn't finish.""" + agent = _make_agent() + mock_response = Thought( + thought="I want to call another tool", mode="act", + actions=ToolCall(tool="Function Locator", package_name="pkg", function_name="fn", reason="test"), + final_answer=None, + ) + agent.thought_llm.ainvoke = AsyncMock(return_value=mock_response) + + state = { + "step": 10, "max_steps": 10, + "runtime_prompt": "You are a security analyst.", + "messages": [], + "observation": None, + "thought": None, + } + + with patch("vuln_analysis.functions.base_graph_agent.AGENT_TRACER"): + result = await agent.forced_finish_node(state) + + assert "Failed to generate a final answer" in result["output"] diff --git a/tests/test_base_tool_descriptions.py b/tests/test_base_tool_descriptions.py index 37929060..f457f59f 100644 --- a/tests/test_base_tool_descriptions.py +++ b/tests/test_base_tool_descriptions.py @@ -13,10 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Tests for the base build_tool_descriptions() function. +"""Tests for the base build_tool_descriptions() function. -Tests that the consolidated base function provides simple tool descriptions +Verifies that the consolidated base function provides simple tool descriptions that can be formatted by specialized functions for different contexts. 
""" @@ -76,7 +75,7 @@ def test_base_all_tools(): result = build_tool_descriptions(tool_names) - # Should have 7 descriptions + # Should have 7 descriptions (CONTAINER_ANALYSIS_DATA has no description in the base function) assert len(result) == 7 # Verify all tools are present @@ -155,36 +154,6 @@ def test_mod_few_shot_structure(): print("✓ MOD_FEW_SHOT structure validated") -def test_cve_web_search_description_warns_about_versions(): - """Test that CVE Web Search description includes version warning.""" - # Without Version Finder: generic version warning - tool_names = [ToolNames.CVE_WEB_SEARCH] - result = build_tool_descriptions(tool_names) - assert len(result) == 1 - desc = result[0] - assert "MULTIPLE versions" in desc - assert ToolNames.FUNCTION_LIBRARY_VERSION_FINDER not in desc - - # With Version Finder (Java): references the tool - tool_names = [ToolNames.CVE_WEB_SEARCH, ToolNames.FUNCTION_LIBRARY_VERSION_FINDER] - result = build_tool_descriptions(tool_names) - descs = " ".join(result) - assert "MULTIPLE versions" in descs - assert ToolNames.FUNCTION_LIBRARY_VERSION_FINDER in descs - - print("✓ CVE Web Search description includes version warning") - - -def test_agent_prompt_contains_version_awareness_instructions(): - """Test that agent prompt template contains version awareness instructions.""" - from vuln_analysis.utils.prompting import get_agent_prompt - - prompt = get_agent_prompt() - assert "VERSION" in prompt - - print("✓ Agent prompt contains version awareness instructions") - - if __name__ == "__main__": print("Running Base Tool Descriptions tests...\n") @@ -194,7 +163,5 @@ def test_agent_prompt_contains_version_awareness_instructions(): test_base_empty_list() test_checklist_formats_descriptions() test_mod_few_shot_structure() - test_cve_web_search_description_warns_about_versions() - test_agent_prompt_contains_version_awareness_instructions() print("\n✅ All base tool descriptions tests passed!") diff --git 
a/tests/test_build_code_understanding_tools.py b/tests/test_build_code_understanding_tools.py new file mode 100644 index 00000000..1eb88038 --- /dev/null +++ b/tests/test_build_code_understanding_tools.py @@ -0,0 +1,276 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for CodeUnderstandingAgent: tool selection, availability, and comprehension hooks.""" + +from unittest.mock import MagicMock, patch + +from agent_test_helpers import MockTool, make_builder, make_config, make_state +from vuln_analysis.functions.code_understanding_agent import CodeUnderstandingAgent +from vuln_analysis.functions.code_understanding_internals import CodeUnderstandingRulesTracker +from vuln_analysis.tools.tool_names import ToolNames + + +def _get_tools(builder=None, config=None, state=None): + return CodeUnderstandingAgent.get_tools( + builder or make_builder(), + config or make_config(), + state or make_state(), + ) + + +class TestGetTools: + """Test CodeUnderstandingAgent.get_tools selection and availability logic.""" + + def test_filters_to_exactly_4_cu_tools(self): + result = _get_tools() + assert len(result) == 4 + + def test_output_tool_names(self): + result = _get_tools() + result_names = {t.name for t in result} + expected_names = { + ToolNames.DOCS_SEMANTIC_SEARCH, + ToolNames.CODE_KEYWORD_SEARCH, + ToolNames.CONFIGURATION_SCANNER, + ToolNames.IMPORT_USAGE_ANALYZER, + } + assert result_names == expected_names + + def test_excludes_reachability_tools(self): + tools = [ + MockTool(ToolNames.FUNCTION_LOCATOR), + MockTool(ToolNames.CALL_CHAIN_ANALYZER), + MockTool(ToolNames.FUNCTION_CALLER_FINDER), + MockTool(ToolNames.DOCS_SEMANTIC_SEARCH), + MockTool(ToolNames.CODE_KEYWORD_SEARCH), + ] + result = _get_tools(builder=make_builder(tools)) + result_names = {t.name for t in result} + assert ToolNames.FUNCTION_LOCATOR not in result_names + assert ToolNames.CALL_CHAIN_ANALYZER not 
in result_names + assert ToolNames.FUNCTION_CALLER_FINDER not in result_names + assert len(result) == 2 + + def test_excludes_web_and_container_tools(self): + tools = [ + MockTool(ToolNames.CVE_WEB_SEARCH), + MockTool(ToolNames.CONTAINER_ANALYSIS_DATA), + MockTool(ToolNames.DOCS_SEMANTIC_SEARCH), + MockTool(ToolNames.CODE_KEYWORD_SEARCH), + ] + result = _get_tools(builder=make_builder(tools)) + result_names = {t.name for t in result} + assert ToolNames.CVE_WEB_SEARCH not in result_names + assert ToolNames.CONTAINER_ANALYSIS_DATA not in result_names + assert len(result) == 2 + + def test_excludes_version_finder(self): + tools = [ + MockTool(ToolNames.FUNCTION_LIBRARY_VERSION_FINDER), + MockTool(ToolNames.DOCS_SEMANTIC_SEARCH), + MockTool(ToolNames.CODE_KEYWORD_SEARCH), + ] + result = _get_tools(builder=make_builder(tools)) + result_names = {t.name for t in result} + assert ToolNames.FUNCTION_LIBRARY_VERSION_FINDER not in result_names + assert len(result) == 2 + + def test_empty_builder_returns_empty(self): + result = _get_tools(builder=make_builder(tools=[])) + assert result == [] + + def test_no_matching_tools_returns_empty(self): + tools = [ + MockTool(ToolNames.FUNCTION_LOCATOR), + MockTool(ToolNames.CALL_CHAIN_ANALYZER), + MockTool(ToolNames.FUNCTION_CALLER_FINDER), + MockTool(ToolNames.CVE_WEB_SEARCH), + ] + result = _get_tools(builder=make_builder(tools)) + assert result == [] + + def test_preserves_tool_object_identity(self): + docs_tool = MockTool(ToolNames.DOCS_SEMANTIC_SEARCH) + keyword_tool = MockTool(ToolNames.CODE_KEYWORD_SEARCH) + locator_tool = MockTool(ToolNames.FUNCTION_LOCATOR) + builder = make_builder(tools=[docs_tool, keyword_tool, locator_tool]) + result = _get_tools(builder=builder) + assert len(result) == 2 + assert docs_tool in result + assert keyword_tool in result + assert locator_tool not in result + + def test_partial_overlap(self): + tools = [ + MockTool(ToolNames.DOCS_SEMANTIC_SEARCH), + MockTool(ToolNames.CODE_KEYWORD_SEARCH), + 
MockTool(ToolNames.FUNCTION_LOCATOR), + MockTool(ToolNames.CVE_WEB_SEARCH), + ] + result = _get_tools(builder=make_builder(tools)) + result_names = {t.name for t in result} + expected_names = { + ToolNames.DOCS_SEMANTIC_SEARCH, + ToolNames.CODE_KEYWORD_SEARCH, + } + assert len(result) == 2 + assert result_names == expected_names + + +class TestGetToolsAvailability: + """get_tools filters out tools whose infrastructure prerequisites are not met.""" + + def test_filters_docs_semantic_search_when_no_vdb(self): + state = make_state(doc_vdb_path=None) + result = _get_tools(state=state) + assert ToolNames.DOCS_SEMANTIC_SEARCH not in {t.name for t in result} + + def test_filters_code_keyword_search_when_no_index(self): + state = make_state(code_index_path=None) + result = _get_tools(state=state) + assert ToolNames.CODE_KEYWORD_SEARCH not in {t.name for t in result} + + def test_cu_only_tools_always_kept(self): + state = make_state(code_vdb_path=None, doc_vdb_path=None, code_index_path=None) + result = _get_tools(state=state) + result_names = {t.name for t in result} + assert ToolNames.CONFIGURATION_SCANNER in result_names + + def test_filters_import_usage_analyzer_when_no_index(self): + state = make_state(code_index_path=None) + result = _get_tools(state=state) + assert ToolNames.IMPORT_USAGE_ANALYZER not in {t.name for t in result} + + def test_import_usage_analyzer_available_with_index(self): + state = make_state(code_index_path="/some/path") + result = _get_tools(state=state) + assert ToolNames.IMPORT_USAGE_ANALYZER in {t.name for t in result} + + +class TestCodeUnderstandingAgentMeta: + """Test create_rules_tracker and agent_type for CodeUnderstandingAgent.""" + + def test_create_rules_tracker_returns_cu_tracker(self): + tracker = CodeUnderstandingAgent.create_rules_tracker() + assert isinstance(tracker, CodeUnderstandingRulesTracker) + + def test_create_rules_tracker_returns_fresh_instance(self): + t1 = CodeUnderstandingAgent.create_rules_tracker() + t2 = 
CodeUnderstandingAgent.create_rules_tracker() + assert t1 is not t2 + + def test_agent_type_is_cu(self): + mock_llm = MagicMock() + mock_llm.with_structured_output = MagicMock(return_value=MagicMock()) + config = MagicMock() + config.max_iterations = 10 + agent = CodeUnderstandingAgent(tools=[], llm=mock_llm, config=config) + assert agent.agent_type == "cu" + + +def _make_cu_agent(): + mock_llm = MagicMock() + mock_llm.with_structured_output = MagicMock(return_value=MagicMock()) + config = MagicMock() + config.max_iterations = 10 + return CodeUnderstandingAgent(tools=[], llm=mock_llm, config=config) + + +def _mock_ctx_state(*vuln_ids): + """Return a mock workflow state with cve_intel entries for the given vuln IDs.""" + intel_list = [] + for vid in vuln_ids: + intel = MagicMock() + intel.vuln_id = vid + intel_list.append(intel) + ws = MagicMock() + ws.cve_intel = intel_list + return ws + + +class TestCUComprehensionHooks: + """Test CU agent comprehension context reduction and CVE sanitization.""" + + @patch("vuln_analysis.functions.code_understanding_agent.ctx_state") + def test_build_comprehension_context_contains_vuln_id_and_package(self, mock_ctx): + mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859") + agent = _make_cu_agent() + state = {"app_package": "com.thoughtworks.xstream:xstream"} + result = agent.build_comprehension_context(state) + assert "CVE-2021-43859" in result + assert "com.thoughtworks.xstream:xstream" in result + + @patch("vuln_analysis.functions.code_understanding_agent.ctx_state") + def test_build_comprehension_context_includes_grounding_instruction(self, mock_ctx): + mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859") + agent = _make_cu_agent() + result = agent.build_comprehension_context({"app_package": "pkg"}) + assert "Only extract facts explicitly stated in the tool output" in result + assert "Do not add CVE IDs" in result + + @patch("vuln_analysis.functions.code_understanding_agent.ctx_state") + def 
test_build_comprehension_context_excludes_critical_context(self, mock_ctx): + mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859") + agent = _make_cu_agent() + state = { + "app_package": "xstream", + "critical_context": ["GHSA description: XStream can cause DoS", "NVD: high severity"], + } + result = agent.build_comprehension_context(state) + assert "GHSA description" not in result + assert "NVD" not in result + + @patch("vuln_analysis.functions.code_understanding_agent.ctx_state") + def test_build_comprehension_context_unknown_fallbacks(self, mock_ctx): + ws = MagicMock() + ws.cve_intel = [] + mock_ctx.get.return_value = ws + agent = _make_cu_agent() + result = agent.build_comprehension_context({}) + assert "unknown" in result + + @patch("vuln_analysis.functions.base_graph_agent.ctx_state") + def test_sanitize_findings_replaces_wrong_cve(self, mock_ctx): + mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859") + agent = _make_cu_agent() + findings = ["XStream 1.4.18 is vulnerable to CVE-2020-26217"] + result = agent.sanitize_findings(findings, {}) + assert result == ["XStream 1.4.18 is vulnerable to the investigated vulnerability"] + + @patch("vuln_analysis.functions.base_graph_agent.ctx_state") + def test_sanitize_findings_keeps_correct_cve(self, mock_ctx): + mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859") + agent = _make_cu_agent() + findings = ["Affects CVE-2021-43859"] + result = agent.sanitize_findings(findings, {}) + assert result == ["Affects CVE-2021-43859"] + + @patch("vuln_analysis.functions.base_graph_agent.ctx_state") + def test_sanitize_findings_replaces_multiple_wrong_cves(self, mock_ctx): + mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859") + agent = _make_cu_agent() + findings = ["CVE-2020-26217 and CVE-2019-10086 affect this, also CVE-2021-43859"] + result = agent.sanitize_findings(findings, {}) + assert "CVE-2020-26217" not in result[0] + assert "CVE-2019-10086" not in result[0] + assert "CVE-2021-43859" 
in result[0] + assert result[0].count("the investigated vulnerability") == 2 + + @patch("vuln_analysis.functions.base_graph_agent.ctx_state") + def test_sanitize_findings_no_cve_ids_unchanged(self, mock_ctx): + mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859") + agent = _make_cu_agent() + findings = ["XStream 1.4.18 found in dependencies", "Package is present"] + result = agent.sanitize_findings(findings, {}) + assert result == findings + + @patch("vuln_analysis.functions.base_graph_agent.ctx_state") + def test_sanitize_findings_multi_cve_intel(self, mock_ctx): + mock_ctx.get.return_value = _mock_ctx_state("CVE-2021-43859", "CVE-2021-39144") + agent = _make_cu_agent() + findings = ["CVE-2021-43859 and CVE-2021-39144 and CVE-2020-26217"] + result = agent.sanitize_findings(findings, {}) + assert "CVE-2021-43859" in result[0] + assert "CVE-2021-39144" in result[0] + assert "CVE-2020-26217" not in result[0] diff --git a/tests/test_code_understanding_internals.py b/tests/test_code_understanding_internals.py new file mode 100644 index 00000000..ad52a8c5 --- /dev/null +++ b/tests/test_code_understanding_internals.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tests for the code-understanding agent internals.

Exercises CodeUnderstandingRulesTracker's tool-gating behavior and
sanity checks on the CU agent prompt constants.

NOTE(review): reconstructed from a whitespace-mangled patch; angle-bracket
markup in a few assertions was stripped by extraction and is re-added
below with TODO markers.
"""

from vuln_analysis.functions.code_understanding_internals import (
    CodeUnderstandingRulesTracker,
    CU_AGENT_SYS_PROMPT,
    CU_AGENT_THOUGHT_INSTRUCTIONS,
)


class TestCodeUnderstandingRulesTracker:
    """Rule checks performed by check_thought_behavior()."""

    def test_iua_allowed_without_survey(self):
        """IUA can be called at any time — no gating."""
        tracker = CodeUnderstandingRulesTracker()
        tracker.set_allowed_tools(["Import Usage Analyzer"])

        violated, msg = tracker.check_thought_behavior(
            "Import Usage Analyzer",
            "com.example.package",
            ["imports found"]
        )

        assert violated is False
        assert msg == ""

    def test_allowed_tool_passes(self):
        tracker = CodeUnderstandingRulesTracker()
        tracker.set_allowed_tools(["Code Keyword Search"])

        violated, msg = tracker.check_thought_behavior(
            "Code Keyword Search",
            "some query",
            ["results"]
        )

        assert violated is False

    def test_docs_semantic_search_passes(self):
        tracker = CodeUnderstandingRulesTracker()
        tracker.set_allowed_tools(["Docs Semantic Search"])

        violated, msg = tracker.check_thought_behavior(
            "Docs Semantic Search",
            "query",
            ["docs"]
        )

        assert violated is False

    def test_check_rule7_fires(self):
        # Two consecutive dotted-identifier queries with empty results should
        # trip Rule 7 (the message mentions "dots").
        tracker = CodeUnderstandingRulesTracker()
        tracker.set_allowed_tools(["Code Keyword Search"])

        tracker.check_thought_behavior(
            "Code Keyword Search",
            "com.example.Class",
            []
        )

        violated, msg = tracker.check_thought_behavior(
            "Code Keyword Search",
            "com.example.Another",
            []
        )

        assert violated is True
        assert "Rule 7" in msg
        assert "dots" in msg

    def test_check_allowed_tools_rejects_unknown(self):
        tracker = CodeUnderstandingRulesTracker()
        tracker.set_allowed_tools(["Code Keyword Search"])

        violated, msg = tracker.check_thought_behavior(
            "Unknown Tool",
            "query",
            ["results"]
        )

        assert violated is True
        assert "AVAILABLE_TOOLS" in msg
        assert "Code Keyword Search" in msg

    def test_check_passes_and_adds_action(self):
        tracker = CodeUnderstandingRulesTracker()
        tracker.set_allowed_tools(["Code Keyword Search"])

        assert "Code Keyword Search" not in tracker.action_history

        violated, msg = tracker.check_thought_behavior(
            "Code Keyword Search",
            "import xstream",
            ["result1", "result2"]
        )

        assert violated is False
        assert msg == ""
        # A passing check is recorded in the per-tool action history.
        assert "Code Keyword Search" in tracker.action_history
        assert len(tracker.action_history["Code Keyword Search"]) == 1
        assert tracker.action_history["Code Keyword Search"][0]["input"] == "import xstream"

    def test_duplicate_config_scanner_blocked(self):
        tracker = CodeUnderstandingRulesTracker()
        tracker.set_allowed_tools(["Configuration Scanner"])
        tracker.check_thought_behavior(
            "Configuration Scanner", "deserialization beanutils", ["config found"]
        )
        violated, msg = tracker.check_thought_behavior(
            "Configuration Scanner", "deserialization beanutils", ["config found"]
        )
        assert violated is True
        assert "already called" in msg

    def test_config_scanner_different_query_allowed(self):
        tracker = CodeUnderstandingRulesTracker()
        tracker.set_allowed_tools(["Configuration Scanner"])
        tracker.check_thought_behavior(
            "Configuration Scanner", "deserialization beanutils", ["config found"]
        )
        violated, msg = tracker.check_thought_behavior(
            "Configuration Scanner", "security allowlist", ["other config"]
        )
        assert violated is False
        assert msg == ""

    def test_check_failing_does_not_add_action(self):
        tracker = CodeUnderstandingRulesTracker()
        tracker.set_allowed_tools(["Configuration Scanner"])

        violated, msg = tracker.check_thought_behavior(
            "Unknown Tool",
            "query",
            ["results"]
        )

        assert violated is True
        # A rejected call must not be recorded in the action history.
        assert "Unknown Tool" not in tracker.action_history


class TestCUConstants:
    """Shape checks on the CU agent prompt constants."""

    def test_sys_prompt_is_nonempty_string(self):
        assert isinstance(CU_AGENT_SYS_PROMPT, str)
        assert len(CU_AGENT_SYS_PROMPT) > 0
        assert len(CU_AGENT_SYS_PROMPT.strip()) > 0

    def test_sys_prompt_mentions_code_understanding(self):
        # FIX: the literal was split across lines by the patch extraction;
        # rejoined into the single phrase the test clearly intends.
        assert "code understanding" in CU_AGENT_SYS_PROMPT.lower()

    def test_sys_prompt_does_not_mention_call_chain_analyzer(self):
        assert "Call Chain Analyzer" not in CU_AGENT_SYS_PROMPT
        assert "call chain analyzer" not in CU_AGENT_SYS_PROMPT.lower()

    def test_thought_instructions_have_rules_tags(self):
        # FIX: the original asserted `"" in ...`, which is vacuously true for
        # any string — the angle-bracket tags were stripped by extraction.
        # Reconstructed as <rules> tags; TODO confirm against the prompt text.
        assert "<rules>" in CU_AGENT_THOUGHT_INSTRUCTIONS
        assert "</rules>" in CU_AGENT_THOUGHT_INSTRUCTIONS

    def test_thought_instructions_have_three_examples(self):
        # FIX: same vacuous-assertion issue; TODO confirm the exact tag names.
        assert "<example1>" in CU_AGENT_THOUGHT_INSTRUCTIONS
        assert "<example2>" in CU_AGENT_THOUGHT_INSTRUCTIONS
        assert "<example3>" in CU_AGENT_THOUGHT_INSTRUCTIONS

    def test_thought_instructions_rules_numbered_1_to_9(self):
        for i in range(1, 10):
            assert f"{i}." in CU_AGENT_THOUGHT_INSTRUCTIONS
"""Tests for the code-understanding prompt factory tables."""

from vuln_analysis.tools.tool_names import ToolNames
from vuln_analysis.utils.code_understanding_prompt_factory import (
    CU_TOOL_GENERAL_DESCRIPTIONS,
    CU_TOOL_SELECTION_STRATEGY,
)

# The four tools the code-understanding agent can route to.
_CU_TOOLS = (
    ToolNames.DOCS_SEMANTIC_SEARCH,
    ToolNames.CODE_KEYWORD_SEARCH,
    ToolNames.CONFIGURATION_SCANNER,
    ToolNames.IMPORT_USAGE_ANALYZER,
)


class TestCUToolGeneralDescriptions:
    """Sanity checks on the per-tool description table."""

    def test_has_4_entries(self):
        assert len(CU_TOOL_GENERAL_DESCRIPTIONS) == 4

    def test_keys_match_cu_tool_names(self):
        # The table must cover exactly the CU tool set, nothing more.
        assert set(CU_TOOL_GENERAL_DESCRIPTIONS.keys()) == set(_CU_TOOLS)

    def test_values_non_empty(self):
        for key, value in CU_TOOL_GENERAL_DESCRIPTIONS.items():
            assert isinstance(value, str), f"Value for '{key}' is not a string"
            assert value.strip(), f"Value for '{key}' is empty"


class TestCUToolSelectionStrategy:
    """Sanity checks on the per-ecosystem tool-selection strategy table."""

    def test_has_5_ecosystems(self):
        assert set(CU_TOOL_SELECTION_STRATEGY.keys()) == {"python", "go", "java", "javascript", "c"}

    def test_strategies_non_empty(self):
        for ecosystem, strategy in CU_TOOL_SELECTION_STRATEGY.items():
            assert isinstance(strategy, str), f"Strategy for '{ecosystem}' is not a string"
            assert strategy.strip(), f"Strategy for '{ecosystem}' is empty"

    def test_each_mentions_at_least_3_tool_names(self):
        # Every ecosystem's guidance should reference most of the tool set.
        for ecosystem, strategy in CU_TOOL_SELECTION_STRATEGY.items():
            mentioned_tools = sum(tool_name in strategy for tool_name in _CU_TOOLS)
            assert mentioned_tools >= 3, f"Strategy for '{ecosystem}' mentions only {mentioned_tools} tool(s)"
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Configuration Scanner tool helpers."""

import pytest

from vuln_analysis.tools.configuration_scanner import (
    _is_config_file,
    _is_in_config_dir,
    _collect_config_files,
    search_config_content,
)


def _paths(collected):
    """Set of relative paths from a (path, content) result list."""
    return {path for path, _ in collected}


class TestIsConfigFile:
    """File-name based config detection."""

    @pytest.mark.parametrize(
        "file_path,expected",
        [
            # Named config files — matched anywhere
            ("config.yaml", True),
            ("config.yml", True),
            ("config.xml", True),
            ("application.properties", True),
            ("application.yml", True),
            ("application.yaml", True),
            ("settings.toml", True),
            ("settings.yaml", True),
            ("settings.yml", True),
            ("beans.xml", True),
            ("web.xml", True),
            ("Dockerfile", True),
            ("Dockerfile.prod", True),
            ("Dockerfile.dev", True),
            ("docker-compose.yml", True),
            ("docker-compose.prod.yml", True),
            # Config-specific extensions — matched anywhere
            ("app.env", True),
            (".env", True),
            ("nginx.conf", True),
            ("config.ini", True),
            ("app.properties", True),
            # Removed: build/dependency files (handled by other tools)
            ("pom.xml", False),
            ("build.gradle", False),
            ("build.gradle.kts", False),
            ("go.mod", False),
            ("go.sum", False),
            ("package.json", False),
            ("requirements.txt", False),
            ("setup.py", False),
            ("setup.cfg", False),
            ("pyproject.toml", False),
            # Removed: catch-all extensions (only matched inside config dirs)
            ("app.yml", False),
            ("app.cfg", False),
            ("data.yaml", False),
            ("docker-compose-dev.yaml", False),
            # Non-config files
            ("main.py", False),
            ("utils.go", False),
            ("README.md", False),
            ("app.js", False),
            ("test.txt", False),
            ("data.json", False),
        ],
    )
    def test_config_file_patterns(self, file_path, expected):
        assert _is_config_file(file_path) == expected

    @pytest.mark.parametrize(
        "file_path,expected",
        [
            ("DOCKERFILE", True),
            ("CONFIG.YAML", True),
            ("APPLICATION.PROPERTIES", True),
            ("DOCKER-COMPOSE.YML", True),
            ("BEANS.XML", True),
            ("SETTINGS.TOML", True),
        ],
    )
    def test_case_insensitive(self, file_path, expected):
        assert _is_config_file(file_path) == expected

    def test_file_in_subdirectory(self):
        # Only the basename matters for named-file matching.
        assert _is_config_file("path/to/config.yaml") is True
        assert _is_config_file("deep/nested/path/application.properties") is True


class TestIsInConfigDir:
    """Directory based config detection."""

    @pytest.mark.parametrize(
        "file_path,expected",
        [
            ("config/app.txt", True),
            ("conf/settings.txt", True),
            ("conf.d/default.conf", True),
            ("etc/nginx/nginx.conf", True),
            ("src/main/resources/app.yml", True),
            ("app/config/database.yml", True),
            ("project/conf/server.xml", True),
            ("src/main/java/App.java", False),
            ("lib/utils.py", False),
            ("tests/test_app.py", False),
            ("data/sample.csv", False),
        ],
    )
    def test_config_dir_patterns(self, file_path, expected):
        assert _is_in_config_dir(file_path) == expected

    @pytest.mark.parametrize(
        "file_path,expected",
        [
            ("Config/app.txt", True),
            ("RESOURCES/app.yml", True),
            ("ETC/nginx.conf", True),
            ("CONF.D/default.conf", True),
        ],
    )
    def test_case_insensitive_dir(self, file_path, expected):
        assert _is_in_config_dir(file_path) == expected


class TestCollectConfigFiles:
    """Filesystem walk that gathers (relative_path, content) pairs."""

    def test_finds_config_files(self, tmp_path):
        (tmp_path / "config.yaml").write_text("key: value")
        (tmp_path / "application.properties").write_text("app.name=test")
        (tmp_path / "nginx.conf").write_text("server {}")

        collected = _collect_config_files(str(tmp_path))

        assert len(collected) == 3
        names = _paths(collected)
        assert "config.yaml" in names
        assert "application.properties" in names
        assert "nginx.conf" in names

        for _, content in collected:
            assert len(content) > 0

    def test_finds_files_in_config_dir(self, tmp_path):
        config_dir = tmp_path / "config"
        config_dir.mkdir()
        (config_dir / "database.txt").write_text("db_config")
        (config_dir / "server.txt").write_text("server_config")

        collected = _collect_config_files(str(tmp_path))

        assert len(collected) == 2
        names = _paths(collected)
        assert "config/database.txt" in names
        assert "config/server.txt" in names

    def test_excludes_git_dir(self, tmp_path):
        git_dir = tmp_path / ".git"
        git_dir.mkdir()
        (git_dir / "config").write_text("git config")
        (tmp_path / "application.yml").write_text("app: config")

        collected = _collect_config_files(str(tmp_path))

        assert len(collected) == 1
        names = _paths(collected)
        assert "application.yml" in names
        assert ".git/config" not in names

    def test_excludes_pycache(self, tmp_path):
        pycache_dir = tmp_path / "__pycache__"
        pycache_dir.mkdir()
        (pycache_dir / "config.yaml").write_text("cached")
        (tmp_path / "application.yml").write_text("app: config")

        collected = _collect_config_files(str(tmp_path))

        assert len(collected) == 1
        names = _paths(collected)
        assert "application.yml" in names
        assert "__pycache__/config.yaml" not in names

    def test_excludes_node_modules(self, tmp_path):
        node_modules_dir = tmp_path / "node_modules"
        node_modules_dir.mkdir()
        (node_modules_dir / "package.json").write_text("{}")
        (tmp_path / "application.yml").write_text("app: config")

        collected = _collect_config_files(str(tmp_path))

        assert len(collected) == 1
        names = _paths(collected)
        assert "application.yml" in names
        assert "node_modules/package.json" not in names

    def test_excludes_tox(self, tmp_path):
        tox_dir = tmp_path / ".tox"
        tox_dir.mkdir()
        (tox_dir / "config.ini").write_text("[tox]")
        (tmp_path / "application.yml").write_text("app: config")

        collected = _collect_config_files(str(tmp_path))

        assert len(collected) == 1
        names = _paths(collected)
        assert "application.yml" in names
        assert ".tox/config.ini" not in names

    def test_skips_large_files(self, tmp_path):
        # 500 001 bytes is just over the collector's size cutoff.
        (tmp_path / "large.properties").write_text("x" * 500_001)
        (tmp_path / "small.properties").write_text("key=value")

        collected = _collect_config_files(str(tmp_path))

        assert len(collected) == 1
        names = _paths(collected)
        assert "small.properties" in names
        assert "large.properties" not in names

    def test_empty_repo_returns_empty(self, tmp_path):
        assert _collect_config_files(str(tmp_path)) == []

    def test_nested_directory_structure(self, tmp_path):
        resources = tmp_path / "src" / "main" / "resources"
        resources.mkdir(parents=True)
        (resources / "application.yml").write_text("app: test")

        conf_dir = tmp_path / "conf"
        conf_dir.mkdir()
        (conf_dir / "server.txt").write_text("server config")

        (tmp_path / "config.yaml").write_text("key: value")

        collected = _collect_config_files(str(tmp_path))

        assert len(collected) == 3
        names = _paths(collected)
        assert "src/main/resources/application.yml" in names
        assert "conf/server.txt" in names
        assert "config.yaml" in names

    def test_handles_read_errors_gracefully(self, tmp_path):
        (tmp_path / "good.properties").write_text("key=value")
        bad_file = tmp_path / "bad.properties"
        bad_file.write_bytes(b"\xff\xfe invalid utf-8")
        bad_file.chmod(0o000)  # unreadable: collector must skip, not raise

        collected = _collect_config_files(str(tmp_path))

        assert "good.properties" in _paths(collected)

        bad_file.chmod(0o644)  # restore so tmp_path cleanup succeeds

    def test_returns_relative_paths(self, tmp_path):
        subdir = tmp_path / "config"
        subdir.mkdir()
        (subdir / "app.yml").write_text("app: config")

        collected = _collect_config_files(str(tmp_path))

        assert len(collected) == 1
        rel_path, _ = collected[0]
        assert rel_path == "config/app.yml"
        assert not rel_path.startswith("/")
        assert str(tmp_path) not in rel_path

    def test_returns_file_contents(self, tmp_path):
        expected_content = "database:\n host: localhost\n port: 5432"
        (tmp_path / "config.yaml").write_text(expected_content)

        collected = _collect_config_files(str(tmp_path))

        assert len(collected) == 1
        rel_path, content = collected[0]
        assert rel_path == "config.yaml"
        assert content == expected_content


class TestSearchConfigContent:
    """Keyword search across collected config contents."""

    def test_empty_config_files(self):
        assert search_config_content([], ["keyword"]) == []

    def test_no_matches(self):
        files = [("app.yml", "database:\n host: localhost")]
        assert search_config_content(files, ["xstream"]) == []

    def test_single_match(self):
        files = [("app.yml", "line1\nxstream_enabled: true\nline3")]
        hits = search_config_content(files, ["xstream"], context_lines=0)
        assert len(hits) == 1
        assert "app.yml" in hits[0]
        assert "xstream_enabled: true" in hits[0]

    def test_source_label(self):
        files = [("app.yml", "match_keyword")]
        hits = search_config_content(files, ["match"], source_label="git://repo", context_lines=0)
        assert "source: git://repo" in hits[0]

    def test_multiple_files(self):
        files = [
            ("a.yml", "xstream: yes"),
            ("b.yml", "no match here"),
            ("c.xml", "xstream config"),
        ]
        hits = search_config_content(files, ["xstream"], context_lines=0)
        assert len(hits) == 2
        assert "a.yml" in hits[0]
        assert "c.xml" in hits[1]

    def test_max_results_cap(self):
        files = [("big.yml", "\n".join(f"keyword_{i}" for i in range(50)))]
        hits = search_config_content(files, ["keyword"], max_results=5, context_lines=0)
        assert len(hits) == 5

    def test_context_lines(self):
        body = "line1\nline2\nMATCH_LINE\nline4\nline5"
        hits = search_config_content([("app.yml", body)], ["match_line"], context_lines=1)
        assert len(hits) == 1
        # One line of context on each side, the match itself marked with ">".
        assert "line2" in hits[0]
        assert "> 3: MATCH_LINE" in hits[0]
        assert "line4" in hits[0]
        assert "line1" not in hits[0]
        assert "line5" not in hits[0]

    def test_case_insensitive_keywords(self):
        files = [("app.yml", "XStream_Config: enabled")]
        assert len(search_config_content(files, ["xstream"], context_lines=0)) == 1

    def test_multiple_keywords_match_any(self):
        files = [
            ("a.yml", "xstream: true"),
            ("b.yml", "deserialization: false"),
            ("c.yml", "no relevant content"),
        ]
        hits = search_config_content(files, ["xstream", "deserialization"], context_lines=0)
        assert len(hits) == 2

    def test_line_number_in_output(self):
        files = [("app.yml", "line1\nline2\nxstream: true")]
        hits = search_config_content(files, ["xstream"], context_lines=0)
        assert "line 3" in hits[0]

    def test_default_source_label(self):
        files = [("app.yml", "match_keyword")]
        hits = search_config_content(files, ["match"], context_lines=0)
        assert "source: unknown" in hits[0]


class TestConfigScannerSourceAwareness:
    """Tests for app/dep classification and source scoping in Configuration Scanner."""

    def _make_configs(self):
        return [
            ("config.yaml", "xstream: app-level"),
            ("src/main/resources/application.yml", "xstream: app-config"),
            ("dependencies-sources/xstream-1.4/config.properties", "xstream: dep-xstream"),
            ("dependencies-sources/commons-io-2.6/config.properties", "xstream: dep-commons"),
        ]

    def test_app_config_prioritized_over_dependency(self):
        from vuln_analysis.utils.source_classification import (
            is_dependency_path, format_app_dep_output,
        )
        configs = self._make_configs()
        app_configs = [c for c in configs if not is_dependency_path(c[0])]
        dep_configs = [c for c in configs if is_dependency_path(c[0])]

        app_matches = search_config_content(app_configs, ["xstream"], max_results=10, context_lines=0)
        dep_matches = search_config_content(dep_configs, ["xstream"], max_results=10, context_lines=0)

        combined = format_app_dep_output(app_matches, dep_matches, len(app_matches), len(dep_matches),
                                         "No matches")
        assert "Main application (2 of 2 results)" in combined
        assert "Application library dependencies (2 of 2 results)" in combined
        # Application-level matches must appear before the dependency section.
        assert combined.index("config.yaml") < combined.index("Application library dependencies")

    def test_source_scope_filters_dependency_configs(self):
        from vuln_analysis.utils.source_classification import (
            is_dependency_path, filter_by_source_scope,
        )
        configs = self._make_configs()
        dep_configs = [c for c in configs if is_dependency_path(c[0])]

        filtered = filter_by_source_scope(dep_configs, ["xstream"], lambda x: x[0])
        assert len(filtered) == 1
        assert "xstream-1.4" in filtered[0][0]

    def test_app_configs_unaffected_by_source_scope(self):
        from vuln_analysis.utils.source_classification import is_dependency_path
        configs = self._make_configs()
        app_configs = [c for c in configs if not is_dependency_path(c[0])]

        app_matches = search_config_content(app_configs, ["xstream"], max_results=10, context_lines=0)
        assert len(app_matches) == 2

    def test_app_dep_headers_in_output(self):
        from vuln_analysis.utils.source_classification import format_app_dep_output
        combined = format_app_dep_output(["match1"], ["match2"], 1, 1, "No matches")
        assert "Main application" in combined
        assert "Application library dependencies" in combined

    def test_no_matches_preserves_message(self):
        from vuln_analysis.utils.source_classification import format_app_dep_output
        combined = format_app_dep_output([], [], 0, 0, "No configuration entries found matching: xstream")
        assert combined == "No configuration entries found matching: xstream"
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the question dispatcher (routing between analysis agents)."""

import pytest
from unittest.mock import AsyncMock
from pydantic import ValidationError
from langchain_core.messages import HumanMessage

from vuln_analysis.functions.dispatcher import (
    QuestionRouting,
    build_routing_prompt,
    dispatch_question,
)


class TestQuestionRouting:
    """Validation behavior of the QuestionRouting schema."""

    def test_valid_reachability_type(self):
        routed = QuestionRouting(agent_type="reachability", reason="test reason")
        assert routed.agent_type == "reachability"
        assert routed.reason == "test reason"

    def test_valid_code_understanding_type(self):
        routed = QuestionRouting(agent_type="code_understanding", reason="test reason")
        assert routed.agent_type == "code_understanding"
        assert routed.reason == "test reason"

    def test_invalid_type_rejected(self):
        with pytest.raises(ValidationError) as err:
            QuestionRouting(agent_type="invalid", reason="test")
        assert "agent_type" in str(err.value)

    def test_missing_agent_type_rejected(self):
        with pytest.raises(ValidationError) as err:
            QuestionRouting(reason="test")
        assert "agent_type" in str(err.value)

    def test_missing_reason_rejected(self):
        with pytest.raises(ValidationError) as err:
            QuestionRouting(agent_type="reachability")
        assert "reason" in str(err.value)


class TestBuildRoutingPrompt:
    """Content checks on the routing prompt template."""

    def test_context_and_question_inserted(self):
        context = "CVE-2024-1234, package: foo:bar"
        question = "Is the function vulnerable()?"
        prompt = build_routing_prompt(context, question)

        assert context in prompt
        assert question in prompt

    def test_all_few_shot_examples_present(self):
        prompt = build_routing_prompt("context", "question")

        for example in (
            "XStream.fromXML()",
            "XML parser",
            "BeanUtils.populate()",
            "commons-beanutils",
            "parseXML()",
            "external entity processing",
            "newTransformer()",
            "deserialize()",
        ):
            assert example in prompt, f"Missing example: {example}"

    def test_both_agent_types_mentioned(self):
        prompt = build_routing_prompt("context", "question")

        assert "reachability" in prompt
        assert "code_understanding" in prompt

    def test_empty_inputs(self):
        prompt = build_routing_prompt("", "")
        assert isinstance(prompt, str)
        assert len(prompt) > 0

    def test_special_chars_in_inputs(self):
        # Brace/percent characters must survive template substitution.
        context = "CVE {test} with % and {{nested}}"
        question = "Is {function}() reachable? % complete"

        prompt = build_routing_prompt(context, question)

        assert context in prompt
        assert question in prompt


class TestDispatchQuestion:
    """dispatch_question wiring against a mocked routing LLM."""

    @pytest.mark.asyncio
    async def test_dispatch_returns_routing_result(self):
        expected = QuestionRouting(
            agent_type="reachability",
            reason="Test reason"
        )

        llm = AsyncMock()
        llm.ainvoke.return_value = expected

        routed = await dispatch_question(
            routing_llm=llm,
            question="Test question?",
            context_block="Test context",
        )

        assert routed == expected
        assert routed.agent_type == "reachability"
        assert routed.reason == "Test reason"

    @pytest.mark.asyncio
    async def test_dispatch_passes_human_message(self):
        llm = AsyncMock()
        llm.ainvoke.return_value = QuestionRouting(
            agent_type="code_understanding",
            reason="Config check"
        )

        question = "Is the XML parser configured?"
        context = "CVE-2024-5678"

        await dispatch_question(
            routing_llm=llm,
            question=question,
            context_block=context,
        )

        llm.ainvoke.assert_called_once()

        # The LLM must receive exactly one HumanMessage embedding both inputs.
        sent_messages = llm.ainvoke.call_args[0][0]
        assert len(sent_messages) == 1
        assert isinstance(sent_messages[0], HumanMessage)
        assert question in sent_messages[0].content
        assert context in sent_messages[0].content

    @pytest.mark.asyncio
    async def test_dispatch_propagates_exception(self):
        llm = AsyncMock()
        llm.ainvoke.side_effect = ValueError("LLM error")

        with pytest.raises(ValueError, match="LLM error"):
            await dispatch_question(
                routing_llm=llm,
                question="Test question",
                context_block="Test context",
            )
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the Import Usage Analyzer tool.

NOTE(review): reconstructed from a whitespace-mangled patch; two C
`#include` literals lost their angle-bracket header names to extraction
and are re-added below with TODO markers.
"""

import re

from exploit_iq_commons.utils.dep_tree import Ecosystem
from exploit_iq_commons.utils.functions_parsers.lang_functions_parsers_factory import get_language_function_parser
from vuln_analysis.tools.import_usage_analyzer import (
    _find_usage_in_file,
    analyze_imports,
)


def _get_patterns(ecosystem: Ecosystem, package_name: str) -> list[re.Pattern]:
    """Return the compiled import-search patterns for *package_name* in *ecosystem*."""
    parser = get_language_function_parser(ecosystem, tree=None)
    return parser.get_import_search_patterns(package_name)


class TestGetImportSearchPatterns:
    """Per-ecosystem import pattern generation."""

    def test_python_import_pattern(self):
        patterns = _get_patterns(Ecosystem.PYTHON, "xml.etree")

        assert any(p.search("import xml.etree.ElementTree") for p in patterns)
        assert any(p.search("from xml.etree import ElementTree") for p in patterns)

    def test_java_import_pattern(self):
        patterns = _get_patterns(Ecosystem.JAVA, "com.thoughtworks.xstream")

        assert any(p.search("import com.thoughtworks.xstream.XStream;") for p in patterns)
        assert any(p.search("import static com.thoughtworks.xstream.XStream;") for p in patterns)

    def test_go_import_pattern(self):
        patterns = _get_patterns(Ecosystem.GO, "encoding/xml")

        assert any(p.search('import "encoding/xml"') for p in patterns)
        assert any(p.search('import (\n "encoding/xml"\n)') for p in patterns)

    def test_javascript_require_pattern(self):
        patterns = _get_patterns(Ecosystem.JAVASCRIPT, "xml2js")

        assert any(p.search("const xml = require('xml2js')") for p in patterns)
        assert any(p.search('const xml = require("xml2js")') for p in patterns)

    def test_javascript_import_from_pattern(self):
        patterns = _get_patterns(Ecosystem.JAVASCRIPT, "xml2js")

        assert any(p.search("import xml from 'xml2js'") for p in patterns)
        assert any(p.search('import { parseString } from "xml2js"') for p in patterns)

    def test_c_include_pattern(self):
        patterns = _get_patterns(Ecosystem.C_CPP, "libxml")

        # FIX: the angle-bracket header name was stripped by extraction,
        # leaving a bare '#include ' that no include pattern should match.
        # Reconstructed as a <...> include — TODO confirm the original header.
        assert any(p.search('#include <libxml/tree.h>') for p in patterns)
        assert any(p.search('#include "libxml/tree.h"') for p in patterns)

    def test_special_chars_escaped(self):
        # Regex metacharacters in the package name must not break compilation.
        patterns = _get_patterns(Ecosystem.PYTHON, "foo.bar[baz]")

        assert len(patterns) > 0
        for p in patterns:
            assert isinstance(p, re.Pattern)

    def test_case_insensitive(self):
        patterns = _get_patterns(Ecosystem.PYTHON, "XmlParser")

        assert any(p.search("import xmlparser") for p in patterns)
        assert any(p.search("import XMLPARSER") for p in patterns)
        assert any(p.search("import XmlParser") for p in patterns)


class TestFindUsageInFile:
    """Locating usage sites while skipping import/include lines."""

    def test_finds_usage_sites(self):
        content = """import xml.etree.ElementTree
tree = ElementTree.parse('data.xml')
root = tree.getroot()
for child in root:
    print(child.tag)
"""
        usages = _find_usage_in_file(content, ["xml.etree.ElementTree"])

        assert len(usages) > 0
        assert any("ElementTree.parse" in u for u in usages)

    def test_skips_import_lines(self):
        content = """import xml.etree.ElementTree
from xml.etree import ElementTree
tree = ElementTree.parse('data.xml')
"""
        usages = _find_usage_in_file(content, ["ElementTree"])

        assert len(usages) == 1
        assert "L3:" in usages[0]
        assert "parse" in usages[0]

    def test_skips_from_lines(self):
        content = """from xml.etree import ElementTree
tree = ElementTree.parse('data.xml')
"""
        usages = _find_usage_in_file(content, ["ElementTree"])

        assert len(usages) == 1
        assert "L2:" in usages[0]

    def test_skips_include_lines(self):
        # FIX: reconstructed the stripped angle-bracket include — TODO confirm.
        content = """#include <libxml/parser.h>
xmlDocPtr doc = xmlParseFile("data.xml");
"""
        usages = _find_usage_in_file(content, ["xmlParseFile"])

        assert len(usages) == 1
        assert "L2:" in usages[0]

    def test_max_usages_cap(self):
        content = "\n".join([f"var x{i} = XStream();" for i in range(20)])

        usages = _find_usage_in_file(content, ["XStream"], max_usages=5)

        assert len(usages) == 5

    def test_word_boundary_matching(self):
        # "MyXStreamHelper" must not match the bare name "XStream".
        content = """class MyXStreamHelper:
    pass
xs = XStream()
"""
        usages = _find_usage_in_file(content, ["XStream"])

        assert len(usages) == 1
        assert any("L3:" in u and "XStream()" in u for u in usages)

    def test_dotted_name_uses_short_component(self):
        # A fully-qualified name is matched by its final component.
        content = """import com.thoughtworks.xstream.XStream;
XStream xs = new XStream();
"""
        usages = _find_usage_in_file(content, ["com.thoughtworks.xstream.XStream"])

        assert len(usages) == 1
        assert "L2:" in usages[0]
        assert "XStream xs" in usages[0]

    def test_no_usages_returns_empty(self):
        content = """import xml.etree.ElementTree
from xml.etree import ElementTree
"""
        assert _find_usage_in_file(content, ["ElementTree"]) == []

    def test_empty_content(self):
        assert _find_usage_in_file("", ["XStream"]) == []

    def test_empty_names(self):
        assert _find_usage_in_file("xs = XStream();", []) == []

    def test_multiple_names_same_line(self):
        content = """import foo, bar
result = foo.process(bar.data)
"""
        usages = _find_usage_in_file(content, ["foo", "bar"])

        assert len(usages) == 2
        assert all("L2:" in u for u in usages)
        assert any("foo.process" in u for u in usages)
        assert any("bar.data" in u for u in usages)

    def test_line_number_format(self):
        content = """import XStream
xs = XStream()
"""
        usages = _find_usage_in_file(content, ["XStream"])

        assert usages[0].startswith(" L2:")

    def test_strips_line_content(self):
        content = """import XStream
        xs = XStream()
"""
        usages = _find_usage_in_file(content, ["XStream"])

        assert "xs = XStream()" in usages[0]
        assert usages[0].count(" " * 8) == 0


class _MockSearcher:
    """Minimal stand-in for the index searcher used by analyze_imports."""

    def __init__(self, docs):
        self.docs = docs
        self.num_docs = len(docs)

    def doc(self, doc_id):
        if doc_id >= len(self.docs):
            raise IndexError(f"doc_id {doc_id} out of range")
        return self.docs[doc_id]


def _import_searcher(*docs):
    """Build a _MockSearcher from document dicts."""
    return _MockSearcher(list(docs))


class TestAnalyzeImportsSourceAwareness:
    """App/dependency classification and source scoping in analyze_imports."""

    def _java_patterns(self, package_name):
        return _get_patterns(Ecosystem.JAVA, package_name)

    def test_app_imports_prioritized(self):
        docs = [
            {"file_path": ["dependencies-sources/xstream-1.4/XStreamConverter.java"],
             "content": ["import com.thoughtworks.xstream.XStream;\nXStream xs = new XStream();"]},
            {"file_path": ["src/main/java/App.java"],
             "content": ["import com.thoughtworks.xstream.XStream;\nXStream converter = new XStream();"]},
        ]
        patterns = self._java_patterns("com.thoughtworks.xstream")
        result = analyze_imports(_import_searcher(*docs), patterns, "com.thoughtworks.xstream")
        assert "Main application" in result
        assert "Application library dependencies" in result
        # Application hits must come before the dependency section.
        app_pos = result.index("Main application")
        dep_pos = result.index("Application library dependencies")
        assert app_pos < dep_pos
        assert "src/main/java/App.java" in result

    def test_source_scope_filters_dep_imports(self):
        docs = [
            {"file_path": ["dependencies-sources/xstream-1.4/Converter.java"],
             "content": ["import com.thoughtworks.xstream.XStream;"]},
            {"file_path": ["dependencies-sources/commons-io/IOUtils.java"],
             "content": ["import com.thoughtworks.xstream.XStream;"]},
        ]
        patterns = self._java_patterns("com.thoughtworks.xstream")
        result = analyze_imports(_import_searcher(*docs), patterns, "com.thoughtworks.xstream",
                                 source_scope=["xstream"])
        assert "xstream-1.4" in result
        assert "commons-io" not in result

    def test_no_imports_unchanged(self):
        docs = [
            {"file_path": ["src/App.java"], "content": ["public class App {}"]},
        ]
        patterns = self._java_patterns("com.thoughtworks.xstream")
        result = analyze_imports(_import_searcher(*docs), patterns, "com.thoughtworks.xstream",
                                 ecosystem_label="java")
        assert "No imports of 'com.thoughtworks.xstream' found" in result

    def test_app_dep_headers_in_output(self):
        docs = [
            {"file_path": ["src/App.java"],
             "content": ["import com.thoughtworks.xstream.XStream;"]},
            {"file_path": ["vendor/lib/Helper.java"],
             "content": ["import com.thoughtworks.xstream.XStream;"]},
        ]
        patterns = self._java_patterns("com.thoughtworks.xstream")
        result = analyze_imports(_import_searcher(*docs), patterns, "com.thoughtworks.xstream")
        assert "Main application" in result
        assert "Application library dependencies" in result

    def test_only_dep_imports(self):
        docs = [
            {"file_path": ["dependencies-sources/lib/Foo.java"],
             "content": ["import com.thoughtworks.xstream.XStream;"]},
        ]
        patterns = self._java_patterns("com.thoughtworks.xstream")
        result = analyze_imports(_import_searcher(*docs), patterns, "com.thoughtworks.xstream")
        assert "Main application (0 of 0 results)" in result
        assert "Application library dependencies (1 of 1 results)" in result
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Tests for intel_utils: build_critical_context RHSA candidate capping."""

from exploit_iq_commons.data_models.cve_intel import CveIntel, CveIntelRhsa
from vuln_analysis.utils.intel_utils import build_critical_context, _MAX_RHSA_CANDIDATES


class TestBuildCriticalContextRhsaCap:
    """RHSA candidate-package capping behavior of build_critical_context."""

    def _make_cve_intel_with_rhsa_packages(self, count):
        """Build a CveIntel whose RHSA section carries `count` package_state entries."""
        states = [
            CveIntelRhsa.PackageState(package_name=f"rhsa-product-{i}")
            for i in range(count)
        ]
        return CveIntel(vuln_id="CVE-2026-99999", rhsa=CveIntelRhsa(package_state=states))

    @staticmethod
    def _rhsa_only(candidates):
        # Narrow the candidate list to entries sourced from RHSA data.
        return [c for c in candidates if c["source"] == "rhsa"]

    def test_rhsa_cap_limits_candidates(self):
        """An RHSA advisory with 100+ packages contributes only _MAX_RHSA_CANDIDATES."""
        intel = self._make_cve_intel_with_rhsa_packages(100)
        _, candidates, _ = build_critical_context([intel])
        assert len(self._rhsa_only(candidates)) == _MAX_RHSA_CANDIDATES

    def test_rhsa_below_cap_all_included(self):
        """Fewer packages than the cap: every one makes it into the candidates."""
        intel = self._make_cve_intel_with_rhsa_packages(5)
        _, candidates, _ = build_critical_context([intel])
        assert len(self._rhsa_only(candidates)) == 5

    def test_rhsa_cap_with_1000_packages(self):
        """Production-scale scenario: well over 1000 RHSA packages stay capped."""
        intel = self._make_cve_intel_with_rhsa_packages(1234)
        _, candidates, _ = build_critical_context([intel])
        assert len(self._rhsa_only(candidates)) == _MAX_RHSA_CANDIDATES

    def test_rhsa_cap_does_not_affect_ghsa(self):
        """The RHSA cap leaves GHSA-sourced candidates untouched."""
        intel = self._make_cve_intel_with_rhsa_packages(100)
        # Attach GHSA data alongside the oversized RHSA section.
        from exploit_iq_commons.data_models.cve_intel import CveIntelGhsa
        intel.ghsa = CveIntelGhsa(
            ghsa_id="GHSA-test-0001",
            vulnerabilities=[{"package": {"name": "xstream", "ecosystem": "Maven"}}],
        )
        _, candidates, _ = build_critical_context([intel])
        ghsa_candidates = [c for c in candidates if c["source"] == "ghsa"]
        assert len(ghsa_candidates) == 1
        assert len(self._rhsa_only(candidates)) == _MAX_RHSA_CANDIDATES

    def test_context_note_still_shows_total_count(self):
        """The context note keeps reporting the uncapped package count."""
        intel = self._make_cve_intel_with_rhsa_packages(50)
        context, _, _ = build_critical_context([intel])
        affected_notes = [line for line in context if "Affected across" in line]
        assert len(affected_notes) == 1
        assert "50" in affected_notes[0]
+""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from langchain_core.messages import HumanMessage + +from vuln_analysis.functions.cve_agent import _process_steps +from vuln_analysis.functions.dispatcher import QuestionRouting +from vuln_analysis.functions.react_internals import ReachabilityRulesTracker +from vuln_analysis.functions.code_understanding_internals import CodeUnderstandingRulesTracker + + +@pytest.fixture +def mock_workflow_state(): + state = MagicMock() + state.cve_intel = [] + return state + + +@pytest.fixture +def mock_graph(): + graph = AsyncMock() + graph.ainvoke = AsyncMock(return_value={"input": "q", "output": "answer"}) + return graph + + +@pytest.fixture +def patch_externals(mock_workflow_state): + """Patch ctx_state, build_critical_context, dispatch_question, and AGENT_TRACER.""" + with ( + patch("vuln_analysis.functions.cve_agent.ctx_state") as mock_ctx, + patch("vuln_analysis.functions.cve_agent.build_critical_context") as mock_bcc, + patch("vuln_analysis.functions.cve_agent.dispatch_question") as mock_dispatch, + patch("vuln_analysis.functions.cve_agent.AGENT_TRACER") as mock_tracer, + ): + mock_ctx.get.return_value = mock_workflow_state + mock_bcc.return_value = (["ctx_line"], [{"name": "pkg"}], ["vuln_func"]) + mock_tracer.push_active_function.return_value.__enter__ = MagicMock() + mock_tracer.push_active_function.return_value.__exit__ = MagicMock(return_value=False) + yield { + "ctx_state": mock_ctx, + "build_critical_context": mock_bcc, + "dispatch_question": mock_dispatch, + "tracer": mock_tracer, + } + + +class TestFallbackAlignment: + """When routed agent_type is not in agents dict, both graph AND tracker + must fall back to reachability.""" + + @pytest.mark.asyncio + async def test_cu_fallback_uses_reachability_graph(self, patch_externals, mock_graph): + patch_externals["dispatch_question"].return_value = QuestionRouting( + agent_type="code_understanding", reason="config question" + ) + 
agents = {"reachability": mock_graph} + + await _process_steps(agents, MagicMock(), ["Is it configured?"], None) + + mock_graph.ainvoke.assert_called_once() + + @pytest.mark.asyncio + async def test_cu_fallback_uses_reachability_tracker(self, patch_externals, mock_graph): + patch_externals["dispatch_question"].return_value = QuestionRouting( + agent_type="code_understanding", reason="config question" + ) + agents = {"reachability": mock_graph} + + await _process_steps(agents, MagicMock(), ["Is it configured?"], None) + + call_args = mock_graph.ainvoke.call_args[0][0] + tracker = call_args["rules_tracker"] + assert isinstance(tracker, ReachabilityRulesTracker) + assert not isinstance(tracker, CodeUnderstandingRulesTracker) + + @pytest.mark.asyncio + async def test_direct_route_uses_correct_tracker(self, patch_externals): + patch_externals["dispatch_question"].return_value = QuestionRouting( + agent_type="code_understanding", reason="config question" + ) + cu_graph = AsyncMock() + cu_graph.ainvoke = AsyncMock(return_value={"input": "q", "output": "a"}) + reach_graph = AsyncMock() + agents = {"reachability": reach_graph, "code_understanding": cu_graph} + + await _process_steps(agents, MagicMock(), ["Is it configured?"], None) + + call_args = cu_graph.ainvoke.call_args[0][0] + tracker = call_args["rules_tracker"] + assert isinstance(tracker, CodeUnderstandingRulesTracker) + reach_graph.ainvoke.assert_not_called() + + @pytest.mark.asyncio + async def test_reachability_route_uses_reachability_tracker(self, patch_externals, mock_graph): + patch_externals["dispatch_question"].return_value = QuestionRouting( + agent_type="reachability", reason="function call check" + ) + agents = {"reachability": mock_graph} + + await _process_steps(agents, MagicMock(), ["Is func reachable?"], None) + + call_args = mock_graph.ainvoke.call_args[0][0] + tracker = call_args["rules_tracker"] + assert isinstance(tracker, ReachabilityRulesTracker) + + +class TestInitialState: + """Verify 
class TestInitialState:
    """Verify initial_state dict has correct keys and values."""

    @pytest.mark.asyncio
    async def test_initial_state_keys(self, patch_externals, mock_graph):
        patch_externals["dispatch_question"].return_value = QuestionRouting(
            agent_type="reachability", reason="test"
        )
        agents = {"reachability": mock_graph}

        await _process_steps(agents, MagicMock(), ["test question"], None)

        state_arg = mock_graph.ainvoke.call_args[0][0]
        assert state_arg["input"] == "test question"
        assert state_arg["step"] == 0
        assert state_arg["max_steps"] == 10
        assert state_arg["thought"] is None
        assert state_arg["observation"] is None
        assert state_arg["output"] == "waiting for the agent to respond"
        # Conversation seeds with the question as a HumanMessage.
        assert isinstance(state_arg["messages"][0], HumanMessage)
        assert state_arg["messages"][0].content == "test question"

    @pytest.mark.asyncio
    async def test_precomputed_intel_passed(self, patch_externals, mock_graph):
        patch_externals["dispatch_question"].return_value = QuestionRouting(
            agent_type="reachability", reason="test"
        )
        agents = {"reachability": mock_graph}

        await _process_steps(agents, MagicMock(), ["q"], None)

        state_arg = mock_graph.ainvoke.call_args[0][0]
        # Tuple produced by the patched build_critical_context fixture.
        assert state_arg["precomputed_intel"] == (["ctx_line"], [{"name": "pkg"}], ["vuln_func"])

    @pytest.mark.asyncio
    async def test_custom_max_iterations(self, patch_externals, mock_graph):
        patch_externals["dispatch_question"].return_value = QuestionRouting(
            agent_type="reachability", reason="test"
        )
        agents = {"reachability": mock_graph}

        await _process_steps(agents, MagicMock(), ["q"], None, max_iterations=25)

        state_arg = mock_graph.ainvoke.call_args[0][0]
        assert state_arg["max_steps"] == 25

    @pytest.mark.asyncio
    async def test_graph_config_recursion_limit(self, patch_externals, mock_graph):
        patch_externals["dispatch_question"].return_value = QuestionRouting(
            agent_type="reachability", reason="test"
        )
        agents = {"reachability": mock_graph}

        await _process_steps(agents, MagicMock(), ["q"], None)

        assert mock_graph.ainvoke.call_args[1]["config"] == {"recursion_limit": 50}


class TestConcurrency:
    """Concurrent step processing and semaphore behavior."""

    @pytest.mark.asyncio
    async def test_multiple_steps_all_processed(self, patch_externals, mock_graph):
        patch_externals["dispatch_question"].return_value = QuestionRouting(
            agent_type="reachability", reason="test"
        )
        agents = {"reachability": mock_graph}

        results = await _process_steps(agents, MagicMock(), ["q1", "q2", "q3"], None)

        assert len(results) == 3
        assert mock_graph.ainvoke.call_count == 3

    @pytest.mark.asyncio
    async def test_semaphore_limits_concurrency(self, patch_externals):
        max_concurrent = 0
        current_concurrent = 0

        async def slow_invoke(state, config=None):
            # Track the peak number of invocations in flight at once.
            nonlocal max_concurrent, current_concurrent
            current_concurrent += 1
            max_concurrent = max(max_concurrent, current_concurrent)
            await asyncio.sleep(0.01)
            current_concurrent -= 1
            return {"input": state["input"], "output": "done"}

        slow_graph = AsyncMock()
        slow_graph.ainvoke = slow_invoke

        patch_externals["dispatch_question"].return_value = QuestionRouting(
            agent_type="reachability", reason="test"
        )
        agents = {"reachability": slow_graph}
        semaphore = asyncio.Semaphore(2)

        await _process_steps(agents, MagicMock(), ["q1", "q2", "q3", "q4"], semaphore)

        assert max_concurrent <= 2

    @pytest.mark.asyncio
    async def test_no_semaphore_allows_full_concurrency(self, patch_externals):
        max_concurrent = 0
        current_concurrent = 0

        async def slow_invoke(state, config=None):
            nonlocal max_concurrent, current_concurrent
            current_concurrent += 1
            max_concurrent = max(max_concurrent, current_concurrent)
            await asyncio.sleep(0.01)
            current_concurrent -= 1
            return {"input": state["input"], "output": "done"}

        slow_graph = AsyncMock()
        slow_graph.ainvoke = slow_invoke

        patch_externals["dispatch_question"].return_value = QuestionRouting(
            agent_type="reachability", reason="test"
        )
        agents = {"reachability": slow_graph}

        await _process_steps(agents, MagicMock(), ["q1", "q2", "q3", "q4"], None)

        assert max_concurrent == 4

    @pytest.mark.asyncio
    async def test_exception_in_step_returned_not_raised(self, patch_externals):
        failing_graph = AsyncMock()
        failing_graph.ainvoke = AsyncMock(side_effect=RuntimeError("graph failed"))

        patch_externals["dispatch_question"].return_value = QuestionRouting(
            agent_type="reachability", reason="test"
        )
        agents = {"reachability": failing_graph}

        results = await _process_steps(agents, MagicMock(), ["q1"], None)

        # The failure surfaces as a returned exception object, not a raise.
        assert len(results) == 1
        assert isinstance(results[0], RuntimeError)

    @pytest.mark.asyncio
    async def test_empty_steps_returns_empty(self, patch_externals, mock_graph):
        agents = {"reachability": mock_graph}
        results = await _process_steps(agents, MagicMock(), [], None)
        assert results == []


class TestTracingSpan:
    """The tracing span must be labeled with the actually-used agent type,
    not merely the routed one."""

    @pytest.mark.asyncio
    async def test_span_logs_actual_type_on_fallback(self, patch_externals, mock_graph):
        patch_externals["dispatch_question"].return_value = QuestionRouting(
            agent_type="code_understanding", reason="config"
        )
        agents = {"reachability": mock_graph}

        await _process_steps(agents, MagicMock(), ["q"], None)

        span_call = patch_externals["tracer"].push_active_function.call_args
        assert "[reachability]" in span_call[1]["input_data"]

    @pytest.mark.asyncio
    async def test_span_logs_routed_type_when_available(self, patch_externals):
        patch_externals["dispatch_question"].return_value = QuestionRouting(
            agent_type="code_understanding", reason="config"
        )
        cu_graph = AsyncMock()
        cu_graph.ainvoke = AsyncMock(return_value={"input": "q", "output": "a"})
        agents = {"reachability": MagicMock(), "code_understanding": cu_graph}

        await _process_steps(agents, MagicMock(), ["q"], None)

        span_call = patch_externals["tracer"].push_active_function.call_args
        assert "[code_understanding]" in span_call[1]["input_data"]
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Unit tests for ReachabilityAgent: get_tools, create_rules_tracker,
agent_type, should_truncate_tool_output."""

import pytest
from unittest.mock import MagicMock

from agent_test_helpers import MockTool, ALL_TOOLS, make_builder, make_config, make_state
from vuln_analysis.functions.reachability_agent import ReachabilityAgent
from vuln_analysis.functions.react_internals import ReachabilityRulesTracker
from vuln_analysis.tools.tool_names import ToolNames


def _make_reachability_agent(tools=None):
    """Construct a ReachabilityAgent with a mocked LLM and minimal config."""
    mock_llm = MagicMock()
    mock_llm.with_structured_output = MagicMock(return_value=MagicMock())
    config = MagicMock()
    config.max_iterations = 10
    return ReachabilityAgent(tools=tools or [], llm=mock_llm, config=config)


class TestGetTools:
    """ReachabilityAgent.get_tools selects reachability tools and filters by availability."""

    def test_keeps_reachability_tools(self):
        selected = ReachabilityAgent.get_tools(make_builder(), make_config(), make_state())
        names = {tool.name for tool in selected}
        assert ToolNames.FUNCTION_LOCATOR in names
        assert ToolNames.CALL_CHAIN_ANALYZER in names
        assert ToolNames.CODE_KEYWORD_SEARCH in names
        assert ToolNames.CVE_WEB_SEARCH in names

    def test_excludes_cu_only_tools(self):
        selected = ReachabilityAgent.get_tools(make_builder(), make_config(), make_state())
        names = {tool.name for tool in selected}
        assert ToolNames.CONFIGURATION_SCANNER not in names
        assert ToolNames.IMPORT_USAGE_ANALYZER not in names

    def test_excludes_container_analysis_data(self):
        selected = ReachabilityAgent.get_tools(make_builder(), make_config(), make_state())
        assert ToolNames.CONTAINER_ANALYSIS_DATA not in {tool.name for tool in selected}

    def test_keeps_all_8_reachability_tools(self):
        selected = ReachabilityAgent.get_tools(make_builder(), make_config(), make_state())
        assert len(selected) == 8

    def test_empty_builder_returns_empty(self):
        selected = ReachabilityAgent.get_tools(make_builder(tools=[]), make_config(), make_state())
        assert selected == []

    def test_preserves_tool_order(self):
        ordered_tools = [
            MockTool(ToolNames.CVE_WEB_SEARCH),
            MockTool(ToolNames.FUNCTION_LOCATOR),
            MockTool(ToolNames.CALL_CHAIN_ANALYZER),
        ]
        selected = ReachabilityAgent.get_tools(
            make_builder(tools=ordered_tools), make_config(), make_state()
        )
        assert [tool.name for tool in selected] == [
            ToolNames.CVE_WEB_SEARCH,
            ToolNames.FUNCTION_LOCATOR,
            ToolNames.CALL_CHAIN_ANALYZER,
        ]

    def test_unknown_tools_excluded(self):
        tools = [MockTool(ToolNames.FUNCTION_LOCATOR), MockTool("Some Future Tool")]
        selected = ReachabilityAgent.get_tools(
            make_builder(tools=tools), make_config(), make_state()
        )
        assert len(selected) == 1
        assert selected[0].name == ToolNames.FUNCTION_LOCATOR


class TestGetToolsAvailability:
    """get_tools filters out tools whose infrastructure prerequisites are not met."""

    def test_filters_code_semantic_search_when_no_vdb(self):
        selected = ReachabilityAgent.get_tools(
            make_builder(), make_config(), make_state(code_vdb_path=None)
        )
        assert ToolNames.CODE_SEMANTIC_SEARCH not in {tool.name for tool in selected}

    def test_filters_docs_semantic_search_when_no_vdb(self):
        selected = ReachabilityAgent.get_tools(
            make_builder(), make_config(), make_state(doc_vdb_path=None)
        )
        assert ToolNames.DOCS_SEMANTIC_SEARCH not in {tool.name for tool in selected}

    def test_filters_code_keyword_search_when_no_index(self):
        selected = ReachabilityAgent.get_tools(
            make_builder(), make_config(), make_state(code_index_path=None)
        )
        assert ToolNames.CODE_KEYWORD_SEARCH not in {tool.name for tool in selected}

    def test_filters_transitive_tools_when_no_index(self):
        selected = ReachabilityAgent.get_tools(
            make_builder(), make_config(), make_state(code_index_path=None)
        )
        names = {tool.name for tool in selected}
        assert ToolNames.CALL_CHAIN_ANALYZER not in names
        assert ToolNames.FUNCTION_CALLER_FINDER not in names
        assert ToolNames.FUNCTION_LOCATOR not in names

    def test_filters_cve_web_search_when_disabled(self):
        selected = ReachabilityAgent.get_tools(
            make_builder(), make_config(cve_web_search_enabled=False), make_state()
        )
        assert ToolNames.CVE_WEB_SEARCH not in {tool.name for tool in selected}

    def test_filters_transitive_tools_when_disabled(self):
        selected = ReachabilityAgent.get_tools(
            make_builder(), make_config(transitive_search_tool_enabled=False), make_state()
        )
        names = {tool.name for tool in selected}
        assert ToolNames.CALL_CHAIN_ANALYZER not in names
        assert ToolNames.FUNCTION_CALLER_FINDER not in names
        assert ToolNames.FUNCTION_LOCATOR not in names

    def test_version_finder_always_kept(self):
        # Even with every index/VDB missing, the version finder survives.
        selected = ReachabilityAgent.get_tools(
            make_builder(),
            make_config(),
            make_state(code_vdb_path=None, doc_vdb_path=None, code_index_path=None),
        )
        assert ToolNames.FUNCTION_LIBRARY_VERSION_FINDER in {tool.name for tool in selected}


class TestReachabilityDuplicateCall:

    def test_duplicate_call_blocked(self):
        tracker = ReachabilityRulesTracker()
        tracker.set_allowed_tools(["Function Locator"])
        tracker.set_target_package("commons-beanutils")
        tracker.check_thought_behavior("Function Locator", "commons-beanutils,getProperty", ["result"])
        violated, msg = tracker.check_thought_behavior(
            "Function Locator", "commons-beanutils,getProperty", ["result"]
        )
        assert violated is True
        assert "already called" in msg


class TestCreateRulesTracker:

    def test_returns_reachability_rules_tracker(self):
        assert isinstance(ReachabilityAgent.create_rules_tracker(), ReachabilityRulesTracker)

    def test_returns_fresh_instance_each_call(self):
        first = ReachabilityAgent.create_rules_tracker()
        second = ReachabilityAgent.create_rules_tracker()
        assert first is not second


class TestAgentType:

    def test_agent_type_is_reachability(self):
        assert _make_reachability_agent().agent_type == "reachability"


class TestShouldTruncateToolOutput:

    def test_true_for_java(self):
        agent = _make_reachability_agent()
        assert agent.should_truncate_tool_output({"ecosystem": "java"}, "any_tool") is True

    def test_true_for_java_uppercase(self):
        agent = _make_reachability_agent()
        assert agent.should_truncate_tool_output({"ecosystem": "Java"}, "any_tool") is True

    def test_false_for_go(self):
        agent = _make_reachability_agent()
        assert agent.should_truncate_tool_output({"ecosystem": "go"}, "any_tool") is False

    def test_false_for_python(self):
        agent = _make_reachability_agent()
        assert agent.should_truncate_tool_output({"ecosystem": "python"}, "any_tool") is False

    def test_false_for_empty_ecosystem(self):
        agent = _make_reachability_agent()
        assert agent.should_truncate_tool_output({"ecosystem": ""}, "any_tool") is False

    def test_false_when_ecosystem_missing(self):
        agent = _make_reachability_agent()
        assert agent.should_truncate_tool_output({}, "any_tool") is False


class TestInit:

    def test_creates_fifth_classification_llm(self):
        mock_llm = MagicMock()
        mock_llm.with_structured_output = MagicMock(return_value=MagicMock())
        config = MagicMock()
        config.max_iterations = 10
        agent = ReachabilityAgent(tools=[], llm=mock_llm, config=config)
        # Constructor wires up five structured-output LLMs in total.
        assert mock_llm.with_structured_output.call_count == 5
        assert hasattr(agent, "_classification_llm")
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from vuln_analysis.functions.react_internals import (
    BaseRulesTracker,
    ReachabilityRulesTracker,
    AgentState,
    _find_image_matching_candidate,
)


class TestBaseRulesTracker:
    def test_init_defaults(self):
        tracker = BaseRulesTracker()
        assert tracker.action_history == {}
        assert tracker.target_package is None
        assert tracker.allowed_tools == []

    def test_set_allowed_tools(self):
        tracker = BaseRulesTracker()
        tool_names = ["Function Locator", "Code Keyword Search"]
        tracker.set_allowed_tools(tool_names)
        assert tracker.allowed_tools == tool_names

    def test_set_target_package(self):
        tracker = BaseRulesTracker()
        tracker.set_target_package("commons-beanutils")
        assert tracker.target_package == "commons-beanutils"

    def test_is_empty_result_empty_list(self):
        assert BaseRulesTracker._is_empty_result([]) is True

    def test_is_empty_result_bracket_string(self):
        assert BaseRulesTracker._is_empty_result("[]") is True

    def test_is_empty_result_empty_string(self):
        assert BaseRulesTracker._is_empty_result("") is True

    def test_is_empty_result_whitespace(self):
        assert BaseRulesTracker._is_empty_result(" ") is True

    def test_is_empty_result_none(self):
        # None is not treated as an "empty result" by the tracker.
        assert BaseRulesTracker._is_empty_result(None) is False

    def test_is_empty_result_int(self):
        assert BaseRulesTracker._is_empty_result(0) is False

    def test_is_empty_result_nonempty_list(self):
        assert BaseRulesTracker._is_empty_result(["item"]) is False

    def test_add_action_accumulates(self):
        tracker = BaseRulesTracker()
        tracker.add_action("Function Locator", "pkg,fn1", ["result1"])
        tracker.add_action("Function Locator", "pkg,fn2", ["result2"])
        history = tracker.action_history["Function Locator"]
        assert len(history) == 2
        assert history[0]["input"] == "pkg,fn1"
        assert history[1]["input"] == "pkg,fn2"


class TestDuplicateCallRule:
    def test_duplicate_call_blocked(self):
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools(["Function Locator"])
        tracker.check_thought_behavior("Function Locator", "pkg,fn", ["result"])
        violated, msg = tracker.check_thought_behavior("Function Locator", "pkg,fn", ["result"])
        assert violated is True
        assert "already called" in msg

    def test_different_input_allowed(self):
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools(["Function Locator"])
        tracker.check_thought_behavior("Function Locator", "pkg,fn1", ["result"])
        violated, msg = tracker.check_thought_behavior("Function Locator", "pkg,fn2", ["result"])
        assert violated is False
        assert msg == ""

    def test_different_tool_same_input_allowed(self):
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools(["Function Locator", "Code Keyword Search"])
        tracker.check_thought_behavior("Function Locator", "pkg,fn", ["result"])
        violated, msg = tracker.check_thought_behavior("Code Keyword Search", "pkg,fn", ["result"])
        assert violated is False
        assert msg == ""

    def test_first_call_always_passes(self):
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools(["Function Locator"])
        violated, msg = tracker.check_thought_behavior("Function Locator", "pkg,fn", ["result"])
        assert violated is False
        assert msg == ""


class TestBaseRule7:
    """Rule 7 fires only on consecutive dotted Code Keyword Search queries
    that both came back empty."""

    def test_rule7_non_cks_tool_ignored(self):
        tracker = BaseRulesTracker()
        tracker.add_action("Code Keyword Search", "org.Class", [])
        assert tracker._rule_number_7("Function Locator", "org.Class", []) is False

    def test_rule7_no_dot_in_query_ignored(self):
        tracker = BaseRulesTracker()
        tracker.add_action("Code Keyword Search", "ClassName", [])
        assert tracker._rule_number_7("Code Keyword Search", "ClassName", []) is False

    def test_rule7_nonempty_output_ignored(self):
        tracker = BaseRulesTracker()
        tracker.add_action("Code Keyword Search", "org.Class", [])
        assert tracker._rule_number_7("Code Keyword Search", "org.Class", ["match"]) is False

    def test_rule7_no_prior_history_ignored(self):
        tracker = BaseRulesTracker()
        assert tracker._rule_number_7("Code Keyword Search", "org.Class", []) is False

    def test_rule7_prior_non_dotted_ignored(self):
        tracker = BaseRulesTracker()
        tracker.add_action("Code Keyword Search", "ClassName", [])
        assert tracker._rule_number_7("Code Keyword Search", "org.Class", []) is False

    def test_rule7_prior_had_results_ignored(self):
        tracker = BaseRulesTracker()
        tracker.add_action("Code Keyword Search", "org.Class", ["match"])
        assert tracker._rule_number_7("Code Keyword Search", "org.Class", []) is False

    def test_rule7_consecutive_dotted_empty_fires(self):
        tracker = BaseRulesTracker()
        tracker.add_action("Code Keyword Search", "org.Class", [])
        assert tracker._rule_number_7("Code Keyword Search", "org.Another", []) is True


class TestBaseRuleAllowedTools:
    def test_allowed_tools_in_list_passes(self):
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools(["Function Locator", "Code Keyword Search"])
        assert tracker._rule_use_allowed_tools("Function Locator") is False

    def test_allowed_tools_not_in_list_fails(self):
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools(["Function Locator"])
        assert tracker._rule_use_allowed_tools("CVE Web Search") is True

    def test_allowed_tools_empty_list_rejects_all(self):
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools([])
        assert tracker._rule_use_allowed_tools("Function Locator") is True


class TestBaseCheckThoughtBehavior:
    def test_happy_path_returns_false_and_adds_to_history(self):
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools(["Function Locator"])
        violated, msg = tracker.check_thought_behavior("Function Locator", "pkg,fn", ["result"])
        assert violated is False
        assert msg == ""
        assert "Function Locator" in tracker.action_history
        assert tracker.action_history["Function Locator"][0]["input"] == "pkg,fn"

    def test_duplicate_priority_over_rule7(self):
        # The duplicate-call rule is reported even when Rule 7 also applies.
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools(["Code Keyword Search"])
        tracker.add_action("Code Keyword Search", "org.Class", [])
        violated, msg = tracker.check_thought_behavior("Code Keyword Search", "org.Class", [])
        assert violated is True
        assert "already called" in msg
        assert "Rule 7" not in msg

    def test_rule7_priority_over_allowed_tools(self):
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools(["Other Tool"])
        tracker.add_action("Code Keyword Search", "org.Class", [])
        violated, msg = tracker.check_thought_behavior("Code Keyword Search", "org.Another", [])
        assert violated is True
        assert "Rule 7" in msg
        assert "AVAILABLE_TOOLS" not in msg

    def test_allowed_tools_error_message_format(self):
        tracker = BaseRulesTracker()
        tracker.set_allowed_tools(["Function Locator", "Code Keyword Search"])
        violated, msg = tracker.check_thought_behavior("CVE Web Search", "query", [])
        assert violated is True
        assert "AVAILABLE_TOOLS" in msg
        assert "['Function Locator', 'Code Keyword Search']" in msg


class TestReachabilityRulesTracker:
    def test_extends_base(self):
        assert isinstance(ReachabilityRulesTracker(), BaseRulesTracker)

    def test_set_target_functions_creates_dict(self):
        tracker = ReachabilityRulesTracker()
        tracker.set_target_functions(["fn1", "fn2"])
        assert tracker.target_functions == {"fn1": False, "fn2": False}

    def test_rule8_target_package_enforcement(self):
        tracker = ReachabilityRulesTracker()
        tracker.set_target_package("commons-beanutils")
        assert tracker._rule_number_8("Function Locator", "wrong-package,fn", []) is True
Locator", "commons-beanutils:commons-beanutils:1.9.4,fn", []) + assert result is False + + def test_rule8_skipped_after_first_call(self): + tracker = ReachabilityRulesTracker() + tracker.set_target_package("commons-beanutils") + tracker.set_allowed_tools(["Function Locator"]) + tracker.add_action("Function Locator", "commons-beanutils,fn", ["result"]) + result = tracker._rule_number_8("Function Locator", "wrong-package,fn", []) + assert result is False + + def test_rule8_no_target_package_skipped(self): + tracker = ReachabilityRulesTracker() + result = tracker._rule_number_8("Function Locator", "any-package,fn", []) + assert result is False + + def test_rule9_vulnerable_functions_first(self): + tracker = ReachabilityRulesTracker() + tracker.set_target_functions(["getProperty", "setProperty"]) + violated, msg = tracker._rule_number_9("Call Chain Analyzer", "pkg,someOtherFunction") + assert violated is True + assert "Rule 9" in msg + assert "getProperty" in msg or "setProperty" in msg + + def test_rule9_passes_after_checking_vulnerable(self): + tracker = ReachabilityRulesTracker() + tracker.set_target_functions(["getProperty"]) + violated, msg = tracker._rule_number_9("Call Chain Analyzer", "pkg,PropertyUtilsBean.getProperty") + assert violated is False + assert msg == "" + violated, msg = tracker._rule_number_9("Call Chain Analyzer", "pkg,someOtherFunction") + assert violated is False + assert msg == "" + + def test_reachability_check_order(self): + tracker = ReachabilityRulesTracker() + tracker.set_allowed_tools(["Other Tool"]) + tracker.set_target_package("commons-beanutils") + tracker.set_target_functions(["getProperty"]) + + tracker.add_action("Code Keyword Search", "org.Class", []) + violated, msg = tracker.check_thought_behavior("Code Keyword Search", "org.Another", []) + assert violated is True + assert "Rule 7" in msg + + tracker2 = ReachabilityRulesTracker() + tracker2.set_allowed_tools(["Call Chain Analyzer"]) + 
tracker2.set_target_package("commons-beanutils") + violated, msg = tracker2.check_thought_behavior("Call Chain Analyzer", "wrong-package,fn", []) + assert violated is True + assert "Rule 8" in msg + + tracker3 = ReachabilityRulesTracker() + tracker3.set_allowed_tools(["Other Tool"]) + violated, msg = tracker3.check_thought_behavior("Call Chain Analyzer", "pkg,fn", []) + assert violated is True + assert "AVAILABLE_TOOLS" in msg + + tracker4 = ReachabilityRulesTracker() + tracker4.set_allowed_tools(["Call Chain Analyzer"]) + tracker4.set_target_package("pkg") + tracker4.set_target_functions(["getProperty"]) + violated, msg = tracker4.check_thought_behavior("Call Chain Analyzer", "pkg,otherFunction", []) + assert violated is True + assert "Rule 9" in msg + + +class TestCheckFinishAllowed: + """Tests for ReachabilityRulesTracker.check_finish_allowed.""" + + def test_blocks_finish_when_cca_never_called(self): + """Reachability agent must call CCA before finishing.""" + tracker = ReachabilityRulesTracker() + allowed, msg = tracker.check_finish_allowed(cca_results=[]) + assert allowed is False + assert "Function Locator" in msg + assert "Call Chain Analyzer" in msg + + def test_allows_finish_when_cca_returned_false(self): + """CCA was called and returned not-reachable — finish is allowed.""" + tracker = ReachabilityRulesTracker() + tracker.add_action("Call Chain Analyzer", "pkg,fn", "(False, [])") + allowed, msg = tracker.check_finish_allowed(cca_results=[False]) + assert allowed is True + + def test_allows_finish_when_cca_returned_true_non_java(self): + """Non-Java: CCA found reachable — finish allowed (no FLVF requirement).""" + tracker = ReachabilityRulesTracker() + tracker.set_ecosystem("go") + tracker.add_action("Call Chain Analyzer", "pkg,fn", "(True, [path])") + allowed, msg = tracker.check_finish_allowed(cca_results=[True]) + assert allowed is True + + def test_blocks_finish_java_cca_true_no_flvf(self): + """Java: CCA found reachable but FLVF not called — block 
finish.""" + tracker = ReachabilityRulesTracker() + tracker.set_ecosystem("java") + tracker.add_action("Call Chain Analyzer", "pkg,fn", "(True, [path])") + allowed, msg = tracker.check_finish_allowed(cca_results=[True]) + assert allowed is False + assert "VERSION CHECK" in msg + assert "Function Library Version Finder" in msg + + def test_allows_finish_java_cca_true_with_flvf(self): + """Java: CCA found reachable and FLVF was called — finish allowed.""" + tracker = ReachabilityRulesTracker() + tracker.set_ecosystem("java") + tracker.add_action("Call Chain Analyzer", "pkg,fn", "(True, [path])") + tracker.add_action("Function Library Version Finder", "pkg", "1.9.4") + allowed, msg = tracker.check_finish_allowed(cca_results=[True]) + assert allowed is True + + def test_allows_finish_java_cca_false(self): + """Java: CCA returned not-reachable — no FLVF requirement.""" + tracker = ReachabilityRulesTracker() + tracker.set_ecosystem("java") + tracker.add_action("Call Chain Analyzer", "pkg,fn", "(False, [])") + allowed, msg = tracker.check_finish_allowed(cca_results=[False]) + assert allowed is True + + def test_allows_finish_when_cca_in_history_but_results_empty(self): + """Edge case: CCA was called (in action_history) but cca_results is + empty — e.g., output parsing failed. 
Don't block since CCA was attempted.""" + tracker = ReachabilityRulesTracker() + tracker.add_action("Call Chain Analyzer", "pkg,fn", "some unparseable output") + allowed, msg = tracker.check_finish_allowed(cca_results=[]) + assert allowed is True + + +class TestAgentState: + def test_default_tracker_type_annotation(self): + assert "rules_tracker" in AgentState.__annotations__ + + def test_default_is_reachability_field(self): + assert "is_reachability" in AgentState.__annotations__ + + +class TestFindImageMatchingCandidate: + """Tests for _find_image_matching_candidate used by package filter fast path.""" + + def test_match_in_image_name(self): + candidates = [ + {"name": "builder", "source": "rhsa"}, + {"name": "kernel", "source": "rhsa"}, + ] + result = _find_image_matching_candidate( + candidates, "registry.redhat.io/openshift4/ose-docker-builder", None, + ) + assert result == "builder" + + def test_match_in_repo(self): + candidates = [ + {"name": "kernel", "source": "rhsa"}, + {"name": "infinispan", "source": "rhsa"}, + ] + result = _find_image_matching_candidate( + candidates, "registry.redhat.io/some-image", "https://github.com/infinispan/infinispan", + ) + assert result == "infinispan" + + def test_no_match(self): + candidates = [ + {"name": "kernel", "source": "rhsa"}, + {"name": "glibc", "source": "rhsa"}, + ] + result = _find_image_matching_candidate( + candidates, "registry.redhat.io/openshift4/ose-docker-builder", + "https://github.com/openshift/builder", + ) + assert result is None + + def test_short_name_skipped(self): + """Candidate names shorter than 3 chars are not matched.""" + candidates = [{"name": "go", "source": "rhsa"}] + result = _find_image_matching_candidate( + candidates, "registry.redhat.io/golang-builder", None, + ) + assert result is None + + def test_no_image_no_repo(self): + candidates = [{"name": "builder", "source": "rhsa"}] + result = _find_image_matching_candidate(candidates, None, None) + assert result is None + + def 
test_first_match_wins(self): + """When multiple candidates could match, the first one wins.""" + candidates = [ + {"name": "openshift", "source": "rhsa"}, + {"name": "builder", "source": "rhsa"}, + ] + result = _find_image_matching_candidate( + candidates, "registry.redhat.io/openshift4/ose-docker-builder", + "https://github.com/openshift/builder", + ) + assert result == "openshift" diff --git a/tests/test_source_classification.py b/tests/test_source_classification.py new file mode 100644 index 00000000..53f07212 --- /dev/null +++ b/tests/test_source_classification.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from vuln_analysis.utils.source_classification import ( + is_dependency_path, + filter_by_source_scope, + format_app_dep_output, +) + + +class TestIsDependencyPath: + @pytest.mark.parametrize("path", [ + "dependencies-sources/commons-beanutils/PropertyUtilsBean.java", + "vendor/github.com/foo/bar/bar.go", + "transitive_env/lib/requests/models.py", + "node_modules/express/index.js", + "rpm_libs/libxml2/parser.c", + ]) + def test_dependency_paths(self, path): + assert is_dependency_path(path) is True + + @pytest.mark.parametrize("path", [ + "src/main/java/com/example/App.java", + "cmd/main.go", + "app.py", + "pom.xml", + "src/main/resources/application.yml", + ]) + def test_app_paths(self, path): + assert is_dependency_path(path) is False + + def test_empty_path(self): + assert is_dependency_path("") is False + + +class TestFilterBySourceScope: + def test_filters_by_scope(self): + items = [ + ("dependencies-sources/xstream/XStream.java", "entry1"), + ("dependencies-sources/commons-io/IOUtils.java", "entry2"), + ("dependencies-sources/xstream/Converter.java", "entry3"), + ] + result = filter_by_source_scope(items, ["xstream"], lambda x: x[0]) + assert len(result) == 2 + assert result[0][1] == "entry1" + assert result[1][1] == "entry3" + + def test_none_scope_returns_all(self): + items = [("a", "x"), ("b", "y")] + result = filter_by_source_scope(items, None, lambda x: x[0]) + assert result == items + + def test_empty_scope_returns_all(self): + items = [("a", "x"), ("b", "y")] + result = filter_by_source_scope(items, [], lambda x: x[0]) + assert result == items + + def test_multiple_scope_terms(self): + items = [ + ("vendor/foo/main.go", "e1"), + ("vendor/bar/main.go", "e2"), + ("vendor/baz/main.go", "e3"), + ] + result = filter_by_source_scope(items, ["foo", "baz"], lambda x: x[0]) + assert len(result) == 2 + assert result[0][1] == "e1" + assert result[1][1] == "e3" + + def test_no_matches_returns_empty(self): + items = [("vendor/foo/main.go", 
"e1")] + result = filter_by_source_scope(items, ["nonexistent"], lambda x: x[0]) + assert result == [] + + +class TestFormatAppDepOutput: + def test_both_sections(self): + result = format_app_dep_output( + ["app_match_1", "app_match_2"], + ["dep_match_1"], + total_app=2, total_dep=1, + no_results_msg="No results", + ) + assert "Main application (2 of 2 results)" in result + assert "Application library dependencies (1 of 1 results)" in result + assert "app_match_1" in result + assert "dep_match_1" in result + + def test_empty_returns_no_results_msg(self): + result = format_app_dep_output([], [], 0, 0, "Nothing found") + assert result == "Nothing found" + + def test_app_only(self): + result = format_app_dep_output( + ["app1"], [], total_app=1, total_dep=0, + no_results_msg="No results", + ) + assert "Main application (1 of 1 results)" in result + assert "Application library dependencies (0 of 0 results)" in result + assert "app1" in result + + def test_dep_only(self): + result = format_app_dep_output( + [], ["dep1"], total_app=0, total_dep=1, + no_results_msg="No results", + ) + assert "Main application (0 of 0 results)" in result + assert "Application library dependencies (1 of 1 results)" in result + assert "dep1" in result + + def test_trimmed_counts(self): + result = format_app_dep_output( + ["a1", "a2"], ["d1"], + total_app=5, total_dep=3, + no_results_msg="No results", + ) + assert "Main application (2 of 5 results)" in result + assert "Application library dependencies (1 of 3 results)" in result + + def test_app_before_dep(self): + result = format_app_dep_output( + ["APP_SECTION"], ["DEP_SECTION"], + total_app=1, total_dep=1, + no_results_msg="No results", + ) + app_pos = result.index("APP_SECTION") + dep_pos = result.index("DEP_SECTION") + assert app_pos < dep_pos diff --git a/tests/test_transitive_detection.py b/tests/test_transitive_detection.py index 5a5fa5a0..cd65acb9 100644 --- a/tests/test_transitive_detection.py +++ b/tests/test_transitive_detection.py 
@@ -1,28 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for detect_ecosystem in dep_tree.py.""" + import pytest -from exploit_iq_commons.utils.transitive_code_searcher_tool import TransitiveCodeSearcher - -C_CPP_DETECTION_SCENARIOS = [ - # Positive cases - (["src/main.c"], True), - (["src/utils.cpp"], True), - (["include/lib.h"], True), - (["src/deep/nested/logic.cxx"], True), - # Negative cases - (["Makefile", "app.py"], False), - (["README.md", "LICENSE"], False), - (["main.cs"], False), - (["style.css"], False), - ([], False), - # Ignored directories - ([".git/objects/obj.c"], False), - ([".venv/lib/site/test.c"], False), - ([".config/settings.h"], False), +from exploit_iq_commons.utils.dep_tree import Ecosystem, detect_ecosystem + + +# --- C/C++ detection: requires manifest + source files --- + +C_CPP_POSITIVE_SCENARIOS = [ + # manifest + source file combinations + ("CMakeLists.txt", ["src/main.c"]), + ("CMakeLists.txt", ["src/utils.cpp"]), + ("Makefile", ["include/lib.h"]), + ("meson.build", ["src/deep/nested/logic.cxx"]), + ("CMakeLists.txt", ["lib.cc"]), + ("CMakeLists.txt", ["api.hpp"]), ] -@pytest.mark.parametrize("file_paths, expected", C_CPP_DETECTION_SCENARIOS) -def test_has_c_cpp_sources_scenarios(tmp_path, file_paths, expected): + +C_CPP_NEGATIVE_SCENARIOS = [ + # manifest present but no C/C++ source files + ("Makefile", ["app.py"]), + ("CMakeLists.txt", ["README.md", "LICENSE"]), + ("Makefile", ["main.cs"]), + ("CMakeLists.txt", ["style.css"]), + ("Makefile", []), + # C/C++ sources but NO manifest — should not detect + (None, ["src/main.c"]), + (None, ["src/utils.cpp", "include/lib.h"]), + # manifest + source files in ignored dotfile directories + ("CMakeLists.txt", [".git/objects/obj.c"]), + ("CMakeLists.txt", [".venv/lib/site/test.c"]), + ("CMakeLists.txt", [".config/settings.h"]), +] + + +@pytest.mark.parametrize("manifest, 
file_paths", C_CPP_POSITIVE_SCENARIOS) +def test_c_cpp_detected(tmp_path, manifest, file_paths): + (tmp_path / manifest).touch() for path in file_paths: full_path = tmp_path / path full_path.parent.mkdir(parents=True, exist_ok=True) full_path.touch() - result = TransitiveCodeSearcher._has_c_cpp_sources(tmp_path) - assert result is expected, f"Failed for scenario: {file_paths}" \ No newline at end of file + assert detect_ecosystem(tmp_path) == Ecosystem.C_CPP + + +@pytest.mark.parametrize("manifest, file_paths", C_CPP_NEGATIVE_SCENARIOS) +def test_c_cpp_not_detected(tmp_path, manifest, file_paths): + if manifest: + (tmp_path / manifest).touch() + for path in file_paths: + full_path = tmp_path / path + full_path.parent.mkdir(parents=True, exist_ok=True) + full_path.touch() + assert detect_ecosystem(tmp_path) != Ecosystem.C_CPP + + +# --- Other ecosystem detection --- + +OTHER_ECOSYSTEM_SCENARIOS = [ + ("go.mod", Ecosystem.GO), + ("requirements.txt", Ecosystem.PYTHON), + ("pyproject.toml", Ecosystem.PYTHON), + ("setup.py", Ecosystem.PYTHON), + ("package.json", Ecosystem.JAVASCRIPT), + ("pom.xml", Ecosystem.JAVA), +] + + +@pytest.mark.parametrize("manifest, expected", OTHER_ECOSYSTEM_SCENARIOS) +def test_ecosystem_detected_by_manifest(tmp_path, manifest, expected): + (tmp_path / manifest).touch() + assert detect_ecosystem(tmp_path) == expected + + +def test_no_manifest_returns_none(tmp_path): + (tmp_path / "README.md").touch() + assert detect_ecosystem(tmp_path) is None + + +def test_empty_dir_returns_none(tmp_path): + assert detect_ecosystem(tmp_path) is None + + +# --- Priority order: Go > Python > JS > Java > C/C++ --- + +def test_go_takes_priority_over_python(tmp_path): + (tmp_path / "go.mod").touch() + (tmp_path / "requirements.txt").touch() + assert detect_ecosystem(tmp_path) == Ecosystem.GO + + +def test_python_takes_priority_over_java(tmp_path): + (tmp_path / "requirements.txt").touch() + (tmp_path / "pom.xml").touch() + assert detect_ecosystem(tmp_path) == 
Ecosystem.PYTHON + + +def test_java_takes_priority_over_c_cpp(tmp_path): + (tmp_path / "pom.xml").touch() + (tmp_path / "CMakeLists.txt").touch() + (tmp_path / "main.c").touch() + assert detect_ecosystem(tmp_path) == Ecosystem.JAVA + + +def test_configure_in_auto_subdir(tmp_path): + """detect_ecosystem checks for auto/configure as a C/C++ manifest.""" + (tmp_path / "auto").mkdir() + (tmp_path / "auto" / "configure").touch() + (tmp_path / "src").mkdir() + (tmp_path / "src" / "main.c").touch() + assert detect_ecosystem(tmp_path) == Ecosystem.C_CPP