kagamihogeの日記

kagamihogeの日記です。

spring-batchのTaskExecutorによるstepマルチスレッド化

https://docs.spring.io/spring-batch/docs/current/reference/html/scalability.html#scalability のMulti-threaded Stepを試す。

plugins {
    id 'org.springframework.boot' version '2.2.2.RELEASE'
    id 'io.spring.dependency-management' version '1.0.8.RELEASE'
    id 'java'
}

sourceCompatibility = '11'
targetCompatibility = '11'

repositories {
    mavenCentral()
}

dependencies {
    implementation 'org.springframework.boot:spring-boot-starter-batch'
    runtimeOnly 'com.h2database:h2'
    testImplementation('org.springframework.boot:spring-boot-starter-test') {
        exclude group: 'org.junit.vintage', module: 'junit-vintage-engine'
    }
    testImplementation 'org.springframework.batch:spring-batch-test'
}

test {
    useJUnitPlatform()
}

taskExecutorを指定することでそのstepはマルチスレッドで実行される。以下のコードは1-1000の合計数を出すサンプルコードだが、意図通りには動作しない。その点は後述。

@SpringBootApplication
@EnableBatchProcessing
public class App {
    
    @Bean
    public Job job(JobBuilderFactory jobs, @Qualifier("s1") Step s1) {
        return jobs
                .get("myJob")
                .incrementer(new RunIdIncrementer())
                .start(s1)
                .listener(new JobExecutionListener() {
                    @Override
                    public void beforeJob(JobExecution jobExecution) {}
                    
                    @Override
                    public void afterJob(JobExecution jobExecution) {
                        System.out.println(sum);
                    }
                })
                .build();
    }
    
    int sum = 0;
    
    @Bean(name = "s1")
    public Step step1(StepBuilderFactory steps) {
        List<Integer> list = IntStream.range(1, 1001).boxed().collect(Collectors.toList());
        ItemReader<Integer> reader = new ListItemReader<Integer>(list);
        
        TaskletStep step = steps
                .get("step1")
                .<Integer ,Integer>chunk(10)
                .reader(reader)
                .writer(l -> {
                    l.forEach(i -> {
                        System.out.print(i+",");
                        sum += i;
                    });
                    System.out.println();
                })
                .taskExecutor(taskExecutor())
                .build();
        
        return step;
    }
    
    @Bean
    public TaskExecutor taskExecutor(){
        return new SimpleAsyncTaskExecutor("spring_batch");
    }

    public static void main(String[] args) {
        new SpringApplicationBuilder(App.class).run(args);
    }

}

これの実行時の様子は以下の通り。

4,9,17,22,24,29,31,34,36,1,3,9,14,17,22,24,29,31,34,36,
2,5,6,7,8,10,12,15,19,21,23,27,37,
28,30,32,
(省略)
sum=501421

期待される合計値は500500だがそうはならない。これの原因は、マルチスレッドでstepを実行する際に複数スレッドでreader-processor-writerインスタンスを共有するため。プログラム内容に応じて適切な同期化を施さないと期待通りには動かない。上記の通り、readの同期化が無いために9が2回出てきちゃったりしている。また、上記実行状況の通り処理順序がバラバラになるため、順序が重要な場合にはこの方法は適さない。

同期化で改善するには、まず、合計を保持する変数をAtomicIntegerにする。

次に、ItemReaderを同期化対応版にする。spring-batchが提供する各種実装の大半はスレッドセーフではない。javadocに何らか回避策について記述が有る場合もあれば無い場合もある。非スレッドセーフの場合SynchronizedItemStreamReaderでラッピングしたり、自前のラッピングクラス作ったりで対応する。今回はとりあえず自前のクラス作る方向にしてみる。

public class MyItemReader implements ItemReader<Integer> {
    ItemReader<Integer> reader;
    public MyItemReader(ItemReader<Integer> reader) {
        this.reader = reader;
    }
    
    public synchronized Integer read() throws UnexpectedInputException, ParseException, NonTransientResourceException, Exception {
        return reader.read();
    }

}

単純にListItemReaderをラッピングしてreadsynchronizedで行う。

改善後のソースコードは以下の通り。

@SpringBootApplication
@EnableBatchProcessing
public class App {
    
    @Bean
    public Job job(JobBuilderFactory jobs, @Qualifier("s1") Step s1) {/*省略*/}
    
    static AtomicInteger sum = new AtomicInteger(0);
    
    @Bean(name = "s1")
    public Step step1(StepBuilderFactory steps) {
        List<Integer> list = IntStream.range(1, 1001).boxed().collect(Collectors.toList());
        ItemReader<Integer> reader = new MyItemReader(new ListItemReader<Integer>(list));
        
        TaskletStep step = steps
                .get("step1")
                .<Integer ,Integer>chunk(10)
                .reader(reader)
                .writer(l -> {
                    l.forEach(i -> {
                        sum.addAndGet(i);
                        System.out.print(i+",");
                    });
                    System.out.println();
                })
                .taskExecutor(taskExecutor())
                .build();
        
        return step;
    }
    
    @Bean
    public TaskExecutor taskExecutor(){
        return new SimpleAsyncTaskExecutor("spring_batch");
    }

    public static void main(String[] args) {
        new SpringApplicationBuilder(App.class).run(args);
    }

}

感想とか

お手軽にstepをマルチスレッド化できる。ただし、処理順序であるとかreader-processor-writerに同期化の考慮が必要とか、若干のクセがあるのでそこは注意が必要。