Armed with a text editor

mu's views on program and recipe! design

This code sample checks the format strings passed to Python's PyArg_ParseTuple or Py_BuildValue in C source files, and reports any argument whose type it cannot confirm is a proper match for its format character. Tweak the FILTER_* options and the values in fmtfunctions to taste. This really should be done with a proper C parser, but I was too far along before that became completely obvious.

Usage: ./fmtcheck.py <dir1> <dir2> ...

#! /usr/bin/env python
#
#    A lame python format string checker...
#    Copyright (C) 2005  Michael Urman
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of version 2 of the GNU General Public License as
#    published by the Free Software Foundation.
#

class NonLiteral(ValueError):
    """Signals that a call's format argument is not a string literal.

    Raised by getfmtvars() when none of the leading call arguments starts
    with a double quote; checkfile() skips such calls silently.
    """

# Report filters consulted by checkfile().
# FILTER_UNSIGNED: when true, a leading 'unsigned ' is dropped from both the
# expected type and the declared type before they are compared.
FILTER_UNSIGNED = False
# FILTER_UNKNOWN: when true, variables whose type resolved to 'unknown' are
# not reported.
FILTER_UNKNOWN = False

# Names of the C functions whose calls are checked.  findfmt() matches these
# by substring, so a listed name also matches longer names that contain it.
fmtfunctions = [
    "PyArg_ParseTuple",
    #"Py_BuildValue"
]

# Expected C base type (pointer stars omitted) for each format character.
# checkfile() looks up every character of the cleaned-up format string here
# and compares the result with the declared type of the matching argument.
# An empty string means the argument is not checked at all.
mapfmt = {
    # string-like formats
    's': 'char',
    'z': 'char',
    'u': 'Py_UNICODE',
    'e': 'char',

    # numeric formats and the '#' length modifier
    '#': 'int',
    'b': 'char',
    'B': 'unsigned char',
    'h': 'short',
    'H': 'unsigned short',
    'i': 'int',
    'I': 'unsigned int',
    'l': 'long',
    'L': 'PY_LONG_LONG',
    'k': 'unsigned long',
    'K': 'unsigned PY_LONG_LONG',
    'c': 'char',
    'f': 'float',
    'd': 'double',
    'D': 'Py_complex',

    # object formats and their modifiers
    'O': 'PyObject',
    '!': 'PyObject',
    '&': '', # can't check void*
    'N': 'PyObject',
    'S': 'PyObject',
    'U': 'PyObject',
    't': 'char',
    'w': 'char',
}

# Normalises the type text produced by getvartypes() to the spelling used in
# mapfmt.  checkfile() uses maptype.get(t, t), so any type not listed here is
# compared unchanged.
maptype = {
    # alternative spellings of the basic C types
    'const char': 'char',
    'long int': 'long',
    'unsigned long int': 'unsigned long',
    'short int': 'short',
    'unsigned short int': 'unsigned short',

    # object types that are accepted wherever a PyObject is expected
    'PyTypeObject': 'PyObject',
    'PyUnicodeObject': 'PyObject',
    'PyDateTime_DateType': 'PyObject',
    'PyDateTime_TimeType': 'PyObject',
    'PyDateTime_TZInfoType': 'PyObject',
    'statichere PyTypeObject': 'PyObject',
    'staticforward PyTypeObject': 'PyObject',
    'PyCodeObject': 'PyObject',

    # module-specific object structs, likewise treated as PyObject
    'alcobject': 'PyObject',
    'DBObject': 'PyObject',
    'DBLockObject': 'PyObject',
    'PyCursesWindowObject': 'PyObject',

    # NOTE(review): presumably the text getvartypes() extracts from a
    # "#define status_i ..." line in one of the scanned sources -- confirm.
    '#define status_i': 'int'
}

# Identifiers whose type getvartypes() takes from this table instead of
# searching the source for a declaration.  The lookup is done on the
# argument text with any '&' characters stripped from both ends.
knownobj = dict((known_name, 'PyObject') for known_name in (
    'PyCode_Type',
    'PyDict_Type',
    'PyString_Type',
    'PyUnicode_Type',
    'PyList_Type',
    'PyTuple_Type',
    'PyFile_Type',
    'PyInt_Type',
    'PySocketModule.Sock_Type',

    'PyCursesWindow_Type',
    'self',
))

# The one known name that is a C string rather than an object.
knownobj['Py_FileSystemDefaultEncoding'] = 'char'


