ソフトウェア開発 C++

700行数式パーサー

戻る


約700行で数式を字句解析・構文解析するC++コードは以下の通り。

#include <iostream>
#include <map>
#include <string>

#include <cmath>
#include <cstdlib>
#include <cstring>
 
// Size of the scratch buffer lex() copies a single number or identifier
// into (including the terminating NUL).
const int MAX_BUF = 64;

// ASCII-only character classification helpers. They are plain comparisons
// against character literals, so they never index a table and simply return
// false for any byte outside the ranges below.

// Blank: a space or a horizontal tab only (newlines are not blanks).
inline bool ISSPACE(char ch)
{
    switch (ch)
    {
    case ' ':
    case '\t':
        return true;
    default:
        return false;
    }
}
// Uppercase ASCII letter 'A'..'Z'.
inline bool ISUPPER(char ch)
{
    return !(ch < 'A' || ch > 'Z');
}
// Lowercase ASCII letter 'a'..'z'.
inline bool ISLOWER(char ch)
{
    return !(ch < 'a' || ch > 'z');
}
// ASCII letter of either case.
inline bool ISALPHA(char ch)
{
    return ISLOWER(ch) || ISUPPER(ch);
}
// Decimal digit '0'..'9'.
inline bool ISDIGIT(char ch)
{
    return !(ch < '0' || ch > '9');
}
// ASCII letter or decimal digit.
inline bool ISALNUM(char ch)
{
    return ISDIGIT(ch) || ISALPHA(ch);
}
 
// Short aliases for the floating-point, string and integer types used
// throughout this file.
typedef double DOUBLE;
typedef std::string STRING;
typedef int INT;
 
// Variable table: eval() looks up NT_VAR names here (name -> value) and
// reports an error for names that are missing. main() fills it before
// evaluating.
std::map<STRING, DOUBLE> g_varmap;
 
// Kinds of NODE. The lexer's token list and the parser's syntax tree share
// this enum; some kinds occur only in one of the two.
enum NODE_TYPE
{
    NT_INVALID = 0, // sentinel, e.g. "no previous token" in lex()
    NT_NUMBER  = 1, // numeric literal; the value is in NODE::value
    NT_IDENT   = 2, // identifier token; parse() turns it into NT_VAR or NT_FUNC
    NT_LPAREN  = 3, // "(" token (never appears in a tree)
    NT_RPAREN  = 4, // ")" token (never appears in a tree)
    NT_OP      = 5, // operator token before lex() classifies it
    NT_BOP     = 6, // binary operator: + - * / ^
    NT_UOP     = 7, // unary operator: + -
    NT_FUNC    = 8, // function call; the argument subtree is NODE::right
    NT_VAR     = 9  // variable reference, looked up in g_varmap by eval()
};
 
struct NODE
{
    NODE_TYPE type;
    NODE *left, *right;
    STRING str;
    DOUBLE value;
 
    NODE()
    {
        left = right = NULL;
        value = 0.0;
    }
    NODE(NODE_TYPE type_, const STRING& str_, DOUBLE value_ = 0.0)
    {
        type = type_;
        left = right = NULL;
        str = str_;
        value = value_;
    }
    NODE(DOUBLE value_)
    {
        type = NT_NUMBER;
        left = right = NULL;
        value = value_;
    }
    NODE(const NODE& node)
    {
        type = node.type;
        left = right = NULL;
        str = node.str;
        value = node.value;
    }
};
 
// Deletes a whole syntax tree, children before their parent.
// Passing NULL is allowed and does nothing.
void tree_free(NODE *tree)
{
    if (tree == NULL)
        return;

    tree_free(tree->left);
    tree_free(tree->right);
    delete tree;
}
 
// Deletes every node of a token chain, following `right` links from `list`
// until NULL. An empty (NULL) chain is allowed.
void list_free(NODE *list)
{
    while (list != NULL)
    {
        NODE *victim = list;
        list = list->right;
        delete victim;
    }
}
 
