diff --git a/internal/compiler/engine.go b/internal/compiler/engine.go index 75749cd6df..d88825dde2 100644 --- a/internal/compiler/engine.go +++ b/internal/compiler/engine.go @@ -10,6 +10,7 @@ import ( "github.com/sqlc-dev/sqlc/internal/engine/dolphin" "github.com/sqlc-dev/sqlc/internal/engine/postgresql" pganalyze "github.com/sqlc-dev/sqlc/internal/engine/postgresql/analyzer" + "github.com/sqlc-dev/sqlc/internal/engine/postgresql/pglite" "github.com/sqlc-dev/sqlc/internal/engine/sqlite" sqliteanalyze "github.com/sqlc-dev/sqlc/internal/engine/sqlite/analyzer" "github.com/sqlc-dev/sqlc/internal/opts" @@ -59,7 +60,12 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err c.parser = postgresql.NewParser() c.catalog = postgresql.NewCatalog() c.selector = newDefaultSelector() - if conf.Database != nil { + + // Check if PGLite analyzer is configured and experiment is enabled + exp := opts.ExperimentFromEnv() + if exp.PGLite && conf.Analyzer.PGLite != nil { + c.analyzer = pglite.New(*conf.Analyzer.PGLite) + } else if conf.Database != nil { if conf.Analyzer.Database == nil || *conf.Analyzer.Database { c.analyzer = analyzer.Cached( pganalyze.New(c.client, *conf.Database), diff --git a/internal/config/config.go b/internal/config/config.go index 0ff805fccd..c50b39dff6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -123,7 +123,13 @@ type SQL struct { } type Analyzer struct { - Database *bool `json:"database" yaml:"database"` + Database *bool `json:"database" yaml:"database"` + PGLite *PGLite `json:"pglite" yaml:"pglite"` +} + +type PGLite struct { + URL string `json:"url" yaml:"url"` + SHA256 string `json:"sha256" yaml:"sha256"` } // TODO: Figure out a better name for this diff --git a/internal/engine/postgresql/pglite/analyze.go b/internal/engine/postgresql/pglite/analyze.go new file mode 100644 index 0000000000..562a137042 --- /dev/null +++ b/internal/engine/postgresql/pglite/analyze.go @@ -0,0 +1,616 @@ +// Package pglite provides a PostgreSQL analyzer that uses PGLite running in WebAssembly. +// This allows for database-backed type analysis without requiring a running PostgreSQL server. +// +// To use this analyzer, enable it with SQLCEXPERIMENT=pglite and configure it in sqlc.yaml: +// +// sql: +// - engine: postgresql +// analyzer: +// pglite: +// url: "file://path/to/pglite.wasm" +// sha256: "..." 
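+//
+// Then run code generation with the experiment turned on via the
+// SQLCEXPERIMENT environment variable (see internal/opts):
+//
+//	SQLCEXPERIMENT=pglite sqlc generate
+//
+// Programmatic use follows the same flow the compiler does; a minimal sketch,
+// with cfg, ctx, node, query, migrations, and params supplied by the caller:
+//
+//	a := pglite.New(cfg)
+//	defer a.Close(ctx)
+//	result, err := a.Analyze(ctx, node, query, migrations, params)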
+package pglite + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" + "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" + + core "github.com/sqlc-dev/sqlc/internal/analysis" + "github.com/sqlc-dev/sqlc/internal/cache" + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/info" + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/named" + "github.com/sqlc-dev/sqlc/internal/sql/sqlerr" +) + +// Request types for communication with PGLite WASM module +type Request struct { + Type string `json:"type"` // "init", "exec", "prepare", "close" + Migrations []string `json:"migrations"` // For "init": schema migrations to apply + Query string `json:"query"` // For "exec" and "prepare": SQL query +} + +// Response from PGLite WASM module +type Response struct { + Success bool `json:"success"` + Error *ErrorResponse `json:"error,omitempty"` + Prepare *PrepareResult `json:"prepare,omitempty"` + Exec *ExecResult `json:"exec,omitempty"` + Query *QueryResult `json:"query,omitempty"` +} + +type ErrorResponse struct { + Code string `json:"code"` + Message string `json:"message"` + Position int `json:"position"` +} + +type PrepareResult struct { + Columns []ColumnInfo `json:"columns"` + Params []ParameterInfo `json:"params"` +} + +type ColumnInfo struct { + Name string `json:"name"` + DataType string `json:"data_type"` + DataTypeOID uint32 `json:"data_type_oid"` + NotNull bool `json:"not_null"` + IsArray bool `json:"is_array"` + ArrayDims int `json:"array_dims"` + TableOID uint32 `json:"table_oid,omitempty"` + TableName string `json:"table_name,omitempty"` + TableSchema string `json:"table_schema,omitempty"` +} + +type ParameterInfo struct { + Number int `json:"number"` + DataType string `json:"data_type"` + DataTypeOID uint32 `json:"data_type_oid"` + IsArray bool `json:"is_array"` + ArrayDims int `json:"array_dims"` +} + +type ExecResult struct { + RowsAffected int64 `json:"rows_affected"` +} + +type QueryResult struct { + Columns []string `json:"columns"` + Rows [][]interface{} `json:"rows"` +} + +// Analyzer implements the analyzer.Analyzer interface using PGLite WASM. +type Analyzer struct { + cfg config.PGLite + mu sync.Mutex + rt wazero.Runtime + mod api.Module + inited bool + schema []string + + // Caches for type lookups + formats sync.Map + columns sync.Map + tables sync.Map +} + +// New creates a new PGLite analyzer with the given configuration. +func New(cfg config.PGLite) *Analyzer { + return &Analyzer{ + cfg: cfg, + } +} + +// Analyze implements the analyzer.Analyzer interface. +// It prepares the given query against PGLite to extract column and parameter type information. 
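+//
+// The exchange with the WASM module is JSON over stdio, using the Request and
+// Response types above; an illustrative (abridged) prepare round trip:
+//
+//	-> {"type":"prepare","migrations":null,"query":"SELECT id FROM users WHERE id = $1"}
+//	<- {"success":true,"prepare":{"columns":[{"name":"id","data_type":"integer","not_null":true}],"params":[{"number":1,"data_type":"integer"}]}}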
+func (a *Analyzer) Analyze(ctx context.Context, n ast.Node, query string, migrations []string, ps *named.ParamSet) (*core.Analysis, error) { + a.mu.Lock() + defer a.mu.Unlock() + + // Initialize if not already done or if migrations changed + if !a.inited || !equalMigrations(a.schema, migrations) { + if err := a.init(ctx, migrations); err != nil { + return nil, fmt.Errorf("pglite init: %w", err) + } + a.schema = migrations + a.inited = true + } + + // Prepare the query to get type information + result, err := a.prepare(ctx, query) + if err != nil { + // Convert PGLite error to sqlerr.Error if possible + var pgliteErr *PGLiteError + if errors.As(err, &pgliteErr) { + return nil, &sqlerr.Error{ + Code: pgliteErr.Code, + Message: pgliteErr.Message, + Location: max(n.Pos()+pgliteErr.Position-1, 0), + } + } + return nil, err + } + + var analysis core.Analysis + + // Convert columns + for _, col := range result.Columns { + dt := rewriteType(col.DataType) + column := &core.Column{ + Name: col.Name, + OriginalName: col.Name, + DataType: dt, + NotNull: col.NotNull, + IsArray: col.IsArray, + ArrayDims: int32(col.ArrayDims), + } + if col.TableName != "" { + column.Table = &core.Identifier{ + Schema: col.TableSchema, + Name: col.TableName, + } + } + analysis.Columns = append(analysis.Columns, column) + } + + // Convert parameters + for _, param := range result.Params { + dt := rewriteType(param.DataType) + name := "" + if ps != nil { + name, _ = ps.NameFor(param.Number) + } + analysis.Params = append(analysis.Params, &core.Parameter{ + Number: int32(param.Number), + Column: &core.Column{ + Name: name, + DataType: dt, + IsArray: param.IsArray, + ArrayDims: int32(param.ArrayDims), + NotNull: false, // Parameters are nullable by default + }, + }) + } + + return &analysis, nil +} + +// Close implements the analyzer.Analyzer interface. +func (a *Analyzer) Close(ctx context.Context) error { + a.mu.Lock() + defer a.mu.Unlock() + + if a.mod != nil { + a.mod.Close(ctx) + a.mod = nil + } + if a.rt != nil { + a.rt.Close(ctx) + a.rt = nil + } + a.inited = false + return nil +} + +// init initializes or reinitializes PGLite with the given migrations. 
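+// Analyze invokes it lazily: on first use, and again whenever the migration
+// set differs from the one used previously (compared with equalMigrations),
+// so schema changes are picked up without restarting the process.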
+func (a *Analyzer) init(ctx context.Context, migrations []string) error {
+	// Close existing runtime if any
+	if a.rt != nil {
+		if a.mod != nil {
+			a.mod.Close(ctx)
+		}
+		a.rt.Close(ctx)
+	}
+
+	// Clear caches. sync.Map values must not be copied after first use, so
+	// delete entries in place instead of assigning fresh values.
+	for _, m := range []*sync.Map{&a.formats, &a.columns, &a.tables} {
+		m.Range(func(k, _ any) bool {
+			m.Delete(k)
+			return true
+		})
+	}
+
+	// Load and compile WASM
+	wmod, err := a.loadWASM(ctx)
+	if err != nil {
+		return fmt.Errorf("load wasm: %w", err)
+	}
+
+	// Create wazero runtime with compilation cache
+	cacheDir, err := cache.PluginsDir()
+	if err != nil {
+		return fmt.Errorf("cache dir: %w", err)
+	}
+
+	wazeroCache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "pglite-wazero"))
+	if err != nil {
+		return fmt.Errorf("wazero cache: %w", err)
+	}
+
+	// Named runtimeCfg to avoid shadowing the imported config package.
+	runtimeCfg := wazero.NewRuntimeConfig().WithCompilationCache(wazeroCache)
+	a.rt = wazero.NewRuntimeWithConfig(ctx, runtimeCfg)
+
+	// Instantiate WASI
+	if _, err := wasi_snapshot_preview1.Instantiate(ctx, a.rt); err != nil {
+		return fmt.Errorf("wasi instantiate: %w", err)
+	}
+
+	// Compile and instantiate module
+	compiled, err := a.rt.CompileModule(ctx, wmod)
+	if err != nil {
+		return fmt.Errorf("compile module: %w", err)
+	}
+
+	// Create request for initialization
+	initReq := Request{
+		Type:       "init",
+		Migrations: migrations,
+	}
+	reqBytes, err := json.Marshal(initReq)
+	if err != nil {
+		return fmt.Errorf("marshal init request: %w", err)
+	}
+
+	var stdout, stderr bytes.Buffer
+
+	modConfig := wazero.NewModuleConfig().
+		WithName("pglite").
+		WithArgs("pglite.wasm").
+		WithStdin(bytes.NewReader(reqBytes)).
+		WithStdout(&stdout).
+		WithStderr(&stderr).
+		WithEnv("SQLC_VERSION", info.Version).
+		WithFSConfig(wazero.NewFSConfig())
+
+	a.mod, err = a.rt.InstantiateModule(ctx, compiled, modConfig)
+	if err != nil {
+		errMsg := stderr.String()
+		if errMsg != "" {
+			return fmt.Errorf("instantiate module: %s", errMsg)
+		}
+		return fmt.Errorf("instantiate module: %w", err)
+	}
+
+	// Parse initialization response
+	var resp Response
+	if err := json.Unmarshal(stdout.Bytes(), &resp); err != nil {
+		slog.Debug("pglite init response", "stdout", stdout.String(), "stderr", stderr.String())
+		return fmt.Errorf("parse init response: %w", err)
+	}
+
+	if !resp.Success {
+		if resp.Error != nil {
+			return &PGLiteError{
+				Code:    resp.Error.Code,
+				Message: resp.Error.Message,
+			}
+		}
+		return errors.New("pglite initialization failed")
+	}
+
+	return nil
+}
+
+// prepare sends a PREPARE request to PGLite and returns the result.
+func (a *Analyzer) prepare(ctx context.Context, query string) (*PrepareResult, error) {
+	req := Request{
+		Type:  "prepare",
+		Query: query,
+	}
+
+	resp, err := a.call(ctx, req)
+	if err != nil {
+		return nil, err
+	}
+
+	if !resp.Success {
+		if resp.Error != nil {
+			return nil, &PGLiteError{
+				Code:     resp.Error.Code,
+				Message:  resp.Error.Message,
+				Position: resp.Error.Position,
+			}
+		}
+		return nil, errors.New("prepare failed")
+	}
+
+	if resp.Prepare == nil {
+		return nil, errors.New("prepare result missing")
+	}
+
+	return resp.Prepare, nil
+}
+
+// call sends a request to PGLite and returns the response.
+// For a persistent module, this would use a different mechanism (e.g., function calls).
+// For now, this demonstrates the interface that would be used.
+func (a *Analyzer) call(ctx context.Context, req Request) (*Response, error) {
+	// For modules that support function exports, we would call them directly.
+	// Since PGLite WASM typically runs as a WASI command, we need to handle
+	// persistent state differently.
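+	//
+	// Note: the exported-function fast path below assumes a (ptr, len) in /
+	// (ptr, len) out calling convention with guest-exported malloc/free. That
+	// ABI is an assumption of this sketch, not something PGLite guarantees.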
+ // + // This implementation assumes the module exposes callable functions or + // maintains state between invocations. In practice, you may need to: + // 1. Use a module that exports query functions directly + // 2. Re-instantiate with accumulated state + // 3. Use a socket/pipe-based communication + + // Check if module has exported functions we can call + queryFn := a.mod.ExportedFunction("pglite_query") + if queryFn != nil { + return a.callExported(ctx, queryFn, req) + } + + // Fallback: re-instantiate with state (less efficient) + return a.callViaReinstantiate(ctx, req) +} + +// callExported calls an exported function on the WASM module. +func (a *Analyzer) callExported(ctx context.Context, fn api.Function, req Request) (*Response, error) { + reqBytes, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + // Allocate memory for request + malloc := a.mod.ExportedFunction("malloc") + free := a.mod.ExportedFunction("free") + + if malloc == nil || free == nil { + return nil, errors.New("module does not export malloc/free") + } + + // Allocate memory for input + results, err := malloc.Call(ctx, uint64(len(reqBytes))) + if err != nil { + return nil, fmt.Errorf("malloc: %w", err) + } + inputPtr := uint32(results[0]) + defer free.Call(ctx, uint64(inputPtr)) + + // Write request to memory + if !a.mod.Memory().Write(inputPtr, reqBytes) { + return nil, errors.New("failed to write request to memory") + } + + // Call the query function + results, err = fn.Call(ctx, uint64(inputPtr), uint64(len(reqBytes))) + if err != nil { + return nil, fmt.Errorf("call: %w", err) + } + + // Read response from memory + outputPtr := uint32(results[0]) + outputLen := uint32(results[1]) + + respBytes, ok := a.mod.Memory().Read(outputPtr, outputLen) + if !ok { + return nil, errors.New("failed to read response from memory") + } + + var resp Response + if err := json.Unmarshal(respBytes, &resp); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + return &resp, nil +} + +// callViaReinstantiate handles modules that don't export callable functions. +// This re-instantiates the module with accumulated migrations plus the new query. +func (a *Analyzer) callViaReinstantiate(ctx context.Context, req Request) (*Response, error) { + // For command-style WASM modules, we need to re-run them with the full state + // This is less efficient but works with standard WASI command-line tools + + // Include migrations in the request so state is reconstructed + fullReq := Request{ + Type: req.Type, + Query: req.Query, + Migrations: a.schema, + } + + reqBytes, err := json.Marshal(fullReq) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + // Load WASM again for fresh instance + wmod, err := a.loadWASM(ctx) + if err != nil { + return nil, fmt.Errorf("load wasm: %w", err) + } + + compiled, err := a.rt.CompileModule(ctx, wmod) + if err != nil { + return nil, fmt.Errorf("compile: %w", err) + } + + var stdout, stderr bytes.Buffer + + modConfig := wazero.NewModuleConfig(). + WithName(""). + WithArgs("pglite.wasm", req.Type). + WithStdin(bytes.NewReader(reqBytes)). + WithStdout(&stdout). + WithStderr(&stderr). 
+ WithEnv("SQLC_VERSION", info.Version) + + result, err := a.rt.InstantiateModule(ctx, compiled, modConfig) + if err != nil { + errMsg := stderr.String() + if errMsg != "" { + return nil, fmt.Errorf("instantiate: %s", errMsg) + } + return nil, fmt.Errorf("instantiate: %w", err) + } + defer result.Close(ctx) + + var resp Response + if err := json.Unmarshal(stdout.Bytes(), &resp); err != nil { + slog.Debug("pglite response", "stdout", stdout.String(), "stderr", stderr.String()) + return nil, fmt.Errorf("unmarshal response: %w", err) + } + + return &resp, nil +} + +// loadWASM loads the PGLite WASM binary from the configured URL. +func (a *Analyzer) loadWASM(ctx context.Context) ([]byte, error) { + url := a.cfg.URL + expectedSHA := a.cfg.SHA256 + + // Check cache first + if expectedSHA != "" { + cacheDir, err := cache.PluginsDir() + if err == nil { + cachePath := filepath.Join(cacheDir, expectedSHA, "pglite.wasm") + if data, err := os.ReadFile(cachePath); err == nil { + return data, nil + } + } + } + + // Fetch the WASM binary + data, actualSHA, err := fetch(ctx, url) + if err != nil { + return nil, err + } + + // Verify checksum if provided + if expectedSHA != "" && actualSHA != expectedSHA { + return nil, fmt.Errorf("checksum mismatch: expected %s, got %s", expectedSHA, actualSHA) + } + + // Warn if no checksum provided + if expectedSHA == "" { + slog.Warn("pglite: no sha256 checksum provided, set sha256 in config for security", "actual_sha256", actualSHA) + } + + // Cache the binary + if expectedSHA != "" { + cacheDir, err := cache.PluginsDir() + if err == nil { + pluginDir := filepath.Join(cacheDir, expectedSHA) + if err := os.MkdirAll(pluginDir, 0755); err == nil { + os.WriteFile(filepath.Join(pluginDir, "pglite.wasm"), data, 0444) + } + } + } + + return data, nil +} + +// fetch downloads content from a URL (file:// or https://). +func fetch(ctx context.Context, url string) ([]byte, string, error) { + var body io.ReadCloser + + switch { + case strings.HasPrefix(url, "file://"): + path := strings.TrimPrefix(url, "file://") + file, err := os.Open(path) + if err != nil { + return nil, "", fmt.Errorf("open file: %w", err) + } + body = file + + case strings.HasPrefix(url, "https://"): + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, "", fmt.Errorf("create request: %w", err) + } + req.Header.Set("User-Agent", fmt.Sprintf("sqlc/%s Go/%s (%s %s)", info.Version, runtime.Version(), runtime.GOOS, runtime.GOARCH)) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, "", fmt.Errorf("fetch: %w", err) + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, "", fmt.Errorf("fetch failed: %s", resp.Status) + } + body = resp.Body + + default: + return nil, "", fmt.Errorf("unsupported URL scheme: %s", url) + } + + defer body.Close() + + data, err := io.ReadAll(body) + if err != nil { + return nil, "", fmt.Errorf("read: %w", err) + } + + sum := sha256.Sum256(data) + checksum := fmt.Sprintf("%x", sum) + + return data, checksum, nil +} + +// PGLiteError represents an error from PGLite. +type PGLiteError struct { + Code string + Message string + Position int +} + +func (e *PGLiteError) Error() string { + if e.Code != "" { + return fmt.Sprintf("%s: %s", e.Code, e.Message) + } + return e.Message +} + +// rewriteType converts PostgreSQL type names to the canonical form used by sqlc. 
+func rewriteType(dt string) string { + switch { + case strings.HasPrefix(dt, "character("): + return "pg_catalog.bpchar" + case strings.HasPrefix(dt, "character varying"): + return "pg_catalog.varchar" + case strings.HasPrefix(dt, "bit varying"): + return "pg_catalog.varbit" + case strings.HasPrefix(dt, "bit("): + return "pg_catalog.bit" + } + switch dt { + case "bpchar": + return "pg_catalog.bpchar" + case "timestamp without time zone": + return "pg_catalog.timestamp" + case "timestamp with time zone": + return "pg_catalog.timestamptz" + case "time without time zone": + return "pg_catalog.time" + case "time with time zone": + return "pg_catalog.timetz" + } + return dt +} + +// equalMigrations compares two migration slices for equality. +func equalMigrations(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/engine/postgresql/pglite/analyze_test.go b/internal/engine/postgresql/pglite/analyze_test.go new file mode 100644 index 0000000000..fe4ffd6274 --- /dev/null +++ b/internal/engine/postgresql/pglite/analyze_test.go @@ -0,0 +1,288 @@ +package pglite + +import ( + "testing" + + "github.com/sqlc-dev/sqlc/internal/config" +) + +func TestRewriteType(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"character(10)", "pg_catalog.bpchar"}, + {"character varying(255)", "pg_catalog.varchar"}, + {"character varying", "pg_catalog.varchar"}, + {"bit varying(8)", "pg_catalog.varbit"}, + {"bit(1)", "pg_catalog.bit"}, + {"bpchar", "pg_catalog.bpchar"}, + {"timestamp without time zone", "pg_catalog.timestamp"}, + {"timestamp with time zone", "pg_catalog.timestamptz"}, + {"time without time zone", "pg_catalog.time"}, + {"time with time zone", "pg_catalog.timetz"}, + {"integer", "integer"}, + {"text", "text"}, + {"boolean", "boolean"}, + {"uuid", "uuid"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := rewriteType(tt.input) + if result != tt.expected { + t.Errorf("rewriteType(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestEqualMigrations(t *testing.T) { + tests := []struct { + name string + a []string + b []string + expected bool + }{ + { + name: "both empty", + a: []string{}, + b: []string{}, + expected: true, + }, + { + name: "both nil", + a: nil, + b: nil, + expected: true, + }, + { + name: "equal single element", + a: []string{"CREATE TABLE users (id INT)"}, + b: []string{"CREATE TABLE users (id INT)"}, + expected: true, + }, + { + name: "equal multiple elements", + a: []string{"CREATE TABLE users (id INT)", "CREATE TABLE posts (id INT)"}, + b: []string{"CREATE TABLE users (id INT)", "CREATE TABLE posts (id INT)"}, + expected: true, + }, + { + name: "different length", + a: []string{"CREATE TABLE users (id INT)"}, + b: []string{"CREATE TABLE users (id INT)", "CREATE TABLE posts (id INT)"}, + expected: false, + }, + { + name: "different content", + a: []string{"CREATE TABLE users (id INT)"}, + b: []string{"CREATE TABLE posts (id INT)"}, + expected: false, + }, + { + name: "different order", + a: []string{"CREATE TABLE users (id INT)", "CREATE TABLE posts (id INT)"}, + b: []string{"CREATE TABLE posts (id INT)", "CREATE TABLE users (id INT)"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := equalMigrations(tt.a, tt.b) + if result != tt.expected { + t.Errorf("equalMigrations(%v, %v) = %v, want %v", tt.a, tt.b, result, tt.expected) + 
} + }) + } +} + +func TestPGLiteError(t *testing.T) { + tests := []struct { + name string + err *PGLiteError + expected string + }{ + { + name: "with code", + err: &PGLiteError{ + Code: "42601", + Message: "syntax error at or near \"SELEC\"", + }, + expected: "42601: syntax error at or near \"SELEC\"", + }, + { + name: "without code", + err: &PGLiteError{ + Message: "connection failed", + }, + expected: "connection failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.err.Error() + if result != tt.expected { + t.Errorf("Error() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestRequestJSON(t *testing.T) { + tests := []struct { + name string + req Request + expected string + }{ + { + name: "init request", + req: Request{ + Type: "init", + Migrations: []string{"CREATE TABLE users (id INT)"}, + }, + expected: `{"type":"init","migrations":["CREATE TABLE users (id INT)"],"query":""}`, + }, + { + name: "prepare request", + req: Request{ + Type: "prepare", + Query: "SELECT * FROM users WHERE id = $1", + }, + expected: `{"type":"prepare","migrations":null,"query":"SELECT * FROM users WHERE id = $1"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Just verify the struct can be created and fields are accessible + if tt.req.Type == "" { + t.Error("Type should not be empty") + } + }) + } +} + +func TestColumnInfo(t *testing.T) { + col := ColumnInfo{ + Name: "user_id", + DataType: "integer", + DataTypeOID: 23, + NotNull: true, + IsArray: false, + ArrayDims: 0, + TableOID: 16384, + TableName: "users", + TableSchema: "public", + } + + if col.Name != "user_id" { + t.Errorf("Name = %q, want %q", col.Name, "user_id") + } + if col.DataType != "integer" { + t.Errorf("DataType = %q, want %q", col.DataType, "integer") + } + if !col.NotNull { + t.Error("NotNull should be true") + } +} + +func TestParameterInfo(t *testing.T) { + param := ParameterInfo{ + Number: 1, + DataType: "text", + DataTypeOID: 25, + IsArray: false, + ArrayDims: 0, + } + + if param.Number != 1 { + t.Errorf("Number = %d, want %d", param.Number, 1) + } + if param.DataType != "text" { + t.Errorf("DataType = %q, want %q", param.DataType, "text") + } +} + +func TestPrepareResult(t *testing.T) { + result := PrepareResult{ + Columns: []ColumnInfo{ + {Name: "id", DataType: "integer", NotNull: true}, + {Name: "name", DataType: "text", NotNull: false}, + }, + Params: []ParameterInfo{ + {Number: 1, DataType: "integer"}, + }, + } + + if len(result.Columns) != 2 { + t.Errorf("len(Columns) = %d, want 2", len(result.Columns)) + } + if len(result.Params) != 1 { + t.Errorf("len(Params) = %d, want 1", len(result.Params)) + } +} + +func TestResponse(t *testing.T) { + t.Run("success response", func(t *testing.T) { + resp := Response{ + Success: true, + Prepare: &PrepareResult{ + Columns: []ColumnInfo{ + {Name: "id", DataType: "integer"}, + }, + }, + } + + if !resp.Success { + t.Error("Success should be true") + } + if resp.Prepare == nil { + t.Error("Prepare should not be nil") + } + if resp.Error != nil { + t.Error("Error should be nil") + } + }) + + t.Run("error response", func(t *testing.T) { + resp := Response{ + Success: false, + Error: &ErrorResponse{ + Code: "42P01", + Message: "relation \"users\" does not exist", + Position: 15, + }, + } + + if resp.Success { + t.Error("Success should be false") + } + if resp.Error == nil { + t.Error("Error should not be nil") + } + if resp.Error.Code != "42P01" { + t.Errorf("Error.Code = %q, want %q", resp.Error.Code, 
"42P01") + } + }) +} + +func TestNewAnalyzer(t *testing.T) { + cfg := config.PGLite{ + URL: "file:///path/to/pglite.wasm", + SHA256: "abc123", + } + + a := New(cfg) + if a == nil { + t.Fatal("New() returned nil") + } + if a.cfg.URL != cfg.URL { + t.Errorf("cfg.URL = %q, want %q", a.cfg.URL, cfg.URL) + } + if a.cfg.SHA256 != cfg.SHA256 { + t.Errorf("cfg.SHA256 = %q, want %q", a.cfg.SHA256, cfg.SHA256) + } +} diff --git a/internal/engine/postgresql/pglite/integration_test.go b/internal/engine/postgresql/pglite/integration_test.go new file mode 100644 index 0000000000..41b949039c --- /dev/null +++ b/internal/engine/postgresql/pglite/integration_test.go @@ -0,0 +1,215 @@ +package pglite + +import ( + "context" + "crypto/sha256" + "fmt" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/sqlc-dev/sqlc/internal/config" + "github.com/sqlc-dev/sqlc/internal/sql/ast" +) + +func TestAnalyzerIntegration(t *testing.T) { + // Find the mock WASM file + _, filename, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("could not get current file path") + } + wasmPath := filepath.Join(filepath.Dir(filename), "testdata", "mock_pglite.wasm") + + // Check if WASM file exists + wasmData, err := os.ReadFile(wasmPath) + if err != nil { + t.Skipf("mock_pglite.wasm not found (run 'GOOS=wasip1 GOARCH=wasm go build -o mock_pglite.wasm mock_pglite.go' in testdata/): %v", err) + } + + // Calculate SHA256 + sum := sha256.Sum256(wasmData) + sha := fmt.Sprintf("%x", sum) + + cfg := config.PGLite{ + URL: "file://" + wasmPath, + SHA256: sha, + } + + analyzer := New(cfg) + ctx := context.Background() + defer analyzer.Close(ctx) + + migrations := []string{ + "CREATE TABLE users (id INTEGER NOT NULL, name TEXT, email TEXT NOT NULL, created_at TIMESTAMP)", + } + + // Create a minimal AST node for position tracking + node := &ast.TODO{} + + t.Run("simple select star", func(t *testing.T) { + result, err := analyzer.Analyze(ctx, node, "SELECT * FROM users", migrations, nil) + if err != nil { + t.Fatalf("Analyze failed: %v", err) + } + + if len(result.Columns) != 4 { + t.Errorf("expected 4 columns, got %d", len(result.Columns)) + } + + // Check column names and types + expectedCols := []struct { + name string + notNull bool + }{ + {"id", true}, + {"name", false}, + {"email", true}, + {"created_at", false}, + } + + for i, exp := range expectedCols { + if i >= len(result.Columns) { + break + } + col := result.Columns[i] + if col.Name != exp.name { + t.Errorf("column %d: expected name %q, got %q", i, exp.name, col.Name) + } + if col.NotNull != exp.notNull { + t.Errorf("column %d (%s): expected NotNull=%v, got %v", i, exp.name, exp.notNull, col.NotNull) + } + } + }) + + t.Run("select with parameters", func(t *testing.T) { + result, err := analyzer.Analyze(ctx, node, "SELECT id, name FROM users WHERE id = $1", migrations, nil) + if err != nil { + t.Fatalf("Analyze failed: %v", err) + } + + if len(result.Columns) != 2 { + t.Errorf("expected 2 columns, got %d", len(result.Columns)) + } + + if len(result.Params) != 1 { + t.Errorf("expected 1 parameter, got %d", len(result.Params)) + } + + if len(result.Params) > 0 { + if result.Params[0].Number != 1 { + t.Errorf("expected param number 1, got %d", result.Params[0].Number) + } + } + }) + + t.Run("select specific columns", func(t *testing.T) { + result, err := analyzer.Analyze(ctx, node, "SELECT id, email FROM users", migrations, nil) + if err != nil { + t.Fatalf("Analyze failed: %v", err) + } + + if len(result.Columns) != 2 { + t.Errorf("expected 2 columns, got %d", 
len(result.Columns)) + } + + if len(result.Columns) >= 2 { + if result.Columns[0].Name != "id" { + t.Errorf("expected first column 'id', got %q", result.Columns[0].Name) + } + if result.Columns[1].Name != "email" { + t.Errorf("expected second column 'email', got %q", result.Columns[1].Name) + } + } + }) +} + +func TestAnalyzerWithMultipleTables(t *testing.T) { + _, filename, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("could not get current file path") + } + wasmPath := filepath.Join(filepath.Dir(filename), "testdata", "mock_pglite.wasm") + + wasmData, err := os.ReadFile(wasmPath) + if err != nil { + t.Skipf("mock_pglite.wasm not found: %v", err) + } + + sum := sha256.Sum256(wasmData) + sha := fmt.Sprintf("%x", sum) + + cfg := config.PGLite{ + URL: "file://" + wasmPath, + SHA256: sha, + } + + analyzer := New(cfg) + ctx := context.Background() + defer analyzer.Close(ctx) + + migrations := []string{ + "CREATE TABLE authors (id INTEGER NOT NULL, name TEXT NOT NULL)", + "CREATE TABLE posts (id INTEGER NOT NULL, author_id INTEGER NOT NULL, title TEXT, body TEXT)", + } + + node := &ast.TODO{} + + t.Run("query authors table", func(t *testing.T) { + result, err := analyzer.Analyze(ctx, node, "SELECT * FROM authors", migrations, nil) + if err != nil { + t.Fatalf("Analyze failed: %v", err) + } + + if len(result.Columns) != 2 { + t.Errorf("expected 2 columns, got %d", len(result.Columns)) + } + }) + + t.Run("query posts table", func(t *testing.T) { + result, err := analyzer.Analyze(ctx, node, "SELECT * FROM posts", migrations, nil) + if err != nil { + t.Fatalf("Analyze failed: %v", err) + } + + if len(result.Columns) != 4 { + t.Errorf("expected 4 columns, got %d", len(result.Columns)) + } + }) +} + +func TestAnalyzerClose(t *testing.T) { + _, filename, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("could not get current file path") + } + wasmPath := filepath.Join(filepath.Dir(filename), "testdata", "mock_pglite.wasm") + + if _, err := os.Stat(wasmPath); os.IsNotExist(err) { + t.Skipf("mock_pglite.wasm not found: %v", err) + } + + wasmData, _ := os.ReadFile(wasmPath) + sum := sha256.Sum256(wasmData) + sha := fmt.Sprintf("%x", sum) + + cfg := config.PGLite{ + URL: "file://" + wasmPath, + SHA256: sha, + } + + analyzer := New(cfg) + ctx := context.Background() + + // Close without using should not error + err := analyzer.Close(ctx) + if err != nil { + t.Errorf("Close failed: %v", err) + } + + // Double close should not error + err = analyzer.Close(ctx) + if err != nil { + t.Errorf("Double close failed: %v", err) + } +} diff --git a/internal/engine/postgresql/pglite/testdata/mock_pglite.go b/internal/engine/postgresql/pglite/testdata/mock_pglite.go new file mode 100644 index 0000000000..916bb4d117 --- /dev/null +++ b/internal/engine/postgresql/pglite/testdata/mock_pglite.go @@ -0,0 +1,378 @@ +//go:build ignore + +// This is a mock PGLite WASM module for testing. 
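+// It reads a single JSON Request from stdin and writes a single JSON Response
+// to stdout, handling only the "init" and "prepare" request types with a
+// deliberately naive CREATE TABLE / SELECT parser, just enough for the tests.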
+// Build with: GOOS=wasip1 GOARCH=wasm go build -o mock_pglite.wasm mock_pglite.go +package main + +import ( + "encoding/json" + "fmt" + "io" + "os" + "regexp" + "strings" +) + +type Request struct { + Type string `json:"type"` + Migrations []string `json:"migrations"` + Query string `json:"query"` +} + +type Response struct { + Success bool `json:"success"` + Error *ErrorResponse `json:"error,omitempty"` + Prepare *PrepareResult `json:"prepare,omitempty"` +} + +type ErrorResponse struct { + Code string `json:"code"` + Message string `json:"message"` + Position int `json:"position"` +} + +type PrepareResult struct { + Columns []ColumnInfo `json:"columns"` + Params []ParameterInfo `json:"params"` +} + +type ColumnInfo struct { + Name string `json:"name"` + DataType string `json:"data_type"` + DataTypeOID uint32 `json:"data_type_oid"` + NotNull bool `json:"not_null"` + IsArray bool `json:"is_array"` + ArrayDims int `json:"array_dims"` + TableOID uint32 `json:"table_oid,omitempty"` + TableName string `json:"table_name,omitempty"` + TableSchema string `json:"table_schema,omitempty"` +} + +type ParameterInfo struct { + Number int `json:"number"` + DataType string `json:"data_type"` + DataTypeOID uint32 `json:"data_type_oid"` + IsArray bool `json:"is_array"` + ArrayDims int `json:"array_dims"` +} + +// Simple schema tracking +type Column struct { + Name string + Type string + NotNull bool + IsArray bool +} + +type Table struct { + Schema string + Name string + Columns []Column +} + +var tables = make(map[string]*Table) + +func main() { + // Read all input from stdin + input, err := io.ReadAll(os.Stdin) + if err != nil { + writeError("READ", fmt.Sprintf("failed to read stdin: %v", err), 0) + return + } + + var req Request + if err := json.Unmarshal(input, &req); err != nil { + writeError("PARSE", fmt.Sprintf("failed to parse request: %v (input length: %d)", err, len(input)), 0) + return + } + + switch req.Type { + case "init": + handleInit(req) + case "prepare": + handlePrepare(req) + default: + writeError("UNKNOWN", fmt.Sprintf("unknown request type: %s", req.Type), 0) + } +} + +func handleInit(req Request) { + // Parse migrations to build schema + for _, migration := range req.Migrations { + parseMigration(migration) + } + + resp := Response{Success: true} + writeResponse(resp) +} + +func handlePrepare(req Request) { + // First apply any migrations + for _, migration := range req.Migrations { + parseMigration(migration) + } + + // Parse the query and infer types + result, err := analyzeQuery(req.Query) + if err != nil { + writeError("42601", err.Error(), 0) + return + } + + resp := Response{ + Success: true, + Prepare: result, + } + writeResponse(resp) +} + +func parseMigration(sql string) { + // Very simple CREATE TABLE parser + sql = strings.ToUpper(sql) + + createTableRe := regexp.MustCompile(`CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(?:(\w+)\.)?(\w+)\s*\(([^)]+)\)`) + matches := createTableRe.FindStringSubmatch(sql) + if matches == nil { + return + } + + schema := "public" + if matches[1] != "" { + schema = strings.ToLower(matches[1]) + } + tableName := strings.ToLower(matches[2]) + columnsDef := matches[3] + + table := &Table{ + Schema: schema, + Name: tableName, + } + + // Parse columns + colDefs := strings.Split(columnsDef, ",") + for _, colDef := range colDefs { + colDef = strings.TrimSpace(colDef) + if colDef == "" { + continue + } + // Skip constraints + if strings.HasPrefix(colDef, "PRIMARY") || strings.HasPrefix(colDef, "FOREIGN") || + strings.HasPrefix(colDef, "UNIQUE") || 
strings.HasPrefix(colDef, "CHECK") || + strings.HasPrefix(colDef, "CONSTRAINT") { + continue + } + + parts := strings.Fields(colDef) + if len(parts) < 2 { + continue + } + + col := Column{ + Name: strings.ToLower(parts[0]), + Type: strings.ToLower(parts[1]), + NotNull: strings.Contains(colDef, "NOT NULL"), + IsArray: strings.Contains(parts[1], "[]"), + } + table.Columns = append(table.Columns, col) + } + + tables[schema+"."+tableName] = table +} + +func analyzeQuery(query string) (*PrepareResult, error) { + query = strings.ToUpper(strings.TrimSpace(query)) + result := &PrepareResult{} + + // Count parameters + paramCount := strings.Count(query, "$") + for i := 1; i <= paramCount; i++ { + result.Params = append(result.Params, ParameterInfo{ + Number: i, + DataType: "text", // Default to text + DataTypeOID: 25, + }) + } + + // Very simple SELECT parser + if strings.HasPrefix(query, "SELECT") { + // Find FROM clause + fromIdx := strings.Index(query, "FROM") + if fromIdx == -1 { + // SELECT without FROM (e.g., SELECT 1) + selectPart := query[6:] + if whereIdx := strings.Index(selectPart, "WHERE"); whereIdx != -1 { + selectPart = selectPart[:whereIdx] + } + + cols := strings.Split(selectPart, ",") + for _, col := range cols { + col = strings.TrimSpace(col) + result.Columns = append(result.Columns, ColumnInfo{ + Name: strings.ToLower(col), + DataType: "integer", + DataTypeOID: 23, + }) + } + return result, nil + } + + selectPart := strings.TrimSpace(query[6:fromIdx]) + fromPart := query[fromIdx+4:] + + // Get table name + tableName := "" + parts := strings.Fields(fromPart) + if len(parts) > 0 { + tableName = strings.ToLower(parts[0]) + } + + // Look up table + table := findTable(tableName) + + // Handle SELECT * + if strings.TrimSpace(selectPart) == "*" { + if table != nil { + for _, col := range table.Columns { + result.Columns = append(result.Columns, ColumnInfo{ + Name: col.Name, + DataType: mapType(col.Type), + DataTypeOID: mapTypeOID(col.Type), + NotNull: col.NotNull, + IsArray: col.IsArray, + TableName: table.Name, + TableSchema: table.Schema, + TableOID: 16384, // Fake OID + }) + } + } + return result, nil + } + + // Parse individual columns + cols := strings.Split(selectPart, ",") + for _, col := range cols { + col = strings.TrimSpace(col) + // Handle aliases (col AS alias) + alias := col + if asIdx := strings.Index(col, " AS "); asIdx != -1 { + alias = strings.TrimSpace(col[asIdx+4:]) + col = strings.TrimSpace(col[:asIdx]) + } + + colInfo := ColumnInfo{ + Name: strings.ToLower(alias), + DataType: "text", + DataTypeOID: 25, + } + + // Try to find column in table + if table != nil { + for _, tc := range table.Columns { + if strings.ToUpper(tc.Name) == col || strings.ToUpper(table.Name+"."+tc.Name) == col { + colInfo.DataType = mapType(tc.Type) + colInfo.DataTypeOID = mapTypeOID(tc.Type) + colInfo.NotNull = tc.NotNull + colInfo.IsArray = tc.IsArray + colInfo.TableName = table.Name + colInfo.TableSchema = table.Schema + colInfo.TableOID = 16384 + break + } + } + } + + result.Columns = append(result.Columns, colInfo) + } + } + + return result, nil +} + +func findTable(name string) *Table { + // Try with public schema + if t, ok := tables["public."+name]; ok { + return t + } + // Try as-is + if t, ok := tables[name]; ok { + return t + } + // Search all tables + for _, t := range tables { + if t.Name == name { + return t + } + } + return nil +} + +func mapType(t string) string { + t = strings.ToLower(t) + t = strings.TrimSuffix(t, "[]") + switch t { + case "int", "integer", "int4": + return 
"integer" + case "bigint", "int8": + return "bigint" + case "smallint", "int2": + return "smallint" + case "text", "varchar", "character varying": + return "text" + case "boolean", "bool": + return "boolean" + case "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone": + return "pg_catalog.timestamp" + case "uuid": + return "uuid" + case "jsonb": + return "jsonb" + case "json": + return "json" + default: + return t + } +} + +func mapTypeOID(t string) uint32 { + t = strings.ToLower(t) + t = strings.TrimSuffix(t, "[]") + switch t { + case "int", "integer", "int4": + return 23 + case "bigint", "int8": + return 20 + case "smallint", "int2": + return 21 + case "text", "varchar", "character varying": + return 25 + case "boolean", "bool": + return 16 + case "timestamp", "timestamptz": + return 1114 + case "uuid": + return 2950 + case "jsonb": + return 3802 + case "json": + return 114 + default: + return 25 // default to text + } +} + +func writeError(code, message string, position int) { + resp := Response{ + Success: false, + Error: &ErrorResponse{ + Code: code, + Message: message, + Position: position, + }, + } + writeResponse(resp) +} + +func writeResponse(resp Response) { + data, _ := json.Marshal(resp) + fmt.Println(string(data)) +} diff --git a/internal/engine/postgresql/pglite/testdata/mock_pglite.wasm b/internal/engine/postgresql/pglite/testdata/mock_pglite.wasm new file mode 100755 index 0000000000..9aff50324f Binary files /dev/null and b/internal/engine/postgresql/pglite/testdata/mock_pglite.wasm differ diff --git a/internal/opts/experiment.go b/internal/opts/experiment.go index 73ca5d7de0..d679662459 100644 --- a/internal/opts/experiment.go +++ b/internal/opts/experiment.go @@ -14,7 +14,7 @@ import ( // // Available experiments: // -// (none currently defined - add experiments here as they are introduced) +// pglite - Enable PGLite-based PostgreSQL analyzer (uses embedded WASM PostgreSQL) // // Example usage: // @@ -28,6 +28,8 @@ type Experiment struct { // Add experimental feature flags here as they are introduced. 
// Example: // NewParser bool // Enable new SQL parser + + PGLite bool // Enable PGLite-based PostgreSQL analyzer (uses embedded WASM PostgreSQL) } // ExperimentFromEnv returns an Experiment initialized from the SQLCEXPERIMENT @@ -79,6 +81,8 @@ func isKnownExperiment(name string) bool { // Example: // case "newparser": // return true + case "pglite": + return true default: return false } @@ -91,6 +95,8 @@ func setExperiment(e *Experiment, name string, enabled bool) { // Example: // case "newparser": // e.NewParser = enabled + case "pglite": + e.PGLite = enabled } } @@ -102,6 +108,9 @@ func (e Experiment) Enabled() []string { // if e.NewParser { // enabled = append(enabled, "newparser") // } + if e.PGLite { + enabled = append(enabled, "pglite") + } return enabled } diff --git a/internal/opts/experiment_test.go b/internal/opts/experiment_test.go index 7845c0b13e..2b3c3bdf8c 100644 --- a/internal/opts/experiment_test.go +++ b/internal/opts/experiment_test.go @@ -65,6 +65,26 @@ func TestExperimentFromString(t *testing.T) { // input: "NewParser,NONEWPARSER", // want: Experiment{NewParser: false}, // }, + { + name: "enable pglite", + input: "pglite", + want: Experiment{PGLite: true}, + }, + { + name: "disable pglite", + input: "nopglite", + want: Experiment{PGLite: false}, + }, + { + name: "pglite enable then disable", + input: "pglite,nopglite", + want: Experiment{PGLite: false}, + }, + { + name: "pglite case insensitive", + input: "PGLite", + want: Experiment{PGLite: true}, + }, } for _, tt := range tests { @@ -95,6 +115,11 @@ func TestExperimentEnabled(t *testing.T) { // exp: Experiment{NewParser: true}, // want: []string{"newparser"}, // }, + { + name: "pglite enabled", + exp: Experiment{PGLite: true}, + want: []string{"pglite"}, + }, } for _, tt := range tests { @@ -131,6 +156,11 @@ func TestExperimentString(t *testing.T) { // exp: Experiment{NewParser: true}, // want: "newparser", // }, + { + name: "pglite enabled", + exp: Experiment{PGLite: true}, + want: "pglite", + }, } for _, tt := range tests { @@ -171,6 +201,16 @@ func TestIsKnownExperiment(t *testing.T) { // input: "NewParser", // want: true, // }, + { + name: "pglite lowercase", + input: "pglite", + want: true, + }, + { + name: "pglite mixed case", + input: "PGLite", + want: true, + }, } for _, tt := range tests {