From 9a58f0b7d6e13074dc43043e20f31539f340c9e7 Mon Sep 17 00:00:00 2001 From: Vivek Anand <78247712+viveka1302@users.noreply.github.com> Date: Mon, 24 Feb 2025 19:52:49 +0530 Subject: [PATCH 01/10] SQL injection checker python --- checkers/python/avoid-unsanitized-sql.go | 293 +++++++++++++++++++++++ 1 file changed, 293 insertions(+) create mode 100644 checkers/python/avoid-unsanitized-sql.go diff --git a/checkers/python/avoid-unsanitized-sql.go b/checkers/python/avoid-unsanitized-sql.go new file mode 100644 index 00000000..e316f2e0 --- /dev/null +++ b/checkers/python/avoid-unsanitized-sql.go @@ -0,0 +1,293 @@ +package python + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + sitter "github.com/smacker/go-tree-sitter" + "globstar.dev/analysis" +) + +// SQLInjection creates an analyzer that detects unsafe SQL query construction. +func SQLInjection() *analysis.Analyzer { + return &analysis.Analyzer{ + Name: "sql-injection", + Description: "Detects unsafe SQL query construction", + Category: analysis.CategorySecurity, + Severity: analysis.SeverityCritical, + Language: analysis.LangPy, // assuming LangPy is defined elsewhere + Run: func(pass *analysis.Pass) (interface{}, error) { + // Define an inner function that captures pass. + fn := func(n *sitter.Node) { + visitCall(pass, n) + } + analysis.Preorder(pass, fn) + return nil, nil + }, + } +} + +// visitCall checks if a call node is a SQL execution call and if its arguments are unsafe. +func visitCall(pass *analysis.Pass, node *sitter.Node) { + source := pass.FileContext.Source + + // Only process call nodes. + if node.Type() != "call" { + return + } + + // Extract the function part (e.g. cursor.execute). + functionNode := node.ChildByFieldName("function") + if functionNode == nil { + return + } + + // Proceed only if the function is one of our recognized SQL methods. + if !isSQLExecuteMethod(functionNode, source) { + return + } + + // Check the first argument. + argsNode := node.ChildByFieldName("arguments") + if argsNode == nil { + return + } + firstArg := getNthChild(argsNode, 0) + if firstArg == nil { + return + } + + // If the query string is built unsafely, report an issue. + if isUnsafeString(firstArg, source) { + pass.Report(pass, node, "Direct use of unsafe string in SQL query") + return + } + + // If the argument is an identifier, trace its origin. + if firstArg.Type() == "identifier" { + varName := firstArg.Content(source) + traceVariableOrigin(pass, varName, node, make(map[string]bool), make(map[string]bool), source) + } +} + +// isSQLExecuteMethod returns true if the function node is one of the SQL execution methods. +func isSQLExecuteMethod(node *sitter.Node, source []byte) bool { + var funcName string + switch node.Type() { + case "identifier": + funcName = node.Content(source) + case "attribute": + attr := node.ChildByFieldName("attribute") + if attr != nil { + funcName = attr.Content(source) + } + } + + sqlMethods := map[string]bool{ + "execute": true, + "executemany": true, + "executescript": true, + } + return sqlMethods[funcName] +} + +// isUnsafeString returns true if the node represents an unsafely built string (e.g. f-string interpolation or unsafe concatenation). +func isUnsafeString(node *sitter.Node, source []byte) bool { + // Check for f-strings with interpolation. + if node.Type() == "fstring" { + for i := 0; i < int(node.ChildCount()); i++ { + if node.Child(i).Type() == "interpolation" { + return true + } + } + } + + // Check for unsafe binary concatenation. + if node.Type() == "binary_operator" { + op := node.ChildByFieldName("operator") + if op != nil && op.Content(source) == "+" { + return containsVariable(node.ChildByFieldName("left"), source) || + containsVariable(node.ChildByFieldName("right"), source) + } + } + + return false +} + +// traceVariableOrigin recursively traces the origin of a variable through local assignments and cross‐file imports. +func traceVariableOrigin(pass *analysis.Pass, varName string, originalNode *sitter.Node, + visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) { + + if visitedVars[varName] { + return + } + visitedVars[varName] = true + + if traceLocalAssignments(pass, varName, originalNode, visitedVars, visitedFiles, source) { + return + } + + traceCrossFileImports(pass, varName, originalNode, visitedVars, visitedFiles, source) +} + +// traceLocalAssignments looks for local assignments to the variable and reports if it originates from an unsafe string. +func traceLocalAssignments(pass *analysis.Pass, varName string, originalNode *sitter.Node, + visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) bool { + + query := `(assignment left: (identifier) @var right: (_) @value)` + q, err := sitter.NewQuery([]byte(query), pass.Analyzer.Language.Parser()) + if err != nil { + return false + } + defer q.Close() + + cursor := sitter.NewQueryCursor() + defer cursor.Close() + cursor.Exec(q, pass.FileContext.Ast) + + for { + match, ok := cursor.NextMatch() + if !ok { + break + } + + var varNode, valueNode *sitter.Node + for _, capture := range match.Captures { + switch capture.Name { + case "var": + varNode = capture.Node + case "value": + valueNode = capture.Node + } + } + + if varNode != nil && varNode.Content(source) == varName { + if isUnsafeString(valueNode, source) { + pass.Report(pass, originalNode, fmt.Sprintf("Variable '%s' originates from an unsafe string", varName)) + return true + } + + if valueNode.Type() == "identifier" { + newVar := valueNode.Content(source) + traceVariableOrigin(pass, newVar, originalNode, visitedVars, visitedFiles, source) + return true + } + } + } + return false +} + +// traceCrossFileImports looks for import-from statements to trace the variable across files. +func traceCrossFileImports(pass *analysis.Pass, varName string, originalNode *sitter.Node, + visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) { + + query := `( + (import_from_statement + module_name: (dotted_name) @module + name: (dotted_name) @imported_var + ) @import + )` + q, err := sitter.NewQuery([]byte(query), pass.Analyzer.Language.Parser()) + if err != nil { + return + } + defer q.Close() + + cursor := sitter.NewQueryCursor() + defer cursor.Close() + cursor.Exec(q, pass.FileContext.Ast) + + for { + match, ok := cursor.NextMatch() + if !ok { + break + } + + var moduleNode, varNode *sitter.Node + for _, capture := range match.Captures { + switch capture.Name { + case "module": + moduleNode = capture.Node + case "imported_var": + varNode = capture.Node + } + } + + if varNode != nil && varNode.Content(source) == varName && moduleNode != nil { + modulePath := convertImportToPath(moduleNode.Content(source)) + if visitedFiles[modulePath] { + continue + } + visitedFiles[modulePath] = true + + for _, file := range pass.Files { + if strings.HasSuffix(file.FilePath, modulePath) { + newPass := &analysis.Pass{ + Analyzer: pass.Analyzer, + FileContext: file, + Files: pass.Files, + Report: pass.Report, + } + traceVariableOrigin(newPass, varName, originalNode, visitedVars, visitedFiles, file.Source) + } + } + } + } +} + +// containsVariable returns true if the node (or any of its subnodes) is an identifier or attribute. +func containsVariable(node *sitter.Node, source []byte) bool { + if node == nil { + return false + } + switch node.Type() { + case "identifier", "attribute": + return true + case "binary_operator": + return containsVariable(node.ChildByFieldName("left"), source) || + containsVariable(node.ChildByFieldName("right"), source) + case "parenthesized_expression": + return containsVariable(node.NamedChild(0), source) + default: + return false + } +} + +// getNthChild returns the nth child of a node or nil if out of bounds. +func getNthChild(node *sitter.Node, n int) *sitter.Node { + if n < int(node.ChildCount()) { + return node.Child(n) + } + return nil +} + +// convertImportToPath converts a dotted module name to a file path (e.g. "a.b" -> "a/b.py"). +func convertImportToPath(importStr string) string { + return strings.ReplaceAll(importStr, ".", string(filepath.Separator)) + ".py" +} + +//everything beyond this point is a test + +func main() { + // Create a slice of analyzers. In this case, we include the SQLInjection analyzer. + analyzers := []*analysis.Analyzer{ + SQLInjection(), + } + + // Run the analyzers on the specified directory. + issues, err := analysis.RunAnalyzers("./afile.py", analyzers) + if err != nil { + fmt.Println("Error running analyzers:", err) + os.Exit(1) + } + + // Optionally, report the issues in text format. + output, err := analysis.ReportIssues(issues, "text") + if err != nil { + fmt.Println("Error reporting issues:", err) + os.Exit(1) + } + fmt.Println(string(output)) +} From 2cdb4bfe93da1fbccbed6ab66d5b4c658f7fa52e Mon Sep 17 00:00:00 2001 From: Vivek Anand <78247712+viveka1302@users.noreply.github.com> Date: Tue, 25 Feb 2025 01:06:25 +0530 Subject: [PATCH 02/10] Update avoid-unsanitized-sql.go --- checkers/python/avoid-unsanitized-sql.go | 1 + 1 file changed, 1 insertion(+) diff --git a/checkers/python/avoid-unsanitized-sql.go b/checkers/python/avoid-unsanitized-sql.go index e316f2e0..1e764ab4 100644 --- a/checkers/python/avoid-unsanitized-sql.go +++ b/checkers/python/avoid-unsanitized-sql.go @@ -272,6 +272,7 @@ func convertImportToPath(importStr string) string { func main() { // Create a slice of analyzers. In this case, we include the SQLInjection analyzer. + fmt.Println("print if ever executed") analyzers := []*analysis.Analyzer{ SQLInjection(), } From 134b832247f4be28e348a28923bef1906676590f Mon Sep 17 00:00:00 2001 From: Vivek Anand <78247712+viveka1302@users.noreply.github.com> Date: Tue, 25 Feb 2025 01:37:18 +0530 Subject: [PATCH 03/10] Update avoid-unsanitized-sql.go --- checkers/python/avoid-unsanitized-sql.go | 26 ------------------------ 1 file changed, 26 deletions(-) diff --git a/checkers/python/avoid-unsanitized-sql.go b/checkers/python/avoid-unsanitized-sql.go index 1e764ab4..d237b2c1 100644 --- a/checkers/python/avoid-unsanitized-sql.go +++ b/checkers/python/avoid-unsanitized-sql.go @@ -2,7 +2,6 @@ package python import ( "fmt" - "os" "path/filepath" "strings" @@ -267,28 +266,3 @@ func getNthChild(node *sitter.Node, n int) *sitter.Node { func convertImportToPath(importStr string) string { return strings.ReplaceAll(importStr, ".", string(filepath.Separator)) + ".py" } - -//everything beyond this point is a test - -func main() { - // Create a slice of analyzers. In this case, we include the SQLInjection analyzer. - fmt.Println("print if ever executed") - analyzers := []*analysis.Analyzer{ - SQLInjection(), - } - - // Run the analyzers on the specified directory. - issues, err := analysis.RunAnalyzers("./afile.py", analyzers) - if err != nil { - fmt.Println("Error running analyzers:", err) - os.Exit(1) - } - - // Optionally, report the issues in text format. - output, err := analysis.ReportIssues(issues, "text") - if err != nil { - fmt.Println("Error reporting issues:", err) - os.Exit(1) - } - fmt.Println(string(output)) -} From dfa2e0b344713f929f7060610a5fba45bcc43e9f Mon Sep 17 00:00:00 2001 From: Vivek Anand <78247712+viveka1302@users.noreply.github.com> Date: Tue, 25 Feb 2025 02:35:07 +0530 Subject: [PATCH 04/10] Update avoid-unsanitized-sql.go --- checkers/python/avoid-unsanitized-sql.go | 86 ++++++++++-------------- 1 file changed, 37 insertions(+), 49 deletions(-) diff --git a/checkers/python/avoid-unsanitized-sql.go b/checkers/python/avoid-unsanitized-sql.go index d237b2c1..a23e475d 100644 --- a/checkers/python/avoid-unsanitized-sql.go +++ b/checkers/python/avoid-unsanitized-sql.go @@ -9,28 +9,9 @@ import ( "globstar.dev/analysis" ) -// SQLInjection creates an analyzer that detects unsafe SQL query construction. -func SQLInjection() *analysis.Analyzer { - return &analysis.Analyzer{ - Name: "sql-injection", - Description: "Detects unsafe SQL query construction", - Category: analysis.CategorySecurity, - Severity: analysis.SeverityCritical, - Language: analysis.LangPy, // assuming LangPy is defined elsewhere - Run: func(pass *analysis.Pass) (interface{}, error) { - // Define an inner function that captures pass. - fn := func(n *sitter.Node) { - visitCall(pass, n) - } - analysis.Preorder(pass, fn) - return nil, nil - }, - } -} - -// visitCall checks if a call node is a SQL execution call and if its arguments are unsafe. -func visitCall(pass *analysis.Pass, node *sitter.Node) { - source := pass.FileContext.Source +// checkSQLInjection is the rule callback that inspects each call node. +func checkSQLInjection(r analysis.Rule, ana *analysis.Analyzer, node *sitter.Node) { + source := ana.FileContext.Source // Only process call nodes. if node.Type() != "call" { @@ -60,18 +41,28 @@ func visitCall(pass *analysis.Pass, node *sitter.Node) { // If the query string is built unsafely, report an issue. if isUnsafeString(firstArg, source) { - pass.Report(pass, node, "Direct use of unsafe string in SQL query") + ana.Report(&analysis.Issue{ + Message: "Direct use of unsafe string in SQL query", + Range: node.Range(), + }) return } // If the argument is an identifier, trace its origin. if firstArg.Type() == "identifier" { varName := firstArg.Content(source) - traceVariableOrigin(pass, varName, node, make(map[string]bool), make(map[string]bool), source) + traceVariableOrigin(r, ana, varName, node, make(map[string]bool), make(map[string]bool), source) } } -// isSQLExecuteMethod returns true if the function node is one of the SQL execution methods. +// SQLInjection registers the SQL injection rule. +func SQLInjection() analysis.Rule { + var entry analysis.VisitFn = checkSQLInjection + return analysis.CreateRule("call", analysis.LangPy, &entry, nil) +} + +// --- Helper Functions --- + func isSQLExecuteMethod(node *sitter.Node, source []byte) bool { var funcName string switch node.Type() { @@ -92,7 +83,6 @@ func isSQLExecuteMethod(node *sitter.Node, source []byte) bool { return sqlMethods[funcName] } -// isUnsafeString returns true if the node represents an unsafely built string (e.g. f-string interpolation or unsafe concatenation). func isUnsafeString(node *sitter.Node, source []byte) bool { // Check for f-strings with interpolation. if node.Type() == "fstring" { @@ -115,8 +105,7 @@ func isUnsafeString(node *sitter.Node, source []byte) bool { return false } -// traceVariableOrigin recursively traces the origin of a variable through local assignments and cross‐file imports. -func traceVariableOrigin(pass *analysis.Pass, varName string, originalNode *sitter.Node, +func traceVariableOrigin(r analysis.Rule, ana *analysis.Analyzer, varName string, originalNode *sitter.Node, visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) { if visitedVars[varName] { @@ -124,19 +113,18 @@ func traceVariableOrigin(pass *analysis.Pass, varName string, originalNode *sitt } visitedVars[varName] = true - if traceLocalAssignments(pass, varName, originalNode, visitedVars, visitedFiles, source) { + if traceLocalAssignments(r, ana, varName, originalNode, visitedVars, visitedFiles, source) { return } - traceCrossFileImports(pass, varName, originalNode, visitedVars, visitedFiles, source) + traceCrossFileImports(r, ana, varName, originalNode, visitedVars, visitedFiles, source) } -// traceLocalAssignments looks for local assignments to the variable and reports if it originates from an unsafe string. -func traceLocalAssignments(pass *analysis.Pass, varName string, originalNode *sitter.Node, +func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName string, originalNode *sitter.Node, visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) bool { query := `(assignment left: (identifier) @var right: (_) @value)` - q, err := sitter.NewQuery([]byte(query), pass.Analyzer.Language.Parser()) + q, err := sitter.NewQuery([]byte(query), ana.Language.Parser()) if err != nil { return false } @@ -144,7 +132,7 @@ func traceLocalAssignments(pass *analysis.Pass, varName string, originalNode *si cursor := sitter.NewQueryCursor() defer cursor.Close() - cursor.Exec(q, pass.FileContext.Ast) + cursor.Exec(q, ana.FileContext.Ast) for { match, ok := cursor.NextMatch() @@ -164,13 +152,16 @@ func traceLocalAssignments(pass *analysis.Pass, varName string, originalNode *si if varNode != nil && varNode.Content(source) == varName { if isUnsafeString(valueNode, source) { - pass.Report(pass, originalNode, fmt.Sprintf("Variable '%s' originates from an unsafe string", varName)) + ana.Report(&analysis.Issue{ + Message: fmt.Sprintf("Variable '%s' originates from an unsafe string", varName), + Range: originalNode.Range(), + }) return true } if valueNode.Type() == "identifier" { newVar := valueNode.Content(source) - traceVariableOrigin(pass, newVar, originalNode, visitedVars, visitedFiles, source) + traceVariableOrigin(r, ana, newVar, originalNode, visitedVars, visitedFiles, source) return true } } @@ -178,8 +169,7 @@ func traceLocalAssignments(pass *analysis.Pass, varName string, originalNode *si return false } -// traceCrossFileImports looks for import-from statements to trace the variable across files. -func traceCrossFileImports(pass *analysis.Pass, varName string, originalNode *sitter.Node, +func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName string, originalNode *sitter.Node, visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) { query := `( @@ -188,7 +178,7 @@ func traceCrossFileImports(pass *analysis.Pass, varName string, originalNode *si name: (dotted_name) @imported_var ) @import )` - q, err := sitter.NewQuery([]byte(query), pass.Analyzer.Language.Parser()) + q, err := sitter.NewQuery([]byte(query), ana.Language.Parser()) if err != nil { return } @@ -196,7 +186,7 @@ func traceCrossFileImports(pass *analysis.Pass, varName string, originalNode *si cursor := sitter.NewQueryCursor() defer cursor.Close() - cursor.Exec(q, pass.FileContext.Ast) + cursor.Exec(q, ana.FileContext.Ast) for { match, ok := cursor.NextMatch() @@ -221,22 +211,22 @@ func traceCrossFileImports(pass *analysis.Pass, varName string, originalNode *si } visitedFiles[modulePath] = true - for _, file := range pass.Files { + for _, file := range ana.Files { if strings.HasSuffix(file.FilePath, modulePath) { - newPass := &analysis.Pass{ - Analyzer: pass.Analyzer, + // Create a temporary analyzer context for the imported file. + tempAna := &analysis.Analyzer{ + Language: ana.Language, FileContext: file, - Files: pass.Files, - Report: pass.Report, + Files: ana.Files, + Report: ana.Report, // Reuse the report function. } - traceVariableOrigin(newPass, varName, originalNode, visitedVars, visitedFiles, file.Source) + traceVariableOrigin(r, tempAna, varName, originalNode, visitedVars, visitedFiles, file.Source) } } } } } -// containsVariable returns true if the node (or any of its subnodes) is an identifier or attribute. func containsVariable(node *sitter.Node, source []byte) bool { if node == nil { return false @@ -254,7 +244,6 @@ func containsVariable(node *sitter.Node, source []byte) bool { } } -// getNthChild returns the nth child of a node or nil if out of bounds. func getNthChild(node *sitter.Node, n int) *sitter.Node { if n < int(node.ChildCount()) { return node.Child(n) @@ -262,7 +251,6 @@ func getNthChild(node *sitter.Node, n int) *sitter.Node { return nil } -// convertImportToPath converts a dotted module name to a file path (e.g. "a.b" -> "a/b.py"). func convertImportToPath(importStr string) string { return strings.ReplaceAll(importStr, ".", string(filepath.Separator)) + ".py" } From 153b83213ab74f08b250a5b0bc8bea818ac633a7 Mon Sep 17 00:00:00 2001 From: Vivek Anand <78247712+viveka1302@users.noreply.github.com> Date: Tue, 25 Feb 2025 02:45:43 +0530 Subject: [PATCH 05/10] Create avoid-unsanitized-sql.test.py --- checkers/python/avoid-unsanitized-sql.test.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 checkers/python/avoid-unsanitized-sql.test.py diff --git a/checkers/python/avoid-unsanitized-sql.test.py b/checkers/python/avoid-unsanitized-sql.test.py new file mode 100644 index 00000000..1bbc889b --- /dev/null +++ b/checkers/python/avoid-unsanitized-sql.test.py @@ -0,0 +1,23 @@ + +import sqlite3 +from fastapi import FastAPI, Query +import sqlite3 + +app = FastAPI() + +def execute_unsafe_query(query: str): + conn = sqlite3.connect("test.db") + cursor = conn.cursor() + cursor.execute(query) # ✅ Uses parameterized query + result = cursor.fetchall() + conn.commit() + conn.close() + return result + +@app.get("/unsafe_query/") +def unsafe_query(user_input: str): + query = f"SELECT * FROM users WHERE name = {user_input}" + query2= "SELECT * FROM users WHERE name ="+ user_input + result = execute_unsafe_query(query) + result2= execute_unsafe_query(query=query2) + return {"result": result, "result2": result2} From 8e2d0bab47b51f61ad95ff4ed792b3d60b308efc Mon Sep 17 00:00:00 2001 From: Vivek Anand <78247712+viveka1302@users.noreply.github.com> Date: Tue, 25 Feb 2025 04:04:07 +0530 Subject: [PATCH 06/10] Update avoid-unsanitized-sql.go --- checkers/python/avoid-unsanitized-sql.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/checkers/python/avoid-unsanitized-sql.go b/checkers/python/avoid-unsanitized-sql.go index a23e475d..5fbd67f7 100644 --- a/checkers/python/avoid-unsanitized-sql.go +++ b/checkers/python/avoid-unsanitized-sql.go @@ -42,8 +42,10 @@ func checkSQLInjection(r analysis.Rule, ana *analysis.Analyzer, node *sitter.Nod // If the query string is built unsafely, report an issue. if isUnsafeString(firstArg, source) { ana.Report(&analysis.Issue{ - Message: "Direct use of unsafe string in SQL query", - Range: node.Range(), + Message: "Concatenated string in SQL query is an SQL injection threat!", + Category: analysis.CategorySecurity, + Severity: analysis.SeverityCritical, + Range: node.Range(), }) return } From a8ba78ab14316ceacc842dcf3c108384082a0e87 Mon Sep 17 00:00:00 2001 From: Vivek Anand <78247712+viveka1302@users.noreply.github.com> Date: Tue, 25 Feb 2025 23:32:14 +0530 Subject: [PATCH 07/10] test-directives added --- checkers/python/avoid-unsanitized-sql.test.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/checkers/python/avoid-unsanitized-sql.test.py b/checkers/python/avoid-unsanitized-sql.test.py index 1bbc889b..8f3d8cf2 100644 --- a/checkers/python/avoid-unsanitized-sql.test.py +++ b/checkers/python/avoid-unsanitized-sql.test.py @@ -8,16 +8,35 @@ def execute_unsafe_query(query: str): conn = sqlite3.connect("test.db") cursor = conn.cursor() - cursor.execute(query) # ✅ Uses parameterized query + cursor.execute(query) #unsafe with user input result = cursor.fetchall() conn.commit() conn.close() return result +def better_query(query: str, params): + conn = sqlite3.connect("test.db") + cursor = conn.cursor() + cursor.execute(query, params) #safe to execute with user input + result = cursor.fetchall() + conn.commit() + conn.close() + return result + + @app.get("/unsafe_query/") def unsafe_query(user_input: str): + #f-string case + # query = f"SELECT * FROM users WHERE name = {user_input}" + #binary operator case + # query2= "SELECT * FROM users WHERE name ="+ user_input + #should not identify this as an error + query3= "SELECT * FROM user WHERE name= ?" result = execute_unsafe_query(query) result2= execute_unsafe_query(query=query2) - return {"result": result, "result2": result2} + + result3= better_query(query=query3, params=(user_input,)) + + return {"result": result, "result2": result2, "result3": result3} From c57caa6ace24cca089461fa51293a9c6a807c30e Mon Sep 17 00:00:00 2001 From: Sourya Vatsyayan Date: Sun, 9 Mar 2025 18:03:31 +0530 Subject: [PATCH 08/10] chore: migrate to the new runner Signed-off-by: Sourya Vatsyayan --- checkers/checker.go | 5 + checkers/python/avoid-unsanitized-sql.go | 135 +++++++++--------- .../testdata/avoid-unsanitized-sql.test.py | 41 ++++++ 3 files changed, 113 insertions(+), 68 deletions(-) create mode 100644 checkers/python/testdata/avoid-unsanitized-sql.test.py diff --git a/checkers/checker.go b/checkers/checker.go index 20fe6e98..9e9b878f 100644 --- a/checkers/checker.go +++ b/checkers/checker.go @@ -9,6 +9,7 @@ import ( goAnalysis "globstar.dev/analysis" "globstar.dev/checkers/javascript" + "globstar.dev/checkers/python" "globstar.dev/pkg/analysis" ) @@ -69,6 +70,10 @@ var AnalyzerRegistry = []Analyzer{ TestDir: "checkers/javascript/testdata", // relative to the repository root Analyzers: []*goAnalysis.Analyzer{javascript.NoDoubleEq, javascript.SQLInjection}, }, + { + TestDir: "checkers/python/testdata", // relative to the repository root + Analyzers: []*goAnalysis.Analyzer{python.AvoidUnsanitizedSQL}, + }, } func LoadGoCheckers() []*goAnalysis.Analyzer { diff --git a/checkers/python/avoid-unsanitized-sql.go b/checkers/python/avoid-unsanitized-sql.go index 5fbd67f7..b3052382 100644 --- a/checkers/python/avoid-unsanitized-sql.go +++ b/checkers/python/avoid-unsanitized-sql.go @@ -9,58 +9,60 @@ import ( "globstar.dev/analysis" ) +var AvoidUnsanitizedSQL = &analysis.Analyzer{ + Name: "avoid-unsanitized-sql", + Language: analysis.LangPy, + Description: "Check if SQL query is sanitized", + Category: analysis.CategorySecurity, + Severity: analysis.SeverityCritical, + Run: checkSQLInjection, +} + // checkSQLInjection is the rule callback that inspects each call node. -func checkSQLInjection(r analysis.Rule, ana *analysis.Analyzer, node *sitter.Node) { - source := ana.FileContext.Source +func checkSQLInjection(pass *analysis.Pass) (interface{}, error) { + analysis.Preorder(pass, func(node *sitter.Node) { + source := pass.FileContext.Source - // Only process call nodes. - if node.Type() != "call" { - return - } + // Only process call nodes. + if node.Type() != "call" { + return + } - // Extract the function part (e.g. cursor.execute). - functionNode := node.ChildByFieldName("function") - if functionNode == nil { - return - } + // Extract the function part (e.g. cursor.execute). + functionNode := node.ChildByFieldName("function") + if functionNode == nil { + return + } - // Proceed only if the function is one of our recognized SQL methods. - if !isSQLExecuteMethod(functionNode, source) { - return - } + // Proceed only if the function is one of our recognized SQL methods. + if !isSQLExecuteMethod(functionNode, source) { + return + } - // Check the first argument. - argsNode := node.ChildByFieldName("arguments") - if argsNode == nil { - return - } - firstArg := getNthChild(argsNode, 0) - if firstArg == nil { - return - } + // Check the first argument. + argsNode := node.ChildByFieldName("arguments") + if argsNode == nil { + return + } + firstArg := getNthChild(argsNode, 0) + if firstArg == nil { + return + } - // If the query string is built unsafely, report an issue. - if isUnsafeString(firstArg, source) { - ana.Report(&analysis.Issue{ - Message: "Concatenated string in SQL query is an SQL injection threat!", - Category: analysis.CategorySecurity, - Severity: analysis.SeverityCritical, - Range: node.Range(), - }) - return - } + // If the query string is built unsafely, report an issue. + if isUnsafeString(firstArg, source) { + pass.Report(pass, node, "Concatenated string in SQL query is an SQL injection threat") + return + } - // If the argument is an identifier, trace its origin. - if firstArg.Type() == "identifier" { - varName := firstArg.Content(source) - traceVariableOrigin(r, ana, varName, node, make(map[string]bool), make(map[string]bool), source) - } -} + // If the argument is an identifier, trace its origin. + if firstArg.Type() == "identifier" { + varName := firstArg.Content(source) + traceVariableOrigin(pass, varName, node, make(map[string]bool), make(map[string]bool), source) + } + }) -// SQLInjection registers the SQL injection rule. -func SQLInjection() analysis.Rule { - var entry analysis.VisitFn = checkSQLInjection - return analysis.CreateRule("call", analysis.LangPy, &entry, nil) + return nil, nil } // --- Helper Functions --- @@ -107,7 +109,7 @@ func isUnsafeString(node *sitter.Node, source []byte) bool { return false } -func traceVariableOrigin(r analysis.Rule, ana *analysis.Analyzer, varName string, originalNode *sitter.Node, +func traceVariableOrigin(pass *analysis.Pass, varName string, originalNode *sitter.Node, visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) { if visitedVars[varName] { @@ -115,18 +117,18 @@ func traceVariableOrigin(r analysis.Rule, ana *analysis.Analyzer, varName string } visitedVars[varName] = true - if traceLocalAssignments(r, ana, varName, originalNode, visitedVars, visitedFiles, source) { + if traceLocalAssignments(pass, varName, originalNode, visitedVars, visitedFiles, source) { return } - traceCrossFileImports(r, ana, varName, originalNode, visitedVars, visitedFiles, source) + traceCrossFileImports(pass, varName, originalNode, visitedVars, visitedFiles, source) } -func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName string, originalNode *sitter.Node, +func traceLocalAssignments(pass *analysis.Pass, varName string, originalNode *sitter.Node, visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) bool { query := `(assignment left: (identifier) @var right: (_) @value)` - q, err := sitter.NewQuery([]byte(query), ana.Language.Parser()) + q, err := sitter.NewQuery([]byte(query), pass.Analyzer.Language.Grammar()) if err != nil { return false } @@ -134,7 +136,7 @@ func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName stri cursor := sitter.NewQueryCursor() defer cursor.Close() - cursor.Exec(q, ana.FileContext.Ast) + cursor.Exec(q, pass.FileContext.Ast) for { match, ok := cursor.NextMatch() @@ -143,8 +145,8 @@ func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName stri } var varNode, valueNode *sitter.Node - for _, capture := range match.Captures { - switch capture.Name { + for idx, capture := range match.Captures { + switch q.CaptureNameForId(uint32(idx)) { case "var": varNode = capture.Node case "value": @@ -154,16 +156,13 @@ func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName stri if varNode != nil && varNode.Content(source) == varName { if isUnsafeString(valueNode, source) { - ana.Report(&analysis.Issue{ - Message: fmt.Sprintf("Variable '%s' originates from an unsafe string", varName), - Range: originalNode.Range(), - }) + pass.Report(pass, originalNode, fmt.Sprintf("Variable '%s' originates from an unsafe string", varName)) return true } if valueNode.Type() == "identifier" { newVar := valueNode.Content(source) - traceVariableOrigin(r, ana, newVar, originalNode, visitedVars, visitedFiles, source) + traceVariableOrigin(pass, newVar, originalNode, visitedVars, visitedFiles, source) return true } } @@ -171,7 +170,7 @@ func traceLocalAssignments(r analysis.Rule, ana *analysis.Analyzer, varName stri return false } -func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName string, originalNode *sitter.Node, +func traceCrossFileImports(pass *analysis.Pass, varName string, originalNode *sitter.Node, visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) { query := `( @@ -180,7 +179,7 @@ func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName stri name: (dotted_name) @imported_var ) @import )` - q, err := sitter.NewQuery([]byte(query), ana.Language.Parser()) + q, err := sitter.NewQuery([]byte(query), pass.Analyzer.Language.Grammar()) if err != nil { return } @@ -188,7 +187,7 @@ func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName stri cursor := sitter.NewQueryCursor() defer cursor.Close() - cursor.Exec(q, ana.FileContext.Ast) + cursor.Exec(q, pass.FileContext.Ast) for { match, ok := cursor.NextMatch() @@ -197,8 +196,8 @@ func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName stri } var moduleNode, varNode *sitter.Node - for _, capture := range match.Captures { - switch capture.Name { + for idx, capture := range match.Captures { + switch q.CaptureNameForId(uint32(idx)) { case "module": moduleNode = capture.Node case "imported_var": @@ -213,16 +212,16 @@ func traceCrossFileImports(r analysis.Rule, ana *analysis.Analyzer, varName stri } visitedFiles[modulePath] = true - for _, file := range ana.Files { + for _, file := range pass.Files { if strings.HasSuffix(file.FilePath, modulePath) { // Create a temporary analyzer context for the imported file. - tempAna := &analysis.Analyzer{ - Language: ana.Language, + tempPass := &analysis.Pass{ + Analyzer: pass.Analyzer, FileContext: file, - Files: ana.Files, - Report: ana.Report, // Reuse the report function. + Files: pass.Files, + Report: pass.Report, // Reuse the report function. } - traceVariableOrigin(r, tempAna, varName, originalNode, visitedVars, visitedFiles, file.Source) + traceVariableOrigin(tempPass, varName, originalNode, visitedVars, visitedFiles, file.Source) } } } diff --git a/checkers/python/testdata/avoid-unsanitized-sql.test.py b/checkers/python/testdata/avoid-unsanitized-sql.test.py new file mode 100644 index 00000000..72a03e1c --- /dev/null +++ b/checkers/python/testdata/avoid-unsanitized-sql.test.py @@ -0,0 +1,41 @@ +import sqlite3 +from fastapi import FastAPI, Query +import sqlite3 + +app = FastAPI() + +def execute_unsafe_query(query: str): + conn = sqlite3.connect("test.db") + cursor = conn.cursor() + cursor.execute(query) #unsafe with user input + result = cursor.fetchall() + conn.commit() + conn.close() + return result + +def better_query(query: str, params): + conn = sqlite3.connect("test.db") + cursor = conn.cursor() + cursor.execute(query, params) #safe to execute with user input + result = cursor.fetchall() + conn.commit() + conn.close() + return result + + +@app.get("/unsafe_query/") +def unsafe_query(user_input: str): + #f-string case + # + query = f"SELECT * FROM users WHERE name = {user_input}" + #binary operator case + # + query2= "SELECT * FROM users WHERE name ="+ user_input + #should not identify this as an error + query3= "SELECT * FROM user WHERE name= ?" + result = execute_unsafe_query(query) + result2= execute_unsafe_query(query=query2) + + result3= better_query(query=query3, params=(user_input,)) + + return {"result": result, "result2": result2, "result3": result3} \ No newline at end of file From 3e8106e1852a5cee074d6354e65ef76d9424fddb Mon Sep 17 00:00:00 2001 From: Vivek Anand <78247712+viveka1302@users.noreply.github.com> Date: Thu, 13 Mar 2025 12:25:21 +0530 Subject: [PATCH 09/10] Update avoid-unsanitized-sql.test.py Signed-off-by: Vivek Anand <78247712+viveka1302@users.noreply.github.com> --- checkers/python/avoid-unsanitized-sql.test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/checkers/python/avoid-unsanitized-sql.test.py b/checkers/python/avoid-unsanitized-sql.test.py index 8f3d8cf2..d41e160a 100644 --- a/checkers/python/avoid-unsanitized-sql.test.py +++ b/checkers/python/avoid-unsanitized-sql.test.py @@ -8,6 +8,7 @@ def execute_unsafe_query(query: str): conn = sqlite3.connect("test.db") cursor = conn.cursor() + # cursor.execute(query) #unsafe with user input result = cursor.fetchall() conn.commit() @@ -27,11 +28,12 @@ def better_query(query: str, params): @app.get("/unsafe_query/") def unsafe_query(user_input: str): #f-string case - # + query = f"SELECT * FROM users WHERE name = {user_input}" #binary operator case - # + query2= "SELECT * FROM users WHERE name ="+ user_input + #should not identify this as an error query3= "SELECT * FROM user WHERE name= ?" result = execute_unsafe_query(query) From 38893962ef102e4cb79c161cf04a36fbe34d21a1 Mon Sep 17 00:00:00 2001 From: Vivek Anand <78247712+viveka1302@users.noreply.github.com> Date: Thu, 13 Mar 2025 12:36:57 +0530 Subject: [PATCH 10/10] Update avoid-unsanitized-sql.test.py Signed-off-by: Vivek Anand <78247712+viveka1302@users.noreply.github.com> --- checkers/python/testdata/avoid-unsanitized-sql.test.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/checkers/python/testdata/avoid-unsanitized-sql.test.py b/checkers/python/testdata/avoid-unsanitized-sql.test.py index 72a03e1c..32ceb70c 100644 --- a/checkers/python/testdata/avoid-unsanitized-sql.test.py +++ b/checkers/python/testdata/avoid-unsanitized-sql.test.py @@ -7,6 +7,7 @@ def execute_unsafe_query(query: str): conn = sqlite3.connect("test.db") cursor = conn.cursor() + # cursor.execute(query) #unsafe with user input result = cursor.fetchall() conn.commit() @@ -26,10 +27,10 @@ def better_query(query: str, params): @app.get("/unsafe_query/") def unsafe_query(user_input: str): #f-string case - # + query = f"SELECT * FROM users WHERE name = {user_input}" #binary operator case - # + query2= "SELECT * FROM users WHERE name ="+ user_input #should not identify this as an error query3= "SELECT * FROM user WHERE name= ?" @@ -38,4 +39,4 @@ def unsafe_query(user_input: str): result3= better_query(query=query3, params=(user_input,)) - return {"result": result, "result2": result2, "result3": result3} \ No newline at end of file + return {"result": result, "result2": result2, "result3": result3}