// Returns the last node of a chain (the one whose `right` is NULL), or NULL
// when the chain itself is NULL.
NODE *list_last(NODE *list)
{
    if (list == NULL)
        return NULL;

    NODE *last = list;
    while (last->right != NULL)
        last = last->right;
    return last;
}
 
struct LIST
{
    NODE *head, *tail;
    INT count;
 
    LIST()
    {
        head = tail = NULL;
        count = 0;
    }
    LIST(NODE *begin, NODE *end = NULL);
    ~LIST()
    {
        free();
    }
    void add(NODE *node)
    {
        if (head == NULL)
        {
            head = tail = node;
            return;
        }
        tail->right = node;
        node->left = tail;
        tail = node;
        count++;
    }
    void add(NODE_TYPE type_, const STRING& str_, DOUBLE value_ = 0.0)
    {
        add(new NODE(type_, str_, value_));
    }
    void add(DOUBLE value_)
    {
        add(new NODE(value_));
    }
    void free()
    {
        list_free(head);
        head = tail = NULL;
    }
};
 
// Builds a list of detached copies of the nodes from `begin` up to, but not
// including, `end`, walking `right` links. With end == NULL the copy runs to
// the end of the chain. The source nodes are left untouched.
LIST::LIST(NODE *begin, NODE *end/* = NULL*/)
    : head(NULL), tail(NULL), count(0)
{
    for (NODE *src = begin; src != end; src = src->right)
        add(new NODE(*src));
}
 
// Deletes every node owned by `list` and leaves the list empty, with `head`
// and `tail` reset, so it can be reused or deleted afterwards without a
// double free. Passing NULL does nothing.
void list_free(LIST *list)
{
    if (list != NULL)
        list->free();
}
 
// Returns the last node of `list`, or NULL if the list is empty or is itself
// NULL. This matches list_last(NODE*), which also accepts NULL.
NODE *list_last(LIST *list)
{
    return (list != NULL) ? list->tail : NULL;
}
 
