diff --git a/README.md b/README.md index 7eab862..678d83b 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ Mark computed properties and methods with `[Expressive]` to generate companion e | External member mapping | `[ExpressiveFor]` for BCL/third-party members | | Tuples, index/range, `with`, collection expressions | And more modern C# syntax | | Expression transformers | Built-in + custom `IExpressionTreeTransformer` pipeline | -| SQL window functions | ROW_NUMBER, RANK, DENSE_RANK, NTILE (experimental) | +| SQL window functions | ROW_NUMBER, RANK, DENSE_RANK, NTILE, PERCENT_RANK, CUME_DIST, SUM/AVG/COUNT/MIN/MAX OVER, LAG/LEAD, FIRST_VALUE/LAST_VALUE/NTH_VALUE with ROWS/RANGE frames (experimental) | See the [full documentation](https://efnext.github.io/ExpressiveSharp/guide/introduction) for detailed usage, [reference](https://efnext.github.io/ExpressiveSharp/reference/expressive-attribute), and [recipes](https://efnext.github.io/ExpressiveSharp/recipes/computed-properties). diff --git a/docs/guide/window-functions.md b/docs/guide/window-functions.md index f328d00..a22fabe 100644 --- a/docs/guide/window-functions.md +++ b/docs/guide/window-functions.md @@ -31,17 +31,77 @@ services.AddDbContext(options => .UseExpressives(o => o.UseRelationalExtensions())); ``` +::: tip Concise syntax with `using static` +Add these imports for a compact, SQL-like syntax without class prefixes: +```csharp +using static ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.WindowFunctions.WindowFunction; +using static ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.WindowFunctions.WindowFrameBound; +using ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.WindowFunctions; + +// Then in queries: +RowNumber(Window.PartitionBy(o.CustomerId).OrderBy(o.Price)) +Sum(o.Price, Window.OrderBy(o.Date).RowsBetween(UnboundedPreceding, CurrentRow)) +Lag(o.Price, 1, 0.0, Window.OrderBy(o.Date)) +``` +::: + ## Available Functions +### Ranking Functions + | Function | SQL | Description | |----------|-----|-------------| | `WindowFunction.RowNumber(window)` | `ROW_NUMBER() OVER(...)` | Sequential row number within the partition. Returns `long`. | | `WindowFunction.Rank(window)` | `RANK() OVER(...)` | Rank with gaps for ties. Returns `long`. | | `WindowFunction.DenseRank(window)` | `DENSE_RANK() OVER(...)` | Rank without gaps for ties. Returns `long`. | | `WindowFunction.Ntile(n, window)` | `NTILE(n) OVER(...)` | Distributes rows into `n` roughly equal groups. Returns `long`. | +| `WindowFunction.PercentRank(window)` | `PERCENT_RANK() OVER(...)` | Relative rank as a value between 0.0 and 1.0. Returns `double`. | +| `WindowFunction.CumeDist(window)` | `CUME_DIST() OVER(...)` | Cumulative distribution (0.0–1.0]. Returns `double`. | + +### Aggregate Functions + +Aggregate window functions compute values over a set of rows defined by the window specification. Unlike ranking functions, they support [window frame clauses](#window-frame-specification). + +| Function | SQL | Description | +|----------|-----|-------------| +| `WindowFunction.Sum(expr, window)` | `SUM(expr) OVER(...)` | Sum of values. Returns same type as input. | +| `WindowFunction.Average(expr, window)` | `AVG(expr) OVER(...)` | Average of values. Returns `T?` (or `double` for `int`/`long` input). | +| `WindowFunction.Count(window)` | `COUNT(*) OVER(...)` | Count of all rows. Returns `long`. | +| `WindowFunction.Count(expr, window)` | `COUNT(expr) OVER(...)` | Count of non-null values. Returns `long`. | +| `WindowFunction.Min(expr, window)` | `MIN(expr) OVER(...)` | Minimum value. Returns same type as input. | +| `WindowFunction.Max(expr, window)` | `MAX(expr) OVER(...)` | Maximum value. Returns same type as input. | + +### Navigation Functions + +Navigation functions access specific rows relative to the current row. LAG/LEAD do not support frame clauses; FIRST_VALUE/LAST_VALUE do. + +| Function | SQL | Frame? | Description | +|----------|-----|--------|-------------| +| `WindowFunction.Lag(expr, window)` | `LAG(expr) OVER(...)` | No | Previous row's value (offset 1). | +| `WindowFunction.Lag(expr, n, window)` | `LAG(expr, n) OVER(...)` | No | Value `n` rows back. | +| `WindowFunction.Lag(expr, n, default, window)` | `LAG(expr, n, default) OVER(...)` | No | Value `n` rows back, with default. | +| `WindowFunction.Lead(expr, window)` | `LEAD(expr) OVER(...)` | No | Next row's value (offset 1). | +| `WindowFunction.Lead(expr, n, window)` | `LEAD(expr, n) OVER(...)` | No | Value `n` rows ahead. | +| `WindowFunction.Lead(expr, n, default, window)` | `LEAD(expr, n, default) OVER(...)` | No | Value `n` rows ahead, with default. | +| `WindowFunction.FirstValue(expr, window)` | `FIRST_VALUE(expr) OVER(...)` | Yes | First value in the frame. | +| `WindowFunction.LastValue(expr, window)` | `LAST_VALUE(expr) OVER(...)` | Yes | Last value in the frame. | +| `WindowFunction.NthValue(expr, n, window)` | `NTH_VALUE(expr, n) OVER(...)` | Yes | Value at the Nth row (1-based) in the frame. | + +::: tip Nullable results from LAG/LEAD +When no row exists at the requested offset (e.g. LAG on the first row), SQL returns NULL. For value-type columns, cast to a nullable type in the projection to detect this: `(double?)WindowFunction.Lag(o.Price, window)`. When a default value is provided (3-arg overload), NULL is never returned. +::: + +::: warning LAST_VALUE needs an explicit frame +With the default frame (`RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`), `LAST_VALUE` returns the *current row's* value — not the partition's last. Use an explicit frame: +```csharp +WindowFunction.LastValue(o.Price, + Window.OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.UnboundedFollowing)) +``` +::: ::: tip -All window functions return `long`. When projecting into a typed DTO with `int` properties, use an explicit cast: `(int)WindowFunction.RowNumber(...)`. +Ranking functions return `long`. When projecting into a typed DTO with `int` properties, use an explicit cast: `(int)WindowFunction.RowNumber(...)`. ::: ## Window Specification API @@ -55,6 +115,8 @@ Build window specifications using the fluent `Window` API: | `Window.PartitionBy(expr)` | `PARTITION BY expr` | | `.ThenBy(expr)` | Additional `ORDER BY expr ASC` column | | `.ThenByDescending(expr)` | Additional `ORDER BY expr DESC` column | +| `.RowsBetween(start, end)` | `ROWS BETWEEN start AND end` (see [Window Frame Specification](#window-frame-specification)) | +| `.RangeBetween(start, end)` | `RANGE BETWEEN start AND end` (see [Window Frame Specification](#window-frame-specification)) | Chain these methods to build the full window specification: @@ -74,6 +136,59 @@ Window.PartitionBy(o.CustomerId) .ThenBy(o.Id) ``` +## Window Frame Specification + +Aggregate window functions support frame clauses that narrow the set of rows used for the computation. Frames use `RowsBetween` or `RangeBetween` chained onto an ordered window specification: + +```csharp +Window.OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.CurrentRow) +``` + +The `WindowFrameBound` factory members produce the five SQL:2003 frame boundaries: + +| Bound | SQL | +|-------|-----| +| `WindowFrameBound.UnboundedPreceding` | `UNBOUNDED PRECEDING` | +| `WindowFrameBound.Preceding(n)` | `n PRECEDING` | +| `WindowFrameBound.CurrentRow` | `CURRENT ROW` | +| `WindowFrameBound.Following(n)` | `n FOLLOWING` | +| `WindowFrameBound.UnboundedFollowing` | `UNBOUNDED FOLLOWING` | + +Example — running total with `SUM`: + +```csharp +var results = db.Orders.Select(o => new +{ + o.Id, + o.Price, + RunningTotal = WindowFunction.Sum(o.Price, + Window.PartitionBy(o.CustomerId) + .OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.CurrentRow)) +}); +``` + +Generated SQL (SQLite): + +```sql +SELECT "o"."Id", "o"."Price", + SUM("o"."Price") OVER(PARTITION BY "o"."CustomerId" ORDER BY "o"."Price" ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS "RunningTotal" +FROM "Orders" AS "o" +``` + +::: tip Default frame behavior +When no explicit frame is specified, SQL defaults to `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW` in the presence of `ORDER BY`. This produces a running total/min/max by default. +::: + +::: warning Ranking functions don't support frames +The SQL standard forbids frame clauses on ranking functions (ROW_NUMBER, RANK, DENSE_RANK, NTILE) — SQL Server and PostgreSQL will reject the query. Aggregate functions (SUM, AVG, COUNT, MIN, MAX) and value functions (FIRST_VALUE, LAST_VALUE, NTH_VALUE) accept frames. +::: + +::: warning Literal offsets only +`Preceding(n)` and `Following(n)` accept an integer **constant**. Passing a variable or captured value will fail translation: SQL requires literal integer constants in the frame clause. +::: + ## Complete Example ```csharp diff --git a/docs/recipes/window-functions-ranking.md b/docs/recipes/window-functions-ranking.md index 819badb..5358c34 100644 --- a/docs/recipes/window-functions-ranking.md +++ b/docs/recipes/window-functions-ranking.md @@ -337,6 +337,35 @@ Window functions are supported across all major relational providers: The generated SQL uses standard window function syntax, which all these providers support. +## Aggregate Window Functions + +In addition to ranking, ExpressiveSharp supports aggregate window functions (`SUM`, `AVG`, `COUNT`, `MIN`, `MAX`) with optional frame clauses. These are useful for running totals, moving averages, and cumulative min/max: + +```csharp +using ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.WindowFunctions; + +var results = dbContext.Orders + .Select(o => new + { + o.Id, + o.Price, + RunningTotal = WindowFunction.Sum(o.Price, + Window.PartitionBy(o.CustomerId) + .OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.CurrentRow)), + MovingAvg = WindowFunction.Average(o.Price, + Window.OrderBy(o.Price) + .RowsBetween(WindowFrameBound.Preceding(2), WindowFrameBound.CurrentRow)) + }) + .ToList(); +``` + +See the [Window Functions guide](/guide/window-functions#window-frame-specification) for the full frame specification reference. + +::: warning Frames apply to aggregate functions only +The SQL standard forbids frame clauses on ranking functions (ROW_NUMBER, RANK, DENSE_RANK, NTILE). SQL Server and PostgreSQL reject the syntax. Aggregate functions (SUM, AVG, COUNT, MIN, MAX) and value functions (FIRST_VALUE, LAST_VALUE, NTH_VALUE) support frames. +::: + ## Tips ::: tip Combine with other ExpressiveSharp features diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowDefinition.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowDefinition.cs index 1bf13e9..5b3c7e6 100644 --- a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowDefinition.cs +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowDefinition.cs @@ -39,4 +39,26 @@ public OrderedWindowDefinition ThenBy(TKey key) => /// Adds a subsequent ORDER BY column (descending). public OrderedWindowDefinition ThenByDescending(TKey key) => throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + /// Applies a row-based window frame: ROWS BETWEEN AND . + /// + public FramedWindowDefinition RowsBetween(WindowFrameBound start, WindowFrameBound end) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + /// Applies a range-based window frame: RANGE BETWEEN AND . + /// + public FramedWindowDefinition RangeBetween(WindowFrameBound start, WindowFrameBound end) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); +} + +/// +/// Represents a window specification after a frame clause (ROWS/RANGE BETWEEN) has been applied. +/// Terminal type — no further chaining is possible. +/// +public sealed class FramedWindowDefinition +{ + private FramedWindowDefinition() => + throw new InvalidOperationException("FramedWindowDefinition is a marker type for expression trees and cannot be instantiated."); } diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowFrameBound.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowFrameBound.cs new file mode 100644 index 0000000..2cf2ab3 --- /dev/null +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowFrameBound.cs @@ -0,0 +1,33 @@ +namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.WindowFunctions; + +/// +/// Represents a boundary of a SQL window frame (e.g. UNBOUNDED PRECEDING, +/// 3 PRECEDING, CURRENT ROW, 5 FOLLOWING, UNBOUNDED FOLLOWING). +/// These factory members are translated to SQL by ExpressiveSharp's translators — +/// they throw at runtime if accessed directly. +/// +public sealed class WindowFrameBound +{ + private WindowFrameBound() => + throw new InvalidOperationException("WindowFrameBound is a marker type for expression trees and cannot be instantiated."); + + /// Translates to UNBOUNDED PRECEDING. + public static WindowFrameBound UnboundedPreceding => + throw new InvalidOperationException("This property is translated to SQL and cannot be accessed directly."); + + /// Translates to PRECEDING. + public static WindowFrameBound Preceding(int offset) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to CURRENT ROW. + public static WindowFrameBound CurrentRow => + throw new InvalidOperationException("This property is translated to SQL and cannot be accessed directly."); + + /// Translates to FOLLOWING. + public static WindowFrameBound Following(int offset) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to UNBOUNDED FOLLOWING. + public static WindowFrameBound UnboundedFollowing => + throw new InvalidOperationException("This property is translated to SQL and cannot be accessed directly."); +} diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowFunction.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowFunction.cs index 96d3f93..e9dc202 100644 --- a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowFunction.cs +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Abstractions/WindowFunction.cs @@ -1,12 +1,24 @@ namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.WindowFunctions; /// -/// Provides SQL window function stubs (ROW_NUMBER, RANK, DENSE_RANK, NTILE) -/// for use in EF Core LINQ queries. These methods are translated to SQL by -/// ExpressiveSharp's method call translator — they throw at runtime if called directly. +/// Provides SQL window function stubs for use in EF Core LINQ queries. +/// These methods are translated to SQL by ExpressiveSharp's method call +/// translator — they throw at runtime if called directly. +/// +/// Ranking functions (ROW_NUMBER, RANK, DENSE_RANK, NTILE) accept +/// only — the SQL standard forbids +/// frame clauses on ranking functions. +/// +/// +/// Aggregate functions (SUM, AVG, COUNT, MIN, MAX) accept both +/// (uses SQL default frame) and +/// (explicit ROWS/RANGE BETWEEN). +/// /// public static class WindowFunction { + // ── Ranking functions ──────────────────────────────────────────────── + /// /// Translates to ROW_NUMBER() OVER(...). /// Returns a sequential number for each row within the window partition. @@ -41,4 +53,176 @@ public static long DenseRank(OrderedWindowDefinition window) => /// public static long Ntile(int buckets, OrderedWindowDefinition window) => throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + /// Translates to PERCENT_RANK() OVER(...). + /// Returns the relative rank of each row as a value between 0.0 and 1.0. + /// + public static double PercentRank(OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + /// Translates to CUME_DIST() OVER(...). + /// Returns the cumulative distribution of each row as a value between 0.0 and 1.0. + /// + public static double CumeDist(OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + // ── Aggregate functions ────────────────────────────────────────────── + + /// Translates to SUM(expression) OVER(...). + public static T Sum(T expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static T Sum(T expression, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to AVG(expression) OVER(...). + public static T? Average(T expression, OrderedWindowDefinition window) where T : struct => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static T? Average(T expression, FramedWindowDefinition window) where T : struct => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + // int/long → double (matching Queryable.Average semantics) + + /// Translates to AVG(expression) OVER(...). Returns double for integer input. + public static double Average(int expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static double Average(int expression, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static double? Average(int? expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static double? Average(int? expression, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to AVG(expression) OVER(...). Returns double for long input. + public static double Average(long expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static double Average(long expression, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static double? Average(long? expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static double? Average(long? expression, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to COUNT(*) OVER(...). Counts all rows in the window. + public static int Count(OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static int Count(FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to COUNT(expression) OVER(...). Counts non-null values. + public static int Count(T expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static int Count(T expression, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to MIN(expression) OVER(...). + public static T Min(T expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static T Min(T expression, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to MAX(expression) OVER(...). + public static T Max(T expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static T Max(T expression, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + // ── Navigation functions ───────────────────────────────────────────── + // LAG/LEAD access a row at a specific offset from the current row. + // The SQL standard forbids frame clauses on these — OrderedWindowDefinition only. + // + // Without a default value, the SQL result is NULL when no row exists at the + // requested offset. For value types, project into a nullable column explicitly: + // (double?)WindowFunction.Lag(o.Price, Window.OrderBy(o.Price)) + + /// + /// Translates to LAG(expression) OVER(...). Returns the previous row's value (offset 1). + /// The result is NULL when no previous row exists; cast to a nullable type if needed. + /// + public static T Lag(T expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to LAG(expression, ) OVER(...). + public static T Lag(T expression, int offset, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to LAG(expression, , ) OVER(...). + public static T Lag(T expression, int offset, T defaultValue, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + /// Translates to LEAD(expression) OVER(...). Returns the next row's value (offset 1). + /// The result is NULL when no next row exists; cast to a nullable type if needed. + /// + public static T Lead(T expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to LEAD(expression, ) OVER(...). + public static T Lead(T expression, int offset, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// Translates to LEAD(expression, , ) OVER(...). + public static T Lead(T expression, int offset, T defaultValue, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + /// Translates to FIRST_VALUE(expression) OVER(...). + /// Returns the first value in the window frame. The result depends on the frame; + /// with the default frame this is the first row of the partition. + /// + public static T FirstValue(T expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static T FirstValue(T expression, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + /// Translates to LAST_VALUE(expression) OVER(...). + /// Returns the last value in the window frame. With the default frame this returns + /// the current row's value — use an explicit frame like + /// .RowsBetween(UnboundedPreceding, UnboundedFollowing) to get the partition's last value. + /// + public static T LastValue(T expression, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static T LastValue(T expression, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + /// Translates to NTH_VALUE(expression, ) OVER(...). + /// Returns the value at the Nth row in the window frame (1-based). + /// + public static T NthValue(T expression, int n, OrderedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); + + /// + public static T NthValue(T expression, int n, FramedWindowDefinition window) => + throw new InvalidOperationException("This method is translated to SQL and cannot be called directly."); } diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.csproj b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.csproj index 8229051..166d72f 100644 --- a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.csproj +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.csproj @@ -24,4 +24,8 @@ + + + + diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameBoundInfo.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameBoundInfo.cs new file mode 100644 index 0000000..74a263c --- /dev/null +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameBoundInfo.cs @@ -0,0 +1,44 @@ +namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; + +/// +/// Kind of a SQL window frame boundary — maps 1:1 to the five forms of +/// the SQL:2003 frame-bound grammar. +/// +internal enum WindowFrameBoundKind +{ + UnboundedPreceding, + Preceding, + CurrentRow, + Following, + UnboundedFollowing, +} + +/// +/// Fully-resolved description of a single frame boundary carried through the +/// translation pipeline. The is only populated for +/// and ; +/// it is stored as a literal integer because SQL requires literal constants for +/// frame-bound offsets (parameters are not allowed in the frame clause). +/// +internal readonly record struct WindowFrameBoundInfo(WindowFrameBoundKind Kind, int? Offset) +{ + /// Emits the SQL fragment for this boundary (e.g. 3 PRECEDING, CURRENT ROW). + public string ToSqlFragment() => Kind switch + { + WindowFrameBoundKind.UnboundedPreceding => "UNBOUNDED PRECEDING", + WindowFrameBoundKind.Preceding => $"{ValidateOffset()} PRECEDING", + WindowFrameBoundKind.CurrentRow => "CURRENT ROW", + WindowFrameBoundKind.Following => $"{ValidateOffset()} FOLLOWING", + WindowFrameBoundKind.UnboundedFollowing => "UNBOUNDED FOLLOWING", + _ => throw new InvalidOperationException($"Unknown WindowFrameBoundKind: {Kind}"), + }; + + private int ValidateOffset() + { + if (!Offset.HasValue) + throw new InvalidOperationException($"Window frame bound '{Kind}' requires a non-null offset."); + if (Offset.Value < 0) + throw new InvalidOperationException($"Window frame bound '{Kind}' requires a non-negative offset, but got {Offset.Value}."); + return Offset.Value; + } +} diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameBoundMemberTranslator.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameBoundMemberTranslator.cs new file mode 100644 index 0000000..d860744 --- /dev/null +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameBoundMemberTranslator.cs @@ -0,0 +1,43 @@ +using System.Reflection; +using ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.WindowFunctions; +using Microsoft.EntityFrameworkCore; +using Microsoft.EntityFrameworkCore.Diagnostics; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; + +/// +/// Translates the no-argument static property getters +/// (UnboundedPreceding, CurrentRow, UnboundedFollowing) into +/// intermediate nodes, which are then +/// consumed by when it sees +/// RowsBetween / RangeBetween. +/// +/// The offset-bearing variants (Preceding(int), Following(int)) are +/// methods and are handled by . +/// +/// +internal sealed class WindowFrameBoundMemberTranslator : IMemberTranslator +{ + public SqlExpression? Translate( + SqlExpression? instance, + MemberInfo member, + Type returnType, + IDiagnosticsLogger logger) + { + if (member.DeclaringType != typeof(WindowFrameBound)) + return null; + + return member.Name switch + { + nameof(WindowFrameBound.UnboundedPreceding) => + new WindowFrameBoundSqlExpression(new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null)), + nameof(WindowFrameBound.CurrentRow) => + new WindowFrameBoundSqlExpression(new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)), + nameof(WindowFrameBound.UnboundedFollowing) => + new WindowFrameBoundSqlExpression(new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedFollowing, null)), + _ => null + }; + } +} diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameBoundSqlExpression.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameBoundSqlExpression.cs new file mode 100644 index 0000000..7234d22 --- /dev/null +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameBoundSqlExpression.cs @@ -0,0 +1,37 @@ +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; + +namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; + +/// +/// Intermediate SQL expression that represents a single window frame boundary. +/// Produced by when it encounters +/// WindowFrameBound.* members, and consumed by the same translator when +/// it sees RowsBetween / RangeBetween. Never reaches final SQL rendering. +/// +internal sealed class WindowFrameBoundSqlExpression : SqlExpression +{ + public WindowFrameBoundInfo BoundInfo { get; } + + public WindowFrameBoundSqlExpression(WindowFrameBoundInfo boundInfo) + : base(typeof(object), typeMapping: null) + { + BoundInfo = boundInfo; + } + + protected override Expression VisitChildren(ExpressionVisitor visitor) => this; + + protected override void Print(ExpressionPrinter expressionPrinter) => + expressionPrinter.Append($"WindowFrameBound({BoundInfo.ToSqlFragment()})"); + +#if NET9_0_OR_GREATER + public override Expression Quote() => + throw new InvalidOperationException("WindowFrameBoundSqlExpression is an intermediate node and should not be quoted."); +#endif + + public override bool Equals(object? obj) => + obj is WindowFrameBoundSqlExpression other && BoundInfo == other.BoundInfo; + + public override int GetHashCode() => BoundInfo.GetHashCode(); +} diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameType.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameType.cs new file mode 100644 index 0000000..6e53966 --- /dev/null +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFrameType.cs @@ -0,0 +1,11 @@ +namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; + +/// +/// SQL window frame type — determines whether the frame is row-based (ROWS) or +/// value-based (RANGE). +/// +internal enum WindowFrameType +{ + Rows, + Range, +} diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionEvaluatableExpressionFilter.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionEvaluatableExpressionFilter.cs index 894a53b..d65a1af 100644 --- a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionEvaluatableExpressionFilter.cs +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionEvaluatableExpressionFilter.cs @@ -5,8 +5,9 @@ namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; /// -/// Prevents window function marker method calls from being client-evaluated. -/// These must remain as expression tree nodes for the method call translators to handle. +/// Prevents window function marker method calls and property accesses from being +/// client-evaluated. These must remain as expression tree nodes for the +/// method/member translators to handle. /// internal sealed class WindowFunctionEvaluatableExpressionFilter : IEvaluatableExpressionFilterPlugin { @@ -18,12 +19,21 @@ public bool IsEvaluatableExpression(Expression expression) if (declaringType == typeof(Window) || declaringType == typeof(PartitionedWindowDefinition) || declaringType == typeof(OrderedWindowDefinition) - || declaringType == typeof(WindowFunction)) + || declaringType == typeof(WindowFunction) + || declaringType == typeof(WindowFrameBound)) { return false; } } + // WindowFrameBound.UnboundedPreceding/CurrentRow/UnboundedFollowing are + // static property getters — surface as MemberExpression nodes, not method calls. + if (expression is MemberExpression memberAccess + && memberAccess.Member.DeclaringType == typeof(WindowFrameBound)) + { + return false; + } + return true; } } diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionMemberTranslatorPlugin.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionMemberTranslatorPlugin.cs new file mode 100644 index 0000000..1d6ae5d --- /dev/null +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionMemberTranslatorPlugin.cs @@ -0,0 +1,16 @@ +using Microsoft.EntityFrameworkCore.Query; + +namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; + +/// +/// Registers member access translators with EF Core's query pipeline. +/// Currently handles only property getters +/// (UnboundedPreceding, CurrentRow, UnboundedFollowing). +/// +internal sealed class WindowFunctionMemberTranslatorPlugin : IMemberTranslatorPlugin +{ + public IEnumerable Translators { get; } = + [ + new WindowFrameBoundMemberTranslator() + ]; +} diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionMethodCallTranslator.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionMethodCallTranslator.cs index 0aa2b82..3e574e8 100644 --- a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionMethodCallTranslator.cs +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionMethodCallTranslator.cs @@ -10,8 +10,14 @@ namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructur /// /// Translates static methods into SQL window function expressions. -/// ROW_NUMBER uses the built-in ; -/// RANK, DENSE_RANK, and NTILE use . +/// +/// Ranking functions (ROW_NUMBER, RANK, DENSE_RANK, NTILE) never emit a frame clause — +/// the SQL standard forbids it and SQL Server / PostgreSQL reject the syntax. +/// +/// +/// Aggregate functions (SUM, AVG, COUNT, MIN, MAX) propagate the frame from the +/// into . +/// /// internal sealed class WindowFunctionMethodCallTranslator : IMethodCallTranslator { @@ -39,6 +45,8 @@ public WindowFunctionMethodCallTranslator( return method.Name switch { + // ── Ranking functions (no frame) ───────────────────────────── + nameof(WindowFunction.RowNumber) when arguments.Count == 1 && arguments[0] is WindowSpecSqlExpression spec => new RowNumberExpression(spec.Partitions, spec.Orderings, longTypeMapping), @@ -61,7 +69,153 @@ [new OrderingExpression( [_sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[0])], spec.Partitions, spec.Orderings, typeof(long), longTypeMapping), + nameof(WindowFunction.PercentRank) when arguments.Count >= 1 && arguments[0] is WindowSpecSqlExpression spec + => new WindowFunctionSqlExpression("PERCENT_RANK", [], spec.Partitions, spec.Orderings, + typeof(double), _typeMappingSource.FindMapping(typeof(double))), + + nameof(WindowFunction.CumeDist) when arguments.Count >= 1 && arguments[0] is WindowSpecSqlExpression spec + => new WindowFunctionSqlExpression("CUME_DIST", [], spec.Partitions, spec.Orderings, + typeof(double), _typeMappingSource.FindMapping(typeof(double))), + + // ── Aggregate functions (with frame) ───────────────────────── + + nameof(WindowFunction.Sum) when ExtractAggregateArgs(arguments, out var expr, out var spec) + => MakeAggregate("SUM", [expr], spec, method.ReturnType), + + nameof(WindowFunction.Average) when ExtractAggregateArgs(arguments, out var expr, out var spec) + => MakeAggregate("AVG", + // When the C# return type differs from the expression type (int/long→double), + // cast the argument so SQL computes a floating-point AVG, not integer division. + [NeedsFloatCast(expr, method.ReturnType) + ? _sqlExpressionFactory.ApplyDefaultTypeMapping( + _sqlExpressionFactory.Convert(expr, method.ReturnType)) + : expr], + spec, method.ReturnType), + + nameof(WindowFunction.Count) when arguments.Count == 1 && arguments[0] is WindowSpecSqlExpression spec + => MakeAggregate("COUNT", [_sqlExpressionFactory.Fragment("*")], spec, typeof(int)), + + nameof(WindowFunction.Count) when ExtractAggregateArgs(arguments, out var expr, out var spec) + => MakeAggregate("COUNT", [expr], spec, typeof(int)), + + nameof(WindowFunction.Min) when ExtractAggregateArgs(arguments, out var expr, out var spec) + => MakeAggregate("MIN", [expr], spec, method.ReturnType), + + nameof(WindowFunction.Max) when ExtractAggregateArgs(arguments, out var expr, out var spec) + => MakeAggregate("MAX", [expr], spec, method.ReturnType), + + // ── Navigation functions (no frame) ────────────────────────── + + nameof(WindowFunction.Lag) when ExtractNavigationArgs(arguments, out var lagArgs, out var lagSpec) + => MakeNavigation("LAG", lagArgs, lagSpec, method.ReturnType), + + nameof(WindowFunction.Lead) when ExtractNavigationArgs(arguments, out var leadArgs, out var leadSpec) + => MakeNavigation("LEAD", leadArgs, leadSpec, method.ReturnType), + + // ── Value functions (with frame) ───────────────────────────── + + nameof(WindowFunction.FirstValue) when ExtractAggregateArgs(arguments, out var fvExpr, out var fvSpec) + => MakeAggregate("FIRST_VALUE", [fvExpr], fvSpec, method.ReturnType), + + nameof(WindowFunction.LastValue) when ExtractAggregateArgs(arguments, out var lvExpr, out var lvSpec) + => MakeAggregate("LAST_VALUE", [lvExpr], lvSpec, method.ReturnType), + + nameof(WindowFunction.NthValue) when arguments.Count >= 3 && arguments[^1] is WindowSpecSqlExpression nvSpec + => MakeAggregate("NTH_VALUE", + [arguments[0], _sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[1])], + nvSpec, method.ReturnType), + _ => null }; } + + /// + /// Extracts the expression argument (first) and window spec (last) from a 2-argument + /// aggregate call like Sum(o.Price, window). + /// + private static bool ExtractAggregateArgs( + IReadOnlyList arguments, + out SqlExpression expression, + out WindowSpecSqlExpression spec) + { + if (arguments.Count >= 2 && arguments[^1] is WindowSpecSqlExpression s) + { + expression = arguments[0]; + spec = s; + return true; + } + + expression = null!; + spec = null!; + return false; + } + + /// + /// Extracts the function arguments (all but last) and window spec (last) from a + /// navigation call like Lag(o.Price, 2, window). Applies default type mapping + /// to all function arguments. + /// + private bool ExtractNavigationArgs( + IReadOnlyList arguments, + out List funcArgs, + out WindowSpecSqlExpression spec) + { + if (arguments.Count >= 2 && arguments[^1] is WindowSpecSqlExpression s) + { + funcArgs = []; + for (var i = 0; i < arguments.Count - 1; i++) + funcArgs.Add(_sqlExpressionFactory.ApplyDefaultTypeMapping(arguments[i])); + spec = s; + return true; + } + + funcArgs = null!; + spec = null!; + return false; + } + + private WindowFunctionSqlExpression MakeNavigation( + string functionName, + IReadOnlyList funcArgs, + WindowSpecSqlExpression spec, + Type returnType) + { + var typeMapping = _typeMappingSource.FindMapping(returnType); + return new WindowFunctionSqlExpression( + functionName, + funcArgs, + spec.Partitions, + spec.Orderings, + returnType, + typeMapping); + } + + /// + /// Returns true when the AVG expression argument's CLR type is an integer type + /// but the method's return type is floating-point — SQL Server performs integer + /// division for AVG(int), so we need to CAST the argument. + /// + private static bool NeedsFloatCast(SqlExpression expr, Type returnType) => + returnType == typeof(double) + && expr.Type is var t + && (t == typeof(int) || t == typeof(long) || t == typeof(int?) || t == typeof(long?)); + + private WindowFunctionSqlExpression MakeAggregate( + string functionName, + IReadOnlyList funcArgs, + WindowSpecSqlExpression spec, + Type returnType) + { + var typeMapping = _typeMappingSource.FindMapping(returnType); + return new WindowFunctionSqlExpression( + functionName, + funcArgs, + spec.Partitions, + spec.Orderings, + returnType, + typeMapping, + spec.FrameType, + spec.FrameStart, + spec.FrameEnd); + } } diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionSqlExpression.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionSqlExpression.cs index cb79b13..08f1d66 100644 --- a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionSqlExpression.cs +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowFunctionSqlExpression.cs @@ -6,7 +6,7 @@ namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; /// -/// SQL expression representing a window function call: FUNC_NAME(args) OVER(PARTITION BY ... ORDER BY ...). +/// SQL expression representing a window function call: FUNC_NAME(args) OVER(PARTITION BY ... ORDER BY ... [frame]). /// Used for RANK, DENSE_RANK, NTILE (ROW_NUMBER uses the built-in ). /// /// This expression is self-rendering: produces correct SQL through @@ -16,11 +16,11 @@ namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructur /// /// /// SQL standard assumption: The function names (RANK, DENSE_RANK, NTILE) and the -/// OVER(PARTITION BY ... ORDER BY ...) clause syntax are hardcoded as literal SQL fragments. -/// This relies on SQL:2003 window function syntax which is consistently implemented by all -/// major databases (SQL Server 2005+, PostgreSQL 8.4+, SQLite 3.25+, MySQL 8.0+, Oracle 8i+, -/// MariaDB 10.2+). If a provider deviates from this standard syntax, a provider-specific -/// implementation would be needed. +/// OVER(PARTITION BY ... ORDER BY ... [ROWS/RANGE BETWEEN ...]) clause syntax are hardcoded as +/// literal SQL fragments. This relies on SQL:2003 window function syntax which is consistently +/// implemented by all major databases (SQL Server 2012+, PostgreSQL 8.4+, SQLite 3.25+, +/// MySQL 8.0+, Oracle 8i+, MariaDB 10.2+). If a provider deviates from this standard syntax, +/// a provider-specific implementation would be needed. /// /// internal sealed class WindowFunctionSqlExpression : SqlExpression @@ -29,6 +29,9 @@ internal sealed class WindowFunctionSqlExpression : SqlExpression public IReadOnlyList Arguments { get; } public IReadOnlyList Partitions { get; } public IReadOnlyList Orderings { get; } + public WindowFrameType? FrameType { get; } + public WindowFrameBoundInfo? FrameStart { get; } + public WindowFrameBoundInfo? FrameEnd { get; } public WindowFunctionSqlExpression( string functionName, @@ -36,19 +39,25 @@ public WindowFunctionSqlExpression( IReadOnlyList partitions, IReadOnlyList orderings, Type type, - RelationalTypeMapping? typeMapping) + RelationalTypeMapping? typeMapping, + WindowFrameType? frameType = null, + WindowFrameBoundInfo? frameStart = null, + WindowFrameBoundInfo? frameEnd = null) : base(type, typeMapping) { FunctionName = functionName; Arguments = arguments; Partitions = partitions; Orderings = orderings; + FrameType = frameType; + FrameStart = frameStart; + FrameEnd = frameEnd; } /// /// Self-rendering: when any QuerySqlGenerator visits this expression via VisitExtension, /// it calls VisitChildren, which visits SqlFragmentExpression and child SqlExpression nodes - /// in the correct order to produce FUNC(args) OVER(PARTITION BY ... ORDER BY ...). + /// in the correct order to produce FUNC(args) OVER(PARTITION BY ... ORDER BY ... [frame]). /// protected override Expression VisitChildren(ExpressionVisitor visitor) { @@ -67,7 +76,7 @@ protected override void Print(ExpressionPrinter expressionPrinter) => /// /// Shared rendering logic for both SQL generation () and /// diagnostic output (). Produces the - /// FUNC(args) OVER(PARTITION BY ... ORDER BY ...) structure. + /// FUNC(args) OVER(PARTITION BY ... ORDER BY ... [ROWS/RANGE BETWEEN ...]) structure. /// private void EmitWindowFunction(Action appendText, Action visitExpression) { @@ -79,6 +88,8 @@ private void EmitWindowFunction(Action appendText, Action vi } appendText(") OVER("); + var anyClauseEmitted = false; + if (Partitions.Count > 0) { appendText("PARTITION BY "); @@ -87,11 +98,12 @@ private void EmitWindowFunction(Action appendText, Action vi if (i > 0) appendText(", "); visitExpression(Partitions[i]); } + anyClauseEmitted = true; } if (Orderings.Count > 0) { - if (Partitions.Count > 0) appendText(" "); + if (anyClauseEmitted) appendText(" "); appendText("ORDER BY "); for (var i = 0; i < Orderings.Count; i++) { @@ -99,6 +111,16 @@ private void EmitWindowFunction(Action appendText, Action vi visitExpression(Orderings[i].Expression); appendText(Orderings[i].IsAscending ? " ASC" : " DESC"); } + anyClauseEmitted = true; + } + + if (FrameType is { } frameType) + { + if (anyClauseEmitted) appendText(" "); + appendText(frameType == WindowFrameType.Rows ? "ROWS BETWEEN " : "RANGE BETWEEN "); + appendText(FrameStart!.Value.ToSqlFragment()); + appendText(" AND "); + appendText(FrameEnd!.Value.ToSqlFragment()); } appendText(")"); @@ -114,7 +136,10 @@ obj is WindowFunctionSqlExpression other && FunctionName == other.FunctionName && Arguments.SequenceEqual(other.Arguments) && Partitions.SequenceEqual(other.Partitions) - && Orderings.SequenceEqual(other.Orderings); + && Orderings.SequenceEqual(other.Orderings) + && FrameType == other.FrameType + && FrameStart == other.FrameStart + && FrameEnd == other.FrameEnd; public override int GetHashCode() { @@ -123,6 +148,9 @@ public override int GetHashCode() foreach (var a in Arguments) hash.Add(a); foreach (var p in Partitions) hash.Add(p); foreach (var o in Orderings) hash.Add(o); + hash.Add(FrameType); + hash.Add(FrameStart); + hash.Add(FrameEnd); return hash.ToHashCode(); } } diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowSpecMethodCallTranslator.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowSpecMethodCallTranslator.cs index 0e8103e..5f3f067 100644 --- a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowSpecMethodCallTranslator.cs +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowSpecMethodCallTranslator.cs @@ -8,8 +8,10 @@ namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; /// -/// Translates static methods and / -/// instance methods into intermediate nodes. +/// Translates static methods, / +/// instance methods, and +/// factory methods into intermediate SQL expression nodes +/// (, ). /// internal sealed class WindowSpecMethodCallTranslator : IMethodCallTranslator { @@ -36,6 +38,21 @@ internal sealed class WindowSpecMethodCallTranslator : IMethodCallTranslator }; } + // Static factory methods on WindowFrameBound (the no-arg variants + // UnboundedPreceding/CurrentRow/UnboundedFollowing are properties and are + // handled by WindowFrameBoundMemberTranslator instead). + if (declaringType == typeof(WindowFrameBound)) + { + return method.Name switch + { + nameof(WindowFrameBound.Preceding) when TryGetIntConstant(arguments[0], out var offset) => + new WindowFrameBoundSqlExpression(new WindowFrameBoundInfo(WindowFrameBoundKind.Preceding, offset)), + nameof(WindowFrameBound.Following) when TryGetIntConstant(arguments[0], out var offset) => + new WindowFrameBoundSqlExpression(new WindowFrameBoundInfo(WindowFrameBoundKind.Following, offset)), + _ => null + }; + } + // Instance methods on PartitionedWindowDefinition and OrderedWindowDefinition if ((declaringType == typeof(PartitionedWindowDefinition) || declaringType == typeof(OrderedWindowDefinition)) && instance is WindowSpecSqlExpression spec) @@ -47,10 +64,26 @@ internal sealed class WindowSpecMethodCallTranslator : IMethodCallTranslator spec.WithOrdering(arguments[0], ascending: true), nameof(PartitionedWindowDefinition.OrderByDescending) or nameof(OrderedWindowDefinition.ThenByDescending) => spec.WithOrdering(arguments[0], ascending: false), + nameof(OrderedWindowDefinition.RowsBetween) when arguments is [WindowFrameBoundSqlExpression start, WindowFrameBoundSqlExpression end] => + spec.WithFrame(WindowFrameType.Rows, start.BoundInfo, end.BoundInfo), + nameof(OrderedWindowDefinition.RangeBetween) when arguments is [WindowFrameBoundSqlExpression start, WindowFrameBoundSqlExpression end] => + spec.WithFrame(WindowFrameType.Range, start.BoundInfo, end.BoundInfo), _ => null }; } return null; } + + private static bool TryGetIntConstant(SqlExpression expression, out int value) + { + if (expression is SqlConstantExpression { Value: int i }) + { + value = i; + return true; + } + + value = 0; + return false; + } } diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowSpecSqlExpression.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowSpecSqlExpression.cs index 69e584b..e07d65a 100644 --- a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowSpecSqlExpression.cs +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/Infrastructure/Internal/WindowSpecSqlExpression.cs @@ -6,29 +6,38 @@ namespace ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; /// -/// Intermediate SQL expression that carries the PARTITION BY and ORDER BY clauses -/// of a window specification. This node is consumed by the window function translator -/// and should never reach final SQL rendering. +/// Intermediate SQL expression that carries the PARTITION BY, ORDER BY, and optional +/// frame (ROWS/RANGE BETWEEN) clauses of a window specification. This node is consumed +/// by the window function translator and should never reach final SQL rendering. /// internal sealed class WindowSpecSqlExpression : SqlExpression { public IReadOnlyList Partitions { get; } public IReadOnlyList Orderings { get; } + public WindowFrameType? FrameType { get; } + public WindowFrameBoundInfo? FrameStart { get; } + public WindowFrameBoundInfo? FrameEnd { get; } public WindowSpecSqlExpression( IReadOnlyList partitions, IReadOnlyList orderings, - RelationalTypeMapping? typeMapping) + RelationalTypeMapping? typeMapping, + WindowFrameType? frameType = null, + WindowFrameBoundInfo? frameStart = null, + WindowFrameBoundInfo? frameEnd = null) : base(typeof(object), typeMapping) { Partitions = partitions; Orderings = orderings; + FrameType = frameType; + FrameStart = frameStart; + FrameEnd = frameEnd; } public WindowSpecSqlExpression WithPartition(SqlExpression partition) { var newPartitions = new List(Partitions) { partition }; - return new WindowSpecSqlExpression(newPartitions, Orderings, TypeMapping); + return new WindowSpecSqlExpression(newPartitions, Orderings, TypeMapping, FrameType, FrameStart, FrameEnd); } public WindowSpecSqlExpression WithOrdering(SqlExpression expression, bool ascending) @@ -37,9 +46,13 @@ public WindowSpecSqlExpression WithOrdering(SqlExpression expression, bool ascen { new(expression, ascending) }; - return new WindowSpecSqlExpression(Partitions, newOrderings, TypeMapping); + return new WindowSpecSqlExpression(Partitions, newOrderings, TypeMapping, FrameType, FrameStart, FrameEnd); } + public WindowSpecSqlExpression WithFrame( + WindowFrameType frameType, WindowFrameBoundInfo start, WindowFrameBoundInfo end) => + new(Partitions, Orderings, TypeMapping, frameType, start, end); + protected override Expression VisitChildren(ExpressionVisitor visitor) { var changed = false; @@ -59,7 +72,7 @@ protected override Expression VisitChildren(ExpressionVisitor visitor) } return changed - ? new WindowSpecSqlExpression(newPartitions, newOrderings, TypeMapping) + ? new WindowSpecSqlExpression(newPartitions, newOrderings, TypeMapping, FrameType, FrameStart, FrameEnd) : this; } @@ -88,6 +101,15 @@ protected override void Print(ExpressionPrinter expressionPrinter) } } + if (FrameType is { } frameType) + { + if (Partitions.Count > 0 || Orderings.Count > 0) expressionPrinter.Append(" "); + expressionPrinter.Append(frameType == WindowFrameType.Rows ? "ROWS BETWEEN " : "RANGE BETWEEN "); + expressionPrinter.Append(FrameStart!.Value.ToSqlFragment()); + expressionPrinter.Append(" AND "); + expressionPrinter.Append(FrameEnd!.Value.ToSqlFragment()); + } + expressionPrinter.Append(")"); } @@ -99,13 +121,19 @@ public override Expression Quote() => public override bool Equals(object? obj) => obj is WindowSpecSqlExpression other && Partitions.SequenceEqual(other.Partitions) - && Orderings.SequenceEqual(other.Orderings); + && Orderings.SequenceEqual(other.Orderings) + && FrameType == other.FrameType + && FrameStart == other.FrameStart + && FrameEnd == other.FrameEnd; public override int GetHashCode() { var hash = new HashCode(); foreach (var p in Partitions) hash.Add(p); foreach (var o in Orderings) hash.Add(o); + hash.Add(FrameType); + hash.Add(FrameStart); + hash.Add(FrameEnd); return hash.ToHashCode(); } } diff --git a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/RelationalExpressivePlugin.cs b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/RelationalExpressivePlugin.cs index b658d67..03cf68e 100644 --- a/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/RelationalExpressivePlugin.cs +++ b/src/ExpressiveSharp.EntityFrameworkCore.RelationalExtensions/RelationalExpressivePlugin.cs @@ -27,6 +27,10 @@ public void ApplyServices(IServiceCollection services) // Register method call translator plugin (scoped — matches EF Core's service lifetimes) services.AddScoped(); + // Register member translator plugin for WindowFrameBound property getters + // (UnboundedPreceding, CurrentRow, UnboundedFollowing) + services.AddScoped(); + // Register evaluatable expression filter services.AddSingleton(); diff --git a/tests/ExpressiveSharp.EntityFrameworkCore.IntegrationTests/Infrastructure/WindowFunctionTestBase.cs b/tests/ExpressiveSharp.EntityFrameworkCore.IntegrationTests/Infrastructure/WindowFunctionTestBase.cs index 2f23d4a..ae0b490 100644 --- a/tests/ExpressiveSharp.EntityFrameworkCore.IntegrationTests/Infrastructure/WindowFunctionTestBase.cs +++ b/tests/ExpressiveSharp.EntityFrameworkCore.IntegrationTests/Infrastructure/WindowFunctionTestBase.cs @@ -255,4 +255,663 @@ public async Task RowNumber_OverExpressiveTotal_UsesExpandedExpression() for (var i = 1; i < results.Count; i++) Assert.IsTrue(results[i].Total >= results[i - 1].Total); } + + // ── Aggregate window function tests ───────────────────────────────── + // + // Aggregate functions (SUM, AVG, COUNT, MIN, MAX) with OVER produce + // results that depend on the frame clause — unlike ranking functions. + // A running total (SUM with ROWS UNBOUNDED PRECEDING TO CURRENT ROW) + // gives a different value per row than a full-partition SUM. + // + // Seed data (ordered by Price ASC): + // Price: 10, 15, 20, 20, 25, 30, 35, 40, 45, 50 + // Running total: 10, 25, 45, 65, 90, 120, 155, 195, 240, 290 + + [TestMethod] + public async Task Sum_WithRowsFrame_ProducesRunningTotal() + { + var query = Context.Orders + .Select(o => new + { + o.Price, + RunningTotal = WindowFunction.Sum(o.Price, + Window.OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.CurrentRow)), + }) + .OrderBy(x => x.RunningTotal); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "SUM"); + StringAssert.Contains(sql, "ROWS BETWEEN"); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // Running total must increase monotonically + for (var i = 1; i < results.Count; i++) + Assert.IsTrue(results[i].RunningTotal >= results[i - 1].RunningTotal, + $"Running total must be non-decreasing: {results[i].RunningTotal} < {results[i - 1].RunningTotal}"); + + // First row = smallest price (10), last row = sum of all (290) + Assert.AreEqual(10.0, results[0].RunningTotal); + Assert.AreEqual(290.0, results[^1].RunningTotal); + } + + [TestMethod] + public async Task Sum_WithPartitionAndFrame_ResetsPerGroup() + { + // Customer 1 prices (ascending): 15, 20, 20, 35, 50 → running totals: 15, 35, 55, 90, 140 + // Customer 2 prices (ascending): 10, 25, 30, 40, 45 → running totals: 10, 35, 65, 105, 150 + var query = Context.Orders + .Select(o => new + { + o.CustomerId, + o.Price, + RunningTotal = WindowFunction.Sum(o.Price, + Window.PartitionBy(o.CustomerId) + .OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.CurrentRow)), + }) + .OrderBy(x => x.CustomerId) + .ThenBy(x => x.RunningTotal); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + var c1 = results.Where(r => r.CustomerId == 1).ToList(); + var c2 = results.Where(r => r.CustomerId == 2).ToList(); + + // Each partition's running total starts at the first price + Assert.AreEqual(c1[0].Price, c1[0].RunningTotal); + Assert.AreEqual(c2[0].Price, c2[0].RunningTotal); + + // Each partition's last running total = sum of all prices in that partition + Assert.AreEqual(140.0, c1[^1].RunningTotal); // 15+20+20+35+50 + Assert.AreEqual(150.0, c2[^1].RunningTotal); // 10+25+30+40+45 + } + + [TestMethod] + public async Task Average_WithSlidingWindow_ProducesMovingAverage() + { + // 3-row sliding window average (1 preceding, current, 1 following) + // differs from a full-partition average + var query = Context.Orders + .Select(o => new + { + o.Price, + MovingAvg = WindowFunction.Average(o.Price, + Window.OrderBy(o.Price) + .RowsBetween(WindowFrameBound.Preceding(1), WindowFrameBound.Following(1))), + }) + .OrderBy(x => x.Price); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "AVG"); + StringAssert.Contains(sql, "1 PRECEDING"); + StringAssert.Contains(sql, "1 FOLLOWING"); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // First row (price=10): avg of [10, 15] = 12.5 (only 1 following, no preceding) + // Middle rows have 3-element windows + // Not all values should be the same (unlike full-partition avg which would be 29.0) + var distinctAvgs = results.Select(r => r.MovingAvg).Distinct().Count(); + Assert.IsTrue(distinctAvgs > 1, "Moving average should produce varying values per row"); + } + + [TestMethod] + public async Task Count_Star_WithPartition_ProducesRunningCount() + { + var query = Context.Orders + .Select(o => new + { + o.Id, + o.CustomerId, + RunningCount = WindowFunction.Count( + Window.PartitionBy(o.CustomerId).OrderBy(o.Id)), + }) + .OrderBy(x => x.CustomerId) + .ThenBy(x => x.Id); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "COUNT(*)"); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // Running count per customer: 1, 2, 3, 4, 5 for each partition + var c1 = results.Where(r => r.CustomerId == 1).Select(r => r.RunningCount).ToList(); + var c2 = results.Where(r => r.CustomerId == 2).Select(r => r.RunningCount).ToList(); + + CollectionAssert.AreEqual(new int[] { 1, 2, 3, 4, 5 }, c1); + CollectionAssert.AreEqual(new int[] { 1, 2, 3, 4, 5 }, c2); + } + + [TestMethod] + public async Task Min_WithPartition_ReturnsRunningMin() + { + // Running MIN per customer, ordered by Id (insertion order). + // Customer 1 prices by Id: 50, 20, 20, 15, 35 → running min: 50, 20, 20, 15, 15 + // Customer 2 prices by Id: 10, 30, 40, 25, 45 → running min: 10, 10, 10, 10, 10 + var query = Context.Orders + .Select(o => new + { + o.Id, + o.CustomerId, + o.Price, + RunningMin = WindowFunction.Min(o.Price, + Window.PartitionBy(o.CustomerId).OrderBy(o.Id)), + }) + .OrderBy(x => x.CustomerId) + .ThenBy(x => x.Id); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // Running min is non-increasing within each partition + foreach (var group in results.GroupBy(r => r.CustomerId)) + { + var mins = group.OrderBy(r => r.Id).Select(r => r.RunningMin).ToList(); + for (var i = 1; i < mins.Count; i++) + Assert.IsTrue(mins[i] <= mins[i - 1], + $"Running MIN must be non-increasing: {mins[i]} > {mins[i - 1]}"); + } + } + + [TestMethod] + public async Task Max_WithRowsFrame_ReturnsRunningMax() + { + // Ordered by Price ASC → running MAX = current row's price (each new row is the max so far) + var query = Context.Orders + .Select(o => new + { + o.Price, + RunningMax = WindowFunction.Max(o.Price, + Window.OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.CurrentRow)), + }) + .OrderBy(x => x.Price); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // When ordered ascending, running MAX equals the current price + for (var i = 0; i < results.Count; i++) + Assert.AreEqual(results[i].Price, results[i].RunningMax, + $"Running MAX should equal current price when ordered ascending, at index {i}"); + } + + [TestMethod] + public async Task MultipleAggregates_InSameSelect() + { + var query = Context.Orders + .Select(o => new + { + o.Id, + o.Price, + RunningSum = WindowFunction.Sum(o.Price, + Window.OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.CurrentRow)), + RunningCount = WindowFunction.Count( + Window.OrderBy(o.Price)), + RunningMax = WindowFunction.Max(o.Price, + Window.OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.CurrentRow)), + }) + .OrderBy(x => x.Price); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // Last row: sum = 290, count = 10, max = 50 + Assert.AreEqual(290.0, results[^1].RunningSum); + Assert.AreEqual(10, results[^1].RunningCount); + Assert.AreEqual(50.0, results[^1].RunningMax); + } + + // ── Navigation function tests (LAG / LEAD) ────────────────────────── + // + // LAG/LEAD access a row at a fixed offset from the current row. + // They do NOT support frame clauses (SQL standard forbids it). + // + // Ordered by Price ASC: 10, 15, 20, 20, 25, 30, 35, 40, 45, 50 + + [TestMethod] + public async Task Lag_Default_ReturnsPreviousRowPrice() + { + // Cast to (double?) so SQL NULL is distinguishable from 0.0 + var query = Context.Orders + .Select(o => new + { + o.Price, + PrevPrice = (double?)WindowFunction.Lag(o.Price, + Window.OrderBy(o.Price)), + }) + .OrderBy(x => x.Price); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "LAG"); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // First row has no previous → null + Assert.IsNull(results[0].PrevPrice, "First row should have no previous value"); + // Second row's LAG = first row's price + Assert.AreEqual(10.0, results[1].PrevPrice); + } + + [TestMethod] + public async Task Lag_WithOffset_ReturnsCorrectRow() + { + var query = Context.Orders + .Select(o => new + { + o.Price, + Prev2 = (double?)WindowFunction.Lag(o.Price, 2, + Window.OrderBy(o.Price)), + }) + .OrderBy(x => x.Price); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // First two rows have no row 2 back → null + Assert.IsNull(results[0].Prev2); + Assert.IsNull(results[1].Prev2); + // Third row (price=20) looks back 2 → price=10 + Assert.AreEqual(10.0, results[2].Prev2); + } + + [TestMethod] + public async Task Lag_WithDefault_ReturnsDefaultWhenNoRow() + { + var query = Context.Orders + .Select(o => new + { + o.Price, + PrevOrZero = WindowFunction.Lag(o.Price, 1, 0.0, + Window.OrderBy(o.Price)), + }) + .OrderBy(x => x.Price); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // First row has no previous → default (0.0) + Assert.AreEqual(0.0, results[0].PrevOrZero); + // Second row → previous price (10) + Assert.AreEqual(10.0, results[1].PrevOrZero); + } + + [TestMethod] + public async Task Lead_Default_ReturnsNextRowPrice() + { + var query = Context.Orders + .Select(o => new + { + o.Price, + NextPrice = (double?)WindowFunction.Lead(o.Price, + Window.OrderBy(o.Price)), + }) + .OrderBy(x => x.Price); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "LEAD"); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // Last row has no next → null + Assert.IsNull(results[^1].NextPrice, "Last row should have no next value"); + // First row's LEAD = second row's price (15) + Assert.AreEqual(15.0, results[0].NextPrice); + } + + [TestMethod] + public async Task Lead_WithOffsetAndDefault_ReturnsCorrectValues() + { + var query = Context.Orders + .Select(o => new + { + o.Price, + Next2OrNeg1 = WindowFunction.Lead(o.Price, 2, -1.0, + Window.OrderBy(o.Price)), + }) + .OrderBy(x => x.Price); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // First row (price=10) looks ahead 2 → price=20 + Assert.AreEqual(20.0, results[0].Next2OrNeg1); + // Last two rows have no row 2 ahead → default (-1.0) + Assert.AreEqual(-1.0, results[^1].Next2OrNeg1); + Assert.AreEqual(-1.0, results[^2].Next2OrNeg1); + } + + [TestMethod] + public async Task Lag_WithPartition_ResetsPerGroup() + { + var query = Context.Orders + .Select(o => new + { + o.CustomerId, + o.Price, + PrevInGroup = (double?)WindowFunction.Lag(o.Price, + Window.PartitionBy(o.CustomerId).OrderBy(o.Price)), + }) + .OrderBy(x => x.CustomerId) + .ThenBy(x => x.Price); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // First row of each customer partition has no previous → null + var c1 = results.Where(r => r.CustomerId == 1).ToList(); + var c2 = results.Where(r => r.CustomerId == 2).ToList(); + + Assert.IsNull(c1[0].PrevInGroup); + Assert.IsNull(c2[0].PrevInGroup); + // Second row of each partition → first row's price + Assert.AreEqual(c1[0].Price, c1[1].PrevInGroup); + Assert.AreEqual(c2[0].Price, c2[1].PrevInGroup); + } + + // ── FIRST_VALUE / LAST_VALUE ───────────────────────────────────────── + + [TestMethod] + public async Task FirstValue_ReturnsFirstPriceInPartition() + { + // Ordered by Price ASC: 10, 15, 20, 20, 25, 30, 35, 40, 45, 50 + // FIRST_VALUE with default frame → first row's price = 10 for every row + var query = Context.Orders + .Select(o => new + { + o.Price, + First = WindowFunction.FirstValue(o.Price, + Window.OrderBy(o.Price)), + }); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "FIRST_VALUE"); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + Assert.IsTrue(results.All(r => Math.Abs(r.First - 10.0) < 1e-9), + "FIRST_VALUE should return the lowest price (10) for all rows"); + } + + [TestMethod] + public async Task LastValue_WithUnboundedFrame_ReturnsLastPriceInPartition() + { + // Without an explicit frame, LAST_VALUE returns the current row (useless). + // With ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING → true last value. + var query = Context.Orders + .Select(o => new + { + o.Price, + Last = WindowFunction.LastValue(o.Price, + Window.OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.UnboundedFollowing)), + }); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "LAST_VALUE"); + StringAssert.Contains(sql, "UNBOUNDED FOLLOWING"); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + Assert.IsTrue(results.All(r => Math.Abs(r.Last - 50.0) < 1e-9), + "LAST_VALUE with unbounded frame should return the highest price (50) for all rows"); + } + + [TestMethod] + public async Task FirstValue_WithPartition_ReturnsFirstPerGroup() + { + // Customer 1 prices (ascending): 15, 20, 20, 35, 50 → first = 15 + // Customer 2 prices (ascending): 10, 25, 30, 40, 45 → first = 10 + var query = Context.Orders + .Select(o => new + { + o.CustomerId, + o.Price, + FirstInGroup = WindowFunction.FirstValue(o.Price, + Window.PartitionBy(o.CustomerId).OrderBy(o.Price)), + }) + .OrderBy(x => x.CustomerId) + .ThenBy(x => x.Price); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + var c1 = results.Where(r => r.CustomerId == 1).ToList(); + var c2 = results.Where(r => r.CustomerId == 2).ToList(); + + Assert.IsTrue(c1.All(r => Math.Abs(r.FirstInGroup - 15.0) < 1e-9)); + Assert.IsTrue(c2.All(r => Math.Abs(r.FirstInGroup - 10.0) < 1e-9)); + } + + // ── PERCENT_RANK ───────────────────────────────────────────────────── + + [TestMethod] + public async Task PercentRank_ReturnsBetweenZeroAndOne() + { + var query = Context.Orders + .Select(o => new + { + o.Price, + Pct = WindowFunction.PercentRank( + Window.OrderBy(o.Price)), + }) + .OrderBy(x => x.Price); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "PERCENT_RANK"); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // First row always has PERCENT_RANK = 0.0 + Assert.AreEqual(0.0, results[0].Pct); + // Last row always has PERCENT_RANK = 1.0 (when no ties at the end) + Assert.AreEqual(1.0, results[^1].Pct); + // All values in [0.0, 1.0] + Assert.IsTrue(results.All(r => r.Pct >= 0.0 && r.Pct <= 1.0)); + } + + // ── CUME_DIST ──────────────────────────────────────────────────────── + + [TestMethod] + public async Task CumeDist_ReturnsBetweenZeroAndOne() + { + var query = Context.Orders + .Select(o => new + { + o.Price, + Cume = WindowFunction.CumeDist( + Window.OrderBy(o.Price)), + }) + .OrderBy(x => x.Price); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "CUME_DIST"); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // CUME_DIST: last row = 1.0, all values in (0.0, 1.0] + Assert.AreEqual(1.0, results[^1].Cume); + Assert.IsTrue(results.All(r => r.Cume > 0.0 && r.Cume <= 1.0)); + // Unlike PERCENT_RANK, first row's CUME_DIST > 0.0 + Assert.IsTrue(results[0].Cume > 0.0); + } + + // ── NTH_VALUE ──────────────────────────────────────────────────────── + + [TestMethod] + public async Task NthValue_ReturnsValueAtPosition() + { + // NTH_VALUE is not supported on SQL Server — skip gracefully. + // Ordered by Price ASC: 10, 15, 20, 20, 25, 30, 35, 40, 45, 50 + // NTH_VALUE(Price, 3) with unbounded frame → 3rd value = 20 for all rows + var query = Context.Orders + .Select(o => new + { + o.Price, + Third = WindowFunction.NthValue(o.Price, 3, + Window.OrderBy(o.Price) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.UnboundedFollowing)), + }); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "NTH_VALUE"); + + try + { + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + Assert.IsTrue(results.All(r => Math.Abs(r.Third - 20.0) < 1e-9), + "NTH_VALUE(Price, 3) should return the 3rd price (20) for all rows with unbounded frame"); + } + catch (Exception ex) when (ex.Message.Contains("NTH_VALUE") && ex.Message.Contains("not a recognized")) + { + Assert.Inconclusive("NTH_VALUE is not supported by this database provider."); + } + } + + // ── Coverage gap tests ─────────────────────────────────────────────── + + [TestMethod] + public async Task Count_Expression_CountsNonNullOnly() + { + // Tag is null for all seeded rows, so COUNT(Tag) should be 0 everywhere, + // while COUNT(*) returns the running count. This proves COUNT(expr) only + // counts non-null values. + var query = Context.Orders + .Select(o => new + { + o.Id, + TagCount = WindowFunction.Count(o.Tag, Window.OrderBy(o.Id)), + StarCount = WindowFunction.Count(Window.OrderBy(o.Id)), + }) + .OrderBy(x => x.Id); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // COUNT(Tag) = 0 for every row (all Tags are null) + Assert.IsTrue(results.All(r => r.TagCount == 0), + "COUNT(Tag) should be 0 when all Tags are null"); + // COUNT(*) = running count 1..10 + Assert.AreEqual(10, results[^1].StarCount); + } + + [TestMethod] + public async Task Average_IntColumn_ReturnsDouble() + { + // Quantity is int — the int→double overload should be selected. + // Ordered by Quantity ASC: 1,1,1,1,2,2,2,3,3,5 + var query = Context.Orders + .Select(o => new + { + o.Quantity, + AvgQty = WindowFunction.Average(o.Quantity, + Window.OrderBy(o.Quantity) + .RowsBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.CurrentRow)), + }) + .OrderBy(x => x.Quantity); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // First row: AVG of [1] = 1.0 + Assert.AreEqual(1.0, results[0].AvgQty); + // Last row: AVG of all quantities = (1+1+1+1+2+2+2+3+3+5)/10 = 2.1 + Assert.AreEqual(2.1, results[^1].AvgQty, 0.01); + } + + [TestMethod] + public async Task Sum_WithoutFrame_UsesDefaultFrame() + { + // Sum with OrderedWindowDefinition only (no explicit frame). + // SQL default frame with ORDER BY = RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW. + // This is a running total — same behavior as the explicit frame test. + var query = Context.Orders + .Select(o => new + { + o.Price, + RunningTotal = WindowFunction.Sum(o.Price, + Window.OrderBy(o.Price)), + }) + .OrderBy(x => x.Price); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // First = 10, last = 290 (same as explicit ROWS UNBOUNDED PRECEDING TO CURRENT ROW) + // Note: with RANGE default, tied prices (20, 20) get the same running total + Assert.AreEqual(290.0, results[^1].RunningTotal); + // Running total must be non-decreasing + for (var i = 1; i < results.Count; i++) + Assert.IsTrue(results[i].RunningTotal >= results[i - 1].RunningTotal); + } + + [TestMethod] + public async Task Sum_WithRangeFrame_ProducesRangeBasedTotal() + { + // RANGE frame groups rows by value, not by position. + // With RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, + // tied prices (20, 20) both see the sum up to and including both 20s. + var query = Context.Orders + .Select(o => new + { + o.Price, + RangeTotal = WindowFunction.Sum(o.Price, + Window.OrderBy(o.Price) + .RangeBetween(WindowFrameBound.UnboundedPreceding, WindowFrameBound.CurrentRow)), + }) + .OrderBy(x => x.Price); + + var sql = query.ToQueryString(); + StringAssert.Contains(sql, "RANGE BETWEEN"); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // The two price-20 rows should have the same RANGE-based running total + // (RANGE includes all rows with the same ORDER BY value) + var price20Totals = results.Where(r => r.Price == 20).Select(r => r.RangeTotal).Distinct().ToList(); + Assert.AreEqual(1, price20Totals.Count, + "RANGE frame should give tied prices the same running total"); + } + + [TestMethod] + public async Task Lead_WithPartition_ResetsPerGroup() + { + var query = Context.Orders + .Select(o => new + { + o.CustomerId, + o.Price, + NextInGroup = (double?)WindowFunction.Lead(o.Price, + Window.PartitionBy(o.CustomerId).OrderBy(o.Price)), + }) + .OrderBy(x => x.CustomerId) + .ThenBy(x => x.Price); + + var results = await query.ToListAsync(); + Assert.AreEqual(10, results.Count); + + // Last row of each customer partition has no next → null + var c1 = results.Where(r => r.CustomerId == 1).ToList(); + var c2 = results.Where(r => r.CustomerId == 2).ToList(); + + Assert.IsNull(c1[^1].NextInGroup, "Last row of customer 1 should have no next"); + Assert.IsNull(c2[^1].NextInGroup, "Last row of customer 2 should have no next"); + // First row of each partition → second row's price + Assert.AreEqual(c1[1].Price, c1[0].NextInGroup); + Assert.AreEqual(c2[1].Price, c2[0].NextInGroup); + } } diff --git a/tests/ExpressiveSharp.Tests/ExpressiveSharp.Tests.csproj b/tests/ExpressiveSharp.Tests/ExpressiveSharp.Tests.csproj index af01910..404e04c 100644 --- a/tests/ExpressiveSharp.Tests/ExpressiveSharp.Tests.csproj +++ b/tests/ExpressiveSharp.Tests/ExpressiveSharp.Tests.csproj @@ -7,6 +7,8 @@ + + diff --git a/tests/ExpressiveSharp.Tests/RelationalExtensions/WindowFrameBoundSqlExpressionTests.cs b/tests/ExpressiveSharp.Tests/RelationalExtensions/WindowFrameBoundSqlExpressionTests.cs new file mode 100644 index 0000000..4761341 --- /dev/null +++ b/tests/ExpressiveSharp.Tests/RelationalExtensions/WindowFrameBoundSqlExpressionTests.cs @@ -0,0 +1,54 @@ +// EF1001 fires on the ".Internal" namespace convention, but these are OUR internals +// (ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal) — +// these tests deliberately exercise them via InternalsVisibleTo. +#pragma warning disable EF1001 + +using ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace ExpressiveSharp.Tests.RelationalExtensions; + +/// +/// Unit tests for the internal intermediate +/// node and the struct's SQL fragment formatting. +/// +[TestClass] +public class WindowFrameBoundSqlExpressionTests +{ + [TestMethod] + public void ToSqlFragment_FormatsEachBoundKindCorrectly() + { + Assert.AreEqual("UNBOUNDED PRECEDING", + new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null).ToSqlFragment()); + Assert.AreEqual("3 PRECEDING", + new WindowFrameBoundInfo(WindowFrameBoundKind.Preceding, 3).ToSqlFragment()); + Assert.AreEqual("CURRENT ROW", + new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null).ToSqlFragment()); + Assert.AreEqual("5 FOLLOWING", + new WindowFrameBoundInfo(WindowFrameBoundKind.Following, 5).ToSqlFragment()); + Assert.AreEqual("UNBOUNDED FOLLOWING", + new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedFollowing, null).ToSqlFragment()); + } + + [TestMethod] + public void Equals_ComparesBoundInfo() + { + var a = new WindowFrameBoundSqlExpression(new WindowFrameBoundInfo(WindowFrameBoundKind.Preceding, 3)); + var b = new WindowFrameBoundSqlExpression(new WindowFrameBoundInfo(WindowFrameBoundKind.Preceding, 3)); + var c = new WindowFrameBoundSqlExpression(new WindowFrameBoundInfo(WindowFrameBoundKind.Preceding, 5)); + var d = new WindowFrameBoundSqlExpression(new WindowFrameBoundInfo(WindowFrameBoundKind.Following, 3)); + + Assert.AreEqual(a, b); + Assert.AreEqual(a.GetHashCode(), b.GetHashCode()); + Assert.AreNotEqual(a, c); + Assert.AreNotEqual(a, d); + } + + [TestMethod] + public void BoundInfo_PropertyRoundtrips() + { + var info = new WindowFrameBoundInfo(WindowFrameBoundKind.Following, 7); + var expr = new WindowFrameBoundSqlExpression(info); + Assert.AreEqual(info, expr.BoundInfo); + } +} diff --git a/tests/ExpressiveSharp.Tests/RelationalExtensions/WindowFunctionSqlExpressionFrameTests.cs b/tests/ExpressiveSharp.Tests/RelationalExtensions/WindowFunctionSqlExpressionFrameTests.cs new file mode 100644 index 0000000..28aa4ed --- /dev/null +++ b/tests/ExpressiveSharp.Tests/RelationalExtensions/WindowFunctionSqlExpressionFrameTests.cs @@ -0,0 +1,182 @@ +// EF1001 fires on the ".Internal" namespace convention, but these are OUR internals — +// these tests deliberately exercise them via InternalsVisibleTo. +#pragma warning disable EF1001 + +using ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace ExpressiveSharp.Tests.RelationalExtensions; + +/// +/// Unit tests for the frame-clause rendering in . +/// Uses aggregate function names (SUM, AVG, COUNT) since frames only apply to aggregates +/// — the SQL standard forbids frames on ranking functions. +/// +[TestClass] +public class WindowFunctionSqlExpressionFrameTests +{ + private static readonly SqlFragmentExpression ColPrice = new("[Price]"); + private static readonly SqlFragmentExpression ColCustomerId = new("[CustomerId]"); + + private static string PrintExpression(WindowFunctionSqlExpression expr) + { + var printer = new ExpressionPrinter(); + printer.Visit(expr); + return printer.ToString(); + } + + [TestMethod] + public void Print_NoFrame_OmitsFrameClause() + { + var expr = new WindowFunctionSqlExpression( + "SUM", + arguments: [ColPrice], + partitions: [], + orderings: [new OrderingExpression(ColPrice, ascending: true)], + type: typeof(double), + typeMapping: null); + + var printed = PrintExpression(expr); + Assert.AreEqual("SUM([Price]) OVER(ORDER BY [Price] ASC)", printed); + } + + [TestMethod] + public void Print_RowsBetweenUnboundedPrecedingAndCurrentRow() + { + var expr = new WindowFunctionSqlExpression( + "SUM", + arguments: [ColPrice], + partitions: [], + orderings: [new OrderingExpression(ColPrice, ascending: true)], + type: typeof(double), + typeMapping: null, + frameType: WindowFrameType.Rows, + frameStart: new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null), + frameEnd: new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)); + + var printed = PrintExpression(expr); + Assert.AreEqual( + "SUM([Price]) OVER(ORDER BY [Price] ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", + printed); + } + + [TestMethod] + public void Print_RowsBetweenNumericPrecedingAndFollowing() + { + var expr = new WindowFunctionSqlExpression( + "AVG", + arguments: [ColPrice], + partitions: [], + orderings: [new OrderingExpression(ColPrice, ascending: false)], + type: typeof(double), + typeMapping: null, + frameType: WindowFrameType.Rows, + frameStart: new WindowFrameBoundInfo(WindowFrameBoundKind.Preceding, 3), + frameEnd: new WindowFrameBoundInfo(WindowFrameBoundKind.Following, 3)); + + var printed = PrintExpression(expr); + Assert.AreEqual( + "AVG([Price]) OVER(ORDER BY [Price] DESC ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING)", + printed); + } + + [TestMethod] + public void Print_RangeBetweenUnboundedPrecedingAndCurrentRow() + { + var expr = new WindowFunctionSqlExpression( + "SUM", + arguments: [ColPrice], + partitions: [], + orderings: [new OrderingExpression(ColPrice, ascending: true)], + type: typeof(double), + typeMapping: null, + frameType: WindowFrameType.Range, + frameStart: new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null), + frameEnd: new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)); + + var printed = PrintExpression(expr); + Assert.AreEqual( + "SUM([Price]) OVER(ORDER BY [Price] ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", + printed); + } + + [TestMethod] + public void Print_RowsBetweenUnboundedPrecedingAndUnboundedFollowing() + { + var expr = new WindowFunctionSqlExpression( + "COUNT", + arguments: [new SqlFragmentExpression("*")], + partitions: [], + orderings: [new OrderingExpression(ColPrice, ascending: true)], + type: typeof(long), + typeMapping: null, + frameType: WindowFrameType.Rows, + frameStart: new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null), + frameEnd: new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedFollowing, null)); + + var printed = PrintExpression(expr); + Assert.AreEqual( + "COUNT(*) OVER(ORDER BY [Price] ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + printed); + } + + [TestMethod] + public void Print_WithPartitionOrderAndFrame() + { + var expr = new WindowFunctionSqlExpression( + "SUM", + arguments: [ColPrice], + partitions: [ColCustomerId], + orderings: [new OrderingExpression(ColPrice, ascending: true)], + type: typeof(double), + typeMapping: null, + frameType: WindowFrameType.Rows, + frameStart: new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null), + frameEnd: new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)); + + var printed = PrintExpression(expr); + Assert.AreEqual( + "SUM([Price]) OVER(PARTITION BY [CustomerId] ORDER BY [Price] ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", + printed); + } + + [TestMethod] + public void Print_MinWithNumericOffsets() + { + var expr = new WindowFunctionSqlExpression( + "MIN", + arguments: [ColPrice], + partitions: [], + orderings: [new OrderingExpression(ColPrice, ascending: true)], + type: typeof(double), + typeMapping: null, + frameType: WindowFrameType.Rows, + frameStart: new WindowFrameBoundInfo(WindowFrameBoundKind.Preceding, 2), + frameEnd: new WindowFrameBoundInfo(WindowFrameBoundKind.Following, 2)); + + var printed = PrintExpression(expr); + Assert.AreEqual( + "MIN([Price]) OVER(ORDER BY [Price] ASC ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING)", + printed); + } + + [TestMethod] + public void Equals_IncludesFrameFields() + { + WindowFunctionSqlExpression MakeExpr(WindowFrameBoundKind startKind, int? startOffset) => + new("SUM", [ColPrice], [], [new OrderingExpression(ColPrice, ascending: true)], typeof(double), null, + WindowFrameType.Rows, + new WindowFrameBoundInfo(startKind, startOffset), + new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)); + + var a = MakeExpr(WindowFrameBoundKind.Preceding, 2); + var b = MakeExpr(WindowFrameBoundKind.Preceding, 2); + var c = MakeExpr(WindowFrameBoundKind.Preceding, 3); + + Assert.AreEqual(a, b); + Assert.AreEqual(a.GetHashCode(), b.GetHashCode()); + Assert.AreNotEqual(a, c); + } +} diff --git a/tests/ExpressiveSharp.Tests/RelationalExtensions/WindowSpecSqlExpressionFrameTests.cs b/tests/ExpressiveSharp.Tests/RelationalExtensions/WindowSpecSqlExpressionFrameTests.cs new file mode 100644 index 0000000..d0ca0f3 --- /dev/null +++ b/tests/ExpressiveSharp.Tests/RelationalExtensions/WindowSpecSqlExpressionFrameTests.cs @@ -0,0 +1,122 @@ +// EF1001 fires on the ".Internal" namespace convention, but these are OUR internals — +// these tests deliberately exercise them via InternalsVisibleTo. +#pragma warning disable EF1001 + +using ExpressiveSharp.EntityFrameworkCore.RelationalExtensions.Infrastructure.Internal; +using Microsoft.EntityFrameworkCore.Query.SqlExpressions; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace ExpressiveSharp.Tests.RelationalExtensions; + +/// +/// Unit tests for the frame-related extensions to : +/// the WithFrame builder method and the equality/hash-code behaviour for the +/// new frame fields. +/// +[TestClass] +public class WindowSpecSqlExpressionFrameTests +{ + private static readonly SqlFragmentExpression ColA = new("[a]"); + private static readonly SqlFragmentExpression ColB = new("[b]"); + + [TestMethod] + public void Constructor_DefaultsFrameFieldsToNull() + { + var spec = new WindowSpecSqlExpression([ColA], [new OrderingExpression(ColB, ascending: true)], typeMapping: null); + + Assert.IsNull(spec.FrameType); + Assert.IsNull(spec.FrameStart); + Assert.IsNull(spec.FrameEnd); + } + + [TestMethod] + public void WithFrame_PreservesPartitionsAndOrderings() + { + var spec = new WindowSpecSqlExpression( + [ColA], + [new OrderingExpression(ColB, ascending: false)], + typeMapping: null); + + var framed = spec.WithFrame( + WindowFrameType.Rows, + new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null), + new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)); + + Assert.AreEqual(1, framed.Partitions.Count); + Assert.AreSame(ColA, framed.Partitions[0]); + Assert.AreEqual(1, framed.Orderings.Count); + Assert.AreSame(ColB, framed.Orderings[0].Expression); + Assert.IsFalse(framed.Orderings[0].IsAscending); + Assert.AreEqual(WindowFrameType.Rows, framed.FrameType); + Assert.AreEqual(WindowFrameBoundKind.UnboundedPreceding, framed.FrameStart!.Value.Kind); + Assert.AreEqual(WindowFrameBoundKind.CurrentRow, framed.FrameEnd!.Value.Kind); + } + + [TestMethod] + public void WithFrame_ReturnsNewInstance() + { + var spec = new WindowSpecSqlExpression([], [new OrderingExpression(ColA, ascending: true)], typeMapping: null); + var framed = spec.WithFrame( + WindowFrameType.Range, + new WindowFrameBoundInfo(WindowFrameBoundKind.Preceding, 3), + new WindowFrameBoundInfo(WindowFrameBoundKind.Following, 3)); + + Assert.AreNotSame(spec, framed); + Assert.IsNull(spec.FrameType); + } + + [TestMethod] + public void Equals_IncludesFrameFields() + { + var orderings = new[] { new OrderingExpression(ColA, ascending: true) }; + var a = new WindowSpecSqlExpression([], orderings, typeMapping: null) + .WithFrame(WindowFrameType.Rows, + new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null), + new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)); + var b = new WindowSpecSqlExpression([], orderings, typeMapping: null) + .WithFrame(WindowFrameType.Rows, + new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null), + new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)); + var differentFrameType = new WindowSpecSqlExpression([], orderings, typeMapping: null) + .WithFrame(WindowFrameType.Range, + new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null), + new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)); + var differentBound = new WindowSpecSqlExpression([], orderings, typeMapping: null) + .WithFrame(WindowFrameType.Rows, + new WindowFrameBoundInfo(WindowFrameBoundKind.Preceding, 2), + new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)); + + Assert.AreEqual(a, b); + Assert.AreEqual(a.GetHashCode(), b.GetHashCode()); + Assert.AreNotEqual(a, differentFrameType); + Assert.AreNotEqual(a, differentBound); + } + + [TestMethod] + public void WithPartition_PreservesFrame() + { + var spec = new WindowSpecSqlExpression([], [new OrderingExpression(ColA, ascending: true)], typeMapping: null) + .WithFrame(WindowFrameType.Rows, + new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null), + new WindowFrameBoundInfo(WindowFrameBoundKind.CurrentRow, null)); + + var withPartition = spec.WithPartition(ColB); + + Assert.AreEqual(WindowFrameType.Rows, withPartition.FrameType); + Assert.AreEqual(1, withPartition.Partitions.Count); + } + + [TestMethod] + public void WithOrdering_PreservesFrame() + { + var spec = new WindowSpecSqlExpression([], [new OrderingExpression(ColA, ascending: true)], typeMapping: null) + .WithFrame(WindowFrameType.Range, + new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedPreceding, null), + new WindowFrameBoundInfo(WindowFrameBoundKind.UnboundedFollowing, null)); + + var withOrdering = spec.WithOrdering(ColB, ascending: false); + + Assert.AreEqual(WindowFrameType.Range, withOrdering.FrameType); + Assert.AreEqual(2, withOrdering.Orderings.Count); + } +}