diff --git a/changelog.md b/changelog.md index d4ecb8039..ffecdc5ac 100644 --- a/changelog.md +++ b/changelog.md @@ -2,6 +2,12 @@ ## Unreleased +* `FIX` Generic class inheritance with type arguments now works correctly (e.g., `class Bar: Foo`) [#1929](https://github.com/LuaLS/lua-language-server/issues/1929) +* `FIX` Method return types on generic classes now resolve correctly (e.g., `Box:getValue()` returns `string`) [#1863](https://github.com/LuaLS/lua-language-server/issues/1863) +* `FIX` Self-referential generic classes no longer cause infinite expansion in hover display [#1853](https://github.com/LuaLS/lua-language-server/issues/1853) +* `FIX` Generic type parameters now work in `@overload` annotations [#723](https://github.com/LuaLS/lua-language-server/issues/723) +* `NEW` Support `fun` syntax for inline generic function types in `@field` and `@type` annotations [#1170](https://github.com/LuaLS/lua-language-server/issues/1170) +* `FIX` Methods with `@generic T` and `@param self T` now correctly resolve return type to the receiver's concrete type (e.g., `List:identity()` returns `List`) [#1000](https://github.com/LuaLS/lua-language-server/issues/1000) ## 3.16.4 `2025-12-25` diff --git a/script/parser/guide.lua b/script/parser/guide.lua index c407eca67..76a2d4fc0 100644 --- a/script/parser/guide.lua +++ b/script/parser/guide.lua @@ -177,7 +177,7 @@ local childMap = { ['doc.generic.object'] = {'generic', 'extends', 'comment'}, ['doc.vararg'] = {'vararg', 'comment'}, ['doc.type.array'] = {'node'}, - ['doc.type.function'] = {'#args', '#returns', 'comment'}, + ['doc.type.function'] = {'#args', '#returns', '#signs', 'comment'}, ['doc.type.table'] = {'#fields', 'comment'}, ['doc.type.literal'] = {'node'}, ['doc.type.arg'] = {'name', 'extends'}, diff --git a/script/parser/luadoc.lua b/script/parser/luadoc.lua index ad22af1d7..f31c54622 100644 --- a/script/parser/luadoc.lua +++ b/script/parser/luadoc.lua @@ -523,6 +523,8 @@ local function parseTypeUnitFunction(parent) args = {}, returns = {}, } + -- Parse optional generic params: fun(...) + typeUnit.signs = parseSigns(typeUnit) if not nextSymbolOrError('(') then return nil end @@ -617,6 +619,51 @@ local function parseTypeUnitFunction(parent) end end typeUnit.finish = getFinish() + -- Bind local generics from fun to type names within this function + if typeUnit.signs then + local generics = {} + for _, sign in ipairs(typeUnit.signs) do + generics[sign[1]] = sign + end + local function bindTypeNames(obj) + if not obj then return end + if obj.type == 'doc.type.name' and generics[obj[1]] then + obj.type = 'doc.generic.name' + obj.generic = generics[obj[1]] + elseif obj.type == 'doc.type' and obj.types then + for _, t in ipairs(obj.types) do + bindTypeNames(t) + end + elseif obj.type == 'doc.type.array' then + bindTypeNames(obj.node) + elseif obj.type == 'doc.type.table' and obj.fields then + for _, field in ipairs(obj.fields) do + bindTypeNames(field.name) + bindTypeNames(field.extends) + end + elseif obj.type == 'doc.type.sign' then + bindTypeNames(obj.node) + if obj.signs then + for _, s in ipairs(obj.signs) do + bindTypeNames(s) + end + end + elseif obj.type == 'doc.type.function' then + for _, arg in ipairs(obj.args) do + bindTypeNames(arg.extends) + end + for _, ret in ipairs(obj.returns) do + bindTypeNames(ret) + end + end + end + for _, arg in ipairs(typeUnit.args) do + bindTypeNames(arg.extends) + end + for _, ret in ipairs(typeUnit.returns) do + bindTypeNames(ret) + end + end return typeUnit end @@ -1030,6 +1077,12 @@ local docSwitch = util.switch() } return result end + if extend.type == 'doc.extends.name' then + local signResult = parseTypeUnitSign(result, extend) + if signResult then + extend = signResult + end + end result.extends[#result.extends+1] = extend result.finish = getFinish() if not checkToken('symbol', ',', 1) then @@ -1850,7 +1903,9 @@ local function bindGeneric(binded) or doc.type == 'doc.return' or doc.type == 'doc.type' or doc.type == 'doc.class' - or doc.type == 'doc.alias' then + or doc.type == 'doc.alias' + or doc.type == 'doc.field' + or doc.type == 'doc.overload' then guide.eachSourceType(doc, 'doc.type.name', function (src) local name = src[1] if generics[name] then diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index af7b7cc69..d10c171f1 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -234,6 +234,107 @@ local function searchLiteralFieldFromTable(source, key, callback) end end +---@param obj parser.object +---@return boolean +local function containsGenericName(obj) + if not obj then + return false + end + if obj.type == 'doc.generic.name' then + return true + end + if obj.type == 'doc.type' and obj.types then + for _, t in ipairs(obj.types) do + if containsGenericName(t) then + return true + end + end + elseif obj.type == 'doc.type.array' then + return containsGenericName(obj.node) + elseif obj.type == 'doc.type.table' and obj.fields then + for _, field in ipairs(obj.fields) do + if containsGenericName(field.name) or containsGenericName(field.extends) then + return true + end + end + elseif obj.type == 'doc.type.sign' then + if obj.signs then + for _, s in ipairs(obj.signs) do + if containsGenericName(s) then + return true + end + end + end + elseif obj.type == 'doc.type.function' then + for _, arg in ipairs(obj.args or {}) do + if containsGenericName(arg.extends) then + return true + end + end + for _, ret in ipairs(obj.returns or {}) do + if containsGenericName(ret) then + return true + end + end + end + return false +end + +---@param uri uri +---@param classGlobal vm.global +---@param signs parser.object[] +---@return table? +local function getClassGenericMap(uri, classGlobal, signs) + for _, set in ipairs(classGlobal:getSets(uri)) do + if set.type == 'doc.class' and set.signs then + local resolved = {} + for i, signName in ipairs(set.signs) do + local signType = signs[i] + if signType and signName[1] then + resolved[signName[1]] = vm.compileNode(signType) + end + end + if next(resolved) then + return resolved + end + break + end + end + return nil +end + +---@param uri uri +---@param classGlobal vm.global +---@param field parser.object +---@param signs parser.object[] +---@return parser.object? +local function resolveGenericField(uri, classGlobal, field, signs) + if field.type ~= 'doc.field' or not field.extends then + return nil + end + if not containsGenericName(field.extends) then + return nil + end + local resolved = getClassGenericMap(uri, classGlobal, signs) + if not resolved then + return nil + end + local newExtends = vm.cloneObject(field.extends, resolved) + if not newExtends then + return nil + end + return { + type = field.type, + start = field.start, + finish = field.finish, + parent = field.parent, + field = field.field, + extends = newExtends, + visible = field.visible, + optional = field.optional, + } +end + local searchFieldSwitch = util.switch() : case 'table' : call(function (_suri, source, key, pushResult) @@ -357,7 +458,16 @@ local searchFieldSwitch = util.switch() if not globalVar then return end - vm.getClassFields(suri, globalVar, key, pushResult) + vm.getClassFields(suri, globalVar, key, function (field, isMark) + if source.signs then + local newField = resolveGenericField(suri, globalVar, field, source.signs) + if newField then + pushResult(newField, isMark) + return + end + end + pushResult(field, isMark) + end) end) : case 'global' : call(function (suri, node, key, pushResult) @@ -565,7 +675,6 @@ function vm.getClassFields(suri, object, key, pushResult) for _, set in ipairs(sets) do if set.type == 'doc.class' then - -- look into extends(if field not found) if not searchedFields[key] and set.extends then for _, extend in ipairs(set.extends) do if extend.type == 'doc.extends.name' then @@ -573,6 +682,14 @@ function vm.getClassFields(suri, object, key, pushResult) if extendType then searchClass(extendType, searchedFields) end + elseif extend.type == 'doc.type.sign' then + searchFieldSwitch(extend.type, suri, extend, key, function (field, isMark) + local fieldKey = guide.getKeyName(field) + if fieldKey and not searchedFields[fieldKey] then + hasFounded[fieldKey] = true + pushResult(field, isMark) + end + end) end end end @@ -1484,16 +1601,162 @@ local function bindReturnOfFunction(source, mfunc, index, args) if not returnObject then return end + + local resolveArgs = args + if source.func and source.func.type == 'getmethod' then + local receiver = source.func.node + if receiver then + resolveArgs = { receiver } + if args then + for i = 2, #args do + resolveArgs[#resolveArgs + 1] = args[i] + end + end + end + end + local returnNode = vm.compileNode(returnObject) + + local selfGenericResolved = nil + if source.func and source.func.type == 'getmethod' and mfunc.type == 'function' and mfunc.bindDocs then + local receiver = source.func.node + if receiver then + local receiverNode = vm.compileNode(receiver) + local selfGenericName = nil + for _, doc in ipairs(mfunc.bindDocs) do + if doc.type == 'doc.param' and doc.param and doc.param[1] == 'self' then + if doc.extends then + for _, typeUnit in ipairs(doc.extends.types or {}) do + if typeUnit.type == 'doc.generic.name' then + selfGenericName = typeUnit[1] + break + end + end + end + break + end + end + if selfGenericName then + local filteredNode = vm.createNode() + for item in receiverNode:eachObject() do + if item.type == 'doc.type.sign' + or (item.type == 'global' and item.cate == 'type') + or item.type == 'doc.type.table' + or item.type == 'doc.type.array' then + filteredNode:merge(item) + end + end + if not filteredNode:isEmpty() then + selfGenericResolved = { [selfGenericName] = filteredNode } + else + selfGenericResolved = { [selfGenericName] = receiverNode } + end + end + end + end + for rnode in returnNode:eachObject() do if rnode.type == 'generic' then - returnNode = rnode:resolve(guide.getUri(source), args) + if selfGenericResolved and rnode.sign then + local resolved = rnode.sign:resolve(guide.getUri(source), resolveArgs) or {} + for k, v in pairs(selfGenericResolved) do + resolved[k] = v + end + local protoNode = vm.compileNode(rnode.proto) + local result = vm.createNode() + for nd in protoNode:eachObject() do + if nd.type == 'global' or nd.type == 'variable' then + result:merge(nd) + else + local clonedObject = vm.cloneObject(nd, resolved) + if clonedObject then + result:merge(vm.compileNode(clonedObject)) + end + end + end + if protoNode:isOptional() then + result:addOptional() + end + returnNode = result + else + returnNode = rnode:resolve(guide.getUri(source), resolveArgs) + end break end end + + if mfunc.type == 'function' then + local hasUnresolvedGeneric = false + for rnode in returnNode:eachObject() do + if rnode.type == 'doc.generic.name' and not rnode._resolved then + hasUnresolvedGeneric = true + break + end + end + if hasUnresolvedGeneric then + local sign = vm.getSign(mfunc) + if sign and resolveArgs and #resolveArgs > 0 then + local resolved = sign:resolve(guide.getUri(source), resolveArgs) + if resolved and next(resolved) then + local newReturnNode = vm.createNode() + for rnode in returnNode:eachObject() do + local cloned = vm.cloneObject(rnode, resolved) + if cloned then + newReturnNode:merge(vm.compileNode(cloned)) + else + newReturnNode:merge(rnode) + end + end + returnNode = newReturnNode + end + end + end + end + + local call = source.parent + if not selfGenericResolved and call and call.type == 'call' then + local callNode = call.node + if callNode and (callNode.type == 'getmethod' or callNode.type == 'getfield') then + local receiver = callNode.node + if receiver then + local receiverNode = vm.compileNode(receiver) + for rn in receiverNode:eachObject() do + if rn.type == 'doc.type.sign' and rn.signs and rn.node and rn.node[1] then + local classGlobal = vm.getGlobal('type', rn.node[1]) + if classGlobal then + local genericMap = getClassGenericMap(guide.getUri(source), classGlobal, rn.signs) + if genericMap and mfunc.bindDocs then + for _, doc in ipairs(mfunc.bindDocs) do + if doc.type == 'doc.return' then + for _, rtn in ipairs(doc.returns) do + if rtn.returnIndex == index then + local newRtn = vm.cloneObject(rtn, genericMap) + if newRtn then + returnNode = vm.compileNode(newRtn) + for rnode in returnNode:eachObject() do + if rnode.type == 'generic' then + returnNode = rnode:resolve(guide.getUri(source), args) + break + end + end + end + break + end + end + break + end + end + end + break + end + end + end + end + end + end + if returnNode then for rnode in returnNode:eachObject() do - -- TODO: narrow type if rnode.type ~= 'doc.generic.name' then vm.setNode(source, rnode) end @@ -2068,7 +2331,11 @@ local compilerSwitch = util.switch() end) : case 'doc.generic.name' : call(function (source) - vm.setNode(source, source) + if source._resolved then + vm.setNode(source, source._resolved) + else + vm.setNode(source, source) + end end) : case 'doc.type.sign' : call(function (source) @@ -2088,6 +2355,9 @@ local compilerSwitch = util.switch() if ext.type == 'doc.type.table' then if vm.getGeneric(ext) then local resolved = vm.getGeneric(ext):resolve(uri, source.signs) + for obj in resolved:eachObject() do + obj.hideView = true + end vm.setNode(source, resolved) end end diff --git a/script/vm/generic.lua b/script/vm/generic.lua index f1eaaf99d..b6cffe5a5 100644 --- a/script/vm/generic.lua +++ b/script/vm/generic.lua @@ -34,6 +34,21 @@ local function cloneObject(source, resolved) end return newName end + if source.type == 'doc.type.name' then + local key = source[1] + if resolved[key] then + local newName = { + type = 'doc.generic.name', + start = source.start, + finish = source.finish, + parent = source.parent, + [1] = source[1], + } + vm.setNode(newName, resolved[key], true) + newName._resolved = resolved[key] + return newName + end + end if source.type == 'doc.type' then local newType = { type = source.type, @@ -106,13 +121,43 @@ local function cloneObject(source, resolved) newDocFunc.args[i] = newObj end for i, ret in ipairs(source.returns) do - local newObj = cloneObject(ret, resolved) + local newObj = cloneObject(ret, resolved) newObj.parent = newDocFunc newObj.optional = ret.optional - newDocFunc.returns[i] = cloneObject(ret, resolved) + newDocFunc.returns[i] = newObj end return newDocFunc end + if source.type == 'doc.type.sign' and source.signs then + local needsClone = false + for _, sign in ipairs(source.signs) do + if sign.type == 'doc.type' then + for _, tp in ipairs(sign.types) do + if tp.type == 'doc.type.name' and resolved[tp[1]] then + needsClone = true + break + end + end + elseif sign.type == 'doc.type.name' and resolved[sign[1]] then + needsClone = true + end + if needsClone then break end + end + if needsClone then + local newSign = { + type = source.type, + start = source.start, + finish = source.finish, + parent = source.parent, + node = source.node, + signs = {}, + } + for i, sign in ipairs(source.signs) do + newSign.signs[i] = cloneObject(sign, resolved) + end + return newSign + end + end return source end @@ -173,3 +218,10 @@ function vm.createGeneric(proto, sign) }, mt) return generic end + +---@param source vm.object? +---@param resolved? table +---@return vm.object? +function vm.cloneObject(source, resolved) + return cloneObject(source, resolved) +end diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index 7f6d854c7..c54c82ee1 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -4909,3 +4909,263 @@ function f(...args) print() end ]] + +TEST 'integer' [[ +---@class Foo +---@field a T + +---@class Bar: Foo + +---@type Bar +local x +local = x.a +]] + +TEST 'string' [[ +---@class GenericBase +---@field value T + +---@class StringHolder: GenericBase + +---@type StringHolder +local holder +local = holder.value +]] + +TEST 'boolean' [[ +---@class Container +---@field key K +---@field val V + +---@class BoolContainer: Container + +---@type BoolContainer +local c +local = c.val +]] + +TEST 'string' [[ +---@class Container +---@field key K +---@field val V + +---@class BoolContainer: Container + +---@type BoolContainer +local c +local = c.key +]] + +TEST 'string' [[ +---@class Box +---@field value T +local Box = {} + +---@return T +function Box:getValue() + return self.value +end + +---@type Box +local b + +local = b:getValue() +]] + +TEST 'integer' [[ +---@class Wrapper +---@field item V +local Wrapper = {} + +---@return V +function Wrapper:unwrap() + return self.item +end + +---@type Wrapper +local w + +local = w:unwrap() +]] + +-- Issue #1856: Generic class display format +-- Current behavior shows list<>|{...} - the <> indicates an unresolved generic +-- The resolved table type is also shown +TEST 'list<>|{ [integer]: string }' [[ +---@class list: {[integer]:T} + +---@generic T +---@param class `T` +---@return list +local function new_list(class) + return {} +end + +local = new_list('string') +]] + +-- Issue #1853: Recursive expansion on hover of generic type +-- Self-referential generic classes should not expand infinitely +TEST 'store' [[ +---@class store: {set:fun(self:store, key:integer, value:T), get:fun(self:store, key:integer):T} + +local ---@type store +]] + +-- Test composite return types with generics (T[] should resolve to string[]) +TEST 'string[]' [[ +---@class Container +local Container = {} + +---@return T[] +function Container:getAll() + return {} +end + +---@type Container +local c + +local = c:getAll() +]] + +-- Test nested generic return types (Wrapper should resolve to Wrapper) +TEST 'Wrapper' [[ +---@class Wrapper +---@field value V + +---@class Factory +local Factory = {} + +---@return Wrapper +function Factory:wrap() + return {} +end + +---@type Factory +local f + +local = f:wrap() +]] + +-- Issue #723: Generics in @overload +-- @generic should work with @overload annotations +TEST 'string' [[ +---@generic T +---@param x T +---@return T +---@overload fun(x: T): T +local function identity(x) + return x +end + +local = identity("hello") +]] + +-- Issue #723: Multiple generics in @overload +TEST 'integer' [[ +---@generic K, V +---@param k K +---@param v V +---@return V +---@overload fun(k: K, v: V): V +local function getValue(k, v) + return v +end + +local = getValue("key", 42) +]] + +-- Issue #1170: Generics in function type format (fun) +TEST 'string' [[ +---@type fun(x: T): T +local identity + +local = identity("hello") +]] + +-- Issue #1170: Multiple generics in function type +TEST 'boolean' [[ +---@type fun(k: K, v: V): V +local getSecond + +local = getSecond("key", true) +]] + +-- Issue #1170: Generic function in @field +TEST 'integer' [[ +---@class Mapper +---@field transform fun(input: T, fn: fun(x: T): U): U + +---@type Mapper +local m + +local = m.transform("hello", function(x) return #x end) +]] + +-- Issue #1532: Promise-like method chaining +-- Method returning self-type should preserve generic param through chain +TEST 'string' [[ +---@class Promise +---@field value T +local Promise = {} + +---@return Promise +function Promise:next(fn) + return self +end + +---@return T +function Promise:await() + return self.value +end + +---@type Promise +local p + +local = p:next(function() end):await() +]] + +-- Issue #1532: Multiple chained methods +TEST 'number' [[ +---@class Chain +local Chain = {} + +---@return Chain +function Chain:map(fn) + return self +end + +---@return Chain +function Chain:filter(fn) + return self +end + +---@return V +function Chain:first() + return nil +end + +---@type Chain +local c + +local = c:map(function() end):filter(function() end):first() +]] + +-- Issue #1000: Generic self parameter should resolve to concrete type +-- When @generic T and @param self T, calling on List should return List +TEST 'List' [[ +---@class List +local List = {} + +---@generic T +---@param self T +---@return T +function List:identity() + return self +end + +---@type List +local mylist + +local = mylist:identity() +]]