Skip to content

Instantly share code, notes, and snippets.

@H2CO3
Created January 31, 2014 10:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save H2CO3/8729897 to your computer and use it in GitHub Desktop.
Save H2CO3/8729897 to your computer and use it in GitHub Desktop.
//
// derive.c
// just because
// (just because I wanted to demonstrate what the Sparkling API is good for)
//
// created by H2CO3 on 31/01/2014
// use for good, not for evil
//
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <spn/api.h>
#include <spn/parser.h>
#include <spn/ast.h>
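// assumptions:
// - the variable of differentiation is called 'x'
// - only basic arithmetic (+, -, *, /) and unary +/- are used
// - every function call takes exactly one argument, and the function must be one
//   that derivative_func() has a table entry for (e.g. sin, cos, tan, exp, ln,
//   sinh, cosh); anything else -- including asin, acos and atan -- makes the
//   derivative step exit with an error
// - literal numbers and any identifier other than 'x' are treated as constants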
static void real_dump(SpnAST *ast, int parens)
{
static const char ops[] = "+-*/";
switch (ast->node) {
case SPN_NODE_ADD:
case SPN_NODE_SUB:
case SPN_NODE_MUL:
case SPN_NODE_DIV:
if (parens)
printf("(");
real_dump(ast->left, 1);
printf(" %c ", ops[ast->node - SPN_NODE_ADD]);
real_dump(ast->right, 1);
if (parens)
printf(")");
break;
case SPN_NODE_UNPLUS:
real_dump(ast->left, 1);
break;
case SPN_NODE_UNMINUS:
printf("-");
real_dump(ast->left, 1);
break;
case SPN_NODE_LITERAL:
assert(ast->value.t == SPN_TYPE_NUMBER);
if (ast->value.f & SPN_TFLG_FLOAT)
printf("%g", ast->value.v.fltv);
else
printf("%ld", ast->value.v.intv);
break;
case SPN_NODE_FUNCCALL:
assert(ast->left->node == SPN_NODE_IDENT);
assert(ast->right->left == NULL); // only ONE argument, s'il vous plait
printf("%s(", ast->left->name->cstr);
real_dump(ast->right->right, 0);
printf(")");
break;
case SPN_NODE_IDENT:
printf("%s", ast->name->cstr);
break;
default:
printf("\n\nError: unrecognized node/operation: %d\n", ast->node);
exit(-1);
break;
}
}
// Prints a whole expression at top level (no enclosing parentheses),
// followed by a blank line.
static void dump(SpnAST *ast)
{
	const int top_level = 0;

	real_dump(ast, top_level);
	fputs("\n\n", stdout);
}
// Recursively duplicates `orig` and its left/right subtrees into new nodes.
// The literal value and identifier name are not duplicated. They are shared
// with `orig` and retained (spn_value_retain / spn_object_retain) so the copy
// holds its own reference to them.
static SpnAST *copy_ast(SpnAST *orig)
{
	SpnAST *dup = spn_ast_new(orig->node, orig->lineno);

	dup->value = orig->value;
	spn_value_retain(&dup->value);

	if (orig->name != NULL) {
		dup->name = orig->name;
		spn_object_retain(dup->name);
	}

	if (orig->left != NULL)
		dup->left = copy_ast(orig->left);

	if (orig->right != NULL)
		dup->right = copy_ast(orig->right);

	return dup;
}
// Returns 1 if `ast` is a leaf that does not depend on x: a literal, or an
// identifier other than "x". Returns 0 otherwise. Compound expressions are
// never considered constant, even if all their operands are (e.g. 2 * 3).
static int is_constant(SpnAST *ast)
{
	switch (ast->node) {
	case SPN_NODE_LITERAL:
		return 1;
	case SPN_NODE_IDENT:
		return strcmp(ast->name->cstr, "x") != 0;
	default:
		return 0;
	}
}
// Allocates a new integer literal node holding `value`, tagged with `lineno`.
static SpnAST *make_int_literal(long value, unsigned long lineno)
{
	SpnAST *lit = spn_ast_new(SPN_NODE_LITERAL, lineno);
	lit->value = (SpnValue){ .t = SPN_TYPE_NUMBER, .f = 0, .v.intv = value };
	return lit;
}

// A new integer literal node for the constant 0.
static SpnAST *make_literal_zero(unsigned long lineno)
{
	return make_int_literal(0, lineno);
}

// A new integer literal node for the constant 1.
static SpnAST *make_literal_one(unsigned long lineno)
{
	return make_int_literal(1, lineno);
}
// Returns a new identifier node naming the derivative of the one-argument
// function that the identifier `ast` refers to.
//
// The names in the table are a printing trick, not real function names: the
// caller wraps the result in a call node, so real_dump() prints e.g.
// "-sin(x)", "1 / cos^2(x)" or "1 / (x)". The tree is meant to be printed,
// not evaluated.
//
// "log" is accepted alongside "ln" because the file's stated assumptions list
// it as supported. NOTE(review): this assumes Sparkling's log() is the natural
// logarithm -- confirm.
//
// An unknown function name prints an error and exits the process.
static SpnAST *derivative_func(SpnAST *ast)
{
	assert(ast->node == SPN_NODE_IDENT);
	static const struct {
		const char *f;
		const char *fprime;
	} dict[] = {
		{ "sin",  "cos" },
		{ "cos",  "-sin" },
		{ "tan",  "1 / cos^2" },
		{ "exp",  "exp" },
		{ "ln",   "1 / " },
		{ "log",  "1 / " },
		{ "sinh", "cosh" },
		{ "cosh", "sinh" }
	};
	for (size_t i = 0; i < sizeof dict / sizeof dict[0]; i++) {
		if (!strcmp(ast->name->cstr, dict[i].f)) {
			SpnAST *ret = spn_ast_new(SPN_NODE_IDENT, ast->lineno);
			// NOTE(review): the 0 flag presumably tells the string object not
			// to free this static buffer -- confirm against <spn/api.h>.
			ret->name = spn_string_new_nocopy(dict[i].fprime, 0);
			return ret;
		}
	}
	printf("\n\nUnrecognized function: %s\n", ast->name->cstr);
	exit(-1);
	return NULL;
}
static SpnAST *derivative(SpnAST *ast);

// Allocates a binary node of kind `node` that takes ownership of `left` and `right`.
static SpnAST *new_binary(int node, unsigned long lineno, SpnAST *left, SpnAST *right)
{
	SpnAST *ret = spn_ast_new(node, lineno);
	ret->left = left;
	ret->right = right;
	return ret;
}

// Sum/difference rule: (f +- g)' = f' +- g', skipping constant operands.
static SpnAST *derive_sum(SpnAST *ast)
{
	int left_const = is_constant(ast->left);
	int right_const = is_constant(ast->right);

	if (left_const && right_const)
		return make_literal_zero(ast->lineno);

	if (left_const) {
		SpnAST *dg = derivative(ast->right);
		// (c + g)' = g', but (c - g)' = -g': the sign must not be dropped
		if (ast->node == SPN_NODE_SUB) {
			SpnAST *neg = spn_ast_new(SPN_NODE_UNMINUS, ast->lineno);
			neg->left = dg;
			return neg;
		}
		return dg;
	}

	if (right_const) // (f +- c)' = f'
		return derivative(ast->left);

	// evaluated in sequence so errors are reported left operand first
	SpnAST *df = derivative(ast->left);
	SpnAST *dg = derivative(ast->right);
	return new_binary(ast->node, ast->lineno, df, dg);
}

// Product rule: (f * g)' = f' * g + f * g', with shortcuts for constant factors.
static SpnAST *derive_product(SpnAST *ast)
{
	unsigned long line = ast->lineno;
	int left_const = is_constant(ast->left);
	int right_const = is_constant(ast->right);

	if (left_const && right_const)
		return make_literal_zero(line);

	if (left_const) { // (c * g)' = c * g'
		SpnAST *c = copy_ast(ast->left);
		SpnAST *dg = derivative(ast->right);
		return new_binary(SPN_NODE_MUL, line, c, dg);
	}

	if (right_const) { // (f * c)' = f' * c
		SpnAST *df = derivative(ast->left);
		return new_binary(SPN_NODE_MUL, line, df, copy_ast(ast->right));
	}

	SpnAST *df = derivative(ast->left);
	SpnAST *dg = derivative(ast->right);
	SpnAST *df_g = new_binary(SPN_NODE_MUL, line, df, copy_ast(ast->right));
	SpnAST *f_dg = new_binary(SPN_NODE_MUL, line, copy_ast(ast->left), dg);
	return new_binary(SPN_NODE_ADD, line, df_g, f_dg);
}

// Quotient rule: (f / g)' = (f' * g - f * g') / (g * g); a constant divisor is kept as-is.
static SpnAST *derive_quotient(SpnAST *ast)
{
	unsigned long line = ast->lineno;

	if (is_constant(ast->right)) {
		if (is_constant(ast->left))
			return make_literal_zero(line);
		SpnAST *df = derivative(ast->left); // (f / c)' = f' / c
		return new_binary(SPN_NODE_DIV, line, df, copy_ast(ast->right));
	}

	SpnAST *df = derivative(ast->left);
	SpnAST *dg = derivative(ast->right);
	SpnAST *numer = new_binary(SPN_NODE_SUB, line,
	                           new_binary(SPN_NODE_MUL, line, df, copy_ast(ast->right)),
	                           new_binary(SPN_NODE_MUL, line, copy_ast(ast->left), dg));
	SpnAST *denom = new_binary(SPN_NODE_MUL, line, copy_ast(ast->right), copy_ast(ast->right));
	return new_binary(SPN_NODE_DIV, line, numer, denom);
}

// Chain rule: (f(g))' = f'(g) * g', where the name of f' comes from derivative_func().
static SpnAST *derive_call(SpnAST *ast)
{
	assert(ast->left->node == SPN_NODE_IDENT);
	assert(ast->right->left == NULL); // only ONE argument, s'il vous plait

	SpnAST *arg = ast->right->right;
	if (is_constant(arg))
		return make_literal_zero(ast->lineno);

	SpnAST *outer = spn_ast_new(SPN_NODE_FUNCCALL, ast->lineno);
	outer->left = derivative_func(ast->left);
	outer->right = copy_ast(ast->right);

	// a non-constant identifier can only be 'x', and x' = 1: omit the "* 1"
	if (arg->node == SPN_NODE_IDENT)
		return outer;

	SpnAST *dg = derivative(arg);
	return new_binary(SPN_NODE_MUL, ast->lineno, outer, dg);
}

// Returns a newly built AST for d/dx of `ast`. The input tree is only read;
// the result shares no nodes with it (subexpressions are copied with
// copy_ast()). Unsupported node kinds or functions print an error and exit
// the process.
static SpnAST *derivative(SpnAST *ast)
{
	switch (ast->node) {
	case SPN_NODE_ADD:
	case SPN_NODE_SUB:
		return derive_sum(ast);
	case SPN_NODE_MUL:
		return derive_product(ast);
	case SPN_NODE_DIV:
		return derive_quotient(ast);
	case SPN_NODE_UNPLUS:
		return derivative(ast->left);
	case SPN_NODE_UNMINUS: {
		if (is_constant(ast->left)) // (-c)' = 0; avoids printing "-0"
			return make_literal_zero(ast->lineno);
		SpnAST *ret = spn_ast_new(SPN_NODE_UNMINUS, ast->lineno);
		ret->left = derivative(ast->left);
		return ret;
	}
	case SPN_NODE_LITERAL:
		return make_literal_zero(ast->lineno);
	case SPN_NODE_IDENT:
		// any identifier other than 'x' is a constant; dx/dx = 1
		return is_constant(ast) ? make_literal_zero(ast->lineno)
		                        : make_literal_one(ast->lineno);
	case SPN_NODE_FUNCCALL:
		return derive_call(ast);
	default:
		printf("\n\nError: unrecognized node/operation: %d\n", ast->node);
		exit(-1);
		return NULL;
	}
}
// Command-line driver: parses the expression in argv[1] (a function of x),
// prints it as f(x), then prints its symbolic derivative f'(x).
//
// Usage: derive '<expression in x>'
// Returns EXIT_SUCCESS, or EXIT_FAILURE on a usage, allocation or parse error.
int main(int argc, char *argv[])
{
	if (argc < 2) {
		fprintf(stderr, "usage: %s '<expression in x>'\n",
		        argc > 0 ? argv[0] : "derive");
		return EXIT_FAILURE;
	}

	// Turn the bare expression into a ';'-terminated statement for the parser.
	size_t len = strlen(argv[1]);
	char *expr = malloc(len + 2);
	if (expr == NULL) {
		perror("malloc");
		return EXIT_FAILURE;
	}
	memcpy(expr, argv[1], len);
	expr[len] = ';';
	expr[len + 1] = '\0';

	SpnParser *parser = spn_parser_new();
	SpnAST *ast = spn_parser_parse(parser, expr);
	spn_parser_free(parser);
	free(expr);

	// NOTE(review): assumes spn_parser_parse() returns NULL on a syntax error
	// and that the parsed expression hangs off ast->left -- confirm against
	// <spn/parser.h>. Either way, never dereference a missing tree.
	if (ast == NULL || ast->left == NULL) {
		fprintf(stderr, "error: could not parse expression: %s\n", argv[1]);
		if (ast != NULL)
			spn_ast_free(ast);
		return EXIT_FAILURE;
	}

	printf("f(x) = ");
	dump(ast->left);

	SpnAST *der = derivative(ast->left);
	spn_ast_free(ast);

	printf("f'(x) = ");
	dump(der);
	spn_ast_free(der);

	return EXIT_SUCCESS;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment