Commit ce55c1e8 authored by Romain Reuillon's avatar Romain Reuillon
Browse files

Merge branch 'fixabc' into 'master'

Fixabc

See merge request !21
parents 10b8cbae 26d003cc
......@@ -6,50 +6,52 @@ import org.openmole.core.dsl.extension._
object ABCHook {
def apply(abc: DSLContainer[ABC.ABCParameters], dir: FromContext[File], frequency: OptionalArgument[Long] = None)(implicit name: sourcecode.Name, definitionScope: DefinitionScope) =
def apply(abc: DSLContainer[ABC.ABCParameters], dir: FromContext[File], frequency: Long = 1)(implicit name: sourcecode.Name, definitionScope: DefinitionScope) =
Hook("ABCHook") { p
import p._
import org.openmole.plugin.tool.csv._
context(abc.data.state) match {
case MonAPMC.Empty() ()
case MonAPMC.State(_, s)
val step = context(abc.data.step)
val filePath = dir / s"step${step}.csv"
val file = filePath.from(context)
val size = s.thetas.size
val dim = s.thetas(0).size
val paramNames = abc.data.prior.map { x x.v.name }
val header =
(Vector("epsilon,pAcc,t,ts,rhos,weight") ++
paramNames).mkString(",")
val data =
(Vector.fill(size)(s.epsilon) zip
Vector.fill(size)(s.pAcc) zip
Vector.fill(size)(s.t) zip
s.ts zip
s.rhos zip
s.weights zip
s.thetas).map {
case ((((((epsilon, pAcc), t), ti), rhoi), wi), thetai)
epsilon.formatted("%.12f") ++ "," ++
pAcc.formatted("%.12f") ++ "," ++
t.formatted("%d") ++ "," ++
ti.formatted("%d") ++ "," ++
rhoi.formatted("%.12f") ++ "," ++
wi.formatted("%.12f") ++ "," ++
thetai.map { _.formatted("%.12f") }.mkString(",")
}.mkString("\n")
file.createParentDir
file.content = header ++ "\n" ++ data
if (context(abc.data.step) % frequency == 0) {
context(abc.data.state) match {
case MonAPMC.Empty() ()
case MonAPMC.State(_, s)
val step = context(abc.data.step)
val filePath = dir / s"step${step}.csv"
val file = filePath.from(context)
val size = s.thetas.size
val dim = s.thetas(0).size
val paramNames = abc.data.prior.map { x x.v.name }
val header =
(Vector("epsilon,pAcc,t,ts,rhos,weight") ++
paramNames).mkString(",")
val data =
(Vector.fill(size)(s.epsilon) zip
Vector.fill(size)(s.pAcc) zip
Vector.fill(size)(s.t) zip
s.ts zip
s.rhos zip
s.weights zip
s.thetas).map {
case ((((((epsilon, pAcc), t), ti), rhoi), wi), thetai)
epsilon.formatted("%.12f") ++ "," ++
pAcc.formatted("%.12f") ++ "," ++
t.formatted("%d") ++ "," ++
ti.formatted("%d") ++ "," ++
rhoi.formatted("%.12f") ++ "," ++
wi.formatted("%.12f") ++ "," ++
thetai.map { _.formatted("%.12f") }.mkString(",")
}.mkString("\n")
file.createParentDir
file.content = header ++ "\n" ++ data
}
}
context
......
......@@ -113,7 +113,7 @@ package object abc {
val masterState = Val[MonAPMC.MonState]("masterState", abcNamespace)
val islandState = state
val step = Val[Int]("step", abcNamespace)
val step = Val[Int]("masterStep", abcNamespace)
val stop = Val[Boolean]
val n = sample + generated
......@@ -127,7 +127,8 @@ package object abc {
val master =
MoleTask(appendSplit -- terminationTask) set (
exploredOutputs += islandState.array
exploredOutputs += islandState.array,
step := 0
)
val slave =
......@@ -158,9 +159,9 @@ package object abc {
}
implicit class ABCContainer(dsl: DSLContainer[ABCParameters]) extends DSLContainerHook(dsl) {
def hook(directory: FromContext[File]): DSLContainer[ABC.ABCParameters] = {
def hook(directory: FromContext[File], frequency: Long = 1): DSLContainer[ABC.ABCParameters] = {
implicit val defScope = dsl.scope
dsl hook ABCHook(dsl, directory)
dsl hook ABCHook(dsl, directory, frequency)
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment