Skip to content

Instantly share code, notes, and snippets.

@iracigt
Created December 9, 2019 18:10
Show Gist options
  • Select an option

  • Save iracigt/32148aecd5145807acd6562af42c412d to your computer and use it in GitHub Desktop.

Select an option

Save iracigt/32148aecd5145807acd6562af42c412d to your computer and use it in GitHub Desktop.
A parser combinator library for parsing (and unparsing) binary data
import shapeless.Generic
/** Example record with two `Int` fields, `re` and `im`; `ParserExample.parseComplex` maps it to two 8-bit unsigned fields. */
case class Complex(re: Int, im: Int)
/** Example record with three `Int` fields; `ParserExample.parseFoo` maps them to unsigned fields of 2, 1 and 1 bits. */
case class Foo(a: Int, b: Int, c: Int)
/**
 * Demo entry point: unparses a `Complex` to bits, prints the bits as a binary
 * string, then parses those bits back and prints the parse result.
 *
 * NOTE(review): `Parser`, `UIntParser`, the `~` / `^^` combinators and
 * `asBinary` are all declared inside `object structs`; this object needs them
 * in scope (e.g. via `import structs._`) - confirm how the file is built.
 */
object ParserExample extends App {
  /** Three unsigned fields (2, 1 and 1 bits wide) mapped onto `Foo`. */
  def parseFoo: Parser[Foo] =
    (UIntParser(2) ~ UIntParser(1) ~ UIntParser(1)) ^^ Generic[Foo]

  /** Two unsigned 8-bit fields mapped onto `Complex`. */
  def parseComplex: Parser[Complex] =
    (UIntParser(8) ~ UIntParser(8)) ^^ Generic[Complex]

  val cplx = Complex(123, 45)
  val bits = parseComplex.unparse(cplx)

  // Print the encoded bits, or "None" when the value could not be unparsed.
  println(bits.fold("None")(_.asBinary))
  // Round-trip: feed the encoded bits back through the parser.
  println(bits.fold[Any]("None")(encoded => parseComplex(encoded)))
}
import shapeless._
import shapeless.ops.hlist.{Prepend, Split, Length}
import shapeless.ops.coproduct.Align
import scala.language.implicitConversions
import scala.collection.mutable.Buffer
object structs {
/** Exception type used as the failure reason of parsers; constructed with a message and a `null` cause. */
class ParserException(msg: String) extends Exception(msg, null)
/** A source of bits (`Boolean`s) that parsers consume from. */
trait ParserInput {
/** Returns `n` bits together with the remaining input, or `None` when that is not possible (e.g. `SeqBackedInput` returns `None` if fewer than `n` bits are left). */
def take(n: Int): Option[(Seq[Boolean], ParserInput)]
/** Whether at least `n` bits are available. */
def has(n: Int): Boolean
/** Bit counter; the implementations in this file add `n` to it on every successful `take(n)`. */
def consumed: Int
}
/** A `ParserInput` whose bits can be materialised as a `Seq[Boolean]` and concatenated; this is also the type produced by `unparse`. */
trait SeqInput extends ParserInput {
/** The bits held by this input. */
def toSeq: Seq[Boolean]
/** Concatenation: this input's bits followed by the bits of `right`. */
def ++(right: SeqInput): SeqInput
}
/**
 * `SeqInput` backed directly by a `Seq[Boolean]`.
 *
 * @param seq      the bits that have not been consumed yet
 * @param consumed number of bits taken so far (grows by `n` on each successful `take(n)`)
 */
case class SeqBackedInput(seq: Seq[Boolean], consumed: Int = 0) extends SeqInput {
  /** True when at least `n` bits remain. */
  def has(n: Int) = n <= seq.length

  /** Splits off the first `n` bits; `None` when exactly `n` bits cannot be supplied. */
  def take(n: Int) = {
    val prefix = seq.take(n)
    if (prefix.length != n) None
    else Some((prefix, SeqBackedInput(seq.drop(n), consumed + n)))
  }

  def toSeq = seq

  /** Concatenates the remaining bits of both inputs; the result starts with `consumed` at 0. */
  def ++(right: SeqInput) = SeqBackedInput(seq ++ right.toSeq, consumed = 0)

  /** Appends raw bits; the result starts with `consumed` at 0. */
  def ++(right: Seq[Boolean]) = SeqBackedInput(seq ++ right, consumed = 0)
}
/** Builds a `SeqBackedInput` from byte values, expanding each value into its low 8 bits, most significant bit first. */
object ByteInput {
  def apply(seq: Seq[Int]): SeqBackedInput = {
    // Walk bit positions 7 down to 0 so the most significant bit of each byte comes first.
    val bits = seq.flatMap(byte => (7 to 0 by -1).map(pos => (byte & (1 << pos)) != 0))
    SeqBackedInput(bits)
  }
}
/** Rendering helpers for any `SeqInput`: as byte values, as a binary string, or as hex. */
implicit class SeqInputExtensions(val s: SeqInput) extends AnyVal {
  /**
   * Packs the bits into integers, 8 bits per value, first bit in the most
   * significant position. A trailing group shorter than 8 bits is left-aligned
   * (its missing low bits are zero).
   */
  def asBytes = s.toSeq.grouped(8).map { group =>
    group.zipWithIndex.foldLeft(0) {
      case (acc, (bit, idx)) => if (bit) acc | (0x80 >> idx) else acc
    }
  }.toSeq

  /** The bits as a string of '1' and '0' characters. */
  def asBinary = s.toSeq.map(bit => if (bit) "1" else "0").mkString

  /** The packed bytes as two-digit upper-case hex values separated by spaces. */
  def asHex = s.asBytes.map(byte => f"$byte%02X").mkString(" ")
}
/**
 * Decorates another `ParserInput` with its own `consumed` counter, so the
 * number of bits taken through the wrapper can be measured independently of
 * the wrapped input's counter.
 */
case class WrapperInput(in: ParserInput, consumed: Int) extends ParserInput {
  /** Delegates to the wrapped input and re-wraps the remainder with the counter advanced by `n`. */
  def take(n: Int) = in.take(n) match {
    case Some((bits, rest)) => Some((bits, WrapperInput(rest, consumed + n)))
    case None               => None
  }

  def has(n: Int) = in.has(n)
}
/** Outcome of running a parser: either a `ParseResult.Success` or a `ParseResult.Failure`. */
sealed trait ParseResult[+A] {
/** Transforms the parsed value; a failure is passed through unchanged. */
def map[B](f: A => B): ParseResult[B]
/** On success, continues with `f` applied to the parsed value and then to the remaining input; a failure is passed through. */
def flatMapWithNext[B](f: A => ParserInput => ParseResult[B]): ParseResult[B]
/** Keeps a success as is; on failure, returns the result of `r` instead. */
def merge[C >: A, B <: C](r: Unit => ParseResult[B]): ParseResult[C]
}
/** The two `ParseResult` cases plus helpers for building common failures. */
object ParseResult {
  /** Shared failure reason used when the input has too few bits. */
  final case object TooShortException extends ParserException("Input too short")