// Tokenises `str` into a newly allocated LIST and runs a first syntax check.
// On success the caller owns the returned list and must delete it. On error
// a message is written to std::cerr and NULL is returned.
//
// Tokens: numbers (a run of digits and '.'), identifiers
// ([A-Za-z_][A-Za-z0-9_]*), '(' and ')', and the operators + - * / ^.
// Only ' ' and '\t' are skipped as whitespace.
// Each NT_OP is then reclassified: '+' and '-' become NT_UOP when they do not
// follow an operand (start of input, '(' or another operator) and NT_BOP
// otherwise; '*', '/' and '^' always become NT_BOP.
LIST *lex(const char *str)
{
    const char *p;
    INT count;
    char buf[MAX_BUF];
    LIST *list = new LIST;

    while(*str)
    {
        while(ISSPACE(*str)) str++;
        if (*str == 0)
            break;

        if (*str == '(')
        {
            list->add(NT_LPAREN, STRING());
            str++;
        }
        else if (*str == ')')
        {
            list->add(NT_RPAREN, STRING());
            str++;
        }
        else if (ISDIGIT(*str) || *str == '.')
        {
            p = str + 1;
            while(ISDIGIT(*p) || *p == '.') p++;
            count = p - str;
            if (count >= MAX_BUF)
            {
                std::cerr << "ERROR: too long number" << std::endl;
                delete list;
                return NULL;
            }
            memcpy(buf, str, count);
            buf[count] = 0;
            // The whole run must be one valid number. atof() silently read
            // "1.2.3" as 1.2 and "." as 0.
            char *num_end;
            DOUBLE value = std::strtod(buf, &num_end);
            if (num_end != buf + count)
            {
                std::cerr << "ERROR: invalid number" << std::endl;
                delete list;
                return NULL;
            }
            list->add(value);
            str = p;
        }
        else if (ISALPHA(*str) || *str == '_')
        {
            p = str + 1;
            while(ISALNUM(*p) || *p == '_') p++;
            count = p - str;
            if (count >= MAX_BUF)
            {
                std::cerr << "ERROR: too long ident" << std::endl;
                delete list;
                return NULL;
            }
            memcpy(buf, str, count);
            buf[count] = 0;
            list->add(NT_IDENT, STRING(buf));
            str = p;
        }
        else
        {
            switch(*str)
            {
            case '+': case '-': case '*': case '/': case '^':
                buf[0] = *str;
                buf[1] = 0;
                list->add(NT_OP, buf);
                str++;
                break;

            default:
                std::cerr << "ERROR: invalid character" << std::endl;
                delete list;
                return NULL;
            }
        }
    }

    // Second pass: check parenthesis balance, classify operators and reject
    // adjacent operands. `flag` is 1 right after an operand (a number, an
    // identifier or ')') and 0 otherwise; it decides whether '+'/'-' is
    // unary or binary.
    INT level = 0, flag = 0;
    NODE *node = list->head;
    NODE_TYPE type, old_type = NT_INVALID;
    while(node != NULL)
    {
        type = node->type;
        if (type == NT_LPAREN)
        {
            level++;
            flag = 0;
            // '(' may not directly follow ')' or a number: "(1)(2)", "2(3)".
            // It may follow an identifier, which makes a function call.
            if (old_type == NT_RPAREN || old_type == NT_NUMBER)
            {
                std::cerr << "ERROR: syntax error (1)" << std::endl;
                delete list;
                return NULL;
            }
        }
        else if (type == NT_RPAREN)
        {
            level--;
            if (level < 0)
            {
                std::cerr << "ERROR: mismatch parentheses" << std::endl;
                delete list;
                return NULL;
            }
            flag = 1;
            if (old_type == NT_LPAREN)
            {
                std::cerr << "ERROR: syntax error (2)" << std::endl;
                delete list;
                return NULL;
            }
        }
        else if (type == NT_OP)
        {
            if (node->str == "+" || node->str == "-")
            {
                if (flag == 0)
                {
                    node->type = NT_UOP;
                }
                else
                {
                    node->type = NT_BOP;
                    flag = 0;
                }
            }
            else
            {
                if (node->str == "*" || node->str == "/" || node->str == "^")
                    node->type = NT_BOP;
                flag = 0;
            }
        }
        else if (type == NT_NUMBER || type == NT_IDENT)
        {
            // Two operands in a row ("2 3", "x y", "(1) 2") are rejected.
            if (old_type == NT_NUMBER || old_type == NT_IDENT ||
                old_type == NT_RPAREN)
            {
                std::cerr << "ERROR: syntax error (4)" << std::endl;
                delete list;
                return NULL;
            }
            flag = 1;
        }
        old_type = type;
        node = node->right;
    }

    if (level != 0)
    {
        std::cerr << "ERROR: mismatch parentheses" << std::endl;
        delete list;
        return NULL;
    }

    return list;
}
 
