1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use std::collections::HashMap;
use std::sync::Arc;

use winnow::combinator::preceded;
use winnow::error::ErrMode;
use winnow::Parser;

use crate::parser::{op_tuple, AsmPResult, PError, Parsable, ParseStream};
use crate::Attr;
use crate::ContextRef;
use crate::ContextWRef;
use crate::Printable;
use crate::TyAssembly;

/// A handle to an IR type: the `(dialect_id, type_id)` pair identifying its kind inside an
/// owning context, plus the attribute map that the type's dialect parser produces and its
/// printer consumes.
///
/// Only a weak reference to the context is held, so a `Type` does not keep its context alive.
/// Equality (see the `PartialEq` impl) compares the context by pointer identity.
#[derive(Debug, Clone)]
pub struct Type {
    // Weak back-reference to the owning context; upgraded on demand by `get_context`.
    context: ContextWRef,
    // Id of the owning dialect, matched against `Dialect::get_id()`.
    dialect_id: u32,
    // Id of this type kind within its dialect, as returned by `Dialect::get_type_id`.
    type_id: u32,
    // Per-instance attributes, keyed by name.
    attrs: HashMap<String, Attr>,
}

impl Type {
    /// Creates a type of kind `(dialect_id, type_id)` carrying `attrs`.
    ///
    /// Only a weak reference to `context` is stored, so the new type does not keep the
    /// context alive.
    pub fn new(
        context: ContextRef,
        dialect_id: u32,
        type_id: u32,
        attrs: HashMap<String, Attr>,
    ) -> Self {
        Type {
            context: Arc::downgrade(&context),
            dialect_id,
            type_id,
            attrs,
        }
    }

    /// Returns the owning context, or `None` if it has already been dropped.
    pub fn get_context(&self) -> Option<ContextRef> {
        self.context.upgrade()
    }

    /// Returns the id of the dialect this type belongs to.
    pub fn get_dialect_id(&self) -> u32 {
        self.dialect_id
    }

    /// Returns the id of this type kind within its dialect.
    pub fn get_type_id(&self) -> u32 {
        self.type_id
    }

    /// Returns the attributes attached to this type instance.
    pub fn get_attrs(&self) -> &HashMap<String, Attr> {
        &self.attrs
    }

    /// Returns `true` if this type is an instance of the concrete type kind `T`.
    ///
    /// `T` is resolved to runtime ids through this type's context by dialect name and type
    /// name. If `T`'s dialect is not registered in the context, the result is `false`,
    /// because no type of that dialect can exist there.
    ///
    /// # Panics
    /// Panics if the owning context has been dropped.
    pub fn isa<T: Ty>(&self) -> bool {
        let context = self
            .get_context()
            .expect("Type::isa called after its Context was dropped");

        // An unregistered dialect cannot own `self`: that is a mismatch, not an error.
        let dialect = match context.get_dialect_by_name(T::get_dialect_name()) {
            Some(dialect) => dialect,
            None => return false,
        };

        dialect.get_id() == self.dialect_id
            && dialect.get_type_id(T::get_type_name()) == Some(self.type_id)
    }
}

impl Printable for Type {
    /// Prints the type as `!`, then `<dialect>.` unless the dialect is the builtin one,
    /// followed by whatever the dialect's registered printer emits for this type's attributes.
    ///
    /// # Panics
    /// Panics if the owning context has been dropped, if `dialect_id` is unknown to the
    /// context, or if the dialect has no printer registered for `type_id`.
    fn print(&self, fmt: &mut dyn crate::IRFormatter) {
        let context = self
            .get_context()
            .expect("Type printed after its Context was dropped");
        let dialect = context
            .get_dialect(self.dialect_id)
            .expect("Type refers to a dialect id not registered in its Context");

        fmt.write_direct("!");
        // Builtin types are printed without a dialect prefix.
        let dialect_name = dialect.get_name();
        if dialect_name != crate::builtin::DIALECT_NAME {
            fmt.write_direct(&format!("{}.", dialect_name));
        }

        let printer = dialect
            .get_type_printer(self.type_id)
            .expect("no printer registered for this type id in its dialect");
        printer(&self.attrs, fmt);
    }
}

impl Parsable<Type> for Type {
    /// Parses a `!`-prefixed type reference followed by the type's own attribute syntax.
    ///
    /// `op_tuple` yields the `(dialect name, type name)` pair. Presumably it accepts
    /// `dialect.type` and implies the builtin dialect when the prefix is omitted, mirroring
    /// `Printable` (TODO confirm in `op_tuple`). The dialect's registered type parser then
    /// consumes the attributes.
    ///
    /// # Errors
    /// Returns a cut `PError::UnknownDialect` / `PError::UnknownType` when the names do not
    /// resolve in the parser's context, and propagates any error from the dialect's parser.
    fn parse(input: &mut ParseStream<'_>) -> AsmPResult<Type> {
        let (dialect_name, type_name) = preceded("!", op_tuple).parse_next(input)?;

        let context = input.state.get_context();
        // `ok_or_else` so the error string is only allocated on failure.
        let dialect = context
            .get_dialect_by_name(dialect_name)
            .ok_or_else(|| ErrMode::Cut(PError::UnknownDialect(dialect_name.to_string())))?;
        let id = dialect
            .get_type_id(type_name)
            .ok_or_else(|| ErrMode::Cut(PError::UnknownType(type_name.to_string())))?;

        // By definition every existing type has a parser, and `id` was just obtained from
        // this dialect, so a missing parser is a broken dialect registration.
        let mut parser = dialect
            .get_type_parser(id)
            .expect("type id obtained from this dialect has no registered parser");

        let attrs = parser.parse_next(input)?;

        Ok(Type::new(context, dialect.get_id(), id, attrs))
    }
}

/// Static identity of a concrete type kind.
///
/// `Type::isa` uses these names to resolve the kind to runtime `(dialect_id, type_id)`
/// values through a context.
pub trait Ty: TyAssembly {
    /// Name of the dialect that registers this type kind.
    fn get_dialect_name() -> &'static str;

    /// Name of this type kind within its dialect.
    fn get_type_name() -> &'static str;
}

impl PartialEq for Type {
    /// Types are equal when they belong to the same context and have the same dialect id,
    /// type id and attributes. The context is compared by pointer identity, not by value.
    fn eq(&self, other: &Self) -> bool {
        let same_context = std::ptr::eq(self.context.as_ptr(), other.context.as_ptr());
        if !same_context {
            return false;
        }

        self.dialect_id == other.dialect_id
            && self.type_id == other.type_id
            && self.attrs == other.attrs
    }
}