def findfmt(line, functions=None):
    """Return the index in *line* of the first checked function name found.

    The names are tried in list order and the position of the first one that
    occurs anywhere in *line* (plain substring match) is returned.

    line      -- one line of C source text.
    functions -- optional sequence of function names to look for; defaults
                 to the module-level ``fmtfunctions`` list.

    Returns -1 when no name occurs in the line, including when the list of
    names is empty.
    """
    if functions is None:
        functions = fmtfunctions
    for func in functions:
        i = line.find(func)
        if i >= 0:
            return i
    # Explicit result: with an empty list the loop never runs, so there is
    # no loop variable left over to return.
    return -1

def checkfile(name):
    """Check every format-function call in the C source file *name*.

    The file is read, comments are removed, and each line on which findfmt()
    finds a function name is parsed with getfmtvars().  The declared type of
    every argument (from getvartypes()) is compared with the type that
    ``mapfmt`` expects for the corresponding format character; each pair
    that cannot be confirmed as a match is printed to standard output.

    Calls whose format is not a string literal are skipped silently; calls
    that fail to parse for any other reason produce a "skipping" line.
    Nothing is returned.
    """
    import sys
    from os.path import basename

    # Close the file as soon as it has been read instead of relying on
    # garbage collection.
    source = open(name)
    try:
        lines = [line.strip() for line in source]
    finally:
        source.close()
    kill_comments(lines)

    important = [i for (i, line) in enumerate(lines) if findfmt(line) >= 0]
    for i in important:
        try:
            fmt, variables = getfmtvars(lines, i)
        except NonLiteral:
            continue
        except Exception:
            # sys.exc_info() and a single parenthesised print argument are
            # used so this parses on both Python 2 and Python 3.
            err = sys.exc_info()[1]
            print('skipping %s:%d %s' % (name, i + 1, err))
            continue
        if not fmt:
            continue

        types = getvartypes(lines, i, variables)

        # Reduce the format to one character per C argument: drop grouping
        # and optional markers and anything after ':' or ';'.
        fakefmt = fmt
        for c in '()[]|':
            fakefmt = fakefmt.replace(c, '')
        if ':' in fakefmt:
            fakefmt = fakefmt[:fakefmt.find(':')]
        if ';' in fakefmt:
            fakefmt = fakefmt[:fakefmt.find(';')]
        fakefmt = fakefmt.replace('O&', '&&') # can't check function or void*

        for t, f, v in zip(types, fakefmt, variables):
            # None for a format character that mapfmt does not know; such
            # arguments are still reported below (shown as "None").
            checkfmt = mapfmt.get(f)
            checktype = maptype.get(t, t)
            if FILTER_UNKNOWN and checktype == 'unknown':
                continue
            if FILTER_UNSIGNED:
                # Guard against None so an unknown format character cannot
                # crash the run when this filter is enabled.
                if checkfmt is not None and checkfmt.startswith('unsigned '):
                    checkfmt = checkfmt[9:]
                if checktype.startswith('unsigned '):
                    checktype = checktype[9:]
            if checkfmt != '' and checkfmt != checktype:
                print("%-20s%40s >< %s" % (
                    "%s:%d" % (basename(name), i + 1),
                    " `%s' type `%s'" % (v.strip('&'), t),
                    "%s: %s" % (f, checkfmt)))

def kill_comments(lines):
    """Remove C comments from *lines*, modifying the list in place.

    Both ``/* ... */`` comments (which may span several lines) and ``//``
    comments are removed.  Lines that lie entirely inside a block comment
    become empty strings, so line numbering is preserved.

    lines -- list of source lines; each entry is replaced by its text with
             the comments cut out.

    String literals are not recognised: comment markers inside a string are
    treated like any other.  Returns None.
    """
    in_comment = False
    for i, line in enumerate(lines):
        if in_comment:
            end = line.find('*/')
            if end < 0:
                # Still inside a block comment: the whole line goes.
                lines[i] = ''
                continue
            line = line[end + 2:]
            in_comment = False

        while True:
            bs = line.find('/*')
            ls = line.find('//')

            if ls >= 0 and (bs < 0 or ls < bs):
                # A line comment starts first; the rest of the line is gone.
                line = line[:ls]
                break
            if bs < 0:
                break

            # Look for the terminator only AFTER the opener.  Searching from
            # the start of the line would pair the opener with an earlier,
            # unrelated "*/" (making the line grow on every pass) and would
            # treat "/*/" as a complete comment.
            be = line.find('*/', bs + 2)
            if be >= 0:
                line = line[:bs] + line[be + 2:]
            else:
                line = line[:bs]
                in_comment = True
                break
        lines[i] = line

