Skip to content

Commit

Permalink
Fix non-exhaustive warnings with the Ast.expr.Call node
Browse files Browse the repository at this point in the history
Because function pointers is forbidden in the expression language from now parser
won't allow anymore following constructions:
- 42()
- "string"()
- true()
- (...)()
- []()
  • Loading branch information
Mingun committed Mar 8, 2024
1 parent d5b4b8e commit ea3fcf0
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ class ExpressionsSpec extends AnyFunSpec {
)
}

// Attribute / method call
// Attributes
it("parses 123.to_s") {
Expressions.parse("123.to_s") should be (Attribute(IntNum(123),identifier("to_s")))
}
Expand All @@ -404,6 +404,63 @@ class ExpressionsSpec extends AnyFunSpec {
Expressions.parse("foo.bar") should be (Attribute(Name(identifier("foo")),identifier("bar")))
}

// Method calls
describe("parses method") {
it("without parameters") {
Expressions.parse("foo.bar()") should be (Call(Name(identifier("foo")),identifier("bar"), Seq()))
}

it("with parameters") {
Expressions.parse("foo.bar(42)") should be (Call(Name(identifier("foo")),identifier("bar"), Seq(IntNum(42))))
}

it("on strings") {
Expressions.parse("\"foo\".bar(42)") should be (Call(Str("foo"),identifier("bar"), Seq(IntNum(42))))
Expressions.parse("'foo'.bar(42)") should be (Call(Str("foo"),identifier("bar"), Seq(IntNum(42))))
}

it("on booleans") {
Expressions.parse("true.bar(42)") should be (Call(Bool(true), identifier("bar"), Seq(IntNum(42))))
Expressions.parse("false.bar(42)") should be (Call(Bool(false),identifier("bar"), Seq(IntNum(42))))
}

it("on integer") {
Expressions.parse("42.bar(42)") should be (Call(IntNum(42), identifier("bar"), Seq(IntNum(42))))
}

it("on float") {
Expressions.parse("42.0.bar(42)") should be (Call(FloatNum(42.0), identifier("bar"), Seq(IntNum(42))))
}

it("on array") {
Expressions.parse("[].bar(42)") should be (Call(List(Nil), identifier("bar"), Seq(IntNum(42))))
}

it("on slice") {
Expressions.parse("foo[1].bar(42)") should be (
Call(
Subscript(Name(identifier("foo")), IntNum(1)),
identifier("bar"),
Seq(IntNum(42))
)
)
}

it("on group") {
Expressions.parse("(42).bar(42)") should be (Call(IntNum(42), identifier("bar"), Seq(IntNum(42))))
}

it("on expression") {
Expressions.parse("(1+2).bar(42)") should be (
Call(
BinOp(IntNum(1), Add, IntNum(2)),
identifier("bar"),
Seq(IntNum(42))
)
)
}
}

