#include <assert.h>
#include <cosmo.h>
#include <linux/rseq.h>
#include <pthread.h>

// Number of hit() calls each worker thread performs. The constant has type
// long; an uppercase suffix is used because a lowercase 'l' is easily
// misread as the digit '1'.
#define ITERATIONS 10000000L

// Table of counters with CPU_SETSIZE slots. hit() increments the slot whose
// index is the CPU id it read from the rseq area, and count() sums all slots.
// Each element is aligned to 64 bytes, so consecutive slots are at least 64
// bytes apart (presumably to keep each CPU's counter on its own cache line;
// the cache line size is not checked here).
struct {
  alignas(64) long x;
} counter[CPU_SETSIZE];

// Increments counter[cpu].x, where cpu is the value loaded from
// rseq->cpu_id_start, using a Linux restartable sequence instead of an
// atomic read-modify-write.
//
// Layout of the sequence, as encoded by the local assembler labels:
//   300  descriptor emitted into .rodata.rseq (fields are named in the
//        comments next to each directive below)
//   301  start of the sequence (rseq_cs::start_ip)
//   302  end of the sequence; defined by the final asm statement, and used
//        to compute rseq_cs::post_commit_offset as 302 - 301
//   303  abort handler (rseq_cs::abort_ip), placed in .text.unlikely; it
//        branches straight back to 301 so the whole sequence is retried
//
// NOTE(review): when and why control reaches the abort handler is decided by
// the kernel and is not visible in this file; see the rseq documentation.
void hit() {
  long cpu_id;
  // Pointer to this thread's rseq area, as returned by __get_rseq().
  struct rseq *rseq = __get_rseq();
  // NOTE: this code is NOT production quality
  // The descriptor below is shared by both architecture branches; the asm
  // statement is then completed (operands and closing parenthesis) inside
  // whichever #ifdef branch is selected.
  asm volatile(".pushsection .rodata.rseq,\"a\",@progbits\n"
               "	.balign	32\n"
               "300:	.long	0\n"          // rseq_cs::version
               "	.long	0\n"          // rseq_cs::flags
               "	.quad	301f\n"       // rseq_cs::start_ip
               "	.quad	302f-301f\n"  // rseq_cs::post_commit_offset
               "	.quad	303f\n"       // rseq_cs::abort_ip
               "	.popsection\n"
#ifdef __x86_64__
               // Operand %0 is rseq->rseq_cs (memory), %1 is cpu_id
               // (register; %k1 selects its 32-bit form), %2 is
               // rseq->cpu_id_start (memory).
               "301:	movq	$300b,%0\n"  // give kernel above data
               "	movl	%2,%k1\n"    // fetch cpu_id_start
               "	.pushsection .text.unlikely,\"ax\",@progbits\n"
               // The seven bytes emitted next sit immediately before the
               // abort handler; presumably they are the signature the kernel
               // expects in front of abort_ip -- TODO confirm against the
               // value registered with the rseq syscall.
               "	.byte	0x0f,0xb9,0x3d\n"
               "	.long	0x53053053\n"
               "303:	jmp	301b\n"  // restart on abort
               "	.popsection"
               : "=m"(rseq->rseq_cs), "=r"(cpu_id)
               : "m"(rseq->cpu_id_start)
               : "memory");
#elifdef __aarch64__
               // x16 is used as scratch to form the descriptor's address
               // (adrp + add of the low 12 bits), which is why it appears in
               // the clobber list. Operands are the same as in the x86_64
               // branch; %w1 selects the 32-bit form of cpu_id's register.
               "301:	adrp	x16, 300b\n"
               "	add	x16, x16, :lo12:300b\n"
               "	str	x16, %0\n"  // give kernel above data
               "	ldr	%w1, %2\n"  // fetch cpu_id_start
               "	.pushsection .text.unlikely,\"ax\",@progbits\n"
               // Word emitted immediately before the abort handler;
               // presumably the aarch64 form of the rseq signature -- TODO
               // confirm.
               "	.long	0xd428bc00\n"
               "303:	b	301b\n"  // restart on abort
               "	.popsection"
               : "=m"(rseq->rseq_cs), "=r"(cpu_id)
               : "m"(rseq->cpu_id_start)
               : "x16", "memory");
#endif
  // The actual work of the sequence: bump this CPU's slot.
  // NOTE(review): this increment is compiler-generated code sitting between
  // two separate asm statements, so its instruction sequence and placement
  // relative to labels 301/302 are up to the compiler.
  ++counter[cpu_id].x;
  // Defines label 302, the end of the sequence referenced by the descriptor.
  asm volatile("\n302:");
}

// Returns the sum of the x field across all CPU_SETSIZE slots of counter[].
// No synchronization is performed here; main() calls this only after every
// worker thread has been joined.
long count(void) {
  long total = 0;
  long slot = 0;
  while (slot != CPU_SETSIZE) {
    total += counter[slot].x;
    ++slot;
  }
  return total;
}

// Thread entry point: calls hit() exactly ITERATIONS times.
// The argument is not used, and the thread's result is always a null pointer.
void *worker(void *arg) {
  long remaining = ITERATIONS;
  while (remaining-- > 0)
    hit();
  return 0;
}

// Starts one worker thread per CPU reported by cosmo_cpu_count(), waits for
// all of them, and asserts that the per-CPU counters add up to exactly
// threads * ITERATIONS (i.e. no increment was lost).
//
// Returns 0 on success and 1 if the CPU count is unusable or a thread could
// not be created. A mismatch in the final total trips the assert.
int main(void) {
  int threads = cosmo_cpu_count();
  // A non-positive count would make the variable-length array below
  // undefined and would let the final check pass without doing any work.
  if (threads < 1)
    return 1;
  pthread_t th[threads];
  // pthread_create() returns nonzero on failure and leaves the thread id
  // unspecified, so remember how many threads really started and join only
  // those.
  long started = 0;
  while (started < threads) {
    if (pthread_create(&th[started], 0, worker, 0))
      break;
    ++started;
  }
  for (long i = 0; i < started; ++i)
    pthread_join(th[i], 0);
  if (started != threads)
    return 1;
  assert(count() == threads * ITERATIONS);
}