  /** A successful parse: the value and the input that remains after it. */
  final case class Success[A](value: A, input: ParserInput) extends ParseResult[A] {
    def map[B](f: A => B): ParseResult[B] = copy(value = f(value))

    def flatMapWithNext[B](f: A => ParserInput => ParseResult[B]): ParseResult[B] = {
      val continuation = f(value)
      continuation(input)
    }

    // A success wins; the alternative is never evaluated.
    def merge[C >: A, B <: C](r: Unit => ParseResult[B]): ParseResult[C] = this
  }

  /** A failed parse: the reason and the input the failure was reported against. */
  final case class Failure[A](reason: Exception, input: ParserInput) extends ParseResult[A] {
    def map[B](f: A => B): ParseResult[B] = Failure[B](reason, input)

    def flatMapWithNext[B](f: A => ParserInput => ParseResult[B]): ParseResult[B] =
      Failure[B](reason, input)

    // Fall back to the alternative result.
    def merge[C >: A, B <: C](r: Unit => ParseResult[B]): ParseResult[C] = r.apply(())
  }

  /** Failure carrying `TooShortException`. */
  def TooShort[A](buffer: ParserInput): Failure[A] = Failure[A](TooShortException, buffer)

  /** Failure carrying a new `ParserException` with the given message. */
  def Fail[A](msg: String, input: ParserInput): Failure[A] =
    Failure[A](new ParserException(msg), input)
}
/** A bidirectional codec: `parse` reads a `T` from bits, `unparse` writes a `T` back to bits. */
trait Parser[T] {
/** Reads a `T` from `in`, returning the value with the remaining input, or a failure. */
def parse(in: ParserInput): ParseResult[T]
/** Encodes `x` as bits, or `None` when `x` cannot be represented. */
def unparse(x: T): Option[SeqInput]
/** Alias for `parse`, so a parser can be applied like a function. */
def apply(in: ParserInput): ParseResult[T] = parse(in)
}
/** A parser whose result is an `HList`, which lets parsers be chained into sequences with `~`. */
trait HParser[L <: HList] extends Parser[L] {
/**
 * Sequencing: runs this parser, then `right` on the remaining input, and
 * concatenates the two results (`prep`). Unparsing splits the combined HList
 * after the first `N` elements (`N` being the length of `L`), unparses each
 * half with its own parser and concatenates the bits; it yields `None` if
 * either half cannot be unparsed.
 */
def ~[R <: HList, T <: HList, N <: Nat](right: HParser[R])
(implicit prep: Prepend[L, R]{type Out = T}, split: Split.Aux[T, N, L, R],
len: Length[L]{type Out = N}) = new HParser[T] {
def parse(in: ParserInput) = HParser.this.parse(in).
flatMapWithNext(l => (rem => right.parse(rem).map(r => prep(l, r))))
def unparse(x: T) = split(x) match {
case (l, r) => HParser.this.unparse(l).flatMap(a => right.unparse(r).map(b => a ++ b))
}
}
/**
 * Maps the HList result onto a type `O` whose `Generic` representation is `L`:
 * parsing converts with `t.from`, unparsing converts back with `t.to`.
 */
def ^^[O](t: Generic[O]{type Repr = L}): Parser[O] = new Parser[O] {
def parse(in: ParserInput) = HParser.this.parse(in).map(t.from(_))
def unparse(x: O) = HParser.this.unparse(t.to(x))
}
}
/** Lifts a plain `Parser[T]` into an `HParser` producing a one-element HList. */
case class WrappedHParser[T](par: Parser[T]) extends HParser[T :: HNil] {
  /** Parses with the wrapped parser and wraps the value as `value :: HNil`. */
  def parse(in: ParserInput): ParseResult[T :: HNil] =
    par.parse(in).map(value => value :: HNil)

  /** Unparses the single element of the HList with the wrapped parser. */
  def unparse(x: T :: HNil): Option[SeqInput] = {
    val single = x.head
    par.unparse(single)
  }
}
/** Implicit view wrapping any `Parser[T]` in a `WrappedHParser`, so the `HParser` combinators (`~`, `^^`) can be used on plain parsers. */
implicit def parser2HParser[T](par: Parser[T]): HParser[T :: HNil] = WrappedHParser(par)
/** A parser whose result is a `Coproduct`, which lets alternatives be combined with `|`. */
trait CParser[L <: Coproduct] {
/** Parses one of the alternatives of `L` from `buffer`. */
def cparse(buffer: ParserInput): ParseResult[L]
/** Encodes whichever alternative `x` holds, or `None` if it cannot be represented. */
def cunparse(x: L): Option[SeqInput]
/**
 * Adds `right` as a further alternative. Parsing tries the existing
 * alternatives first (result wrapped in `Inr`); only if they fail is `right`
 * tried, on the same original `buffer` (result wrapped in `Inl`). Unparsing
 * dispatches on `Inl` / `Inr` to the matching parser.
 */
def |[U](right: Parser[U]): CParser[U :+: L] = new CParser[U :+: L] {
def cparse(buffer: ParserInput) = CParser.this.cparse(buffer).
map[:+:[U,L]](Inr[U, L](_)).merge(_ => right.parse(buffer).map(Inl[U, L](_)))
def cunparse(x: (U :+: L)) = x match {
case Inl(l) => right.unparse(l)
case Inr(r) => CParser.this.cunparse(r)
}
}
/**
 * Maps the coproduct result onto a type `O` whose `Generic` representation is
 * `M`. `align` reorders the parsed `L` into `M` before `g.from`; `align2`
 * reorders `g.to(x)` back into `L` before unparsing.
 */
def |>[M <: Coproduct, O](g: Generic[O] { type Repr = M })
(implicit align: Align[L, M], align2: Align[M, L]) = new Parser[O] {
def parse(buffer: ParserInput) = CParser.this.cparse(buffer).map(x => g.from(align(x)))
def unparse(x: O) = CParser.this.cunparse(align2(g.to(x)))
}
}
/** Lifts a plain `Parser[T]` into a `CParser` with a single alternative (`T :+: CNil`). */
case class WrappedCParser[T](parser: Parser[T]) extends CParser[T :+: CNil] {
  /** Parses with the wrapped parser and injects the value as the `Inl` alternative. */
  def cparse(buffer: ParserInput) = parser.parse(buffer).map(Inl[T, CNil](_))

  /** Unparses the only possible alternative with the wrapped parser. */
  def cunparse(x: (T :+: CNil)) = x match {
    case Inl(l) => parser.unparse(l)
    // CNil has no values, so this branch cannot be reached; `impossible` has
    // type Nothing and states that explicitly instead of the "not implemented" marker `???`.
    case Inr(cnil) => cnil.impossible
  }
}
/** Implicit view wrapping any `Parser[T]` in a `WrappedCParser`, so the `CParser` combinators (`|`, `|>`) can be used on plain parsers. */
implicit def parser2CParser[T](parser: Parser[T]): CParser[T :+: CNil] = WrappedCParser(parser)
/**
 * Runs a `Parser[Unit]` but contributes nothing (`HNil`) to the result, so the
 * field is consumed and emitted without appearing in the parsed value.
 * Useful for fixed fields.
 */
final case class IgnoreParser(parser: Parser[Unit]) extends HParser[HNil] {
  /** Consumes whatever the wrapped parser consumes and yields `HNil`. */
  override def parse(buffer: ParserInput): ParseResult[HNil] = {
    val consumed = parser.parse(buffer)
    consumed.map(_ => HNil)
  }

  /** Emits the bits the wrapped parser produces for `()`. */
  override def unparse(x: HNil): Option[SeqInput] = parser.unparse(())
}
/**
 * Matches an exact bit pattern. Parsing consumes `pat.length` bits and
 * succeeds only when they equal `pat`; unparsing always emits `pat`.
 * Useful for matching specific values.
 */
final case class FixedParser(pat: Seq[Boolean]) extends Parser[Unit] {
  def parse(in: ParserInput): ParseResult[Unit] =
    in.take(pat.length) match {
      case None => ParseResult.TooShort[Unit](in)
      case Some((bits, rem)) if bits.sameElements(pat) => ParseResult.Success((), rem)
      // Failures are reported against the original input, not the remainder.
      case Some(_) => ParseResult.Fail[Unit]("Mismatch", in)
    }

  def unparse(x: Unit) = Some(SeqBackedInput(pat))
}
/**
 * Convenience method for using bitmasks: builds a `FixedParser` from the low
 * `len` bits of `pat`, most significant bit first.
 */
def FixedParser(len: Int, pat: Int): FixedParser = {
  // tabulate yields bit 0 first; reversing puts the most significant bit first.
  val lsbFirst = Array.tabulate(len)(pos => (pat & (1 << pos)) != 0)
  FixedParser(lsbFirst.reverse)
}
/**
 * Skips `len` bits whatever their value. Unparsing emits `len` zero bits.
 */
final case class DontCareParser(len: Int) extends Parser[Unit] {
  def parse(in: ParserInput): ParseResult[Unit] =
    in.take(len) match {
      case Some((_, rem)) => ParseResult.Success((), rem)
      case None           => ParseResult.TooShort[Unit](in)
    }

  def unparse(x: Unit) = {
    val zeros = Array.fill(len)(false)
    Some(SeqBackedInput(zeros))
  }
}
/**
 * A parser that never succeeds: parsing always fails with the message
 * "CantParse" and unparsing always yields `None`.
 */
final case class CantParse[T]() extends Parser[T] {
  def parse(buffer: ParserInput) = ParseResult.Fail[Nothing]("CantParse", buffer)

  def unparse(x: T) = None
}
/**
 * Consumes no bits and always yields the constant `c`. Unparsing emits an
 * empty bit sequence for any argument (the argument is not compared with `c`).
 */
final case class ConstParse[T](c: T) extends Parser[T] {
  def parse(buffer: ParserInput) = ParseResult.Success(c, buffer)

  def unparse(x: T) = Some(SeqBackedInput(Seq.empty[Boolean]))
}
/**
 * Single-bit parser: reads one bit as a `Boolean` and writes a `Boolean` as one bit.
 */
final case object BoolParser extends Parser[Boolean] {
  def parse(in: ParserInput): ParseResult[Boolean] =
    in.take(1) match {
      case Some((bits, rem)) => ParseResult.Success(bits(0), rem)
      case None              => ParseResult.TooShort[Boolean](in)
    }

  def unparse(x: Boolean) = Some(SeqBackedInput(Array(x)))
}
/**
 * Bit-ordering marker with two cases, `MSBFirst` and `LSBFirst`.
 * CURRENTLY UNUSED: nothing else in this file refers to it.
 */
sealed trait BitOrder
final case object MSBFirst extends BitOrder
final case object LSBFirst extends BitOrder
/**
 * Big-Endian (MSB first) unsigned integer parser
 *
 * Reads and writes a field of `len` bits holding a non-negative `Int`.
 *
 * @param len width of the field in bits
 * @Author: Grant Iraci
 */
final case class UIntParser(len: Int) extends Parser[Int] {
  /**
   * Consumes `len` bits and folds them, most significant bit first, into an `Int`.
   * No overflow check is made: a field wider than 31 bits whose high bits are
   * set produces a negative or truncated value.
   */
  def parse(in: ParserInput): ParseResult[Int] =
    in.take(len) match {
      case Some((bits, rem)) =>
        val value = bits.foldLeft(0)((acc, bit) => (acc << 1) | (if (bit) 1 else 0))
        ParseResult.Success(value, rem)
      case None => ParseResult.TooShort[Int](in)
    }

  /**
   * Encodes `x` as `len` bits, most significant bit first.
   *
   * @return `None` when `x` is negative or needs more than `len` bits
   */
  def unparse(x: Int): Option[SeqBackedInput] = {
    // Int shift counts are taken modulo 32 (1 << 32 == 1), so a mask built from
    // `1 << len` is wrong for len >= 32. Compare widths explicitly instead: any
    // non-negative Int fits in 31 or more bits.
    val width = math.max(len, 0)
    val fits = x >= 0 && (width >= 31 || (x >>> width) == 0)
    if (!fits) None
    else {
      // Bit positions 32 and above do not exist in an Int; emit them as zero padding.
      val lsbFirst = Array.tabulate(len)(pos => pos < 32 && ((x >>> pos) & 1) != 0)
      Some(SeqBackedInput(lsbFirst.reverse))
    }
  }
}
/**
 * Big-Endian (MSB first) signed integer parser
 *
 * Reads and writes a two's-complement field of `len` bits.
 *
 * @param len width of the field in bits
 * @Author: Grant Iraci
 */
final case class SIntParser(len: Int) extends Parser[Int] {
  /**
   * Consumes `len` bits, most significant bit first, and sign-extends the
   * result from the first bit. A zero-width field parses as 0.
   */
  def parse(in: ParserInput): ParseResult[Int] =
    in.take(len) match {
      case Some((bits, rem)) =>
        // Seeding with all ones when the sign bit is set sign-extends the value.
        // headOption avoids indexing into an empty sequence when len == 0.
        val seed = if (bits.headOption.contains(true)) -1 else 0
        val value = bits.foldLeft(seed)((acc, bit) => (acc << 1) | (if (bit) 1 else 0))
        ParseResult.Success(value, rem)
      case None => ParseResult.TooShort[Int](in)
    }

  /**
   * Encodes `x` as `len` two's-complement bits, most significant bit first.
   *
   * @return `None` when `x` lies outside `[-2^(len-1), 2^(len-1) - 1]`, i.e.
   *         when the emitted bits would not parse back to `x`
   */
  def unparse(x: Int): Option[SeqBackedInput] = {
    val fits =
      if (len <= 0) x == 0 // a zero-width field can only stand for 0
      else if (len >= 32) true // every Int fits
      else {
        // x fits iff everything from the sign bit (len - 1) upwards is uniform,
        // i.e. the arithmetic shift leaves only copies of the sign bit.
        val fromSignBit = x >> (len - 1)
        fromSignBit == 0 || fromSignBit == -1
      }
    if (!fits) None
    else {
      // Positions above bit 31 repeat the sign bit (sign extension); clamping the
      // shift count also avoids the JVM's modulo-32 shift wrap-around.
      val lsbFirst = Array.tabulate(len)(pos => ((x >> math.min(pos, 31)) & 1) != 0)
      Some(SeqBackedInput(lsbFirst.reverse))
    }
  }
}
/**
 * Wraps a parser with a trailing checksum. Requires a checksum function that
 * maps bits to bits.
 *
 * Parsing runs `parser`, measures how many bits it consumed, recomputes the
 * checksum over exactly those bits and then requires the following bits to
 * match it. Unparsing emits the payload bits followed by their checksum.
 */
final case class ChecksumParser[T <: HList](parser: HParser[T],
                                            chksum: Seq[Boolean] => Seq[Boolean]) extends HParser[T] {
  def parse(in: ParserInput): ParseResult[T] =
    // Parse through a WrapperInput starting at 0 so `consumed` counts only the payload bits.
    parser.parse(WrapperInput(in, 0)).flatMapWithNext { value => afterPayload =>
      // Re-read the same span from the original input to get the raw payload bits.
      in.take(afterPayload.consumed) match {
        case Some((payload, rest)) =>
          FixedParser(chksum(payload)).parse(rest).map(_ => value)
        case None =>
          ParseResult.Fail[T]("Checksum internal error. Mutating ParserInput?", in)
      }
    }

  def unparse(x: T) = parser.unparse(x).map { payload =>
    val trailer = SeqBackedInput(chksum(payload.toSeq))
    payload ++ trailer
  }
}
object ChecksumParser {
  /**
   * Adapts an integer checksum over blocks to a bits-to-bits checksum: packs
   * `data` into `blkSize`-bit integers (first bit most significant; a short
   * trailing block is left-aligned), applies `chksum`, and renders the low
   * `width` bits of the result most significant bit first.
   */
  private[this] def convertChksum(chksum: Seq[Int] => Int, width: Int, blkSize: Int = 8)(data: Seq[Boolean]) = {
    val topBit = 1 << (blkSize - 1)
    val blocks = data.sliding(blkSize, blkSize).map { block =>
      block.foldRight(0)((bit, acc) => (acc >> 1) | (if (bit) topBit else 0))
    }.toSeq
    val sum = chksum(blocks)
    Array.tabulate(width)(pos => (sum & (1 << pos)) != 0).reverse
  }