describe("f-strings") {
it("parses f-string with just a string") {
Expressions.parse("f\"abc\"") should be(InterpolatedStr(Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,8 @@ class GraphvizClassCompiler(classSpecs: ClassSpecs, topClass: ClassSpec) extends
case _ =>
affectedVars(value)
}
case Ast.expr.Call(func, args) =>
val fromFunc = func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) => affectedVars(obj)
}
fromFunc ::: affectedVars(Ast.expr.List(args))
case Ast.expr.Call(value, _, args) =>
affectedVars(value) ::: affectedVars(Ast.expr.List(args))
case Ast.expr.Subscript(value, idx) =>
affectedVars(value) ++ affectedVars(idx)
case SwitchType.ELSE_CONST =>
Expand Down
12 changes: 11 additions & 1 deletion shared/src/main/scala/io/kaitai/struct/exprlang/Ast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,17 @@ object Ast {
// case class Dict(keys: Seq[expr], values: Seq[expr]) extends expr
/** Represents `X < Y`, `X > Y` and so on. */
case class Compare(left: expr, ops: cmpop, right: expr) extends expr
case class Call(func: expr, args: Seq[expr]) extends expr
/**
* Represents function call on some expression:
* ```
* <obj>.<methodName>(<args>)
* ```
*
* @param obj expression on which method is called
* @param methodName method to call
* @param args method arguments
*/
case class Call(obj: expr, methodName: identifier, args: Seq[expr]) extends expr
case class IntNum(n: BigInt) extends expr
case class FloatNum(n: BigDecimal) extends expr
case class Str(s: String) extends expr
Expand Down
13 changes: 10 additions & 3 deletions shared/src/main/scala/io/kaitai/struct/exprlang/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,20 @@ object Expressions {
def list_contents[$: P] = P( test.rep(1, ",") ~ ",".? )
def list[$: P] = P( list_contents ).map(Ast.expr.List(_))

def call[$: P] = P("(" ~ arglist ~ ")").map { case (args) => (lhs: Ast.expr) => Ast.expr.Call(lhs, args)}
def call[$: P] = P("(" ~ arglist ~ ")")
def slice[$: P] = P("[" ~ test ~ "]").map { case (args) => (lhs: Ast.expr) => Ast.expr.Subscript(lhs, args)}
def cast[$: P] = P( "." ~ "as" ~ "<" ~ TYPE_NAME ~ ">" ).map(
typeName => (lhs: Ast.expr) => Ast.expr.CastToType(lhs, typeName)
)
def attr[$: P] = P("." ~ NAME).map(id => (lhs: Ast.expr) => Ast.expr.Attribute(lhs, id))
def trailer[$: P]: P[Ast.expr => Ast.expr] = P( call | slice | cast | attr )
// Returns function that accept lsh expression and returns Attribute or Call
// node depending on existence of parameters
def attr[$: P] = P("." ~ NAME ~ call.?).map{
case (id, args) => (lhs: Ast.expr) => args match {
case Some(args) => Ast.expr.Call(lhs, id, args)
case None => Ast.expr.Attribute(lhs, id)
}
}
def trailer[$: P]: P[Ast.expr => Ast.expr] = P( slice | cast | attr )

def exprlist[$: P]: P[Seq[Ast.expr]] = P( expr.rep(1, sep = ",") ~ ",".? )
def testlist[$: P]: P[Seq[Ast.expr]] = P( test.rep(1, sep = ",") ~ ",".? )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,27 +87,23 @@ abstract trait CommonMethods[T] extends TypeDetector {
* @return result of translation as [[T]]
*/
def translateCall(call: Ast.expr.Call): T = {
val func = call.func
val obj = call.obj
val args = call.args

func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) =>
val objType = detectType(obj)
(objType, methodName.name) match {
// TODO: check argument quantity
case (_: StrType, "substring") => strSubstring(obj, args(0), args(1))
case (_: StrType, "to_i") => strToInt(obj, args(0))
case (_: BytesType, "to_s") =>
args match {
case Seq(Ast.expr.Str(encoding)) =>
bytesToStr(obj, encoding)
case Seq(x) =>
throw new TypeMismatchError(s"to_s: argument #0: expected string literal, got $x")
case _ =>
throw new TypeMismatchError(s"to_s: expected 1 argument, got ${args.length}")
}
case _ => throw new TypeMismatchError(s"don't know how to call method '$methodName' of object type '$objType'")
val objType = detectType(obj)
(objType, call.methodName.name) match {
// TODO: check argument quantity
case (_: StrType, "substring") => strSubstring(obj, args(0), args(1))
case (_: StrType, "to_i") => strToInt(obj, args(0))
case (_: BytesType, "to_s") =>
args match {
case Seq(Ast.expr.Str(encoding)) =>
bytesToStr(obj, encoding)
case Seq(x) =>
throw new TypeMismatchError(s"to_s: argument #0: expected string literal, got $x")
case _ =>
throw new TypeMismatchError(s"to_s: expected 1 argument, got ${args.length}")
}
case _ => throw new TypeMismatchError(s"don't know how to call method '${call.methodName}' of object type '$objType'")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,19 @@ class TypeDetector(provider: TypeProvider) {
/**
* Detects resulting data type of a given function call expression. Typical function
* call expression in KSY is `foo.bar(arg1, arg2)`, which is represented in AST as
* `Call(Attribute(foo, bar), Seq(arg1, arg2))`.
* `Call(foo, bar, Seq(arg1, arg2))`.
* @note Must be kept in sync with [[CommonMethods.translateCall]]
* @param call function call expression
* @return data type
*/
def detectCallType(call: Ast.expr.Call): DataType = {
call.func match {
case Ast.expr.Attribute(obj: Ast.expr, methodName: Ast.identifier) =>
val objType = detectType(obj)
// TODO: check number and type of arguments in `call.args`
(objType, methodName.name) match {
case (_: StrType, "substring") => CalcStrType
case (_: StrType, "to_i") => CalcIntType
case (_: BytesType, "to_s") => CalcStrType
case _ =>
throw new MethodNotFoundError(methodName.name, objType)
}
val objType = detectType(call.obj)
// TODO: check number and type of arguments in `call.args`
(objType, call.methodName.name) match {
case (_: StrType, "substring") => CalcStrType
case (_: StrType, "to_i") => CalcIntType
case (_: BytesType, "to_s") => CalcStrType
case _ => throw new MethodNotFoundError(call.methodName.name, objType)
}
}

Expand Down

0 comments on commit ea3fcf0

Please sign in to comment.