www

Unnamed repository; edit this file 'description' to name the repository.
Log | Files | Refs | README

infer.rkt (13771B)


      1 #lang s-exp macrotypes/typecheck
      2 (extends "ext-stlc.rkt"
      3          #:except define #%app λ → + - void = zero? sub1 add1 not
      4          #:rename [~→ ~ext-stlc:→])
      5 (reuse cons [head hd] [tail tl] nil [isnil nil?] List list
      6        #:from "stlc+cons.rkt")
      7 (reuse tup × proj
      8        #:from "stlc+tup.rkt")
      9 (require (only-in "sysf.rkt" ∀ ~∀ ∀? Λ))
     10 (require (for-syntax "../type-constraints.rkt"))
     11 
     12 ;; a language with local type inference using bidirectional type checking
     13 
     14 (provide →
        ;; Exports of this language:
        ;;  - → : the ∀-wrapping arrow type constructor defined in this file
        ;;  - the primitive operations listed under typed-out, each paired with
        ;;    the type written next to it
        ;;  - this file's own versions of define, λ and #%app
        ;; NOTE(review): every name exported here except abs also appears in the
        ;; #:except list of the (extends "ext-stlc.rkt" ...) form at the top of
        ;; the file -- presumably so these typings replace the inherited ones;
        ;; confirm.
     15         (typed-out [+ : (→ Int Int Int)]
     16                    [- : (→ Int Int Int)]
     17                    [void : (→ Unit)]
     18                    [= : (→ Int Int Bool)]
     19                    [zero? : (→ Int Bool)]
     20                    [sub1 : (→ Int Int)]
     21                    [add1 : (→ Int Int)]
     22                    [not : (→ Bool Bool)]
     23                    [abs : (→ Int Int)])
     24         define λ #%app)
     25 
     26 (define-syntax → ; wrapping →
        ;; Surface arrow type. Both clauses expand to a ∀ around the arrow
        ;; imported from ext-stlc:
        ;;   (∀ (X ...) (ext-stlc:→ . rst))
        ;; so every function type written with this → has a ∀ at its head.
        ;;  - Clause 1: the first argument is a group of identifiers {X ...} that
        ;;    satisfies brace?; those identifiers become the ∀ binders.
        ;;  - Clause 2: no binder group; the ∀ binds no variables.
        ;; In both cases the result is passed through add-orig together with
        ;; (get-orig this-syntax).
        ;; NOTE(review): presumably this keeps the user-written form of the type
        ;; available for error messages -- confirm against add-orig/get-orig.
     27   (syntax-parser
     28     [(_ (~and Xs {X:id ...}) . rst)
     29      #:when (brace? #'Xs) ; otherwise fall through to the binder-less clause
     30      (add-orig #'(∀ (X ...) (ext-stlc:→ . rst)) (get-orig this-syntax))]
     31     [(_ . rst) (add-orig #'(∀ () (ext-stlc:→ . rst)) (get-orig this-syntax))]))
     32 
     33 (begin-for-syntax
     34 ;; redefine all inferX functions to use 'env prop --------------------
     35   (define (infer es #:ctx [ctx null] #:tvctx [tvctx null])
            ;; Expands the expressions `es` under the term-variable context `ctx`
            ;; and the type-variable context `tvctx`, and returns a 4-element list
            ;;   (list tvs+ xs+ (e+ ...) (τ ...))
            ;; where tvs+ / xs+ are the binder lists taken from the expanded
            ;; lambdas, e+ are the expanded expressions and each τ is
            ;; (typeof e+).
            ;;
            ;; es    : sequence of expressions (destructured with `(e ...)`)
            ;; ctx   : entries of the form [x : τ]; entries written [x τ] are
            ;;         normalized by the last clause, which re-calls infer with
            ;;         the `:` inserted
            ;; tvctx : entries of the form [tv sep tvk ...], i.e. a type variable
            ;;         followed by alternating identifier/value pairs
            ;;
            ;; An entry whose τ is the variable x itself (as built by the λ form
            ;; with `([x : x] ...)`) marks x as having no annotation; see the
            ;; transformer bound to x below.
            ;; Raises exn:fail:type:infer when such an unannotated x is referenced
            ;; without an expected type, or is used in function position.
     36     (syntax-parse ctx #:datum-literals (:)
     37       [([x : τ] ...) ; dont expand yet bc τ may have references to tvs
     38        #:with ([tv (~seq sep:id tvk) ...] ...) tvctx
     39        #:with (e ...) es
                ;; The pattern below mirrors the shape of the term built further
                ;; down after full expansion: an outer lambda over the type
                ;; variables, an inner lambda over the term variables, and the
                ;; expanded expressions followed by `void`.
                ;; NOTE(review): the pairs of nested empty let-values are
                ;; presumably what each let-syntax expands into -- confirm.
     40        #:with
     41        ((~literal #%plain-lambda) tvs+
     42         ((~literal let-values) () ((~literal let-values) ()
     43          ((~literal #%expression)
     44           ((~literal #%plain-lambda) xs+
     45            ((~literal let-values) () ((~literal let-values) ()
     46             ((~literal #%expression) e+) ... (~literal void))))))))
     47        (expand/df
     48         #`(λ- (tv ...)
                     ;; Each tv is rebound as a rename transformer to the tv
                     ;; identifier itself, after folding every (sep, tvk) pair
                     ;; onto it with `attach` and setting the preserved
                     ;; property 'tyvar to #t.
     49             (let-syntax ([tv (make-rename-transformer
     50                               (set-stx-prop/preserved
     51                                (for/fold ([tv-id #'tv])
     52                                          ([s (in-list (list 'sep ...))]
     53                                           [k (in-list (list #'tvk ...))])
     54                                  (attach tv-id s k))
     55                                'tyvar #t))] ...)
     56               (λ- (x ...)
                         ;; Each x is rebound to a transformer that attaches the
                         ;; variable's type at every reference.
     57                 (let-syntax 
     58                   ([x 
     59                     (syntax-parser 
                              ;; Case 1: x used as a plain identifier.
     60                      [i:id
                               ;; τ being x itself means "no annotation given".
     61                       (if (and (identifier? #'τ) (free-identifier=? #'x #'τ))
                                   ;; Unannotated: take the type from the expected
                                   ;; type on this reference, if there is one, and
                                   ;; record the (x type) pair with add-env.
     62                           (if (get-expected-type #'i)
     63                               (add-env 
     64                                 (assign-type #'x (get-expected-type #'i)) 
     65                                 #`((x #,(get-expected-type #'i))))
     66                               (raise
     67                                (exn:fail:type:infer
     68                                  (format "~a (~a:~a): could not infer type of ~a; add annotation(s)"
     69                                          (syntax-source #'x) (syntax-line #'x) (syntax-column #'x)
     70                                          (syntax->datum #'x))
     71                                  (current-continuation-marks))))
                                   ;; Annotated: use the declared type.
     72                           (assign-type #'x #'τ))]
     73                      [(o . rst) ; handle if x used in fn position
                               ;; The message expression of #:fail-when is only
                               ;; evaluated when the condition holds, so an
                               ;; unannotated x in function position raises
                               ;; exn:fail:type:infer here.
     74                       #:fail-when (and (identifier? #'τ) (free-identifier=? #'x #'τ))
     75                       (raise (exn:fail:type:infer
     76                                  (format "~a (~a:~a): could not infer type of function ~a; add annotation(s)"
     77                                          (syntax-source #'o) (syntax-line #'o) (syntax-column #'o)
     78                                          (syntax->datum #'o))
     79                                (current-continuation-marks)))
                               ;; Otherwise rebuild the form as an explicit
                               ;; application whose #%app takes its lexical
                               ;; context from `o` and whose head is x with its
                               ;; declared type attached.
     80                       #:with app (datum->syntax #'o '#%app)
     81                       (datum->syntax this-syntax
     82                                      (list* #'app (assign-type #'x #'τ) #'rst)
     83                                      this-syntax)])] ...)
     84                   (#%expression e) ... void)))))
     85        (list #'tvs+ #'xs+ #'(e+ ...)
     86               (stx-map typeof #'(e+ ...)))]
              ;; Context written without `:` separators: insert them and retry.
     87       [([x τ] ...) (infer es #:ctx #'([x : τ] ...) #:tvctx tvctx)]))
     88 
     89   (define (infer/ctx+erase ctx e)
            ;; Runs `infer` on the single expression `e` under the term-variable
            ;; context `ctx` and returns (list xs e+ τ): the expanded binders, the
            ;; expanded expression and its type. The type-variable binders in
            ;; infer's result are ignored.
     90     (syntax-parse (infer (list e) #:ctx ctx)
     91       [(_ xs (e+) (τ)) (list #'xs #'e+ #'τ)]))
     92   (define (infers/ctx+erase ctx es)
            ;; Multi-expression variant of infer/ctx+erase: runs `infer` on `es`
            ;; under `ctx` and drops the first element (the type-variable binders)
            ;; of its result, leaving the binders xs+, the expanded expressions and
            ;; their types.
     93     (stx-cdr (infer es #:ctx ctx)))
     94   ; tyctx = kind env for bound type vars in term e
          ; Runs `infer` on the single expression `e` with `ctx` passed as the
          ; type-variable context (#:tvctx) and returns (list tvs e+ τ): the
          ; expanded type-variable binders, the expanded expression and its type.
          ; The term-variable binders in infer's result are ignored.
     95   (define (infer/tyctx+erase ctx e)
     96     (syntax-parse (infer (list e) #:tvctx ctx)
     97       [(tvs _ (e+) (τ)) (list #'tvs #'e+ #'τ)]))
     98   (define (infers/tyctx+erase ctx es)
            ;; Multi-expression variant of infer/tyctx+erase: runs `infer` on `es`
            ;; with `ctx` as the type-variable context and returns
            ;; (list tvs+ es+ tys), skipping the term-variable binders.
     99     (syntax-parse (infer es #:tvctx ctx)
    100       [(tvs+ _ es+ tys) (list #'tvs+ #'es+ #'tys)]))
    101 
    102   ;; find-free-Xs : (Stx-Listof Id) Type -> (Listof Id)
    103   ;; Returns, in their original order, those identifiers of Xs for which
          ;; (stx-contains-id? ty X) holds, i.e. the Xs that occur in the type ty.
    104   (define (find-free-Xs Xs ty)
    105     (for/list ([X (in-list (stx->list Xs))]
    106                #:when (stx-contains-id? ty X))
    107       X))
    108 
    109   ;; solve : (Stx-Listof Id) (Stx-Listof Stx) (Stx-Listof Type-Stx)
    110   ;;         -> (List Constraints (Listof (Stx-List Stx Type-Stx)))
    111   ;; Solves for the Xs by inferring the type of each arg and unifying it against
    112   ;; each corresponding expected-τ (which could have free Xs in them).
    113   ;; Returns a list of 2 values if successful, else throws a type error:
    114   ;;  - the constraints for substituting the types
    115   ;;  - a list of all the (expanded) arguments paired with their types
          ;; Arguments are processed left to right, so constraints learned from
          ;; earlier arguments are used when instantiating the expected type of
          ;; later ones. args and expected-τs are walked in parallel; the callers
          ;; in this file check that they have the same length before calling.
    116   (define (solve Xs args expected-τs)
    117     (let-values
    118         ([(cs e+τs)
                  ;; cs accumulates the constraints; e+τs accumulates the [e τ]
                  ;; pairs in reverse order.
                  ;; NOTE(review): e+τs starts as the syntax object #'() and is
                  ;; extended with plain cons, so the final stx->list must accept a
                  ;; chain of pairs ending in a syntax-wrapped empty list -- confirm.
    119           (for/fold ([cs '()] [e+τs #'()])
    120                     ([e_arg (syntax->list args)]
    121                      [τ_inX (syntax->list expected-τs)])
                    ;; Instantiate the expected type with what is known so far.
    122             (define τ_in (inst-type/cs Xs cs τ_inX))
                    ;; Only when no Xs remain free in the instantiated type is it
                    ;; attached to the argument as its expected type before the
                    ;; argument is inferred.
    123             (define/with-syntax [e τ]
    124               (infer+erase (if (empty? (find-free-Xs Xs τ_in))
    125                                (add-expected-ty e_arg τ_in)
    126                                e_arg)))
    127             ;             (displayln #'(e τ))
                    ;; Relate the expected type to the inferred one.
    128             (define cs* (add-constraints Xs cs #`([#,τ_in τ])))
    129             (values cs* (cons #'[e τ] e+τs)))])
              ;; Restore argument order. The last parenthesis on the next line closes
              ;; the enclosing begin-for-syntax.
    130       (list cs (reverse (stx->list e+τs))))))
    131 
    132 (define-typed-syntax define
        ;; Typed definition form with three shapes:
        ;;  1. (define x e)
        ;;       variable definition; the type of x is inferred from e
        ;;  2. (define {X ...} (f [x : τ] ... → τ_out) e)
        ;;       function definition with explicit type variables X ...
        ;;  3. (define (f [x : τ] ... → τ_out) e)
        ;;       function definition without type variables
        ;; In every case the user-visible name becomes a rename transformer to a
        ;; freshly generated identifier that carries the type, and the generated
        ;; identifier is what is actually defined with define-.
        ;; In the function shapes the body is wrapped as (add-expected e τ_out)
        ;; before being placed under the lambda.
    133   [(_ x:id e)
    134    #:with (e- τ) (infer+erase #'e)
    135    #:with y (generate-temporary)
    136    #'(begin-
    137        (define-syntax x (make-rename-transformer (⊢ y : τ)))
    138        (define- y e-))]
    139   [(_ (~and Xs {X:id ...}) (f:id [x:id (~datum :) τ] ... (~datum →) τ_out) e)
    140    #:when (brace? #'Xs) ; otherwise this clause does not apply
    141    #:with g (generate-temporary #'f)
    142    #:with e_ann #'(add-expected e τ_out)
           ;; f is given the type (∀ (X ...) (ext-stlc:→ τ ... τ_out)), with the
           ;; form (→ τ ... τ_out) passed to add-orig as its second argument.
           ;; NOTE(review): the #, below sits inside a plain #' template, so it is
           ;; not evaluated here; presumably ⊢ quasi-quotes its arguments and the
           ;; unsyntax runs when the define-syntax right-hand side is evaluated --
           ;; confirm against the definition of ⊢.
    143    #'(begin-
    144        (define-syntax f
    145          (make-rename-transformer
    146           (⊢ g : #,(add-orig #'(∀ (X ...) (ext-stlc:→ τ ... τ_out))
    147                              #'(→ τ ... τ_out)))))
    148        (define- g (Λ (X ...) (ext-stlc:λ ([x : τ] ...) e_ann))))]
    149   [(_ (f:id [x:id (~datum :) τ] ... (~datum →) τ_out) e)
    150    #:with g (generate-temporary #'f)
    151    #:with e_ann #'(add-expected e τ_out)
           ;; The type is written with this file's →, which wraps it in (∀ () ...).
    152    #'(begin-
    153        (define-syntax f (make-rename-transformer (⊢ g : (→ τ ... τ_out))))
    154        (define- g (ext-stlc:λ ([x : τ] ...) e_ann)))])
    155 
    156 ; all λs have type (∀ (X ...) (→ τ_in ... τ_out))
        ; Three clauses, tried in order:
        ;  1. parameters without annotations, with argument types supplied from
        ;     outside through the 'given-τ-args syntax property (the #%app form
        ;     in this file sets that property on the function expression)
        ;  2. parameters without annotations and no such property: the parameter
        ;     types are recovered from the environment attached to the expanded
        ;     body
        ;  3. anything else (e.g. annotated parameters): delegated to ext-stlc:λ
        ; The resulting function type is always wrapped in (∀ () ...).
        ; NOTE(review): `stx` is not bound in this block; presumably
        ; define-typed-syntax binds it to the whole form -- confirm.
    157 (define-typed-syntax λ #:datum-literals (:)
    158   [(_ (x:id ...) e) ; no annotations, try to infer from outer ctx, ie an app
    159    #:with given-τ-args (syntax-property stx 'given-τ-args)
           ;; A missing property reads back as #f, which fails this clause.
    160    #:fail-unless (syntax-e #'given-τ-args) ; cant infer type and no annotations
    161                  (format
    162                   "input types for ~a could not be inferred; add annotations"
    163                   (syntax->datum stx))
    164    #:with (τ_arg ...) #'given-τ-args
           ;; Re-run as an annotated ext-stlc:λ using the supplied types.
    165    #:with [fn- τ_fn] (infer+erase #'(ext-stlc:λ ([x : τ_arg] ...) e))
    166    (⊢ fn- : #,(add-orig #'(∀ () τ_fn) (get-orig #'τ_fn)))]
    167   [(_ (x:id ...) ~! e) ; no annotations, couldnt infer from ctx (eg, unapplied lam), try to infer from body
           ;; The ~! cut commits to this clause once the unannotated shape has
           ;; matched. The context ([x : x] ...) uses each variable as its own
           ;; "type", which `infer` treats as "no annotation".
    168    #:with (xs- e- τ_res) (infer/ctx+erase #'([x : x] ...) #'e)
    169    #:with env (get-env #'e-)
    170    #:fail-unless (syntax-e #'env)
    171      (format
    172       "input types for ~a could not be inferred; add annotations"
    173       (syntax->datum stx))
           ;; Look up the type recorded for each expanded parameter.
    174    #:with (τ_arg ...) (stx-map (λ (y) (lookup y #'env)) #'xs-)
    175    #:fail-unless (stx-andmap syntax-e #'(τ_arg ...))
    176      (format
    177       "some input types for ~a could not be inferred; add annotations"
    178       (syntax->datum stx))
    179    ;; propagate up inferred types of variables
    180    #:with res (add-env #'(λ- xs- e-) #'env)
    181 ;   #:with [fn- τ_fn] (infer+erase #'(ext-stlc:λ ([x : x] ...) e))
    182    (⊢ res : #,(add-orig #'(∀ () (ext-stlc:→ τ_arg ... τ_res))
    183                         #`(→ #,@(stx-map get-orig #'(τ_arg ... τ_res)))))]
    184    ;(⊢ (λ- xs- e-) : (∀ () (ext-stlc:→ τ_arg ... τ_res)))]
    185   [(_ . rst)
    186    #:with [fn- τ_fn] (infer+erase #'(ext-stlc:λ . rst))
    187    (⊢ fn- : #,(add-orig #'(∀ () τ_fn) (get-orig #'τ_fn)))])
    188 
    189 (define-typed-syntax #%app
        ;; Typed function application. Two strategies, tried in order:
        ;;  1. "args first": infer the types of all arguments, hand them to the
        ;;     function expression through the 'given-τ-args syntax property,
        ;;     then check the function's (instantiated) input types against them.
        ;;     This clause is abandoned if inferring the arguments raises
        ;;     exn:fail:type:infer.
        ;;  2. "fn first": expand the function, then use `solve` to infer each
        ;;     argument against the corresponding input type while accumulating
        ;;     constraints on the function's type variables.
        ;; Both clauses check arity, instantiate the function type with the
        ;; resulting constraints, check the arguments with typechecks?, and
        ;; attach the environments collected from the expanded arguments to the
        ;; resulting application, whose type is the instantiated output type.
        ;; NOTE(review): `stx` is not bound in this block; presumably
        ;; define-typed-syntax binds it to the whole form -- confirm.
    190   [(_ e_fn e_arg ...) ; infer args first
    191  ;  #:when (printf "args first ~a\n" (syntax->datum stx))
           ;; An inference failure in any argument turns into #f here ...
    192    #:with maybe-inferred-τs (with-handlers ([exn:fail:type:infer? (λ _ #f)])
    193                                  (infers+erase #'(e_arg ...)))
           ;; ... which makes this clause fail so the "fn first" clause is tried.
    194    #:when (syntax-e #'maybe-inferred-τs)
    195    #:with ([e_arg- τ_arg] ...) #'maybe-inferred-τs
           ;; Make the argument types visible to e_fn (read by the λ form).
    196    #:with e_fn_anno (syntax-property #'e_fn 'given-τ-args #'(τ_arg ...))
    197 ;   #:with [e_fn- (τ_in ... τ_out)] (⇑ e_fn_anno as →)
           ;; Function type must have the shape (∀ (X ...) (→ τ_inX ... τ_outX)).
    198    #:with [e_fn- ((X ...) ((~ext-stlc:→ τ_inX ... τ_outX)))] (⇑ e_fn_anno as ∀)
    199    #:fail-unless (stx-length=? #'(τ_inX ...) #'(e_arg ...)) ; check arity
    200                  (type-error #:src stx
    201                   #:msg (string-append
    202                   (format "~a (~a:~a) Wrong number of arguments given to function ~a.\n"
    203                           (syntax-source stx) (syntax-line stx) (syntax-column stx)
    204                           (syntax->datum #'e_fn))
    205                   (format "Expected: ~a arguments with types: "
    206                           (stx-length #'(τ_inX ...)))
    207                   (string-join (stx-map type->str #'(τ_inX ...)) ", " #:after-last "\n")
    208                   "Given:\n"
    209                   (string-join
    210                    (map (λ (e t) (format "  ~a : ~a" e t)) ; indent each line
    211                         (syntax->datum #'(e_arg ...))
    212                         (stx-map type->str #'(τ_arg ...)))
    213                    "\n")))
           ;; Constrain the Xs by pairing each declared input type with the
           ;; already-inferred argument type, then instantiate the whole signature.
    214    #:with cs (add-constraints #'(X ...) '() #'([τ_inX τ_arg] ...))
    215    #:with (τ_in ... τ_out) (inst-types/cs #'(X ...) #'cs #'(τ_inX ... τ_outX))
    216    ; some code duplication
    217    #:fail-unless (typechecks? #'(τ_arg ...) #'(τ_in ...))
    218                  (type-error #:src stx
    219                   #:msg (string-append
    220                   (format "~a (~a:~a) Arguments to function ~a have wrong type(s).\n"
    221                           (syntax-source stx) (syntax-line stx) (syntax-column stx)
    222                           (syntax->datum #'e_fn))
    223                   "Given:\n"
    224                   (string-join
    225                    (map (λ (e t) (format "  ~a : ~a" e t)) ; indent each line
    226                         (syntax->datum #'(e_arg ...))
    227                         (stx-map type->str #'(τ_arg ...)))
    228                    "\n" #:after-last "\n")
    229                   (format "Expected: ~a arguments with type(s): "
    230                           (stx-length #'(τ_in ...)))
    231                   (string-join (stx-map type->str #'(τ_in ...)) ", ")))
    232    ; propagate inferred types for variables up
           ;; get-env results that are #f are dropped before flattening.
    233    #:with env (stx-flatten (filter (λ (x) x) (stx-map get-env #'(e_arg- ...))))
    234    #:with result-app (add-env #'(#%app- e_fn- e_arg- ...) #'env)
    235    ;(⊢ (#%app- e_fn- e_arg- ...) : τ_out)]
    236    (⊢ result-app : τ_out)]
    237   [(_ e_fn e_arg ...) ; infer fn first ------------------------- ; TODO: remove code dup
    238 ;   #:when (printf "fn first ~a\n" (syntax->datum stx))
    239    #:with [e_fn- ((X ...) ((~ext-stlc:→ τ_inX ... τ_outX)))] (⇑ e_fn as ∀)
           ;; Argument types are not known yet at this point, so the arity message
           ;; lists only the argument expressions.
    240    #:fail-unless (stx-length=? #'(τ_inX ...) #'(e_arg ...)) ; check arity
    241                  (type-error #:src stx
    242                   #:msg (string-append
    243                   (format "~a (~a:~a) Wrong number of arguments given to function ~a.\n"
    244                           (syntax-source stx) (syntax-line stx) (syntax-column stx)
    245                           (syntax->datum #'e_fn))
    246                   (format "Expected: ~a arguments with types: "
    247                           (stx-length #'(τ_inX ...)))
    248                   (string-join (stx-map type->str #'(τ_inX ...)) ", " #:after-last "\n")
    249                   "Given args: "
    250                   (string-join (map ~a (syntax->datum #'(e_arg ...))) ", ")))
    251 ;   #:with ([e_arg- τ_arg] ...) #'(infers+erase #'(e_arg ...))
           ;; solve infers the arguments one by one against the input types and
           ;; returns the constraints plus the expanded arguments with their types.
    252    #:with (cs ([e_arg- τ_arg] ...))
    253           (solve #'(X ...) #'(e_arg ...) #'(τ_inX ...))
    254    #:with env (stx-flatten (filter (λ (x) x) (stx-map get-env #'(e_arg- ...))))
    255    #:with (τ_in ... τ_out) (inst-types/cs #'(X ...) #'cs #'(τ_inX ... τ_outX))
    256    ; some code duplication
           ;; NOTE(review): unlike the other checks in this form, the failure
           ;; message here is a plain string rather than a (type-error ...) call,
           ;; so a mismatch is reported through the #:fail-unless failure itself --
           ;; confirm whether this difference is intended.
    257    #:fail-unless (typechecks? #'(τ_arg ...) #'(τ_in ...))
    258                  (string-append
    259                   (format "~a (~a:~a) Arguments to function ~a have wrong type(s).\n"
    260                           (syntax-source stx) (syntax-line stx) (syntax-column stx)
    261                           (syntax->datum #'e_fn))
    262                   "Given:\n"
    263                   (string-join
    264                    (map (λ (e t) (format "  ~a : ~a" e t)) ; indent each line
    265                         (syntax->datum #'(e_arg ...))
    266                         (stx-map type->str #'(τ_arg ...)))
    267                    "\n" #:after-last "\n")
    268                   (format "Expected: ~a arguments with type(s): "
    269                           (stx-length #'(τ_in ...)))
    270                   (string-join (stx-map type->str #'(τ_in ...)) ", "))
    271   #:with result-app (add-env #'(#%app- e_fn- e_arg- ...) #'env)
    272   ;(⊢ (#%app- e_fn- e_arg- ...) : τ_out)])
    273   (⊢ result-app : τ_out)])