Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer function parameter types when overriding the same-named class function in an instance of that class #2859

Merged
merged 4 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## Unreleased
<!-- Add all new changes here. They will be moved under a version at release -->
* `NEW` Added support for Japanese locale
* `NEW` Infer function parameter types when overriding the same-named class function in an instance of that class [#2158](https://github.com/LuaLS/lua-language-server/issues/2158)
* `FIX` Eliminate floating point error in test benchmark output
* `FIX` Remove luamake install from make scripts

Expand Down
6 changes: 6 additions & 0 deletions script/core/diagnostics/duplicate-set-field.lua
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ return function (uri, callback)
if not defValue or defValue.type ~= 'function' then
goto CONTINUE
end
if vm.getDefinedClass(guide.getUri(def), def.node)
and not vm.getDefinedClass(guide.getUri(src), src.node)
then
-- allow type variable to override function defined in class variable
goto CONTINUE
end
callback {
start = src.start,
finish = src.finish,
Expand Down
60 changes: 45 additions & 15 deletions script/vm/compiler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1117,26 +1117,56 @@ local function compileFunctionParam(func, source)
end
---@cast aindex integer

-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
local funcNode = vm.compileNode(func)
local found = false
for n in funcNode:eachObject() do
if n.type == 'doc.type.function' and n.args[aindex] then
local argNode = vm.compileNode(n.args[aindex])
for an in argNode:eachObject() do
if an.type ~= 'doc.generic.name' then
vm.setNode(source, an)
if func.parent.type == 'callargs' then
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
for n in funcNode:eachObject() do
if n.type == 'doc.type.function' and n.args[aindex] then
local argNode = vm.compileNode(n.args[aindex])
for an in argNode:eachObject() do
if an.type ~= 'doc.generic.name' then
vm.setNode(source, an)
end
end
end
-- NOTE: keep existing behavior for local call which only set type based on the 1st match
if func.parent.type == 'callargs' then
-- NOTE: keep existing behavior for function as argument which only set type based on the 1st match
return true
end
found = true
end
end
if found then
return true
else
-- function declaration: use info from all `fun()`, also from the base function when overriding
--[[
---@type fun(x: string)|fun(x: number)
local function f1(x) end --> x -> string|number

---@overload fun(x: string)
---@overload fun(x: number)
local function f2(x) end --> x -> string|number

---@class A
local A = {}
---@param x number
function A:f(x) end --> x -> number
---@type A
local a = {}
function a:f(x) end --> x -> number
]]
local found = false
for n in funcNode:eachObject() do
if (n.type == 'doc.type.function' or n.type == 'function')
and n.args[aindex] and n.args[aindex] ~= source
then
local argNode = vm.compileNode(n.args[aindex])
for an in argNode:eachObject() do
if an.type ~= 'doc.generic.name' then
vm.setNode(source, an)
end
end
found = true
end
end
if found then
return true
end
end

local derviationParam = config.get(guide.getUri(func), 'Lua.type.inferParamType')
Expand Down
26 changes: 26 additions & 0 deletions test/diagnostics/duplicate-set-field.lua
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,29 @@ else
function X.f() end
end
]]

TEST [[
---@class A
X = {}

function X:f() end

---@type x
local x

function x:f() end
]]

TEST [[
---@class A
X = {}

function X:f() end

---@type x
local x

function <!x:f!>() end

function <!x:f!>() end
]]
12 changes: 12 additions & 0 deletions test/type_inference/common.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4441,3 +4441,15 @@ local B = {}

function B:func(<?x?>) end
]]

TEST 'number' [[
---@class A
local A = {}

---@param x number
function A:func(x) end

---@type A
local a = {}
function a:func(<?x?>) end
]]
Loading