[patches] [AVRO 3/3] Allow to preserve extra fields in AST and fingerprint

AKhatskevich avkhatskevich at tarantool.org
Tue Feb 20 11:26:33 MSK 2018


From: "AKhatskevich avkhatskevich at tarantool.org" <avkhatskevich at gmail.com>

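Add two options (each a table of attribute names) to `avro.create`:
 - preserve_in_ast: attributes that are kept in the AST instead of being dropped
 - preserve_in_fingerprint: attributes that are kept while calculating the
   fingerprint (each of them must also be listed in preserve_in_ast)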

This feature is useful when building frameworks that work on top of
Avro.

Closes #31
---
 CMakeLists.txt              |  2 +-
 avro_schema/fingerprint.lua | 29 ++++++++-----
 avro_schema/frontend.lua    | 55 ++++++++++++++++---------
 avro_schema/init.lua        | 57 ++++++++++++++++++++++----
 avro_schema/utils.lua       | 12 ++++++
 test/api_tests.lua          | 99 ++++++++++++++++++++++++++++++++++++++++++++-
 6 files changed, 217 insertions(+), 37 deletions(-)
 create mode 100644 avro_schema/utils.lua

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b7a80da..5633fab 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -85,7 +85,7 @@ add_custom_target(postprocess_lua ALL DEPENDS
 # Install module
 install(FILES avro_schema/init.lua avro_schema/compiler.lua
               avro_schema/frontend.lua avro_schema/runtime.lua
-              avro_schema/fingerprint.lua
+              avro_schema/fingerprint.lua avro_schema/utils.lua
         DESTINATION ${TARANTOOL_INSTALL_LUADIR}/avro_schema)
 
 install(FILES ${CMAKE_BINARY_DIR}/il.lua
diff --git a/avro_schema/fingerprint.lua b/avro_schema/fingerprint.lua
index 0391835..4c8ab4e 100644
--- a/avro_schema/fingerprint.lua
+++ b/avro_schema/fingerprint.lua
@@ -20,21 +20,29 @@ local function is_primitive_type(xtype)
     return false
 end
 
-local function avro_json_array(data)
+local function avro_json_array(data, extra_fields)
     local res = {}
     for _,item in ipairs(data) do
-        table.insert(res,avro_json(item))
+        table.insert(res,avro_json(item, extra_fields))
     end
     return string.format("[%s]", table.concat(res, ","))
 end
 
-local function avro_json_object(data)
+local function avro_json_object(data, extra_fields)
     local res = {}
     local necessary_order = {"name", "type", "fields", "symbols", "items", "values", "size"}
+    --
+    -- There are cases in which it is necessary to extend a schema.
+    -- The code below provides a way to add those attrs in a sustainable way.
+    --
+    for _, val in ipairs(extra_fields) do
+        table.insert(necessary_order, val)
+    end
+
     for _,name in ipairs(necessary_order) do
         local item = data[name]
         if item ~= nil then
-            local inner = avro_json(item)
+            local inner = avro_json(item, extra_fields)
             inner = string.format([[%s:%s]], json.encode(name), inner)
             table.insert(res, inner)
         end
@@ -44,7 +52,10 @@ end
 
 -- Takes normalized avro schema and produces normalized schema representation
 -- encoded in json format.
-avro_json = function (data)
+avro_json = function (data, extra_fields)
+    extra_fields = extra_fields or {}
+    -- should be sorted for consistency
+    table.sort(extra_fields)
     local xtype = type(data)
     if is_primitive_type(xtype) then
         return json.encode(data)
@@ -54,17 +65,17 @@ avro_json = function (data)
     end
     -- array
     if #data > 0 then
-        return avro_json_array(data)
+        return avro_json_array(data, extra_fields)
     end
     -- object (dict)
-    return avro_json_object(data)
+    return avro_json_object(data, extra_fields)
 end
 
-local function get_fingerprint(schema, algo, size)
+local function get_fingerprint(schema, algo, size, options)
     if digest[algo] == nil or type(digest[algo]) ~= "function" then
         raise_error("The hash function %s is not supported", algo)
     end
-    local fp = digest[algo](avro_json(schema))
+    local fp = digest[algo](avro_json(schema, options.preserve_in_fingerprint))
     return fp:sub(1, size)
 end
 
diff --git a/avro_schema/frontend.lua b/avro_schema/frontend.lua
index 646b448..2a6cc34 100644
--- a/avro_schema/frontend.lua
+++ b/avro_schema/frontend.lua
@@ -67,7 +67,7 @@ local floor = math.floor
 local clear = require('table.clear')
 local next, type = next, type
 
-function deepcopy(orig)
+local function deepcopy(orig)
     local orig_type = type(orig)
     local copy
     if orig_type == 'table' then
@@ -131,6 +131,7 @@ local function type_tag(t)
     return (type(t) == 'string' and t) or t.name or t.type
 end
 
+local copy_schema
 local copy_schema_error
 local copy_schema_location_info
 
@@ -200,13 +201,22 @@ local dcache = setmetatable({}, { __mode = 'k' })
 
 local copy_field_default
 
+local function copy_fields(from, to, fields)
+    for _,field in ipairs(fields) do
+        if from[field] ~= nil then
+            to[field] = deepcopy(from[field])
+        end
+    end
+end
 -- create a private copy and sanitize recursively;
 -- [ns]       current ns (or nil)
 -- [scope]    a dictionary of named types (ocasionally used for unnamed too)
 -- [open_rec] a set consisting of the current record + parent records;
 --            it is used to reject records containing themselves
-copy_schema = function(schema, ns, scope, open_rec)
-    local res, ptr -- we depend on these being locals #5 and #6
+-- [options]  options table, contains:
+--             - preserve_in_ast: names of attrs which should not be deleted
+copy_schema = function(schema, ns, scope, open_rec, options)
+    local res, ptr -- we depend on these being locals #6 and #7
     if type(schema) == 'table' then
         if scope[schema] then
             -- this check is necessary for unnamed complex types (union, array map)
@@ -219,7 +229,7 @@ copy_schema = function(schema, ns, scope, open_rec)
             res = {}
             for branchno, xbranch in ipairs(schema) do
                 ptr = branchno
-                local branch = copy_schema(xbranch, ns, scope)
+                local branch = copy_schema(xbranch, ns, scope, nil, options)
                 local bxtype, bxname
                 if type(branch) == 'table' and not branch.type then
                     copy_schema_error('Union may not immediately contain other unions')
@@ -248,13 +258,21 @@ copy_schema = function(schema, ns, scope, open_rec)
             nullable, xtype = extract_nullable(xtype)
 
             if primitive_type[xtype] then
+                -- Preserve fields which are asked to be in ast.
+                res = {}
+                copy_fields(schema, res, options.preserve_in_ast)
                 -- primitive type normalization
-                if nullable == nil then
+                if nullable == nil and not next(res) then
                     return xtype
                 end
-                return {type = xtype, nullable = nullable}
+                res.type = xtype
+                res.nullable = nullable
+                return res
             elseif xtype == 'record' then
-                res = { type = 'record' }
+                -- Preserve fields which are asked to be in ast.
+                res = {}
+                copy_fields(schema, res, options.preserve_in_ast)
+                res.type = 'record'
                 res.nullable = nullable
                 local name, ns = checkname(schema, ns, scope)
                 scope[name] = res
@@ -298,19 +316,19 @@ copy_schema = function(schema, ns, scope, open_rec)
                     if not xtype then
                         copy_schema_error('Record field must have a "type"')
                     end
-                    field.type = copy_schema(xtype, ns, scope, open_rec)
+                    field.type = copy_schema(xtype, ns, scope, open_rec, options)
                     if open_rec[field.type] then
                         local path, n = {}
                         for i = 1, 1000000 do
-                            local _, res = debug.getlocal(i, 5)
+                            local _, res = debug.getlocal(i, 6)
                             if res == field.type then
                                 n = i
                                 break
                             end
                         end
                         for i = n, 1, -1 do
-                            local _, res = debug.getlocal(i, 5)
-                            local _, ptr = debug.getlocal(i, 6)
+                            local _, res = debug.getlocal(i, 6)
+                            local _, ptr = debug.getlocal(i, 7)
                             insert(path, res.fields[ptr].name)
                         end
                         error(format('Record %s contains itself via %s',
@@ -389,7 +407,7 @@ copy_schema = function(schema, ns, scope, open_rec)
                 if not xitems then
                     copy_schema_error('Array type must have "items"')
                 end
-                res.items = copy_schema(xitems, ns, scope)
+                res.items = copy_schema(xitems, ns, scope, nil, options)
                 scope[schema] = nil
                 return res
             elseif xtype == 'map' then
@@ -399,7 +417,7 @@ copy_schema = function(schema, ns, scope, open_rec)
                 if not xvalues then
                     copy_schema_error('Map type must have "values"')
                 end
-                res.values = copy_schema(xvalues, ns, scope)
+                res.values = copy_schema(xvalues, ns, scope, nil, options)
                 scope[schema] = nil
                 return res
             elseif xtype == 'fixed' then
@@ -478,13 +496,14 @@ copy_schema_location_info = function()
     local top, bottom = find_frames(copy_schema)
     local res = {}
     for i = bottom, top, -1 do
-        local _, node = debug.getlocal(i, 5)
-        local _, ptr  = debug.getlocal(i, 6)
+        -- 6 and 7 are res and ptr vars from copy func
+        local _, node = debug.getlocal(i, 6)
+        local _, ptr  = debug.getlocal(i, 7)
         if type(node) == 'table' then
             if node.type == nil then -- union
                 insert(res, '<union>')
                 if i <= top + 1 then
-                    local _, next_node = debug.getlocal(i - 1, 6)
+                    local _, next_node = debug.getlocal(i - 1, 7)
                     if i == top or (i == top + 1 and
                                     not (next_node and next_node.name)) then
                         insert(res, format('<branch-%d>', ptr))
@@ -525,8 +544,8 @@ copy_schema_error = function(fmt, ...)
 end
 
 -- validate schema definition (creates a copy)
-local function create_schema(schema)
-    return copy_schema(schema, nil, {})
+local function create_schema(schema, options)
+    return copy_schema(schema, nil, {}, nil, options)
 end
 
 -- get a mapping from a (string) type tag -> union branch id
diff --git a/avro_schema/init.lua b/avro_schema/init.lua
index 621030d..c3989de 100644
--- a/avro_schema/init.lua
+++ b/avro_schema/init.lua
@@ -5,6 +5,7 @@ local il          = require('avro_schema.il')
 local backend_lua = require('avro_schema.backend')
 local rt          = require('avro_schema.runtime')
 local fingerprint = require('avro_schema.fingerprint')
+local utils       = require('avro_schema.utils')
 
 local format, find, sub = string.format, string.find, string.sub
 local insert, remove, concat = table.insert, table.remove, table.concat
@@ -22,6 +23,7 @@ local rt_universal_decode = rt.universal_decode
 local install_lua_backend = backend_lua.install
 
 -- We give away a handle but we never expose schema data.
+-- {schema=schema, options=options}
 local schema_by_handle = setmetatable( {}, { __mode = 'k' } )
 
 local function get_schema(handle)
@@ -29,7 +31,7 @@ local function get_schema(handle)
     if not schema then
         error(format('Not a schema: %s', handle), 0)
     end
-    return schema
+    return schema.schema
 end
 
 local function is_schema(schema_handle)
@@ -62,7 +64,7 @@ local function get_ir(from_schema, to_schema, inverse)
 end
 
 local function schema_to_string(handle)
-    local schema = schema_by_handle[handle]
+    local schema = get_schema(handle)
     return format('Schema (%s)',
                   handle[1] or (type(schema) ~= 'table' and schema) or
                   schema.name or schema.type or 'union')
@@ -119,16 +121,53 @@ augment_defaults = function(schema, visited)
     end
 end
 
+local function create_options_validate(options)
+    options = options or {}
+    options = table.deepcopy(options)
+    if type(options) ~= 'table' then
+        return false, "Options should be a table"
+    end
+    if type(options.preserve_in_ast) ~= 'table' then
+        options.preserve_in_ast = {}
+    end
+    for _, f_ast in ipairs(options.preserve_in_ast) do
+        if type(f_ast) ~= 'string' then
+            return false, "preserve fields should be of string type"
+        end
+    end
+    if type(options.preserve_in_fingerprint) ~= 'table' then
+        options.preserve_in_fingerprint = {}
+    end
+    -- preserve_in_fingerprint must not contain fields which are not
+    -- present in preserve_in_ast
+    for _, f_f in ipairs(options.preserve_in_fingerprint) do
+        if type(f_f) ~= 'string' then
+            return false, "preserve fields should be of string type"
+        end
+        if not utils.table_contains(options.preserve_in_ast, f_f) then
+            return false, "fingerprint should contain only fields from AST"
+        end
+    end
+    return true, options
+end
+
 local function create(raw_schema, options)
-    local ok, schema = pcall(f_create_schema, raw_schema)
+    local ok
+    ok, options = create_options_validate(options)
+    if ok == false then
+        return false, options
+    end
+    local schema
+    ok, schema = pcall(f_create_schema, raw_schema, options)
     if not ok then
         return false, schema
     end
-    if type(options) == 'table' and options.defaults == 'auto' then
+    if options.defaults == 'auto' then
         augment_defaults(schema, {})
     end
     local schema_handle = setmetatable({}, schema_handle_mt)
-    schema_by_handle[schema_handle] = schema
+    schema_by_handle[schema_handle] = {schema = schema,
+                                       options = options}
     return true, schema_handle
 end
 
@@ -511,10 +550,12 @@ end
 local function export(schema_h)
     return export_helper(get_schema(schema_h), {})
 end
-local function get_fingerprint(schema_h, algo, size)
-    if algo == nil then algo = "sha256" end
+local function get_fingerprint(schema_h, hash, size)
+    if hash == nil then hash = "sha256" end
     if size == nil then size = 8 end
-    return fingerprint.get_fingerprint(get_schema(schema_h), algo, size)
+    local schema = schema_by_handle[schema_h]
+    return fingerprint.get_fingerprint(schema.schema, hash,
+                                       size, schema.options)
 end
 local function to_json(schema_h)
     return fingerprint.avro_json(get_schema(schema_h))
diff --git a/avro_schema/utils.lua b/avro_schema/utils.lua
new file mode 100644
index 0000000..da10c25
--- /dev/null
+++ b/avro_schema/utils.lua
@@ -0,0 +1,12 @@
+local function table_contains(t, xval)
+    for k, val in ipairs(t) do
+        if type(k) == "number" and val == xval then
+            return true
+        end
+    end
+    return false
+end
+
+return {
+    table_contains = table_contains
+}
\ No newline at end of file
diff --git a/test/api_tests.lua b/test/api_tests.lua
index 42fc7f4..02be147 100644
--- a/test/api_tests.lua
+++ b/test/api_tests.lua
@@ -5,7 +5,7 @@ local msgpack = require('msgpack')
 
 local test = tap.test('api-tests')
 
-test:plan(54)
+test:plan(64)
 
 test:is_deeply({schema.create()}, {false, 'Unknown Avro type: nil'},
                'error unknown type')
@@ -280,5 +280,102 @@ for i, testcase in ipairs(fingerprint_testcases) do
     test:is(string.lower(string.tohex(fingerprint)), testcase.fingerprint, "Fingerprint testcase "..i)
 end
 
+local schema_preserve_fields_testcases = {
+    {
+        name = "1",
+        schema = {
+            type="int",
+            extra_field="extra_field"
+        },
+        options = {},
+        ast = "int"
+    },
+    {
+        name = "2",
+        schema = {
+            type="int",
+            extra_field="extra_field"
+        },
+        options = {preserve_in_ast={"extra_field"}},
+        ast = {
+            type="int",
+            extra_field="extra_field"
+        }
+    },
+    {
+        name = "3-complex",
+        schema = {
+            type="int",
+            extra_field={extra_field={"extra_field"}}
+        },
+        options = {preserve_in_ast={"extra_field"}},
+        ast = {
+            type="int",
+            extra_field={extra_field={"extra_field"}}
+        }
+    }
+}
+
+for _, testcase in ipairs(schema_preserve_fields_testcases) do
+    res = {schema.create(testcase.schema, testcase.options)}
+    test:is_deeply(schema.export(res[2]), testcase.ast, 'schema extra fields '..testcase.name)
+end
+
+test:is_deeply(
+        {schema.create("int", {
+                                preserve_in_ast={},
+                                preserve_in_fingerprint={"extra_field"},
+                             })},
+        {false, "fingerprint should contain only fields from AST"},
+        'preserve_in_fingerprint contains more fields than AST')
+
+local fingerprint
+res = {schema.create(
+        {
+            type = "record",
+            name = "test",
+            extra_field = "extra_field",
+            fields = {
+                { name = "bar", type = "null", default = msgpack.NULL, extra_field = "extra" },
+                { name = "foo", type = {"null", "int"}, default = msgpack.NULL },
+            }
+        }, nil)}
+fingerprint = schema.fingerprint(res[2], "sha256", 32)
+test:is(string.lower(string.tohex(fingerprint)),
+        "a64098ee437e9020923c6005db88f37a234ed60daae23b26e33d8ae1bf643356",
+        "Fingerprint extra fields 1")
+
+res = {schema.create(
+        {
+            type = "record",
+            name = "test",
+            extra_field = "extra_field",
+            fields = {
+                { name = "bar", type = "null", default = msgpack.NULL, extra_field = "extra" },
+                { name = "foo", type = {"null", "int"}, default = msgpack.NULL },
+            }
+        }, {preserve_in_ast={"extra_field"}, preserve_in_fingerprint={"extra_field"}})}
+fingerprint = schema.fingerprint(res[2], "sha256", 32)
+test:is(string.lower(string.tohex(fingerprint)),
+        "70bd295335daafff0a4512cadc39a4298cd81c460defec530c7372bdd1ec6f44",
+        "Fingerprint extra fields 2")
+
+res = {schema.create(
+        {
+            type = "int",
+            extra_field = "extra_field",
+        }, {preserve_in_ast={"extra_field"}})}
+fingerprint = schema.fingerprint(res[2], "sha256", 32)
+test:is_deeply(schema.export(res[2]), {type = "int", extra_field = "extra_field"},
+        "Prevent primitive type collapse by extra field")
+
+-- avro_json is used for fingerprint
+fingerprint = require("avro_schema.fingerprint")
+test:is(fingerprint.avro_json({field1="1"}), "{}", "avro_json 1")
+test:is(fingerprint.avro_json({field1="1"}, {"field1"}), '{"field1":"1"}', "avro_json 2")
+test:is(fingerprint.avro_json({field2="1", field1="1"}, {"field2", "field1"}),
+        '{"field1":"1","field2":"1"}', "avro_json 3 order")
+
+
 test:check()
 os.exit(test.planned == test.total and test.failed == 0 and 0 or -1)
-- 
2.14.1




More information about the Tarantool-patches mailing list