def getfmtvars(lines, i):
    """Extract the format string and the remaining arguments of a call.

    lines -- comment-free source lines.
    i     -- index of the line on which findfmt() located a function name.

    The call's argument list is collected (following lines are appended
    until the parentheses balance) and split on commas.  The format string
    is expected as the first, second or third argument.  In the third
    position one further argument after the format is skipped as well
    (presumably the keyword list of the ...AndKeywords variant, which
    findfmt() matches by substring -- confirm).

    Returns (fmt, args): *fmt* is the format text without its quotes and
    without anything from the first ':' or ';' on; *args* is the list of
    argument strings following the format.

    Raises NonLiteral when no candidate argument starts with a double quote,
    ValueError when the function name is not followed by '(' on its line,
    and IndexError when the source ends before the call is complete.

    The splitting is purely textual: commas and parentheses inside nested
    expressions or string literals are not treated specially.
    """
    content = lines[i][findfmt(lines[i]):]
    left = content.find('(')
    if left < 0:
        # Without an opening parenthesis the scan below would start from a
        # bogus position and collect unrelated text.
        raise ValueError("no argument list after function name")

    # Walk forward to the matching ')' of the call, pulling in following
    # lines as needed.
    right = left + 1
    count = 1
    nextline = i + 1
    while count > 0:
        if len(content) == right:
            content += ' ' + lines[nextline]
            nextline += 1
        ch = content[right]
        if ch == '(':
            count += 1
        elif ch == ')':
            count -= 1
        right += 1
    # A list (not a lazy map object) so that it can be indexed and sliced.
    args = [arg.strip() for arg in content[left + 1:right - 1].split(',')]

    # Locate the format literal; `skip` counts the arguments to drop after
    # the complete format.
    if args[0].startswith('"'):
        fmtindex, skip = 0, 0
    elif len(args) > 1 and args[1].startswith('"'):
        fmtindex, skip = 1, 0
    elif len(args) > 2 and args[2].startswith('"'):
        fmtindex, skip = 2, 1
    else:
        raise NonLiteral("unrecognized format string")
    fmt = args[fmtindex]
    args = args[fmtindex + 1:]

    # NOTE(review): cuts the literal at its first space when a '#' follows
    # that space; the motivating input is not evident from this file.
    if ' ' in fmt and fmt.find('#') > fmt.find(' '):
        fmt = fmt[:fmt.find(' ')] + '"'

    # A comma inside the literal made split() break it apart; glue the
    # pieces back together until the closing quote is reached.
    while not fmt.endswith('"'):
        fmt += "," + args.pop(0)

    # Drop the extra argument only now, after the literal is complete, so
    # that a piece of the format itself is never discarded in its place.
    args = args[skip:]
    fmt = fmt[1:-1]

    if ':' in fmt:
        fmt = fmt[:fmt.find(':')]
    if ';' in fmt:
        fmt = fmt[:fmt.find(';')]
    return fmt, args

def getvartypes(lines, i, variables):
    types = {}
    typelist = []
    for var in variables:
        var = var.strip()
        if var.startswith('('):
            types[var] = var[1:var.find(')')].strip('*').strip()
        elif var.startswith('"'):
            types[var] = 'char'
        elif var.strip('&') in knownobj:
            types[var] = knownobj[var.strip('&')]
        elif '.' in var or '->' in var:
            types[var] = 'unknown'
        elif var.startswith('ntohs(') or var.startswith('htons('):
            types[var] = 'short'
        elif var.startswith('ntohl(') or var.startswith('htonl('):
            types[var] = 'long'
        else:
            if var.startswith('&'): var = var[1:]
            if '[' in var: var = var[:var.find('[')]
            j = 1
            for j in range(1,i+1):
                line = lines[i-j]
                if '[]' not in line and var in line:
                    fakeline = ' ' + line + ' '
                    for c in '()[],;+-*/=': fakeline = fakeline.replace(c, ' ')
                    if (' ' + var + ' ') in fakeline:
                        if '(' in line and '=' not in line:
                            line = line[line.find('('):]
                        for c in '=+-/^&!~()[]': line = line.replace(c, ',')
                        pieces = [piece.strip() for piece in line.split(",")]
                        decl = pieces[0].split()
                        if len(decl) > 1:
                            maybe = " ".join(decl[:-1]).strip('*')
                            if maybe != 'return' and '"' not in maybe:
                                types[var] = maybe
                                break
            else:
                types[var] = 'unknown'
        typelist.append(types.get(var, 'unknown'))
        if typelist[-1].startswith('static '): typelist[-1] = types[var][7:]
    return typelist


if __name__ == '__main__':
    import os
    import sys

    # Every command-line argument is a directory tree to scan; each C source
    # or header file found beneath it is handed to checkfile().
    for basedir in sys.argv[1:]:
        for dirpath, dirnames, filenames in os.walk(basedir):
            for name in filenames:
                if name[-2:] in ('.c', '.h'):
                    checkfile(os.path.join(dirpath, name))