infer.rkt (13771B)
#lang s-exp macrotypes/typecheck
(extends "ext-stlc.rkt"
 #:except define #%app λ → + - void = zero? sub1 add1 not
 #:rename [~→ ~ext-stlc:→])
(reuse cons [head hd] [tail tl] nil [isnil nil?] List list
       #:from "stlc+cons.rkt")
(reuse tup × proj
       #:from "stlc+tup.rkt")
(require (only-in "sysf.rkt" ∀ ~∀ ∀? Λ))
(require (for-syntax "../type-constraints.rkt"))

;; a language with local type inference using bidirectional type checking

(provide →
         (typed-out [+ : (→ Int Int Int)]
                    [- : (→ Int Int Int)]
                    [void : (→ Unit)]
                    [= : (→ Int Int Bool)]
                    [zero? : (→ Int Bool)]
                    [sub1 : (→ Int Int)]
                    [add1 : (→ Int Int)]
                    [not : (→ Bool Bool)]
                    [abs : (→ Int Int)])
         define λ #%app)

;; → : surface function-type constructor for this language.
;; Every function type is represented as a ∀ wrapped around an ext-stlc:→:
;;  - (→ {X ...} τ ...) binds the listed identifiers as the ∀'s type variables
;;    (the `brace?` guard is applied to the binder list; presumably it checks
;;    that the list was written with curly braces -- confirm in typecheck core)
;;  - (→ τ ...) produces a ∀ with an empty binder list.
;; In both cases the result of `(get-orig this-syntax)` is attached to the new
;; type with `add-orig`.
(define-syntax → ; wrapping →
  (syntax-parser
    [(_ (~and Xs {X:id ...}) . rst)
     #:when (brace? #'Xs)
     (add-orig #'(∀ (X ...) (ext-stlc:→ . rst)) (get-orig this-syntax))]
    [(_ . rst) (add-orig #'(∀ () (ext-stlc:→ . rst)) (get-orig this-syntax))]))

(begin-for-syntax
  ;; redefine all inferX functions to use 'env prop --------------------

  ;; infer : (Listof Stx) #:ctx Ctx #:tvctx TvCtx -> (List tvs+ xs+ (e+ ...) (τ ...))
  ;; Expands the expressions `es` under the term-variable context `ctx`
  ;; (entries `[x : τ]` or `[x τ]`) and the type-variable context `tvctx`
  ;; (entries `[tv sep k ...]`).
  ;;
  ;; It builds a (λ- (tv ...) (let-syntax ... (λ- (x ...) (let-syntax ... e ...))))
  ;; wrapper, expands it with `expand/df`, and destructures the fully expanded
  ;; result to recover the renamed binders and the expanded expressions.
  ;; The returned list holds: the expanded tv binders, the expanded x binders,
  ;; the expanded expressions, and the result of `typeof` on each of them.
  ;;
  ;; Each `x` is rebound to a macro. When an entry's τ is the identifier `x`
  ;; itself (the λ form below passes `[x : x]` to mean "type not yet known"):
  ;;  - a reference to x takes its type from `get-expected-type` and records
  ;;    the pair (x type) on the result via `add-env`;
  ;;  - if there is no expected type, or x is used in function position,
  ;;    an `exn:fail:type:infer` is raised.
  ;; Otherwise a reference to x is simply assigned the type τ.
  (define (infer es #:ctx [ctx null] #:tvctx [tvctx null])
    (syntax-parse ctx #:datum-literals (:)
      [([x : τ] ...) ; dont expand yet bc τ may have references to tvs
       #:with ([tv (~seq sep:id tvk) ...] ...) tvctx
       #:with (e ...) es
       #:with
       ((~literal #%plain-lambda) tvs+
        ((~literal let-values) () ((~literal let-values) ()
         ((~literal #%expression)
          ((~literal #%plain-lambda) xs+
           ((~literal let-values) () ((~literal let-values) ()
            ((~literal #%expression) e+) ... (~literal void))))))))
       (expand/df
        #`(λ- (tv ...)
            (let-syntax ([tv (make-rename-transformer
                              (set-stx-prop/preserved
                               (for/fold ([tv-id #'tv])
                                         ([s (in-list (list 'sep ...))]
                                          [k (in-list (list #'tvk ...))])
                                 (attach tv-id s k))
                               'tyvar #t))] ...)
              (λ- (x ...)
                (let-syntax
                    ([x
                      (syntax-parser
                        [i:id
                         ;; τ being the identifier x itself marks "type unknown"
                         (if (and (identifier? #'τ) (free-identifier=? #'x #'τ))
                             (if (get-expected-type #'i)
                                 (add-env
                                  (assign-type #'x (get-expected-type #'i))
                                  #`((x #,(get-expected-type #'i))))
                                 (raise
                                  (exn:fail:type:infer
                                   (format "~a (~a:~a): could not infer type of ~a; add annotation(s)"
                                           (syntax-source #'x) (syntax-line #'x) (syntax-column #'x)
                                           (syntax->datum #'x))
                                   (current-continuation-marks))))
                             (assign-type #'x #'τ))]
                        [(o . rst) ; handle if x used in fn position
                         #:fail-when (and (identifier? #'τ) (free-identifier=? #'x #'τ))
                         (raise (exn:fail:type:infer
                                 (format "~a (~a:~a): could not infer type of function ~a; add annotation(s)"
                                         (syntax-source #'o) (syntax-line #'o) (syntax-column #'o)
                                         (syntax->datum #'o))
                                 (current-continuation-marks)))
                         #:with app (datum->syntax #'o '#%app)
                         (datum->syntax this-syntax
                                        (list* #'app (assign-type #'x #'τ) #'rst)
                                        this-syntax)])] ...)
                  (#%expression e) ... void)))))
       (list #'tvs+ #'xs+ #'(e+ ...)
             (stx-map typeof #'(e+ ...)))]
      ;; context written without the `:` separator: normalize and retry
      [([x τ] ...) (infer es #:ctx #'([x : τ] ...) #:tvctx tvctx)]))

  ;; infer/ctx+erase : Ctx Stx -> (List xs+ e+ τ)
  ;; Single-expression wrapper around `infer` with only a term-variable context.
  (define (infer/ctx+erase ctx e)
    (syntax-parse (infer (list e) #:ctx ctx)
      [(_ xs (e+) (τ)) (list #'xs #'e+ #'τ)]))
  ;; infers/ctx+erase : Ctx (Listof Stx) -> everything `infer` returns except
  ;; its first element (the expanded tv binders).
  (define (infers/ctx+erase ctx es)
    (stx-cdr (infer es #:ctx ctx)))
  ; tyctx = kind env for bound type vars in term e
  ;; infer/tyctx+erase : TvCtx Stx -> (List tvs+ e+ τ)
  (define (infer/tyctx+erase ctx e)
    (syntax-parse (infer (list e) #:tvctx ctx)
      [(tvs _ (e+) (τ)) (list #'tvs #'e+ #'τ)]))
  ;; infers/tyctx+erase : TvCtx (Listof Stx) -> (List tvs+ es+ τs)
  (define (infers/tyctx+erase ctx es)
    (syntax-parse (infer es #:tvctx ctx)
      [(tvs+ _ es+ tys) (list #'tvs+ #'es+ #'tys)]))

  ;; find-free-Xs : (Stx-Listof Id) Type -> (Listof Id)
  ;; finds the free Xs in the type
  (define (find-free-Xs Xs ty)
    (for/list ([X (in-list (stx->list Xs))]
               #:when (stx-contains-id? ty X))
      X))

  ;; solve : (Stx-Listof Id) (Stx-Listof Stx) (Stx-Listof Type-Stx)
  ;;         -> (List Constraints (Listof (Stx-List Stx Type-Stx)))
  ;; Solves for the Xs by inferring the type of each arg and unifying it against
  ;; each corresponding expected-τ (which could have free Xs in them).
  ;; It returns a list of 2 values if successful, else throws a type error
  ;;  - the constraints for substituting the types
  ;;  - a list containing all of the arguments paired with their types
  ;; Arguments are processed left to right, so constraints learned from an
  ;; earlier argument are already substituted into the expected type of a later
  ;; one; an argument receives an expected type only once that type has no
  ;; remaining free Xs.
  (define (solve Xs args expected-τs)
    (let-values
        ([(cs e+τs)
          (for/fold ([cs '()] [e+τs #'()])
                    ([e_arg (syntax->list args)]
                     [τ_inX (syntax->list expected-τs)])
            (define τ_in (inst-type/cs Xs cs τ_inX))
            (define/with-syntax [e τ]
              (infer+erase (if (empty? (find-free-Xs Xs τ_in))
                               (add-expected-ty e_arg τ_in)
                               e_arg)))
            ; (displayln #'(e τ))
            (define cs* (add-constraints Xs cs #`([#,τ_in τ])))
            (values cs* (cons #'[e τ] e+τs)))])
      ;; pairs were accumulated in reverse argument order
      (list cs (reverse (stx->list e+τs))))))

;; define : three forms
;;  - (define x e)                                   infer the type of e
;;  - (define {X ...} (f [x : τ] ... → τ_out) e)     polymorphic function
;;  - (define (f [x : τ] ... → τ_out) e)             monomorphic function
;; Each form binds the user-visible name as a rename transformer to a fresh
;; runtime identifier carrying the type. In the function forms the body is
;; wrapped in `add-expected` with the declared result type τ_out.
(define-typed-syntax define
  [(_ x:id e)
   #:with (e- τ) (infer+erase #'e)
   #:with y (generate-temporary)
   #'(begin-
       (define-syntax x (make-rename-transformer (⊢ y : τ)))
       (define- y e-))]
  [(_ (~and Xs {X:id ...}) (f:id [x:id (~datum :) τ] ... (~datum →) τ_out) e)
   #:when (brace? #'Xs)
   #:with g (generate-temporary #'f)
   #:with e_ann #'(add-expected e τ_out)
   #'(begin-
       (define-syntax f
         (make-rename-transformer
          (⊢ g : #,(add-orig #'(∀ (X ...) (ext-stlc:→ τ ... τ_out))
                             #'(→ τ ... τ_out)))))
       (define- g (Λ (X ...) (ext-stlc:λ ([x : τ] ...) e_ann))))]
  [(_ (f:id [x:id (~datum :) τ] ... (~datum →) τ_out) e)
   #:with g (generate-temporary #'f)
   #:with e_ann #'(add-expected e τ_out)
   #'(begin-
       (define-syntax f (make-rename-transformer (⊢ g : (→ τ ... τ_out))))
       (define- g (ext-stlc:λ ([x : τ] ...) e_ann)))])

; all λs have type (∀ (X ...) (→ τ_in ... τ_out))
;; λ : three strategies, tried in order
;;  1. unannotated params, and the 'given-τ-args syntax property is present on
;;     the form (the args-first #%app clause below attaches it): use those
;;     types as the parameter annotations.
;;  2. unannotated params, no such property: expand the body with each x given
;;     the placeholder context entry [x : x] and read the parameter types back
;;     from the env recorded on the expanded body. The ~! cut commits to this
;;     clause, so its failures are reported rather than falling through.
;;  3. anything else (eg, annotated params): defer to ext-stlc:λ.
(define-typed-syntax λ #:datum-literals (:)
  [(_ (x:id ...) e) ; no annotations, try to infer from outer ctx, ie an app
   #:with given-τ-args (syntax-property stx 'given-τ-args)
   #:fail-unless (syntax-e #'given-τ-args) ; cant infer type and no annotations
                 (format
                  "input types for ~a could not be inferred; add annotations"
                  (syntax->datum stx))
   #:with (τ_arg ...) #'given-τ-args
   #:with [fn- τ_fn] (infer+erase #'(ext-stlc:λ ([x : τ_arg] ...) e))
   (⊢ fn- : #,(add-orig #'(∀ () τ_fn) (get-orig #'τ_fn)))]
  [(_ (x:id ...) ~! e) ; no annotations, couldnt infer from ctx (eg, unapplied lam), try to infer from body
   #:with (xs- e- τ_res) (infer/ctx+erase #'([x : x] ...) #'e)
   #:with env (get-env #'e-)
   #:fail-unless (syntax-e #'env)
                 (format
                  "input types for ~a could not be inferred; add annotations"
                  (syntax->datum stx))
   #:with (τ_arg ...) (stx-map (λ (y) (lookup y #'env)) #'xs-)
   #:fail-unless (stx-andmap syntax-e #'(τ_arg ...))
                 (format
                  "some input types for ~a could not be inferred; add annotations"
                  (syntax->datum stx))
   ;; propagate up inferred types of variables
   #:with res (add-env #'(λ- xs- e-) #'env)
;   #:with [fn- τ_fn] (infer+erase #'(ext-stlc:λ ([x : x] ...) e))
   (⊢ res : #,(add-orig #'(∀ () (ext-stlc:→ τ_arg ... τ_res))
                        #`(→ #,@(stx-map get-orig #'(τ_arg ... τ_res)))))]
   ;(⊢ (λ- xs- e-) : (∀ () (ext-stlc:→ τ_arg ... τ_res)))]
  [(_ . rst)
   #:with [fn- τ_fn] (infer+erase #'(ext-stlc:λ . rst))
   (⊢ fn- : #,(add-orig #'(∀ () τ_fn) (get-orig #'τ_fn)))])

(begin-for-syntax
  ;; Error-message builders shared by the two #%app clauses below.
  ;; They only build strings; the caller decides how to report them.

  ;; app-arity-error-prefix : Stx Stx (Stx-Listof Type) -> String
  ;; Common head of the "wrong number of arguments" message: the source
  ;; location of the application `app-stx`, the function expression `fn-stx`,
  ;; and the expected parameter types.
  (define (app-arity-error-prefix app-stx fn-stx expected-τs)
    (string-append
     (format "~a (~a:~a) Wrong number of arguments given to function ~a.\n"
             (syntax-source app-stx) (syntax-line app-stx) (syntax-column app-stx)
             (syntax->datum fn-stx))
     (format "Expected: ~a arguments with types: "
             (stx-length expected-τs))
     (string-join (stx-map type->str expected-τs) ", " #:after-last "\n")))

  ;; given-arg-lines : (Stx-Listof Stx) (Stx-Listof Type) -> (Listof String)
  ;; One "<arg> : <type>" line per argument.
  (define (given-arg-lines arg-stxs arg-τs)
    (map (λ (e t) (format " ~a : ~a" e t)) ; indent each line
         (syntax->datum arg-stxs)
         (stx-map type->str arg-τs)))

  ;; app-arg-types-error-msg : Stx Stx (Stx-Listof Stx) (Stx-Listof Type)
  ;;                           (Stx-Listof Type) -> String
  ;; Full "arguments have wrong type(s)" message.
  (define (app-arg-types-error-msg app-stx fn-stx arg-stxs arg-τs expected-τs)
    (string-append
     (format "~a (~a:~a) Arguments to function ~a have wrong type(s).\n"
             (syntax-source app-stx) (syntax-line app-stx) (syntax-column app-stx)
             (syntax->datum fn-stx))
     "Given:\n"
     (string-join (given-arg-lines arg-stxs arg-τs) "\n" #:after-last "\n")
     (format "Expected: ~a arguments with type(s): "
             (stx-length expected-τs))
     (string-join (stx-map type->str expected-τs) ", "))))

;; #%app : two strategies, tried in order
;;  1. args first: infer every argument on its own. If any argument raises
;;     exn:fail:type:infer the clause is abandoned. Otherwise the argument
;;     types are attached to the function expression as 'given-τ-args (read by
;;     the λ form above) before the function itself is typechecked.
;;  2. fn first: typecheck the function, then use `solve` to infer the
;;     arguments one at a time against the expected parameter types.
;; Both clauses instantiate the function's ∀-bound Xs from the collected
;; constraints, check the argument types, and propagate the arguments'
;; recorded env onto the resulting application.
(define-typed-syntax #%app
  [(_ e_fn e_arg ...) ; infer args first
;   #:when (printf "args first ~a\n" (syntax->datum stx))
   #:with maybe-inferred-τs (with-handlers ([exn:fail:type:infer? (λ _ #f)])
                              (infers+erase #'(e_arg ...)))
   #:when (syntax-e #'maybe-inferred-τs)
   #:with ([e_arg- τ_arg] ...) #'maybe-inferred-τs
   #:with e_fn_anno (syntax-property #'e_fn 'given-τ-args #'(τ_arg ...))
;   #:with [e_fn- (τ_in ... τ_out)] (⇑ e_fn_anno as →)
   #:with [e_fn- ((X ...) ((~ext-stlc:→ τ_inX ... τ_outX)))] (⇑ e_fn_anno as ∀)
   #:fail-unless (stx-length=? #'(τ_inX ...) #'(e_arg ...)) ; check arity
                 (type-error #:src stx
                  #:msg (string-append
                         (app-arity-error-prefix stx #'e_fn #'(τ_inX ...))
                         "Given:\n"
                         (string-join
                          (given-arg-lines #'(e_arg ...) #'(τ_arg ...))
                          "\n")))
   #:with cs (add-constraints #'(X ...) '() #'([τ_inX τ_arg] ...))
   #:with (τ_in ... τ_out) (inst-types/cs #'(X ...) #'cs #'(τ_inX ... τ_outX))
   #:fail-unless (typechecks? #'(τ_arg ...) #'(τ_in ...))
                 (type-error #:src stx
                  #:msg (app-arg-types-error-msg
                         stx #'e_fn #'(e_arg ...) #'(τ_arg ...) #'(τ_in ...)))
   ; propagate inferred types for variables up
   #:with env (stx-flatten (filter (λ (x) x) (stx-map get-env #'(e_arg- ...))))
   #:with result-app (add-env #'(#%app- e_fn- e_arg- ...) #'env)
   ;(⊢ (#%app- e_fn- e_arg- ...) : τ_out)]
   (⊢ result-app : τ_out)]
  [(_ e_fn e_arg ...) ; infer fn first -------------------------
;   #:when (printf "fn first ~a\n" (syntax->datum stx))
   #:with [e_fn- ((X ...) ((~ext-stlc:→ τ_inX ... τ_outX)))] (⇑ e_fn as ∀)
   #:fail-unless (stx-length=? #'(τ_inX ...) #'(e_arg ...)) ; check arity
                 (type-error #:src stx
                  #:msg (string-append
                         (app-arity-error-prefix stx #'e_fn #'(τ_inX ...))
                         ;; argument types are not known yet in this clause,
                         ;; so only the argument expressions are listed
                         "Given args: "
                         (string-join (map ~a (syntax->datum #'(e_arg ...))) ", ")))
;   #:with ([e_arg- τ_arg] ...) #'(infers+erase #'(e_arg ...))
   #:with (cs ([e_arg- τ_arg] ...))
          (solve #'(X ...) #'(e_arg ...) #'(τ_inX ...))
   #:with env (stx-flatten (filter (λ (x) x) (stx-map get-env #'(e_arg- ...))))
   #:with (τ_in ... τ_out) (inst-types/cs #'(X ...) #'cs #'(τ_inX ... τ_outX))
   ;; NOTE: unlike the args-first clause, this failure is reported through
   ;; syntax-parse's own #:fail-unless message rather than `type-error`;
   ;; kept as in the original so the raised error is unchanged.
   #:fail-unless (typechecks? #'(τ_arg ...) #'(τ_in ...))
                 (app-arg-types-error-msg
                  stx #'e_fn #'(e_arg ...) #'(τ_arg ...) #'(τ_in ...))
   #:with result-app (add-env #'(#%app- e_fn- e_arg- ...) #'env)
   ;(⊢ (#%app- e_fn- e_arg- ...) : τ_out)])
   (⊢ result-app : τ_out)])