GCJのためのMulti-thread Framework
ということで、とりあえず書いた。これがなかなか面倒だった。
まあ、設計に問題があったりもするかもしれないが、とりあえず現在動いているので、この設計方針をメモ。ソースコード自体は最後に貼り付けておく。
まず、やりたかったこととしては、「今までのプログラムが(ほとんど)そのまま動く」ということ。マルチスレッドを意識しながら描くのは非常にめんどうなので、そんなことは考えないようなことをやりたかった。
ちなみに今までは、以下のようなものを使っていた。
import java.util.*; import java.math.*; class Main{ public static void main(String [] argv){ Scanner sc = new Scanner(System.in); int n = sc.nextInt(); for(int i = 0; i < n; i++){ System.out.printf("Case #%d: ", i+1); solve(sc); } } public static void solve(Scanner sc){ } }
この中で、solveの中で好き勝手にやればよい、というのが今までのやり方だった。これをマルチスレッド化するとどうなるか、と言うと、上記のfor分のところを適当に並列化すればよい、と思うのだが、実はそれだけではうまくいかない。どこがうまくいかないかと言えば、IOの部分だ。
まず、Inputの部分としては、これは全てScannerクラスから読み出すわけだが、このScannerの使用に関してはlockをかけるなり、synchronizedにするなり、ということをしなくてはならない。が、synchronizedを使うには「必ず読み込みをメソッド切り分けしなくてはならない」という制約が加わる。わざわざ一次変数で値を受け取って、それから・・・というのは非常にやりにくいこともある。このような自由度を削るのは美しくない。したがって、lockを使うことにした。今回はReentrantLockを使ったが、別に何でもいい。で、読み込みが終わったところでlockをunlockする。このunlock処理は絶対に呼ばなくてはならない、というのが美しくないが、これは我慢することにしよう。
次にOutputの部分だ。まず、やりたいこととしては、スレッド並列で実行する以上、各タスクの終了はどのような順番になるかわからないのでそれをちゃんとシリアライズする、ということだ。で、それだけならいいのだが、できることならば実行がどこまで進んだかをチェックしたい。したがって、「printする時には一旦標準エラー出力に出しておいて、すべての実行が終わった後で改めて標準出力に出す」と言う方針。このためには、print, println, printfのそれぞれと同じインターフェースを持つ関数を自分で定義しなくてはならない。仕方がないので、これは適当に書いてしまった。そのせいでソースが長くなってしまっている。特に、Javaのprintとprintlnの仕様が各primitive型について定義されているので、面倒である。
これらのことを、全部まとめて書いてしまったので、今後使う時には、
System.out.println(hogehoge)
の代わりに
myout.println(hogehoge);
と書けばよい。
ちなみに、Runnnableを継承させたクラスをMainではなくてTaskPoolという別クラスにしているが、これは少しでもMainで定義されている変数を減らしたかったという意図でしかない。それでも、LockやらScannerやらがMainクラスで定義されているのは、Mainから呼ぶ場合にTaskPool.scなどとわざわざ書くのがめんどくさかった、と言うだけのことだ。
import java.util.*; import java.math.*; import java.util.concurrent.locks.ReentrantLock; class TaskPool implements Runnable{ public static int tasksize; public static int now = 0; public void run(){ while(true){ Main.lock.lock(); if(now >= tasksize) break; //exec int problem_num = now; now++; Main main = new Main(); main.solve(problem_num, Main.sc); } Main.lock.unlock(); // important } } class Main extends Thread{ final static int NUM_THREADS = 8; public static ReentrantLock lock; static Scanner sc; public Myout myout; static{ lock = new ReentrantLock(); } public static void main(String argv){ sc = new Scanner(System.in); int n = sc.nextInt(); TaskPool.tasksize = n; Thread threads = new Thread[NUM_THREADS]; for(int i = 0; i < NUM_THREADS; i++){ threads[i] = new Thread(new TaskPool()); threads[i].start(); } for(int i = 0; i < NUM_THREADS; i++){ try{ threads[i].join(); }catch(InterruptedException e){ System.err.println(e); } } Myout.flush(); } public void solve(int problem_number, Scanner sc){ this.myout = Myout.get(problem_number); //read phase //read ends lock.unlock(); } } class Myout implements Comparable{ public int num; private String str; private boolean state; private static LinkedList list; static{ list = new LinkedList (); } private Myout(int problem_number){ this.num = problem_number; this.state = false; this.str = ""; } public static Myout get(int problem_number){ Myout out = new Myout(problem_number); list.add(out); return out; } public void printFirst(){ if(!state){ System.err.printf("Case #%d: ", num + 1); } state = true; } public void print(boolean x){ printMain("" + x); } public void print(char x){ printMain("" + x); } public void print(char x){ printMain("" + x); } public void print(double x){ printMain("" + x); } public void print(float x){ printMain("" + x); } public void print(int x){ printMain("" + x); } public void print(long x){ printMain("" + x); } public void print(Object x){ printMain("" + x); } public void println(){ printlnMain(""); } public void println(boolean x){ printlnMain("" + x); } public void println(char x){ printlnMain("" + x); } public void println(char x){ printlnMain("" + x); } public void println(double x){ printlnMain("" + x); } public void println(float x){ printlnMain("" + x); } public void println(int x){ printlnMain("" + x); } public void println(long x){ printlnMain("" + x); } public void println(Object x){ printlnMain("" + x); } private void printlnMain(String str){ printFirst(); this.str = this.str + str; System.err.println(str); } public void printf(String format, Object... args){ printFirst(); String f = String.format(format, args); this.str = this.str + f; System.err.print(f); } private void printMain(String str){ printFirst(); this.str = this.str + str; System.err.print(str); } public static void flush(){ Collections.sort(list); for(Myout m: list){ System.out.printf("Case #%d: ", m.num + 1); System.out.println(m.str); } } public int compareTo(Myout m){ return this.num - m.num; } }