1. #include <unordered_map>
    
  2. #include <stdio.h>
    
  3. #include <string>
    
  4. #include <vector>
    
  5. #include <cstring>
    
  6. #include <sys/mman.h>
    
  7. #include <unistd.h>
    
  8. #include <malloc.h>
    
  9. #include "mm.h"
    
  10. #include <cassert>
    
  11. 
    
  12. using std::unordered_map;
    
  13. using std::string;
    
  14. using std::vector;
    
  15. 
    
  // A single byte of machine code.
  using byte = unsigned char;
  // A growable buffer of machine-code bytes.
  using Bytes = std::vector<byte>;

  // Number of cells on the tape.
  constexpr int TAPE_LEN = 30 * 1000;
    
  // Bracket-matching tables for one program. Keys and values are
  // character indices into the program text.
  struct Jumps {
      // Index of an opening brace -> index of its matching closing brace.
      std::unordered_map<int, int> forward;
      // Index of a closing brace -> index of its matching opening brace.
      std::unordered_map<int, int> backward;
  };
    
  28. 
    
  // Print `msg` followed by a newline on stdout and terminate the process
  // with exit status 1.
  //
  // The int return type exists only so the call can be used as an operand
  // of an int-valued expression (e.g. a ternary); the function never
  // actually returns.
  int fatal(const char* msg) {
      fputs(msg, stdout);
      fputc('\n', stdout);
      exit(1);
      return 0;  // unreachable
  }
    
  34.     
    
  35. Jumps build_jumps(const string& prog) {
    
  36.     /*
    
  37.     build tables to lookup the matching
    
  38.     brace for a program.
    
  39.     */
    
  40.     Jumps result;
    
  41.     vector<int> stack;
    
  42.     int i = 0;
    
  43.     while(i < prog.size()) {
    
  44.         if (prog[i] == '[') {
    
  45.             stack.push_back(i);
    
  46.         } else if (prog[i] == ']') {
    
  47.             if (stack.size() == 0) {
    
  48.                 puts("no [ on stack.");
    
  49.                 exit(1);
    
  50.             }
    
  51.             int match = stack.back(); stack.pop_back();
    
  52.             result.backward[i] = match;
    
  53.             result.forward[match] = i;
    
  54.         }
    
  55.         i += 1;
    
  56.     }
    
  57.     if (stack.size() > 0) {
    
  58.         puts("jump stack not empty");
    
  59.         exit(1);
    
  60.     }
    
  61.     return result;
    
  62. }
    
  63. 
    
  // Look up `index` in a jump table and return the index it maps to.
  // Prints a diagnostic and terminates the process (exit status 1) when
  // the table has no entry for `index`.
  int jump(const std::unordered_map<int, int>& jumps,
           int index) {
      const auto entry = jumps.find(index);
      if (entry == jumps.end()) {
          printf("no jumps for %d\n", index);
          exit(1);
      }
      return entry->second;
  }
    
  74. 
    
  75. extern "C" {
    
  76.     char tape[TAPE_LEN];
    
  77.     int i = 0;
    
  78.     int pc = 0;
    
  79.     void plus(void);
    
  80.     void plus_end(void);
    
  81.     void minus(void);
    
  82.     void minus_end(void);
    
  83.     void right(void);
    
  84.     void right_end(void);
    
  85.     void left(void);
    
  86.     void left_end(void);
    
  87.     void dot(void);
    
  88.     void dot_end(void);
    
  89.     void prologue(void);
    
  90.     void prologue_end(void);
    
  91.     // branches. left and right brace characters.
    
  92.     void bleft(void);
    
  93.     void bleft_end(void);
    
  94.     void bright(void);
    
  95.     void bright_end(void);
    
  96. }
    
  97. 
    
  98. // Assembled native code for each instruction.
    
  99. typedef unordered_map<char, Bytes> Instructions;
    
  100. Instructions instructions;
    
  101. Bytes& instruction(char instc) {
    
  102.     auto it = instructions.find(instc);
    
  103.     if (it != instructions.end()) {
    
  104.         return it->second;
    
  105.     } else {
    
  106.         printf("unknown instruction: %c\n", instc);
    
  107.         exit(1);
    
  108.     }
    
  109. }
    
  110. 
    
  111. Bytes copy_range(void (*start)(void),void (*end)(void)) {
    
  112.     int len = (byte*)end - (byte*)start;
    
  113.     Bytes result;
    
  114.     result.resize(len);
    
  115.     memcpy(result.data(), (void*)start, len);
    
  116.     for(int i = 0; i < result.size(); i++) {
    
  117.         printf("%x ", result[i]);
    
  118.     }
    
  119.     puts("");
    
  120.     return result;
    
  121. }
    
  122. void init_instructions() {
    
  123.     instructions['+'] = copy_range(plus, plus_end);
    
  124.     instructions['-'] = copy_range(minus, minus_end);
    
  125.     instructions['<'] = copy_range(left, left_end);
    
  126.     instructions['>'] = copy_range(right, right_end);
    
  127.     instructions['.'] = copy_range(dot, dot_end);
    
  128.     instructions['['] = copy_range(bleft, bleft_end);
    
  129.     instructions[']'] = copy_range(bright, bright_end);
    
  130.     // end, prologue + ret used at end of block
    
  131.     instructions['e'] = copy_range(prologue, prologue_end);
    
  132. }
    
  133. void jit_init() {
    
  134.     init_instructions();
    
  135.     printf("+ len: %x\n", instruction('+').size());
    
  136.     printf("- len: %x\n", instruction('-').size());
    
  137.     printf("< len: %x\n", instruction('<').size());
    
  138.     printf("> len: %x\n", instruction('>').size());
    
  139.     printf(". len: %x\n", instruction('.').size());
    
  140.     printf("e len: %x\n", instruction('e').size());
    
  141.     printf("pagesize: %d\n", getpagesize());
    
  142. 
    
  143.     puts("jit setup complete");
    
  144. }
    
  145. 
    
  // A compiled block of native code.
  struct Block {
      // Address of the generated code; called as a void(void) function.
      void* code;
      // Number of program instructions the block covers; the program
      // counter is advanced by this amount after the block has run.
      int count;
  };
    
  151. 
    
  152. // mapping from pc to compiled blocks. 
    
  153. unordered_map<int, Block*> block_cache;
    
  154. 
    
  155. static int alloc_bytes = 0;
    
  156. static MM memory;
    
  157. 
    
  158. void run_block(Block* data) {
    
  159.     void (*block)(void);
    
  160.     block = reinterpret_cast<void(*)(void)>(data->code);
    
  161.     block();
    
  162.     pc += data->count;
    
  163. }
    
  164. 
    
  165. // Rewrite the last 4 bytes at code_output  to the jump offset.
    
  166. void fixup_last_branch(
    
  167.     int pc_dest,
    
  168.     byte* code_output,
    
  169.     int code_length,
    
  170.     const unordered_map<int, int>& jump_offsets) {
    
  171.     // x86_64 always has eip set to the following instruction. 
    
  172.     // to and from are relative to the start of the block.
    
  173.     auto dst = jump_offsets.find(pc_dest);
    
  174.     int to = dst != jump_offsets.end() 
    
  175.         ? dst->second 
    
  176.         : fatal("no jump");
    
  177.     int from = code_length;
    
  178.     int32_t delta = to - from;
    
  179.     memcpy(code_output - 4, &delta, sizeof(delta));
    
  180. }
    
  181. 
    
  182.     
    
  183. Block* compile_block(const string& prog, int pc, const Jumps& jumps) {
    
  184.     int start = pc;
    
  185.     int bytes = 0;
    
  186.     Block* block = block_cache[pc] = new Block;
    
  187.     assert(pc == 0);
    
  188. 
    
  189.     // pc -> byte offset in block of the following instruction
    
  190.     unordered_map <int, int> jump_offsets; 
    
  191.   
    
  192.     while (pc < prog.length()) {
    
  193.         char inst = prog[pc];
    
  194.         bytes += instruction(inst).size();
    
  195.         jump_offsets[pc] = bytes;
    
  196.         pc++;
    
  197.     }
    
  198.     bytes += instruction('e').size();
    
  199.     block->count = pc - start;
    
  200. 
    
  201.     byte *code;
    
  202.     memory.alloc((void**)&code, bytes);
    
  203.     alloc_bytes += bytes;
    
  204.     byte *write = code;
    
  205.     // body
    
  206.     int offset = 0;
    
  207.     while (start < pc) {
    
  208.         char instc = prog[start];
    
  209.         Bytes &instn = instruction(instc);
    
  210.         memcpy(write + offset, instn.data(), instn.size());
    
  211. 
    
  212.         offset += instn.size();
    
  213.         if (instc == '[') {
    
  214.             fixup_last_branch(jump(jumps.forward, start),
    
  215.                 write+offset, offset, jump_offsets);
    
  216.         } else if (instc == ']') {
    
  217.             fixup_last_branch(jump(jumps.backward, start),
    
  218.                 write+offset, offset, jump_offsets);
    
  219.         }
    
  220.         start++;
    
  221.     }
    
  222.     // prologue
    
  223.     {
    
  224.         Bytes &instn = instruction('e');
    
  225.         memcpy(write + offset, instn.data(), instn.size());
    
  226.     }
    
  227.     // printf("block at %x, len=%d\n", code, bytes);
    
  228. 
    
  229.     block->code = reinterpret_cast<void*>(code);
    
  230.     
    
  231.     return block;
    
  232. }
    
  233.     
    
  234. void bf(const string& prog) {
    
  235.     memset(tape, TAPE_LEN, 0);
    
  236.     Jumps jumps = build_jumps(prog);
    
  237.     int len = prog.length();
    
  238.     // One block now represents the entire program.
    
  239.     // The ret at the end represents the end of program.
    
  240.     Block* block = compile_block(prog, pc, jumps);
    
  241.     run_block(block);
    
  242. }
    
  243. 
    
  // Read the entire file at `path` and return its contents, byte for byte
  // (the file is opened in binary mode; embedded NUL bytes are preserved).
  //
  // On failure to open or read the file, prints a diagnostic with perror
  // and terminates the process with exit status 1.
  std::string slurp(const std::string& path) {
      FILE* f = fopen(path.c_str(), "rb");
      if (!f) {
          perror("failed to open file");
          exit(1);
      }

      std::string output;
      char buffer[1024];
      while (true) {
          const size_t n = fread(buffer, 1, sizeof(buffer), f);
          output.append(buffer, n);
          if (n < sizeof(buffer)) {
              // A short read means end-of-file or an error. fread returns
              // an unsigned count, so errors must be detected via ferror.
              if (ferror(f)) {
                  perror("file read error");
                  fclose(f);
                  exit(1);
              }
              break;
          }
      }
      fclose(f);
      return output;
  }
    
  266. 
    
  // Strip comments and whitespace from program source.
  //
  //  - A ';' starts a comment that runs up to the end of the line.
  //  - Spaces, tabs, carriage returns and newlines are dropped, so sources
  //    with CRLF line endings or tab indentation are accepted.
  //  - Every other character is kept, in order. Unsupported characters are
  //    deliberately not filtered out here.
  std::string clean(const std::string& input) {
      std::string program;
      program.reserve(input.length());

      bool in_comment = false;
      for (const char c : input) {
          if (in_comment) {
              // The newline that ends a comment is whitespace itself.
              in_comment = (c != '\n');
              continue;
          }
          switch (c) {
              case ';':
                  in_comment = true;
                  break;
              case '\n':
              case '\r':
              case '\t':
              case ' ':
                  break;
              default:
                  program.push_back(c);
          }
      }
      return program;
  }
    
  294.     
    
  295. int main(int argc, char** argv) {
    
  296.     if (argc < 2) {
    
  297.         puts("usage: ./bf file.bf");
    
  298.         return 0;
    
  299.     }
    
  300.     jit_init();
    
  301.     string program = slurp(argv[1]);
    
  302.     program = clean(program);
    
  303.     bf(program);
    
  304.     puts("done");
    
  305.     printf("jit stats: \n");
    
  306.     printf("alloc_bytes: %d\n", alloc_bytes);
    
  307.     printf("num blocks: %d\n", block_cache.size());
    
  308.     return 0;
    
  309. }