NODE *parse(NODE *begin, NODE *end)
{
    NODE *node;
    INT level;
 
    NODE *rbegin = (end == NULL ? list_last(begin) : end->left);
    NODE *rend = begin->left;
 
    level = 0;
    node = rbegin;
    while(node != rend)
    {
        if (node->type == NT_RPAREN)
        {
            level++;
        }
        else if (node->type == NT_LPAREN)
        {
            level--;
        }
        else if (level == 0)
        {
            if (node->type == NT_BOP)
            {
                if (node->str == "+" || node->str == "-")
                {
                    if (node->right != NULL)
                    {
                        NODE *left = parse(begin, node);
                        NODE *right = parse(node->right, end);
                        NODE *tree = new NODE(NT_BOP, node->str);
                        if (left != NULL && right != NULL)
                        {
                            tree->left = left;
                            tree->right = right;
                        }
                        else
                        {
                            tree_free(left);
                            tree_free(right);
                            delete tree;
                            tree = NULL;
                        }
                        return tree;
                    }
                    else
                    {
                        std::cerr << "ERROR: syntax error (5)" << std::endl;
                        return NULL;
                    }
                }
            }
        }
        node = node->left;
    }
 
    level = 0;
    node = rbegin;
    while(node != rend)
    {
        if (node->type == NT_RPAREN)
        {
            level++;
        }
        else if (node->type == NT_LPAREN)
        {
            level--;
        }
        else if (level == 0)
        {
            if (node->type == NT_BOP)
            {
                if (node->str == "*" || node->str == "/")
                {
                    if (node->right != NULL)
                    {
                        NODE *left = parse(begin, node);
                        NODE *right = parse(node->right, end);
                        NODE *tree = new NODE(NT_BOP, node->str);
                        if (left != NULL && right != NULL)
                        {
                            tree->left = left;
                            tree->right = right;
                        }
                        else
                        {
                            tree_free(left);
                            tree_free(right);
                            delete tree;
                            tree = NULL;
                        }
                        return tree;
                    }
                    else
                    {
                        std::cerr << "ERROR: syntax error (6)" << std::endl;
                        return NULL;
                    }
                }
            }
        }
        node = node->left;
    }
 
    level = 0;
    node = rbegin;
    while(node != rend)
    {
        if (node->type == NT_RPAREN)
        {
            level++;
        }
        else if (node->type == NT_LPAREN)
        {
            level--;
        }
        else if (level == 0)
        {
            if (node->type == NT_BOP)
            {
                if (node->str == "^")
                {
                    if (node->right != NULL)
                    {
                        NODE *left = parse(begin, node);
                        NODE *right = parse(node->right, end);
                        NODE *tree = new NODE(NT_BOP, node->str);
                        if (left != NULL && right != NULL)
                        {
                            tree->left = left;
                            tree->right = right;
                        }
                        else
                        {
                            tree_free(left);
                            tree_free(right);
                            delete tree;
                            tree = NULL;
                        }
                        return tree;
                    }
                    else
                    {
                        std::cerr << "ERROR: syntax error (7)" << std::endl;
                        return NULL;
                    }
                }
            }
        }
        node = node->left;
    }
 
    node = begin;
    if (node->type == NT_UOP)
    {
        NODE *right = parse(node->right, end);
        NODE *tree = new NODE(NT_UOP, node->str);
        if (right != NULL)
        {
            tree->right = right;
        }
        else
        {
            tree_free(right);
            delete tree;
            tree = NULL;
        }
        return tree;
    }
 
    if (begin->type == NT_LPAREN)
    {
        level = 0;
        node = begin;
        INT flag = 1;
        while(node != end)
        {
            if (node->type == NT_LPAREN)
            {
                level++;
            }
            else if (node->type == NT_RPAREN)
            {
                level--;
                if (level == 0 && node->right != end)
                {
                    flag = 0;
                    break;
                }
            }
            node = node->right;
        }
        if (flag)
            return parse(begin->right, rbegin);
    }
 
    if (begin->type == NT_IDENT)
    {
        node = begin->right;
        if (node != NULL && node->type == NT_LPAREN)
        {
            level = 0;
            INT flag = 1;
            while(node != end)
            {
                if (node->type == NT_LPAREN)
                {
                    level++;
                }
                else if (node->type == NT_RPAREN)
                {
                    level--;
                    if (level == 0 && node->right != end)
                    {
                        flag = 0;
                        break;
                    }
                }
                node = node->right;
            }
            if (flag)
            {
                NODE *tree = new NODE(NT_FUNC, begin->str);
                tree->left = NULL;
                tree->right = parse(begin->right, end);
                if (tree->right == NULL)
                {
                    delete tree;
                    tree = NULL;
                }
                return tree;
            }
        }
    }
 
    node = begin;
    if (node->right == end)
    {
        if (node->type == NT_IDENT)
            return new NODE(NT_VAR, node->str);
        if (node->type == NT_NUMBER)
            return new NODE(node->value);
    }
 
    if (node->type == NT_UOP)
    {
        NODE *tree = new NODE(NT_UOP, node->str);
        tree->right = parse(node->right, end);
        if (tree->right == NULL)
        {
            delete tree;
            tree = NULL;
        }
        return tree;
    }
 
    std::cerr << "ERROR: syntax error (8)" << std::endl;
    return NULL;
}
 