  /**
   * Convenience method. Requires a checksum function that maps bytes
   * (more generally, `blkSize`-bit blocks) to an Int, of which the low
   * `width` bits are used as the checksum field.
   */
  def apply[T <: HList](chksum: Seq[Int] => Int, width: Int, blkSize: Int = 8)(parser: HParser[T]): ChecksumParser[T] =
    new ChecksumParser[T](parser, bits => convertChksum(chksum, width, blkSize)(bits))
}
/**
 * A parser that parses a fixed number of *bytes*
 *
 * Each byte is an unsigned 8-bit field (`UIntParser(8)`).
 *
 * @param len number of bytes to read or write
 */
final case class FixedLenParser(len: Int) extends Parser[Array[Int]] {
  /**
   * Reads `len` consecutive bytes. Stops at the first failing byte and returns
   * that failure; a non-positive `len` yields an empty array without consuming input.
   */
  def parse(in: ParserInput): ParseResult[Array[Int]] = {
    val byteParser = UIntParser(8)
    val start: ParseResult[Vector[Int]] = ParseResult.Success(Vector.empty[Int], in)
    // Accumulate in a Vector (cheap append) and convert once at the end.
    (0 until len).foldLeft(start) { (acc, _) =>
      acc.flatMapWithNext(bytes => rem => byteParser.parse(rem).map(bytes :+ _))
    }.map(_.toArray)
  }

  /**
   * Writes the bytes of `x` in order.
   *
   * @return `None` when `x` does not hold exactly `len` elements (the output
   *         would not parse back to `x`) or when any element does not fit in
   *         an unsigned byte; an empty array for `len == 0` yields an empty
   *         bit sequence
   */
  def unparse(x: Array[Int]): Option[SeqInput] =
    if (x.length != len) None
    else {
      val byteParser = UIntParser(8)
      // Fold from an empty input so a zero-length array is handled without
      // reducing over an empty collection.
      val empty: Option[SeqInput] = Some(SeqBackedInput(Seq()))
      x.foldLeft(empty) { (acc, byte) =>
        for {
          bitsSoFar <- acc
          next <- byteParser.unparse(byte)
        } yield bitsSoFar ++ next
      }
    }
}
/**
 * A parser that parses a variable number of *bytes*: a length field, read
 * with the `len` parser, followed by that many bytes.
 */
final case class VarLenParser(len: Parser[Int]) extends Parser[Array[Int]] {
  /** Reads the length field, then that many bytes from the remaining input. */
  def parse(in: ParserInput): ParseResult[Array[Int]] =
    len.parse(in).flatMapWithNext { count => rem =>
      FixedLenParser(count).parse(rem)
    }

  /** Writes `x.length` with the `len` parser followed by the bytes of `x`; `None` if either part cannot be encoded. */
  def unparse(x: Array[Int]): Option[SeqInput] =
    for {
      header <- len.unparse(x.length)
      payload <- FixedLenParser(x.length).unparse(x)
    } yield header ++ payload
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment