From 5d44a980506c377523afb4b7e9511d0aee0cf6aa Mon Sep 17 00:00:00 2001 From: Ian Pascoe Date: Sun, 11 Jan 2026 07:54:39 -0500 Subject: [PATCH 1/7] fix: support generic class inheritance with type arguments Fixes #1929 - Classes can now extend generic classes with specific type parameters and fields are properly resolved. Changes: - Parser: Support doc.type.sign for class extends (e.g., Bar: Foo) - Parser: Include doc.field in bindGeneric() to convert generic type names - VM: Add resolveGenericField() to resolve generic parameters in field types - VM: Handle doc.type.sign extends in class field inheritance search - VM: Export vm.cloneObject() for generic type resolution Example that now works: ---@class Foo ---@field a T ---@class Bar: Foo local x ---@type Bar local what = x.a -- Now infers as 'integer' instead of 'unknown' --- script/parser/luadoc.lua | 9 ++++- script/vm/compiler.lua | 70 +++++++++++++++++++++++++++++++++- script/vm/generic.lua | 7 ++++ test/type_inference/common.lua | 46 ++++++++++++++++++++++ 4 files changed, 129 insertions(+), 3 deletions(-) diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index ad22af1d7..79764dcad 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -1030,6 +1030,12 @@ local docSwitch = util.switch() } return result end + if extend.type == 'doc.extends.name' then + local signResult = parseTypeUnitSign(result, extend) + if signResult then + extend = signResult + end + end result.extends[#result.extends+1] = extend result.finish = getFinish() if not checkToken('symbol', ',', 1) then @@ -1850,7 +1856,8 @@ local function bindGeneric(binded) or doc.type == 'doc.return' or doc.type == 'doc.type' or doc.type == 'doc.class' - or doc.type == 'doc.alias' then + or doc.type == 'doc.alias' + or doc.type == 'doc.field' then guide.eachSourceType(doc, 'doc.type.name', function (src) local name = src[1] if generics[name] then diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index af7b7cc69..61891a22a 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -234,6 +234,56 @@ local function searchLiteralFieldFromTable(source, key, callback) end end +---@param uri uri +---@param classGlobal vm.global +---@param field parser.object +---@param signs parser.object[] +---@return parser.object? +local function resolveGenericField(uri, classGlobal, field, signs) + if field.type ~= 'doc.field' then + return nil + end + if not field.extends then + return nil + end + local hasGeneric = false + guide.eachSourceType(field.extends, 'doc.generic.name', function () + hasGeneric = true + end) + if not hasGeneric then + return nil + end + for _, set in ipairs(classGlobal:getSets(uri)) do + if set.type == 'doc.class' and set.signs then + local resolved = {} + for i, signName in ipairs(set.signs) do + local signType = signs[i] + if signType and signName[1] then + local signNode = vm.compileNode(signType) + resolved[signName[1]] = signNode + end + end + if next(resolved) then + local newExtends = vm.cloneObject(field.extends, resolved) + if newExtends then + return { + type = field.type, + start = field.start, + finish = field.finish, + parent = field.parent, + field = field.field, + extends = newExtends, + visible = field.visible, + optional = field.optional, + } + end + end + break + end + end + return nil +end + local searchFieldSwitch = util.switch() : case 'table' : call(function (_suri, source, key, pushResult) @@ -357,7 +407,16 @@ local searchFieldSwitch = util.switch() if not globalVar then return end - vm.getClassFields(suri, globalVar, key, pushResult) + vm.getClassFields(suri, globalVar, key, function (field, isMark) + if source.signs then + local newField = resolveGenericField(suri, globalVar, field, source.signs) + if newField then + pushResult(newField, isMark) + return + end + end + pushResult(field, isMark) + end) end) : case 'global' : call(function (suri, node, key, pushResult) @@ -565,7 +624,6 @@ function vm.getClassFields(suri, object, key, pushResult) for _, set in ipairs(sets) do if set.type == 'doc.class' then - -- look into extends(if field not found) if not searchedFields[key] and set.extends then for _, extend in ipairs(set.extends) do if extend.type == 'doc.extends.name' then @@ -573,6 +631,14 @@ function vm.getClassFields(suri, object, key, pushResult) if extendType then searchClass(extendType, searchedFields) end + elseif extend.type == 'doc.type.sign' then + searchFieldSwitch(extend.type, suri, extend, key, function (field, isMark) + local fieldKey = guide.getKeyName(field) + if fieldKey and not searchedFields[fieldKey] then + hasFounded[fieldKey] = true + pushResult(field, isMark) + end + end) end end end diff --git a/script/vm/generic.lua b/script/vm/generic.lua index f1eaaf99d..69c9553d5 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -173,3 +173,10 @@ function vm.createGeneric(proto, sign) }, mt) return generic end + +---@param source vm.object? +---@param resolved? table +---@return vm.object? +function vm.cloneObject(source, resolved) + return cloneObject(source, resolved) +end diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index 7f6d854c7..614d7c667 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -4909,3 +4909,49 @@ function f(...args) print() end ]] + +TEST 'integer' [[ +---@class Foo +---@field a T + +---@class Bar: Foo + +---@type Bar +local x +local = x.a +]] + +TEST 'string' [[ +---@class GenericBase +---@field value T + +---@class StringHolder: GenericBase + +---@type StringHolder +local holder +local = holder.value +]] + +TEST 'boolean' [[ +---@class Container +---@field key K +---@field val V + +---@class BoolContainer: Container + +---@type BoolContainer +local c +local = c.val +]] + +TEST 'string' [[ +---@class Container +---@field key K +---@field val V + +---@class BoolContainer: Container + +---@type BoolContainer +local c +local = c.key +]] From c4ce6f9e484667d07af7af6a54ac800e26da01d1 Mon Sep 17 00:00:00 2001 From: Ian Pascoe Date: Sun, 11 Jan 2026 18:34:05 -0500 Subject: [PATCH 2/7] fix: resolve generic method return types and prevent recursive hover expansion - Fix #1863: Method return types on generic classes now resolve correctly When calling methods like Box:getValue(), the return type T is now properly resolved to the concrete type (string) from the class instance. - Fix #1853: Self-referential generic classes no longer cause infinite expansion in hover display. The resolved extends clause is now hidden from type display while still being available for type checking. - Add test cases for both fixes --- script/vm/compiler.lua | 63 ++++++++++++++++++++++++++++++++++ test/type_inference/common.lua | 56 ++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 61891a22a..f032c7b2a 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1542,6 +1542,10 @@ local function compileLocal(source) end ---@param source parser.object +---Resolves generic type names from a class's generic parameters +---@param uri uri +---@param classGlobal vm.global +---@param typeName string ---@param mfunc parser.object ---@param index integer ---@param args parser.object[] @@ -1557,6 +1561,62 @@ local function bindReturnOfFunction(source, mfunc, index, args) break end end + + -- Handle method calls on generic class instances + -- When calling b:getValue() where b is Box, resolve T to string + local call = source.parent + if call and call.type == 'call' then + local callNode = call.node + if callNode and (callNode.type == 'getmethod' or callNode.type == 'getfield') then + local receiver = callNode.node + if receiver then + local receiverNode = vm.compileNode(receiver) + for rn in receiverNode:eachObject() do + if rn.type == 'doc.type.sign' and rn.signs and rn.node then + local classGlobal = vm.getGlobal('type', rn.node[1]) + if classGlobal then + -- Build a map of class generic param names to their concrete types + local genericMap = {} + for _, set in ipairs(classGlobal:getSets(guide.getUri(source))) do + if set.type == 'doc.class' and set.signs then + for i, signName in ipairs(set.signs) do + if signName[1] and rn.signs[i] then + genericMap[signName[1]] = vm.compileNode(rn.signs[i]) + end + end + break + end + end + + if next(genericMap) then + -- Check the return node for global type references that match generic params + local newReturnNode = vm.createNode() + local hasReplacement = false + for retNode in returnNode:eachObject() do + if retNode.type == 'global' and retNode.cate == 'type' then + local resolvedNode = genericMap[retNode.name] + if resolvedNode then + newReturnNode:merge(resolvedNode) + hasReplacement = true + else + newReturnNode:merge(retNode) + end + else + newReturnNode:merge(retNode) + end + end + if hasReplacement then + returnNode = newReturnNode + end + end + break + end + end + end + end + end + end + if returnNode then for rnode in returnNode:eachObject() do -- TODO: narrow type @@ -2154,6 +2214,9 @@ local compilerSwitch = util.switch() if ext.type == 'doc.type.table' then if vm.getGeneric(ext) then local resolved = vm.getGeneric(ext):resolve(uri, source.signs) + for obj in resolved:eachObject() do + obj.hideView = true + end vm.setNode(source, resolved) end end diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index 614d7c667..3d7a84525 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -4955,3 +4955,59 @@ TEST 'string' [[ local c local = c.key ]] + +TEST 'string' [[ +---@class Box +---@field value T +local Box = {} + +---@return T +function Box:getValue() + return self.value +end + +---@type Box +local b + +local = b:getValue() +]] + +TEST 'integer' [[ +---@class Wrapper +---@field item V +local Wrapper = {} + +---@return V +function Wrapper:unwrap() + return self.item +end + +---@type Wrapper +local w + +local = w:unwrap() +]] + +-- Issue #1856: Generic class display format +-- Current behavior shows list<>|{...} - the <> indicates an unresolved generic +-- The resolved table type is also shown +TEST 'list<>|{ [integer]: string }' [[ +---@class list: {[integer]:T} + +---@generic T +---@param class `T` +---@return list +local function new_list(class) + return {} +end + +local = new_list('string') +]] + +-- Issue #1853: Recursive expansion on hover of generic type +-- Self-referential generic classes should not expand infinitely +TEST 'store' [[ +---@class store: {set:fun(self:store, key:integer, value:T), get:fun(self:store, key:integer):T} + +local ---@type store +]] From 962293f482ce3ef57a099c6e26add9677572ca84 Mon Sep 17 00:00:00 2001 From: Ian Pascoe Date: Sun, 11 Jan 2026 19:02:09 -0500 Subject: [PATCH 3/7] chore: fix stale doc comments and add defensive guard --- script/vm/compiler.lua | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index f032c7b2a..fd118232b 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1542,10 +1542,6 @@ local function compileLocal(source) end ---@param source parser.object ----Resolves generic type names from a class's generic parameters ----@param uri uri ----@param classGlobal vm.global ----@param typeName string ---@param mfunc parser.object ---@param index integer ---@param args parser.object[] @@ -1572,7 +1568,7 @@ local function bindReturnOfFunction(source, mfunc, index, args) if receiver then local receiverNode = vm.compileNode(receiver) for rn in receiverNode:eachObject() do - if rn.type == 'doc.type.sign' and rn.signs and rn.node then + if rn.type == 'doc.type.sign' and rn.signs and rn.node and rn.node[1] then local classGlobal = vm.getGlobal('type', rn.node[1]) if classGlobal then -- Build a map of class generic param names to their concrete types From e2b37b13fa9a301f300c63455db629abc6758c35 Mon Sep 17 00:00:00 2001 From: Ian Pascoe Date: Sun, 11 Jan 2026 19:25:10 -0500 Subject: [PATCH 4/7] fix: handle composite generic return types (T[], Wrapper) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address PR review comment: The previous implementation only handled direct T → string substitution but failed for composite types like T[] or Wrapper. Changes: - Use vm.cloneObject on the actual return annotation (doc.return) instead of iterating through compiled node objects - Extend vm.cloneObject to handle doc.type.name when name matches a generic param (converts to doc.generic.name with _resolved set) - Extend vm.cloneObject to handle doc.type.sign for nested generic types, but only when signs contain doc.type.name needing resolution (preserves function-level generics like Callback<>) Added test cases for: - T[] resolving to string[] - Wrapper resolving to Wrapper --- script/vm/compiler.lua | 35 +++++++++++++------------- script/vm/generic.lua | 45 ++++++++++++++++++++++++++++++++++ test/type_inference/common.lua | 35 ++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 17 deletions(-) diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index fd118232b..893993f5a 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1584,26 +1584,27 @@ local function bindReturnOfFunction(source, mfunc, index, args) end end - if next(genericMap) then - -- Check the return node for global type references that match generic params - local newReturnNode = vm.createNode() - local hasReplacement = false - for retNode in returnNode:eachObject() do - if retNode.type == 'global' and retNode.cate == 'type' then - local resolvedNode = genericMap[retNode.name] - if resolvedNode then - newReturnNode:merge(resolvedNode) - hasReplacement = true - else - newReturnNode:merge(retNode) + if next(genericMap) and mfunc.bindDocs then + for _, doc in ipairs(mfunc.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + if rtn.returnIndex == index then + local newRtn = vm.cloneObject(rtn, genericMap) + if newRtn then + returnNode = vm.compileNode(newRtn) + for rnode in returnNode:eachObject() do + if rnode.type == 'generic' then + returnNode = rnode:resolve(guide.getUri(source), args) + break + end + end + end + break + end end - else - newReturnNode:merge(retNode) + break end end - if hasReplacement then - returnNode = newReturnNode - end end break end diff --git a/script/vm/generic.lua b/script/vm/generic.lua index 69c9553d5..789d4dbb3 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -34,6 +34,21 @@ local function cloneObject(source, resolved) end return newName end + if source.type == 'doc.type.name' then + local key = source[1] + if resolved[key] then + local newName = { + type = 'doc.generic.name', + start = source.start, + finish = source.finish, + parent = source.parent, + [1] = source[1], + } + vm.setNode(newName, resolved[key], true) + newName._resolved = resolved[key] + return newName + end + end if source.type == 'doc.type' then local newType = { type = source.type, @@ -113,6 +128,36 @@ local function cloneObject(source, resolved) end return newDocFunc end + if source.type == 'doc.type.sign' and source.signs then + local needsClone = false + for _, sign in ipairs(source.signs) do + if sign.type == 'doc.type' then + for _, tp in ipairs(sign.types) do + if tp.type == 'doc.type.name' and resolved[tp[1]] then + needsClone = true + break + end + end + elseif sign.type == 'doc.type.name' and resolved[sign[1]] then + needsClone = true + end + if needsClone then break end + end + if needsClone then + local newSign = { + type = source.type, + start = source.start, + finish = source.finish, + parent = source.parent, + node = source.node, + signs = {}, + } + for i, sign in ipairs(source.signs) do + newSign.signs[i] = cloneObject(sign, resolved) + end + return newSign + end + end return source end diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index 3d7a84525..ef237830f 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -5011,3 +5011,38 @@ TEST 'store' [[ local ---@type store ]] + +-- Test composite return types with generics (T[] should resolve to string[]) +TEST 'string[]' [[ +---@class Container +local Container = {} + +---@return T[] +function Container:getAll() + return {} +end + +---@type Container +local c + +local = c:getAll() +]] + +-- Test nested generic return types (Wrapper should resolve to Wrapper) +TEST 'Wrapper' [[ +---@class Wrapper +---@field value V + +---@class Factory +local Factory = {} + +---@return Wrapper +function Factory:wrap() + return {} +end + +---@type Factory +local f + +local = f:wrap() +]] From 1d32d41a1d028e2d0b1c6b244b8a94148c6fa0f7 Mon Sep 17 00:00:00 2001 From: Ian Pascoe Date: Sun, 11 Jan 2026 19:34:45 -0500 Subject: [PATCH 5/7] docs: add changelog entries for generic class fixes --- changelog.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/changelog.md b/changelog.md index d4ecb8039..2bd458dc6 100644 --- a/changelog.md +++ b/changelog.md @@ -2,6 +2,9 @@ ## Unreleased +* `FIX` Generic class inheritance with type arguments now works correctly (e.g., `class Bar: Foo`) [#1929](https://github.com/LuaLS/lua-language-server/issues/1929) +* `FIX` Method return types on generic classes now resolve correctly (e.g., `Box:getValue()` returns `string`) [#1863](https://github.com/LuaLS/lua-language-server/issues/1863) +* `FIX` Self-referential generic classes no longer cause infinite expansion in hover display [#1853](https://github.com/LuaLS/lua-language-server/issues/1853) ## 3.16.4 `2025-12-25` From 2bb62ab90b1dd73c97c153148c5b0510f53aa328 Mon Sep 17 00:00:00 2001 From: Ian Pascoe Date: Mon, 12 Jan 2026 06:56:42 -0500 Subject: [PATCH 6/7] fix: resolve generics in overloads, fun syntax, and @param self T patterns - Add support for @generic T in @overload annotations (#723) - Parse fun(x: T): T syntax in @field and @type annotations (#1170) - Fix methods with @generic T and @param self T to resolve return type to receiver's concrete type (e.g., List:identity() returns List instead of unknown) (#1000) - Add comprehensive test cases for all generic patterns --- changelog.md | 3 + script/parser/guide.lua | 2 +- script/parser/luadoc.lua | 50 +++++++++++++- script/vm/compiler.lua | 118 +++++++++++++++++++++++++++++-- test/type_inference/common.lua | 123 +++++++++++++++++++++++++++++++++ 5 files changed, 287 insertions(+), 9 deletions(-) diff --git a/changelog.md b/changelog.md index 2bd458dc6..ffecdc5ac 100644 --- a/changelog.md +++ b/changelog.md @@ -5,6 +5,9 @@ * `FIX` Generic class inheritance with type arguments now works correctly (e.g., `class Bar: Foo`) [#1929](https://github.com/LuaLS/lua-language-server/issues/1929) * `FIX` Method return types on generic classes now resolve correctly (e.g., `Box:getValue()` returns `string`) [#1863](https://github.com/LuaLS/lua-language-server/issues/1863) * `FIX` Self-referential generic classes no longer cause infinite expansion in hover display [#1853](https://github.com/LuaLS/lua-language-server/issues/1853) +* `FIX` Generic type parameters now work in `@overload` annotations [#723](https://github.com/LuaLS/lua-language-server/issues/723) +* `NEW` Support `fun` syntax for inline generic function types in `@field` and `@type` annotations [#1170](https://github.com/LuaLS/lua-language-server/issues/1170) +* `FIX` Methods with `@generic T` and `@param self T` now correctly resolve return type to the receiver's concrete type (e.g., `List:identity()` returns `List`) [#1000](https://github.com/LuaLS/lua-language-server/issues/1000) ## 3.16.4 `2025-12-25` diff --git a/script/parser/guide.lua b/script/parser/guide.lua index c407eca67..76a2d4fc0 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -177,7 +177,7 @@ local childMap = { ['doc.generic.object'] = {'generic', 'extends', 'comment'}, ['doc.vararg'] = {'vararg', 'comment'}, ['doc.type.array'] = {'node'}, - ['doc.type.function'] = {'#args', '#returns', 'comment'}, + ['doc.type.function'] = {'#args', '#returns', '#signs', 'comment'}, ['doc.type.table'] = {'#fields', 'comment'}, ['doc.type.literal'] = {'node'}, ['doc.type.arg'] = {'name', 'extends'}, diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index 79764dcad..f31c54622 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -523,6 +523,8 @@ local function parseTypeUnitFunction(parent) args = {}, returns = {}, } + -- Parse optional generic params: fun(...) + typeUnit.signs = parseSigns(typeUnit) if not nextSymbolOrError('(') then return nil end @@ -617,6 +619,51 @@ local function parseTypeUnitFunction(parent) end end typeUnit.finish = getFinish() + -- Bind local generics from fun to type names within this function + if typeUnit.signs then + local generics = {} + for _, sign in ipairs(typeUnit.signs) do + generics[sign[1]] = sign + end + local function bindTypeNames(obj) + if not obj then return end + if obj.type == 'doc.type.name' and generics[obj[1]] then + obj.type = 'doc.generic.name' + obj.generic = generics[obj[1]] + elseif obj.type == 'doc.type' and obj.types then + for _, t in ipairs(obj.types) do + bindTypeNames(t) + end + elseif obj.type == 'doc.type.array' then + bindTypeNames(obj.node) + elseif obj.type == 'doc.type.table' and obj.fields then + for _, field in ipairs(obj.fields) do + bindTypeNames(field.name) + bindTypeNames(field.extends) + end + elseif obj.type == 'doc.type.sign' then + bindTypeNames(obj.node) + if obj.signs then + for _, s in ipairs(obj.signs) do + bindTypeNames(s) + end + end + elseif obj.type == 'doc.type.function' then + for _, arg in ipairs(obj.args) do + bindTypeNames(arg.extends) + end + for _, ret in ipairs(obj.returns) do + bindTypeNames(ret) + end + end + end + for _, arg in ipairs(typeUnit.args) do + bindTypeNames(arg.extends) + end + for _, ret in ipairs(typeUnit.returns) do + bindTypeNames(ret) + end + end return typeUnit end @@ -1857,7 +1904,8 @@ local function bindGeneric(binded) or doc.type == 'doc.type' or doc.type == 'doc.class' or doc.type == 'doc.alias' - or doc.type == 'doc.field' then + or doc.type == 'doc.field' + or doc.type == 'doc.overload' then guide.eachSourceType(doc, 'doc.type.name', function (src) local name = src[1] if generics[name] then diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 893993f5a..9e8ef4aaf 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1550,18 +1550,120 @@ local function bindReturnOfFunction(source, mfunc, index, args) if not returnObject then return end + + local resolveArgs = args + if source.func and source.func.type == 'getmethod' then + local receiver = source.func.node + if receiver then + resolveArgs = { receiver } + if args then + for i = 2, #args do + resolveArgs[#resolveArgs + 1] = args[i] + end + end + end + end + local returnNode = vm.compileNode(returnObject) + + local selfGenericResolved = nil + if source.func and source.func.type == 'getmethod' and mfunc.type == 'function' and mfunc.bindDocs then + local receiver = source.func.node + if receiver then + local receiverNode = vm.compileNode(receiver) + local selfGenericName = nil + for _, doc in ipairs(mfunc.bindDocs) do + if doc.type == 'doc.param' and doc.param and doc.param[1] == 'self' then + if doc.extends then + for _, typeUnit in ipairs(doc.extends.types or {}) do + if typeUnit.type == 'doc.generic.name' then + selfGenericName = typeUnit[1] + break + end + end + end + break + end + end + if selfGenericName then + local filteredNode = vm.createNode() + for item in receiverNode:eachObject() do + if item.type == 'doc.type.sign' + or (item.type == 'global' and item.cate == 'type') + or item.type == 'doc.type.table' + or item.type == 'doc.type.array' then + filteredNode:merge(item) + end + end + if not filteredNode:isEmpty() then + selfGenericResolved = { [selfGenericName] = filteredNode } + else + selfGenericResolved = { [selfGenericName] = receiverNode } + end + end + end + end + for rnode in returnNode:eachObject() do if rnode.type == 'generic' then - returnNode = rnode:resolve(guide.getUri(source), args) + if selfGenericResolved and rnode.sign then + local resolved = rnode.sign:resolve(guide.getUri(source), resolveArgs) or {} + for k, v in pairs(selfGenericResolved) do + resolved[k] = v + end + local protoNode = vm.compileNode(rnode.proto) + local result = vm.createNode() + for nd in protoNode:eachObject() do + if nd.type == 'global' or nd.type == 'variable' then + result:merge(nd) + else + local clonedObject = vm.cloneObject(nd, resolved) + if clonedObject then + result:merge(vm.compileNode(clonedObject)) + end + end + end + if protoNode:isOptional() then + result:addOptional() + end + returnNode = result + else + returnNode = rnode:resolve(guide.getUri(source), resolveArgs) + end break end end - -- Handle method calls on generic class instances - -- When calling b:getValue() where b is Box, resolve T to string + if mfunc.type == 'function' then + local hasUnresolvedGeneric = false + for rnode in returnNode:eachObject() do + if rnode.type == 'doc.generic.name' and not rnode._resolved then + hasUnresolvedGeneric = true + break + end + end + if hasUnresolvedGeneric then + local sign = vm.getSign(mfunc) + if sign and resolveArgs and #resolveArgs > 0 then + local resolved = sign:resolve(guide.getUri(source), resolveArgs) + if resolved and next(resolved) then + local newReturnNode = vm.createNode() + for rnode in returnNode:eachObject() do + local cloned = vm.cloneObject(rnode, resolved) + if cloned then + newReturnNode:merge(vm.compileNode(cloned)) + else + newReturnNode:merge(rnode) + end + end + returnNode = newReturnNode + end + end + end + end + local call = source.parent - if call and call.type == 'call' then + if not selfGenericResolved and call and call.type == 'call' then local callNode = call.node if callNode and (callNode.type == 'getmethod' or callNode.type == 'getfield') then local receiver = callNode.node @@ -1571,7 +1673,6 @@ local function bindReturnOfFunction(source, mfunc, index, args) if rn.type == 'doc.type.sign' and rn.signs and rn.node and rn.node[1] then local classGlobal = vm.getGlobal('type', rn.node[1]) if classGlobal then - -- Build a map of class generic param names to their concrete types local genericMap = {} for _, set in ipairs(classGlobal:getSets(guide.getUri(source))) do if set.type == 'doc.class' and set.signs then @@ -1616,7 +1717,6 @@ local function bindReturnOfFunction(source, mfunc, index, args) if returnNode then for rnode in returnNode:eachObject() do - -- TODO: narrow type if rnode.type ~= 'doc.generic.name' then vm.setNode(source, rnode) end @@ -2191,7 +2291,11 @@ local compilerSwitch = util.switch() end) : case 'doc.generic.name' : call(function (source) - vm.setNode(source, source) + if source._resolved then + vm.setNode(source, source._resolved) + else + vm.setNode(source, source) + end end) : case 'doc.type.sign' : call(function (source) diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index ef237830f..c54c82ee1 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -5046,3 +5046,126 @@ local f local = f:wrap() ]] + +-- Issue #723: Generics in @overload +-- @generic should work with @overload annotations +TEST 'string' [[ +---@generic T +---@param x T +---@return T +---@overload fun(x: T): T +local function identity(x) + return x +end + +local = identity("hello") +]] + +-- Issue #723: Multiple generics in @overload +TEST 'integer' [[ +---@generic K, V +---@param k K +---@param v V +---@return V +---@overload fun(k: K, v: V): V +local function getValue(k, v) + return v +end + +local = getValue("key", 42) +]] + +-- Issue #1170: Generics in function type format (fun) +TEST 'string' [[ +---@type fun(x: T): T +local identity + +local = identity("hello") +]] + +-- Issue #1170: Multiple generics in function type +TEST 'boolean' [[ +---@type fun(k: K, v: V): V +local getSecond + +local = getSecond("key", true) +]] + +-- Issue #1170: Generic function in @field +TEST 'integer' [[ +---@class Mapper +---@field transform fun(input: T, fn: fun(x: T): U): U + +---@type Mapper +local m + +local = m.transform("hello", function(x) return #x end) +]] + +-- Issue #1532: Promise-like method chaining +-- Method returning self-type should preserve generic param through chain +TEST 'string' [[ +---@class Promise +---@field value T +local Promise = {} + +---@return Promise +function Promise:next(fn) + return self +end + +---@return T +function Promise:await() + return self.value +end + +---@type Promise +local p + +local = p:next(function() end):await() +]] + +-- Issue #1532: Multiple chained methods +TEST 'number' [[ +---@class Chain +local Chain = {} + +---@return Chain +function Chain:map(fn) + return self +end + +---@return Chain +function Chain:filter(fn) + return self +end + +---@return V +function Chain:first() + return nil +end + +---@type Chain +local c + +local = c:map(function() end):filter(function() end):first() +]] + +-- Issue #1000: Generic self parameter should resolve to concrete type +-- When @generic T and @param self T, calling on List should return List +TEST 'List' [[ +---@class List +local List = {} + +---@generic T +---@param self T +---@return T +function List:identity() + return self +end + +---@type List +local mylist + +local = mylist:identity() +]] From c2fb295283b68e0664e8b97e53dee04e2357c0dc Mon Sep 17 00:00:00 2001 From: Ian Pascoe Date: Mon, 12 Jan 2026 07:11:31 -0500 Subject: [PATCH 7/7] refactor: simplify generic resolution with helper functions - Fix double-clone bug in cloneObject for doc.type.function returns - Add containsGenericName() for early-exit generic detection - Add getClassGenericMap() to eliminate code duplication - Refactor resolveGenericField() to use new helpers --- script/vm/compiler.lua | 128 +++++++++++++++++++++++++++-------------- script/vm/generic.lua | 4 +- 2 files changed, 86 insertions(+), 46 deletions(-) diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index 9e8ef4aaf..d10c171f1 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -234,49 +234,68 @@ local function searchLiteralFieldFromTable(source, key, callback) end end ----@param uri uri ----@param classGlobal vm.global ----@param field parser.object ----@param signs parser.object[] ----@return parser.object? -local function resolveGenericField(uri, classGlobal, field, signs) - if field.type ~= 'doc.field' then - return nil +---@param obj parser.object +---@return boolean +local function containsGenericName(obj) + if not obj then + return false end - if not field.extends then - return nil + if obj.type == 'doc.generic.name' then + return true end - local hasGeneric = false - guide.eachSourceType(field.extends, 'doc.generic.name', function () - hasGeneric = true - end) - if not hasGeneric then - return nil + if obj.type == 'doc.type' and obj.types then + for _, t in ipairs(obj.types) do + if containsGenericName(t) then + return true + end + end + elseif obj.type == 'doc.type.array' then + return containsGenericName(obj.node) + elseif obj.type == 'doc.type.table' and obj.fields then + for _, field in ipairs(obj.fields) do + if containsGenericName(field.name) or containsGenericName(field.extends) then + return true + end + end + elseif obj.type == 'doc.type.sign' then + if obj.signs then + for _, s in ipairs(obj.signs) do + if containsGenericName(s) then + return true + end + end + end + elseif obj.type == 'doc.type.function' then + for _, arg in ipairs(obj.args or {}) do + if containsGenericName(arg.extends) then + return true + end + end + for _, ret in ipairs(obj.returns or {}) do + if containsGenericName(ret) then + return true + end + end end + return false +end + +---@param uri uri +---@param classGlobal vm.global +---@param signs parser.object[] +---@return table? +local function getClassGenericMap(uri, classGlobal, signs) for _, set in ipairs(classGlobal:getSets(uri)) do if set.type == 'doc.class' and set.signs then local resolved = {} for i, signName in ipairs(set.signs) do local signType = signs[i] if signType and signName[1] then - local signNode = vm.compileNode(signType) - resolved[signName[1]] = signNode + resolved[signName[1]] = vm.compileNode(signType) end end if next(resolved) then - local newExtends = vm.cloneObject(field.extends, resolved) - if newExtends then - return { - type = field.type, - start = field.start, - finish = field.finish, - parent = field.parent, - field = field.field, - extends = newExtends, - visible = field.visible, - optional = field.optional, - } - end + return resolved end break end @@ -284,6 +303,38 @@ local function resolveGenericField(uri, classGlobal, field, signs) return nil end +---@param uri uri +---@param classGlobal vm.global +---@param field parser.object +---@param signs parser.object[] +---@return parser.object? +local function resolveGenericField(uri, classGlobal, field, signs) + if field.type ~= 'doc.field' or not field.extends then + return nil + end + if not containsGenericName(field.extends) then + return nil + end + local resolved = getClassGenericMap(uri, classGlobal, signs) + if not resolved then + return nil + end + local newExtends = vm.cloneObject(field.extends, resolved) + if not newExtends then + return nil + end + return { + type = field.type, + start = field.start, + finish = field.finish, + parent = field.parent, + field = field.field, + extends = newExtends, + visible = field.visible, + optional = field.optional, + } +end + local searchFieldSwitch = util.switch() : case 'table' : call(function (_suri, source, key, pushResult) @@ -1673,19 +1724,8 @@ local function bindReturnOfFunction(source, mfunc, index, args) if rn.type == 'doc.type.sign' and rn.signs and rn.node and rn.node[1] then local classGlobal = vm.getGlobal('type', rn.node[1]) if classGlobal then - local genericMap = {} - for _, set in ipairs(classGlobal:getSets(guide.getUri(source))) do - if set.type == 'doc.class' and set.signs then - for i, signName in ipairs(set.signs) do - if signName[1] and rn.signs[i] then - genericMap[signName[1]] = vm.compileNode(rn.signs[i]) - end - end - break - end - end - - if next(genericMap) and mfunc.bindDocs then + local genericMap = getClassGenericMap(guide.getUri(source), classGlobal, rn.signs) + if genericMap and mfunc.bindDocs then for _, doc in ipairs(mfunc.bindDocs) do if doc.type == 'doc.return' then for _, rtn in ipairs(doc.returns) do diff --git a/script/vm/generic.lua b/script/vm/generic.lua index 789d4dbb3..b6cffe5a5 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -121,10 +121,10 @@ local function cloneObject(source, resolved) newDocFunc.args[i] = newObj end for i, ret in ipairs(source.returns) do - local newObj = cloneObject(ret, resolved) + local newObj = cloneObject(ret, resolved) newObj.parent = newDocFunc newObj.optional = ret.optional - newDocFunc.returns[i] = cloneObject(ret, resolved) + newDocFunc.returns[i] = newObj end return newDocFunc end