Commit 8ebec709 authored by Romain Reuillon's avatar Romain Reuillon

[Core] fix+enh: refactor csv output format

parent 147302ed
......@@ -42,11 +42,12 @@ package object csv {
case s: ScalarData List(s.v)
}
sealed trait CSVData
case class ScalarData(v: Any) extends CSVData
case class ArrayData(v: List[Any]) extends CSVData
}
sealed trait CSVData
case class ScalarData(v: Any) extends CSVData
case class ArrayData(v: List[Any]) extends CSVData
import CSVData._
def valuesToData(values: Seq[Any]) =
values.map {
......@@ -55,73 +56,60 @@ package object csv {
case v ScalarData(v)
}.toList
def header(prototypes: Seq[Val[_]], values: Seq[Any], arraysOnSingleRow: Boolean = false) = {
val lists =
values.map {
case v: Array[_] v.toList
case l: List[_] l
case v List(v)
}.toList
(prototypes zip lists).flatMap {
case (p, l)
if (arraysOnSingleRow && moreThanOneElement(l))
(0 until l.size).map(i s"${p.name}$i")
else List(p.name)
}.mkString(",")
}
def header(prototypes: Seq[Val[_]]) = prototypes.map(_.name).mkString(",")
def writeVariablesToCSV(
output: PrintStream,
header: Option[String] = None,
values: Seq[Any],
arraysOnSingleRow: Boolean = false): Unit = {
header.foreach(h output.appendLine { h })
output: PrintStream,
header: Option[String] = None,
values: Seq[Any],
unrollArray: Boolean = false): Unit = {
// TODO add option to flatten multidim arrays here be be written on multiple lines
header.foreach(h output.appendLine { h })
def flatAny(o: Any): List[Any] = o match {
case o: List[_] o
case _ List(o)
}
def quote(v: Any): String =
v match {
case v: Array[_] s""""${format(v)}""""
case v: Seq[_] s""""${format(v)}""""
case v v.prettify()
}
def writeData(data: List[CSVData]): Unit = {
val scalars = data.collect { case x: ScalarData x }
if (scalars.size == data.size) writeLine(scalars.map(_.v))
else if (arraysOnSingleRow) {
val lists = data.map(CSVData.toList)
writeLine(lists.flatten(flatAny))
def format(v: Any): String =
v match {
case v: Array[_] s"[${v.map(format).mkString(",")}]"
case v: Seq[_] s"[${v.map(format).mkString(",")}]"
case v v.prettify()
}
else writeArrayData(data)
}
@tailrec def writeArrayData(data: List[CSVData]): Unit = {
if (data.collect { case l: ArrayData l }.forall(_.v.isEmpty)) Unit
else {
val lists = data.map(CSVData.toList)
writeLine(lists.map { _.headOption.getOrElse("") })
def tail(d: CSVData) =
d match {
case a @ ArrayData(Nil) a
case a: ArrayData a.copy(a.v.tail)
case s: ScalarData s
def csvLine(v: Seq[Any]): String = v.map(quote).mkString(",")
def unroll(v: Seq[Any]) = {
def writeLines(lists: Seq[List[Any]]): Unit = {
output.appendLine(csvLine(lists.map(_.head)))
val lastLine = lists.forall(_.tail.isEmpty)
if (!lastLine) {
val skipHead = lists.map {
case h :: Nil h :: Nil
case _ :: t t
case Nil Nil
}
writeArrayData(data.map(tail))
writeLines(skipHead)
}
}
}
def writeLine[T](list: List[T]) = {
output.appendLine(list.map(l {
val prettified = l.prettify()
def shouldBeQuoted = prettified.contains(',') || prettified.contains('"')
def quote(s: String) = '"' + s.replaceAll("\"", "\"\"") + '"'
if (shouldBeQuoted) quote(prettified) else prettified
}).mkString(","))
def lists: Seq[List[Any]] =
v map {
case v: Array[_] v.toList
case v: Seq[_] v.toList
case v List(v)
}
writeLines(lists)
}
writeData(valuesToData(values))
if (unrollArray) unroll(values)
else output.appendLine(csvLine(values))
}
/**
......
package org.openmole.core.csv
import java.io.PrintStream
import org.openmole.tool.stream.StringOutputStream
import org.scalatest._
import org.scalatest.junit._
class CSVSpec extends FlatSpec with Matchers {
def result(f: PrintStream Unit): String = {
val result = new StringOutputStream()
val printStream = new PrintStream(result)
f(printStream)
printStream.close()
result.builder.toString
}
"Function" should "produce conform csv" in {
assert(
result(writeVariablesToCSV(_, None, Seq(42, 56, Array(89, 89)))) ===
"""42,56,"[89,89]"
|""".stripMargin
)
assert(
result(writeVariablesToCSV(_, None, Seq(42, 56, Array(89, 101)), unrollArrays = true)) ===
"""42,56,89
|42,56,101
|""".stripMargin
)
}
}
......@@ -14,14 +14,14 @@ object CSVHook {
apply(output, values.toVector)
def apply(
output: WritableOutput,
values: Seq[Val[_]] = Vector.empty,
exclude: Seq[Val[_]] = Vector.empty,
header: OptionalArgument[FromContext[String]] = None,
arrayOnRow: Boolean = false,
overwrite: Boolean = false)(implicit name: sourcecode.Name, definitionScope: DefinitionScope): mole.FromContextHook =
output: WritableOutput,
values: Seq[Val[_]] = Vector.empty,
exclude: Seq[Val[_]] = Vector.empty,
header: OptionalArgument[FromContext[String]] = None,
unrollArray: Boolean = true,
overwrite: Boolean = false)(implicit name: sourcecode.Name, definitionScope: DefinitionScope): mole.FromContextHook =
FormattedFileHook(
format = CSVOutputFormat(header = header, arrayOnRow = arrayOnRow, append = !overwrite),
format = CSVOutputFormat(header = header, unrollArray = unrollArray, append = !overwrite),
output = output,
values = values,
exclude = exclude,
......@@ -34,7 +34,7 @@ object CSVHook {
override def write(format: CSVOutputFormat, output: WritableOutput, variables: Seq[Variable[_]]): FromContext[Unit] = FromContext { p
import p._
def headerLine = format.header.map(_.from(context)) getOrElse csv.header(variables.map(_.prototype), variables, format.arrayOnRow)
def headerLine = format.header.map(_.from(context)) getOrElse csv.header(variables.map(_.prototype))
output match {
case WritableOutput.FileValue(file)
......@@ -43,13 +43,13 @@ object CSVHook {
val h = if (f.isEmpty) Some(headerLine) else None
if (create) f.atomicWithPrintStream { ps csv.writeVariablesToCSV(ps, h, variables.map(_.value), format.arrayOnRow) }
else f.withPrintStream(append = true, create = true) { ps csv.writeVariablesToCSV(ps, h, variables.map(_.value), format.arrayOnRow) }
if (create) f.atomicWithPrintStream { ps csv.writeVariablesToCSV(ps, h, variables.map(_.value), unrollArray = format.unrollArray) }
else f.withPrintStream(append = true, create = true) { ps csv.writeVariablesToCSV(ps, h, variables.map(_.value), unrollArray = format.unrollArray) }
case WritableOutput.StreamValue(ps, prelude)
prelude.foreach(ps.print)
val header = Some(headerLine)
csv.writeVariablesToCSV(ps, header, variables, format.arrayOnRow)
csv.writeVariablesToCSV(ps, header, variables, unrollArray = format.unrollArray)
}
}
......@@ -63,8 +63,8 @@ object CSVHook {
}
case class CSVOutputFormat(
header: OptionalArgument[FromContext[String]] = None,
arrayOnRow: Boolean = false,
append: Boolean = false)
header: OptionalArgument[FromContext[String]] = None,
unrollArray: Boolean = false,
append: Boolean = false)
}
......@@ -38,7 +38,7 @@ package object directsampling {
def hook[T: OutputFormat](
output: WritableOutput,
values: Seq[Val[_]] = Vector.empty,
format: T = CSVOutputFormat()): DSLContainer[DirectSampling] = {
format: T = CSVOutputFormat(append = true)): DSLContainer[DirectSampling] = {
implicit val defScope = dsl.scope
dsl hook FormattedFileHook(output = output, values = values, format = format)
}
......@@ -49,7 +49,7 @@ package object directsampling {
output: WritableOutput,
values: Seq[Val[_]] = Vector.empty,
includeSeed: Boolean = false,
format: T = CSVOutputFormat()): DSLContainer[Replication] = {
format: T = CSVOutputFormat(append = true)): DSLContainer[Replication] = {
implicit val defScope = dsl.scope
val exclude = if (!includeSeed) Seq(dsl.data.seed) else Seq()
dsl hook FormattedFileHook(output = output, values = values, exclude = exclude, format = format)
......
......@@ -63,7 +63,7 @@ object SavePopulationHook {
}
def apply[T, F: OutputFormat](algorithm: T, output: WritableOutput, frequency: OptionalArgument[Long] = None, last: Boolean = false, format: F = CSVOutputFormat())(implicit wfi: WorkflowIntegration[T], name: sourcecode.Name, definitionScope: DefinitionScope) = {
def apply[T, F: OutputFormat](algorithm: T, output: WritableOutput, frequency: OptionalArgument[Long] = None, last: Boolean = false, format: F = CSVOutputFormat(unrollArray = true))(implicit wfi: WorkflowIntegration[T], name: sourcecode.Name, definitionScope: DefinitionScope) = {
val t = wfi(algorithm)
hook(t, output, frequency.option, last = last, format = format)
}
......@@ -72,24 +72,14 @@ object SavePopulationHook {
object SaveLastPopulationHook {
def apply[T](algorithm: T, file: FromContext[File])(implicit wfi: WorkflowIntegration[T], name: sourcecode.Name, definitionScope: DefinitionScope) = {
def apply[T, F](algorithm: T, output: WritableOutput, format: F = CSVOutputFormat(unrollArray = true))(implicit wfi: WorkflowIntegration[T], name: sourcecode.Name, definitionScope: DefinitionScope, outputFormat: OutputFormat[F]) = {
val t = wfi(algorithm)
Hook("SaveLastPopulationHook") { p
import p._
import org.openmole.core.csv
val values = SavePopulationHook.resultVariables(t).from(context).map(_.value)
def headerLine = csv.header(SavePopulationHook.resultVariables(t).from(context).map(_.prototype.array), values)
file.from(context).withPrintStream(create = true) { ps
csv.writeVariablesToCSV(
ps,
Some(headerLine),
values
)
}
outputFormat.write(format, output, SavePopulationHook.resultVariables(t).from(context)).from(context)
context
} set (inputs += (t.populationPrototype, t.statePrototype))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment