Commit 26d003cc authored by Guillaume Chérel's avatar Guillaume Chérel
Browse files

[Plugin] enh: Use frequency in abc hook.

parent 00420b67
......@@ -6,50 +6,52 @@ import org.openmole.core.dsl.extension._
object ABCHook {
def apply(abc: DSLContainer[ABC.ABCParameters], dir: FromContext[File], frequency: OptionalArgument[Long] = None)(implicit name: sourcecode.Name, definitionScope: DefinitionScope) =
def apply(abc: DSLContainer[ABC.ABCParameters], dir: FromContext[File], frequency: Long = 1)(implicit name: sourcecode.Name, definitionScope: DefinitionScope) =
Hook("ABCHook") { p
import p._
import org.openmole.plugin.tool.csv._
context(abc.data.state) match {
case MonAPMC.Empty() ()
case MonAPMC.State(_, s)
val step = context(abc.data.step)
val filePath = dir / s"step${step}.csv"
val file = filePath.from(context)
val size = s.thetas.size
val dim = s.thetas(0).size
val paramNames = abc.data.prior.map { x x.v.name }
val header =
(Vector("epsilon,pAcc,t,ts,rhos,weight") ++
paramNames).mkString(",")
val data =
(Vector.fill(size)(s.epsilon) zip
Vector.fill(size)(s.pAcc) zip
Vector.fill(size)(s.t) zip
s.ts zip
s.rhos zip
s.weights zip
s.thetas).map {
case ((((((epsilon, pAcc), t), ti), rhoi), wi), thetai)
epsilon.formatted("%.12f") ++ "," ++
pAcc.formatted("%.12f") ++ "," ++
t.formatted("%d") ++ "," ++
ti.formatted("%d") ++ "," ++
rhoi.formatted("%.12f") ++ "," ++
wi.formatted("%.12f") ++ "," ++
thetai.map { _.formatted("%.12f") }.mkString(",")
}.mkString("\n")
file.createParentDir
file.content = header ++ "\n" ++ data
if (context(abc.data.step) % frequency == 0) {
context(abc.data.state) match {
case MonAPMC.Empty() ()
case MonAPMC.State(_, s)
val step = context(abc.data.step)
val filePath = dir / s"step${step}.csv"
val file = filePath.from(context)
val size = s.thetas.size
val dim = s.thetas(0).size
val paramNames = abc.data.prior.map { x x.v.name }
val header =
(Vector("epsilon,pAcc,t,ts,rhos,weight") ++
paramNames).mkString(",")
val data =
(Vector.fill(size)(s.epsilon) zip
Vector.fill(size)(s.pAcc) zip
Vector.fill(size)(s.t) zip
s.ts zip
s.rhos zip
s.weights zip
s.thetas).map {
case ((((((epsilon, pAcc), t), ti), rhoi), wi), thetai)
epsilon.formatted("%.12f") ++ "," ++
pAcc.formatted("%.12f") ++ "," ++
t.formatted("%d") ++ "," ++
ti.formatted("%d") ++ "," ++
rhoi.formatted("%.12f") ++ "," ++
wi.formatted("%.12f") ++ "," ++
thetai.map { _.formatted("%.12f") }.mkString(",")
}.mkString("\n")
file.createParentDir
file.content = header ++ "\n" ++ data
}
}
context
......
......@@ -159,9 +159,9 @@ package object abc {
}
implicit class ABCContainer(dsl: DSLContainer[ABCParameters]) extends DSLContainerHook(dsl) {
def hook(directory: FromContext[File]): DSLContainer[ABC.ABCParameters] = {
def hook(directory: FromContext[File], frequency: Long = 1): DSLContainer[ABC.ABCParameters] = {
implicit val defScope = dsl.scope
dsl hook ABCHook(dsl, directory)
dsl hook ABCHook(dsl, directory, frequency)
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment