1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use crate::builtin::DIALECT_NAME;
use crate::parser::{single_block_region, AsmPResult, ParseStream};
use crate::{IRFormatter, Op, OpAssembly, OpImpl, OpRef, Printable, RegionRef, Terminator};
use tir_macros::{op_implements, Op, OpAssembly, OpValidator};
use winnow::Parser;

use crate as tir_core;

/// The `module` operation of the `builtin` dialect.
///
/// Owns a single region, `body`, into which top-level operations are placed:
/// the hand-written `OpAssembly` impl for this type pushes every parsed op into
/// it, and prints the body's ops back out as a region. A freshly built, empty
/// module prints as `module {\n}\n`.
///
/// `Op` and `OpValidator` are derived (the `builder`, `get_body`,
/// `get_body_region`, and `get_context` accessors used in this file come from
/// those derives). `OpAssembly` is deliberately *not* derived; it is implemented
/// by hand for this type.
#[derive(Op, Debug, OpValidator)]
#[operation(name = "module", dialect = builtin)]
pub struct ModuleOp {
    // The module's body. `single_block, no_args` presumably tells the derive
    // that the region has exactly one block and that block takes no
    // arguments — NOTE(review): confirm against `tir_macros`.
    #[region(single_block, no_args)]
    body: RegionRef,
    // Common per-operation state; presumably required by the `Op` derive.
    r#impl: OpImpl,
}

/// The `module_end` operation of the `builtin` dialect.
///
/// Declares no fields of its own beyond the shared `OpImpl` state. It
/// implements `Terminator`, so it is presumably meant to close a module's body
/// block. NOTE(review): confirm where it is inserted.
///
/// `Op`, `OpValidator`, and `OpAssembly` are all derived.
#[derive(Op, Debug, OpValidator, OpAssembly)]
#[operation(name = "module_end", dialect = builtin)]
pub struct ModuleEndOp {
    // Common per-operation state; presumably required by the `Op` derive.
    r#impl: OpImpl,
}

// Marks `ModuleEndOp` as a block terminator. `op_implements` presumably also
// registers this trait implementation with the builtin dialect, so that runtime
// queries (`op_has_trait::<dyn Terminator>` / `op_dyn_cast::<dyn Terminator>`)
// succeed for this op. The tests in this file rely on those queries returning
// true / `Some`.
#[op_implements(dialect = builtin)]
impl Terminator for ModuleEndOp {}

impl OpAssembly for ModuleOp {
    /// Parses the region part of a textual `module` operation.
    ///
    /// Reads one single-block region from `input` via `single_block_region`,
    /// then builds a fresh `ModuleOp` in the parser's context. Every parsed
    /// operation is appended, in iteration order, to the module's body block.
    ///
    /// # Errors
    /// Propagates any error produced while parsing the region.
    fn parse_assembly(input: &mut ParseStream) -> AsmPResult<OpRef>
    where
        Self: Sized,
    {
        // Parse the region contents first; the module is only built afterwards.
        let parsed_ops = single_block_region.parse_next(input)?;

        let ctx = input.state.get_context();
        let module = ModuleOp::builder(&ctx).build();

        parsed_ops.into_iter().for_each(|child| {
            module.borrow_mut().get_body().push(&child);
        });

        Ok(module)
    }

    /// Prints the module body as a region.
    ///
    /// Opens a region on `fmt`, prints each operation of the body block in
    /// turn, and then closes the region.
    fn print_assembly(&self, fmt: &mut dyn IRFormatter) {
        fmt.start_region();

        let block = self.get_body();
        for child in block.iter() {
            child.borrow().print(fmt);
        }

        fmt.end_region();
    }
}

#[cfg(test)]
mod test {
    use std::any::TypeId;

    use super::*;
    use crate::{
        parse_ir,
        utils::{op_dyn_cast, op_has_trait},
        Context, Printable, StringPrinter,
    };

    /// `ModuleOp` reports its registered name, exposes its generated accessors,
    /// and is *not* a terminator.
    #[test]
    fn test_module() {
        // `assert_eq!` reports both sides on failure, unlike `assert!(a == b)`.
        assert_eq!(ModuleOp::get_operation_name(), "module");

        let context = Context::new();
        let module = ModuleOp::builder(&context).build();
        // Smoke-check: the derive-generated accessors are callable on a fresh op.
        module.borrow().get_body_region();
        module.borrow().get_body();
        module.borrow().get_context();
        assert!(!op_has_trait::<dyn Terminator>(module.clone()));
        assert!(op_dyn_cast::<dyn Terminator>(module).is_none());
    }

    /// `ModuleEndOp` reports its registered name and is discoverable as a
    /// `Terminator` through the runtime trait queries.
    #[test]
    fn test_module_end() {
        assert_eq!(ModuleEndOp::get_operation_name(), "module_end");

        let context = Context::new();

        let module_end = ModuleEndOp::builder(&context).build();
        assert!(op_has_trait::<dyn Terminator>(module_end.clone()));
        assert!(op_dyn_cast::<dyn Terminator>(module_end).is_some());
    }

    /// An empty module prints as a name followed by an empty region.
    // TODO replace this test with a snapshot test
    #[test]
    fn test_module_print() {
        let context = Context::new();
        let module = ModuleOp::builder(&context).build();

        let mut printer = StringPrinter::new();

        module.borrow().print(&mut printer);

        let result = printer.get();

        let golden = "module {\n}\n";
        assert_eq!(result, golden);
    }

    /// Parsing the textual form of an empty module yields a `ModuleOp`.
    #[test]
    fn test_module_parse() {
        let context = Context::new();
        let input = "module {\n}\n";
        let op = parse_ir(context, input).expect("parsed ir");
        // Dispatches through the trait object, so this is the concrete op's TypeId.
        assert_eq!(op.borrow().type_id(), TypeId::of::<ModuleOp>());
    }
}