diff --git a/lib/cors_plug.ex b/lib/cors_plug.ex index 2484779..ad1b034 100644 --- a/lib/cors_plug.ex +++ b/lib/cors_plug.ex @@ -16,17 +16,18 @@ defmodule CORSPlug do end def init(options) do - Dict.merge defaults, options + Keyword.merge(defaults, options) end def call(conn, options) do - conn = put_in conn.resp_headers, conn.resp_headers ++ headers(conn, options) + conn = put_in(conn.resp_headers, conn.resp_headers ++ headers(conn, options)) case conn.method do - "OPTIONS" -> halt send_resp conn, 204, "" + "OPTIONS" -> conn |> send_resp(204, "") |> halt _method -> conn end end + # headers specific to OPTIONS request defp headers(conn = %Plug.Conn{method: "OPTIONS"}, options) do headers(%{conn | method: nil}, options) ++ [ {"access-control-max-age", "#{options[:max_age]}"}, @@ -35,6 +36,7 @@ defmodule CORSPlug do ] end + # universal headers defp headers(conn, options) do [ {"access-control-allow-origin", origin(options[:origin], conn)}, @@ -43,15 +45,25 @@ defmodule CORSPlug do ] end + # normalize non-list to list defp origin(key, conn) when not is_list(key) do - key |> List.wrap |> origin(conn) + origin(List.wrap(key), conn) end + + # whitelist internal requests defp origin([:self], conn) do get_req_header(conn, "origin") |> List.first || "*" end - defp origin(["*"], _conn), do: "*" + + # return "*" if origin list is ["*"] + defp origin(["*"], _conn) do + "*" + end + + # return request origin if in origin list, otherwise "null" string + # see: https://www.w3.org/TR/cors/#access-control-allow-origin-response-header defp origin(origins, conn) when is_list(origins) do req_origin = get_req_header(conn, "origin") |> List.first - if req_origin in origins, do: req_origin, else: nil + if req_origin in origins, do: req_origin, else: "null" end end diff --git a/test/cors_plug_test.exs b/test/cors_plug_test.exs index 96a15e1..c682796 100644 --- a/test/cors_plug_test.exs +++ b/test/cors_plug_test.exs @@ -52,12 +52,12 @@ defmodule CORSPlugTest do get_resp_header(conn, "access-control-allow-origin") end - test "returns nil when the origin is invalid" do + test "returns null string when the origin is invalid" do opts = CORSPlug.init(origin: ["example1.com"]) conn = conn(:get, "/", nil, headers: [{"origin", "example2.com"}]) conn = CORSPlug.call(conn, opts) - assert [nil] == get_resp_header conn, "access-control-allow-origin" + assert ["null"] == get_resp_header conn, "access-control-allow-origin" end test "returns the request host when origin is :self" do