// Parses a whole token list produced by lex(). Returns NULL without a new
// message when `list` is NULL (lex() has already reported the error).
NODE *parse(LIST *list)
{
    if (list == NULL)
        return NULL;
    return parse(list->head, NULL);
}
 
// Lexes and parses `str`. Returns a newly allocated syntax tree (free it with
// tree_free), or NULL if lexing or parsing failed. The temporary token list
// is always released before returning.
NODE *parse(const char *str)
{
    LIST *tokens = lex(str);
    if (tokens == NULL)
        return NULL;

    NODE *tree = parse(tokens);
    delete tokens;
    return tree;
}
 
// Evaluates a syntax tree produced by parse(). A NULL tree evaluates to 0.
//  * Variables (NT_VAR) are looked up in g_varmap.
//  * Supported functions are sin, cos and tan (they call std::sin/cos/tan,
//    so arguments are in radians).
// On an unknown variable, unknown function or unexpected node, a message is
// written to std::cerr and that subtree evaluates to 0; evaluation of the
// rest of the tree continues.
DOUBLE eval(NODE *tree)
{
    if (tree == NULL) return 0;

    switch(tree->type)
    {
    case NT_NUMBER:
        return tree->value;

    case NT_VAR:
        {
            // A single lookup instead of calling find() twice.
            std::map<STRING, DOUBLE>::iterator it = g_varmap.find(tree->str);
            if (it == g_varmap.end())
            {
                std::cerr << "ERROR: variable `" << tree->str <<
                    "' was not found" << std::endl;
                return 0;
            }
            return it->second;
        }

    case NT_BOP:
        {
            // C++ leaves the evaluation order of `a + b` operands unspecified.
            // Evaluating left before right keeps any error messages from the
            // two operands in a fixed order.
            const DOUBLE lhs = eval(tree->left);
            const DOUBLE rhs = eval(tree->right);
            if (tree->str == "+")
                return lhs + rhs;
            if (tree->str == "-")
                return lhs - rhs;
            if (tree->str == "*")
                return lhs * rhs;
            if (tree->str == "/")
                return lhs / rhs;
            if (tree->str == "^")
                return std::pow(lhs, rhs);
            std::cerr << "ERROR: internal error" << std::endl;
            return 0;
        }

    case NT_UOP:
        if (tree->str == "+")
            return +eval(tree->right);
        if (tree->str == "-")
            return -eval(tree->right);
        std::cerr << "ERROR: internal error" << std::endl;
        return 0;

    case NT_FUNC:
        if (tree->str == "sin")
            return std::sin(eval(tree->right));
        if (tree->str == "cos")
            return std::cos(eval(tree->right));
        if (tree->str == "tan")
            return std::tan(eval(tree->right));
        std::cerr << "ERROR: function `" << tree->str <<
            "' was not found" << std::endl;
        return 0;

    default:
        std::cerr << "ERROR: internal error" << std::endl;
        return 0;
    }
}
 
// Parses and evaluates the expression text `str`. Returns 0 when it cannot
// be parsed (the parser has already reported why on std::cerr). The
// intermediate syntax tree is freed before returning.
DOUBLE eval(const char *str)
{
    NODE *tree = parse(str);
    if (tree == NULL)
        return 0;

    const DOUBLE result = eval(tree);
    tree_free(tree);
    return result;
}
 
int main(void)
{
    g_varmap["x"] = 100;
    g_varmap["y"] = 0.1;
    std::cout << eval("x^2 - 10 * y + 100") << std::endl;
    return 0;
}

普通は、bison、lex、yacc などのツール(lex で字句解析器を、yacc/bison で構文解析器を生成する)を使うんですがね……。もっと短くできるかな? ご意見はメールでください。

ソース: parser700.zip


国内格安航空券サイトe航空券.com

戻る

©片山博文MZ
katayama.hirofumi.mz@gmail.com

inserted by FC2 system